TorchBricks
Decoupled and modular approach to building multi-task ML models.
TorchBricks builds pytorch models using small, reusable and decoupled parts - we call them bricks.
The concept is simple and flexible, and allows you to more easily combine and swap out parts of the model (preprocessor, backbone, neck, head or post-processor), change the task, or extend it with multiple tasks.
Basic use-case: Image classification
Let us see it in action:
import torch
from torch import nn

from torchbricks.bricks import BrickCollection, BrickNotTrainable, BrickTrainable, Phase
class Preprocessor(nn.Module):
    def forward(self, raw_input: torch.Tensor) -> torch.Tensor:
        return raw_input / 2
# Defining model
bricks = {
    "preprocessor": BrickNotTrainable(Preprocessor(), input_names=["raw"], output_names=["processed"]),
    "backbone": BrickTrainable(ResNetBackbone(), input_names=["processed"], output_names=["embedding"]),
    "image_classification": BrickTrainable(ImageClassifier(), input_names=["embedding"], output_names=["logits"]),
}
# Executing model
model = BrickCollection(bricks)
outputs = model(named_tensors={"raw": input_images}, phase=Phase.TRAIN)
print(outputs.keys())
# dict_keys(['raw', 'processed', 'embedding', 'logits'])
Note that we are explicitly passing the required input and output names of each brick. The result is a simple DAG connecting the outputs of one node to the inputs of the next.
We use BrickTrainable and BrickNotTrainable bricks to wrap basic nn.Modules. In the real world, each nn.Module would take constructor arguments and so on, but you get the idea.
Note also that we pass in phase=Phase.TRAIN to explicitly specify whether we are doing training, validation, testing or inference. We will get back to that later.
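For the snippet above to actually run, ResNetBackbone and ImageClassifier need definitions. Below is a minimal sketch with hypothetical stand-ins; the layer choices and channel sizes are ours, not the library's:

import torch
from torch import nn

class ResNetBackbone(nn.Module):
    # Hypothetical stand-in: a tiny conv stack producing a feature map ("embedding")
    def __init__(self, out_channels: int = 64):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, out_channels, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.features(x)

class ImageClassifier(nn.Module):
    # Hypothetical stand-in: global average pooling followed by a linear classifier
    def __init__(self, num_classes: int = 10, in_channels: int = 64):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(in_channels, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.pool(x).flatten(1))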
Basic use-case: Semantic Segmentation
After running experiments, we now realize that we also want to do semantic segmentation. This is how it would look:
# We can optionally keep/remove image_classification from before
bricks.pop("image_classification")
# Add upscaling and semantic segmentation nn.Modules
bricks["upscaling"] = BrickTrainable(Upscaling(), input_names=["embedding"], output_names=["embedding_upscaled"])
bricks["semantic_segmentation"] = BrickTrainable(SegmentationClassifier(), input_names=["embedding_upscaled"], output_names=["ss_logits"])
# Executing model
model = BrickCollection(bricks)
outputs = model(named_tensors={"raw": input_images}, phase=Phase.TRAIN)
print(outputs.keys())
# dict_keys(['raw', 'processed', 'embedding', 'embedding_upscaled', 'ss_logits'])
Use-case: Bricks' on_step function for training and evaluation
In the examples above, we showed how to compose trainable and non-trainable bricks, and how a dictionary of tensors is passed to the forward function. But TorchBricks goes beyond that.
Another important feature of a brick collection is the on_step function, which also calculates metrics and losses.
We will extend the example from before:
import torch
from torch import nn
from torchmetrics import classification

from torchbricks.bricks import BrickCollection, BrickLoss, BrickNotTrainable, BrickTorchMetric, BrickTrainable, Phase
# Defining model
bricks = {
    "preprocessor": BrickNotTrainable(Preprocessor(), input_names=["raw"], output_names=["processed"]),
    "backbone": BrickTrainable(ResNetBackbone(), input_names=["processed"], output_names=["embedding"]),
    "image_classification": BrickTrainable(ImageClassifier(), input_names=["embedding"], output_names=["logits"]),
}
accuracy_metric = classification.MulticlassAccuracy(num_classes=num_classes, average='micro')
bricks["accuracy"] = BrickTorchMetric(accuracy_metric, input_names=['class_prediction', 'targets'])
bricks["loss"] = BrickLoss(model=nn.CrossEntropyLoss(), input_names=['logits', 'targets'], output_names=['loss_ce'])
# We can still run the forward-pass as before - Note: The forward call does not require 'targets'
model = BrickCollection(bricks)
outputs = model(named_tensors={"raw": input_images}, phase=Phase.TRAIN)
print(outputs.keys())
# dict_keys(['raw', 'processed', 'embedding', 'logits'])
# Example of running `on_step`. Note: `on_step` requires `targets` to calculate metrics and loss.
named_tensors = {"raw": input_images, "targets": targets}
named_outputs, losses = model.on_step(phase=Phase.TRAIN, named_tensors=named_tensors, batch_idx=0)
named_outputs, losses = model.on_step(phase=Phase.TRAIN, named_tensors=named_tensors, batch_idx=1)
named_outputs, losses = model.on_step(phase=Phase.TRAIN, named_tensors=named_tensors, batch_idx=2)
metrics = model.summarize(phase=Phase.TRAIN, reset=True)
By wrapping the core model computations, metrics and loss functions into a single brick collection, we can more easily swap between running model experiments in notebooks and full trainings.
We provide a forward function to easily run model inference without targets, and an on_step function to easily get metrics and losses in both settings.
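To make this concrete, here is a sketch of a hand-written training loop driven by on_step. It assumes (beyond what the snippet above shows) that the brick collection exposes parameters() like any nn.Module, that losses is a dictionary of scalar tensors, and that train_dataloader yields (images, targets) batches:

# Hand-rolled training loop sketch; `train_dataloader` and its batch layout are assumptions
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
for batch_idx, (input_images, targets) in enumerate(train_dataloader):
    named_tensors = {"raw": input_images, "targets": targets}
    _, losses = model.on_step(phase=Phase.TRAIN, named_tensors=named_tensors, batch_idx=batch_idx)
    total_loss = sum(losses.values())  # combine all named losses into one scalar
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
metrics = model.summarize(phase=Phase.TRAIN, reset=True)  # aggregate and reset epoch metrics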
Use-case: Training with a collection of bricks
We like and love pytorch-lightning! It lets us avoid writing the easy-to-get-wrong training loop and separate validation/test scripts; it creates logs, ensures training runs efficiently on any device (CPU, GPU, TPU), scales to multiple devices with reduced precision, and much more.
But with pytorch-lightning you need to specify a LightningModule, and I find myself hiding the important stuff in the class and using multiple levels of inheritance. It can make your code unnecessarily complicated, hard to read and hard to reuse. It may also require some heavy refactoring when changing to a new task or switching to multiple tasks.
With a brick collection, you should rarely need to change or inherit your lightning module. Instead, you inject the model, metrics and loss functions into a lightning module; changes to preprocessors, backbones, necks, heads, metrics and losses are made on the outside and injected into the lightning module.
Below is an example of how you could inject our brick collection into our custom LightningBrickCollection.
The brick collection can be image classification, semantic segmentation, object detection or all of them at the same time.
from functools import partial

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

create_optimizer_func = partial(torch.optim.SGD, lr=0.05, momentum=0.9, weight_decay=5e-4)
bricks_lightning_module = LightningBrickCollection(path_experiments=path_experiments,
                                                   experiment_name=experiment_name,
                                                   brick_collection=brick_collection,
                                                   create_optimizer_func=create_optimizer_func)
logger = WandbLogger(name=experiment_name, project=PROJECT)
trainer = Trainer(accelerator=args.accelerator, logger=logger, max_epochs=args.max_epochs)
trainer.fit(bricks_lightning_module,
train_dataloaders=data_module.train_dataloader(),
val_dataloaders=data_module.val_dataloader())
trainer.test(bricks_lightning_module, datamodule=data_module)
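As noted in the roadmap below, a proper LightningBrickCollection is still missing from the package, so here is a rough sketch of what one could look like. The method bodies, the Phase.VALIDATION name and the logging keys are our guesses, not the package's implementation:

import pytorch_lightning as pl

class LightningBrickCollection(pl.LightningModule):
    # Hypothetical sketch: wraps a BrickCollection in a task-agnostic LightningModule
    def __init__(self, path_experiments, experiment_name, brick_collection, create_optimizer_func):
        super().__init__()
        self.path_experiments = path_experiments
        self.experiment_name = experiment_name
        self.bricks = brick_collection
        self.create_optimizer_func = create_optimizer_func

    def _step(self, phase, batch, batch_idx):
        images, targets = batch
        named_tensors = {"raw": images, "targets": targets}
        _, losses = self.bricks.on_step(phase=phase, named_tensors=named_tensors, batch_idx=batch_idx)
        return sum(losses.values())

    def training_step(self, batch, batch_idx):
        loss = self._step(Phase.TRAIN, batch, batch_idx)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self._step(Phase.VALIDATION, batch, batch_idx)  # assumes the Phase enum has VALIDATION

    def on_train_epoch_end(self):
        metrics = self.bricks.summarize(phase=Phase.TRAIN, reset=True)
        self.log_dict({f"train/{name}": value for name, value in metrics.items()})

    def configure_optimizers(self):
        return self.create_optimizer_func(self.parameters())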
Nested brick collections
It can handle nested brick collections and nested dictionaries of bricks.
MISSING
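Until this section is written, here is a hedged guess at what nesting could look like, based on the roadmap note below that a BrickCollection accepts a dictionary containing other BrickCollections:

# Hypothetical sketch: grouping task-specific bricks in a nested BrickCollection
classification_bricks = BrickCollection({
    "backbone": BrickTrainable(ResNetBackbone(), input_names=["processed"], output_names=["embedding"]),
    "head": BrickTrainable(ImageClassifier(), input_names=["embedding"], output_names=["logits"]),
})
bricks = {
    "preprocessor": BrickNotTrainable(Preprocessor(), input_names=["raw"], output_names=["processed"]),
    "classification": classification_bricks,  # a BrickCollection inside the dictionary
}
model = BrickCollection(bricks)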
TorchMetric.MetricCollection
MISSING
Why should I explicitly set the train, val or test phase?
MISSING
What are we missing?
- Proper LightningBrickCollection for other people to use
- Collection of helper modules: Preprocessors, Backbones, Necks/Upsamplers, ImageClassification, SemanticSegmentation, ObjectDetection
- All the modules in the README should be easy to import as actual modules.
- Make common brick collections: BricksImageClassification, BricksSegmentation, BricksPointDetection, BricksObjectDetection
- Support preparing data in the dataloader?
- Make common visualizations with Pillow (not OpenCV, to avoid blowing up the required dependencies): ImageClassification, Segmentation, ObjectDetection
- Make an export-to-ONNX function and add it to the README.md
- Minor: BrickCollection supports passing a dictionary with BrickCollections, but we should also convert a nested dictionary into nested brick collections
- Minor: Currently, input_names and output_names support positional arguments, but we should also support keyword arguments.
How does it really work?
????
Install it from PyPI
pip install torchbricks
Usage
Development
Read the CONTRIBUTING.md file.
Install
conda create --name torchbricks --file conda-linux-64.lock
conda activate torchbricks
poetry install
Activating the environment
conda activate torchbricks