Blocks for multimodal and multitask learning.
Multimodal Playground
This package attempts to standardize multimodal learning. It provides a modular and extensible interface between encoders, fusion gates, and task heads, with a consistent API.
Installation
```
pip install -e .
```
Install with dev tools (pytest, ruff):
```
pip install -e ".[dev]"
```
Run tests from the repository root:
```
pytest
```
If imports fail, ensure the package is installed as above or run `PYTHONPATH=src pytest`.
Example usage
- `model(batch)` / `forward(batch)` returns `(predictions, modality_embeddings)`.
- `model.predict(batch)` returns predictions only (one full forward pass; embeddings are dropped).
- `Trainer` calls `model(batch)` and uses both outputs for task losses.
- For optimizers, use `iter_training_parameters(model, tasks)` so stateful per-task losses (e.g. a critic) are included; `model.parameters()` alone can miss those weights.
```python
import torch
from torch import nn

from multimodal.fusion import ConcatFusion
from multimodal.heads import MultiTaskLinearHead
from multimodal.model import MultimodalModel
from multimodal.tasks import ClassificationTask
from multimodal.train import Trainer, TrainerConfig, iter_training_parameters

embed_dim = 32
n_sentiment, n_topic = 3, 10  # two classification heads
fused_dim = embed_dim * 2

model = MultimodalModel(
    encoders={
        "vision": nn.Linear(10, embed_dim),
        "text": nn.Linear(8, embed_dim),
    },
    fusion=ConcatFusion(dim=-1),
    head=MultiTaskLinearHead(
        fused_dim,
        {"sentiment": n_sentiment, "topic": n_topic},
    ),
    fusion_modality_order=["vision", "text"],
)

batch = {
    "vision": torch.randn(16, 10),
    "text": torch.randn(16, 8),
    "sentiment_y": torch.randint(0, n_sentiment, (16,)),
    "topic_y": torch.randint(0, n_topic, (16,)),
}

preds, embs = model(batch)
assert preds["sentiment"].shape == (16, n_sentiment)
assert preds["topic"].shape == (16, n_topic)
logits_only = model.predict(batch)  # dict with the same two keys, no embeddings

tasks = [
    ClassificationTask("sentiment", "sentiment_y"),
    ClassificationTask("topic", "topic_y"),
]
optimizer = torch.optim.Adam(iter_training_parameters(model, tasks), lr=1e-3)

config = TrainerConfig(
    max_epochs=2,
    grad_accum_steps=1,
    mixed_precision=False,
    device="cpu",
)
trainer = Trainer(model, tasks, optimizer, config)

train_loader = [batch]
val_loader = [
    {
        "vision": torch.randn(8, 10),
        "text": torch.randn(8, 8),
        "sentiment_y": torch.randint(0, n_sentiment, (8,)),
        "topic_y": torch.randint(0, n_topic, (8,)),
    },
]
trainer.train(train_loader, val_loader=val_loader)
```
For GPU training, set `device="cuda"` and `mixed_precision=True` in `TrainerConfig` (requires a CUDA device).
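Concretely, the GPU settings amount to a config like the following (same fields as the CPU example above; this is a sketch and assumes a CUDA-capable device is available):

```python
config = TrainerConfig(
    max_epochs=2,
    grad_accum_steps=1,
    mixed_precision=True,  # enable mixed precision (GPU only)
    device="cuda",
)
```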
Freezing encoders (TrainerConfig)
The trainer can freeze encoder weights when it is constructed (after `model.to(device)`):
- `freeze_all_encoders=True`: sets `requires_grad=False` on every submodule in `model.encoders`.
- `freeze_encoder_ids=("vision",)`: freezes only the listed encoder tower keys (must match keys in `model.encoders`). Ignored if `freeze_all_encoders` is True.
```python
from multimodal.train import DDPConfig, TrainerConfig

config = TrainerConfig(
    max_epochs=2,
    grad_accum_steps=1,
    mixed_precision=False,
    device="cpu",
    freeze_encoder_ids=("vision",),  # train `text` encoder + fusion + head
    # freeze_all_encoders=True,  # alternative: freeze every encoder
    # ddp=DDPConfig(backend="nccl", sync_bn=True),  # when using DDP
)
trainer = Trainer(model, tasks, optimizer, config)
```
Optimizers created with `model.parameters()` still work: frozen parameters get no gradient and are not updated. To exclude frozen tensors from the optimizer entirely, use `filter(lambda p: p.requires_grad, model.parameters())`.
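As a standalone illustration of that filtering (plain PyTorch, not this package's API; the two-tower `ModuleDict` here just stands in for `model.encoders`):

```python
import torch
from torch import nn

# Toy two-tower setup standing in for a model's encoders plus a head.
encoders = nn.ModuleDict({
    "vision": nn.Linear(10, 4),
    "text": nn.Linear(8, 4),
})
head = nn.Linear(8, 3)

# Freeze one tower, as freeze_encoder_ids=("vision",) would.
for p in encoders["vision"].parameters():
    p.requires_grad = False

all_params = list(encoders.parameters()) + list(head.parameters())
trainable = [p for p in all_params if p.requires_grad]

# The frozen tower's weight and bias are excluded from the optimizer.
assert len(all_params) == 6
assert len(trainable) == 4
optimizer = torch.optim.Adam(trainable, lr=1e-3)
```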
You can still freeze manually before building the trainer if you prefer not to use these flags.
Overview
We can abstract any multimodal model into the following components:
- Encoders: each modality is encoded into a feature vector (embedding).
- Fusion (optional): a method to fuse the per-modality feature vectors into one or more joint representations.
- Heads / decoders: map fused representation(s) to task-specific outputs.
In this package, each encoder maps a modality tensor to an embedding. `MultimodalModel.forward` runs encode → fuse → head and returns `(predictions, embeddings)`; `MultimodalModel.predict` returns only predictions. List-input fusions use `fusion_modality_order` so modalities are concatenated (or otherwise fused) in a fixed order.
Encoders output `(B, latent_dim)` per modality. Fusion yields `(B, fusion_dim)`; the head maps that to task outputs.
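The encode → fuse → head pattern can be sketched in a few lines of plain PyTorch. This is a generic illustration of the abstraction (with concatenation as the fusion), not this package's implementation:

```python
import torch
from torch import nn

def multimodal_forward(encoders, batch, order, head):
    """Encode each modality, fuse by concatenation in a fixed order,
    then map the fused vector to task outputs."""
    embeddings = {name: encoders[name](batch[name]) for name in order}  # encode: (B, latent_dim) each
    fused = torch.cat([embeddings[name] for name in order], dim=-1)     # fuse: (B, fusion_dim)
    return head(fused), embeddings                                      # head: task outputs + embeddings

encoders = {"vision": nn.Linear(10, 32), "text": nn.Linear(8, 32)}
head = nn.Linear(64, 3)
batch = {"vision": torch.randn(4, 10), "text": torch.randn(4, 8)}

preds, embs = multimodal_forward(encoders, batch, ["vision", "text"], head)
assert preds.shape == (4, 3)
assert embs["vision"].shape == (4, 32)
```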