Design of a high-level PyTorch ML library
~ 3 min read
Filed under: pytorch, ml, sw, design, opensource
This article is an attempt to think through the various activities in a deep-learning-based ML model development process and to distill them into the design of a high-level library, both functional and object-oriented, built on the PyTorch framework.
At a high level, the process of ML model development for a single task can be subdivided into the following stages (a plain-PyTorch sketch of these stages follows the list):
- Data Setup and Feature Preprocessing: possibly multiple datasets, splits, and augmentations
- Model and Objective/Loss Specification: expressed in PyTorch code
- Model Training: optimizing model parameters against the objective
- Performance Evaluation: performance metrics on the validation and test sets
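Before introducing new abstractions, it is worth grounding these stages in plain PyTorch. Below is a minimal sketch of the four stages on a toy dataset; the tensors, layer sizes, and hyperparameters are placeholder assumptions, not part of any proposed API.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Stage 1: data setup and feature preprocessing (toy tensors stand in for real data)
X, y = torch.randn(256, 16), torch.randint(0, 2, (256,))
train_loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

# Stage 2: model and objective/loss specification
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
criterion = nn.CrossEntropyLoss()

# Stage 3: model training (optimization)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(5):
    for xb, yb in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()

# Stage 4: performance evaluation (reusing the training data here only for brevity)
with torch.no_grad():
    accuracy = (model(X).argmax(dim=1) == y).float().mean()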
The goals of the new high-level library would be to support and simplify the following:
- Model specification: the modular ability to combine multiple models like Lego bricks (sketched in plain PyTorch after this list)
- Loss-function specification and tuning: each layer can have different, even multiple, objectives, and objectives can come from models themselves
- Model deployment tools: quantization, factorization, ONNX export, etc.
- Self-supervised learning coupled with data augmentation: support for online self-supervised learning and test-time self-supervised learning
- Multi-task learning: train on multiple datasets and tasks; some tasks can share common representations, with task-specific branches on top
- Curriculum learning: a teacher (an RL agent, perhaps) figures out the sequence of tasks, and of examples within those tasks, to learn better representations and/or perform better at the tasks
- Multi-modal learning: inputs come from two or more domains (audio, sensors, video, images, etc.), with the ability to share/relate representations of the same object
- ML model debugging: feature attribution, hard-example mining, out-of-domain indicators
- Model report card: a summary of model performance, limitations, and intended use
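To make the first few goals concrete, here is a minimal plain-PyTorch sketch of Lego-style composition: a shared trunk with task-specific heads, each carrying its own objective. The module names, sizes, and loss weighting are illustrative assumptions.

import torch.nn as nn
import torch.nn.functional as F

# Shared trunk learns a common representation (a modular, reusable block)
trunk = nn.Sequential(nn.Linear(16, 64), nn.ReLU())

# Task-specific branches sit on top of the shared representation
cls_head = nn.Linear(64, 10)  # e.g., a 10-class classification task
reg_head = nn.Linear(64, 1)   # e.g., a scalar regression task

def multitask_loss(x, y_cls, y_reg, w=0.5):
    # Each branch contributes its own objective; w balances the two tasks
    h = trunk(x)
    loss_cls = F.cross_entropy(cls_head(h), y_cls)
    loss_reg = F.mse_loss(reg_head(h).squeeze(-1), y_reg)
    return loss_cls + w * loss_reg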
Let’s take these goals and sketch an example that allows high-level ideas to be expressed easily and quickly.
High-level example using the JumpML library
import jumpml
import torch.optim as optim
# datasetFile defines:
#   - multiple datasets with train/test/validation splits
#   - feature-processing pipelines (which could include pretrained models themselves)
#   - transformations and augmentations
#   - curriculum learning options
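# A (hypothetical) datasetFile could be a YAML like the sketch below; the
# fields are illustrative assumptions, not a committed format:
#   datasets:
#     - name: my_dataset
#       splits: {train: 0.8, validation: 0.1, test: 0.1}
#       features: [normalize, pretrained_embedding]
#       augmentations: [crop, flip, noise]
#   curriculum: easy_to_hard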
dataloader = jumpml.createDataloader(datasetFile)
# modelFile defines:
#   - PyTorch modules and their connections
#   - objective functions at/between layers/nodes
model = jumpml.loadModel(modelFile)
# Define the optimizer as usual.
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.5)
# Create a Learner based on learnerOptions, which covers:
#   - initialization, EMA parameters, EWC, parameter saving
#   - self-supervised learning phases
#   - curriculum learning driven by learning progress
#   - dynamic model augmentation: adding more capacity or branches on the fly
#   - logging learning progress
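# A (hypothetical) learnerOptions sketch; the keys below are illustrative
# assumptions, not a committed API:
learnerOptions = {
    "init": "xavier",           # parameter initialization scheme
    "ema_decay": 0.999,         # exponential moving average of parameters
    "ewc_lambda": 0.4,          # elastic weight consolidation strength
    "checkpoint_dir": "ckpt/",  # parameter saving
    "ssl_phases": ["pretext"],  # self-supervised learning phases
    "curriculum": "auto",       # curriculum driven by learning progress
    "dynamic_growth": True,     # add capacity or branches dynamically
    "log_every": 100,           # logging cadence for learning progress
}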
learner = jumpml.createLearner(dataloader, model, optimizer, learnerOptions)
# Self-supervised learning phase
learner.ssLearn()
# Supervised learning phase: perhaps trains only the task-specific layers,
# and maybe does some multi-task learning on other, shared layers.
for epoch in range(numEpochs):
    learner.trainEpoch()  # this might only update the classifier branches
# Evaluation
# evalMetrics is a list of functions; each takes model predictions and
# ground truth and computes a metric.
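# For example, accuracy could be one such metric (illustrative, not a fixed API):
def accuracy(predictions, targets):
    return (predictions.argmax(dim=1) == targets).float().mean().item()

evalMetrics = [accuracy]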
evalResults = jumpml.evalModel(dataloader, model, evalMetrics)
# Reports, visualization, deployment
# May run a separate distillation loop (quantization, factorization, ONNX export)
jumpml.deployModel(evalResults, model, deploymentConfig)
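Under the hood, deployModel could wrap PyTorch's existing export and compression tooling. As a rough sketch of what that last step might call (these are real PyTorch APIs; the input shape and file name are assumptions):

import torch

# Export the trained model to ONNX for interoperable deployment
dummy_input = torch.randn(1, 16)
torch.onnx.export(model, dummy_input, "model.onnx")

# Post-training dynamic quantization of the linear layers
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)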