
Blocks for multimodal and multitask learning.

Project description

Multimodal Playground

This package standardizes the structure of multimodal learning code. It provides a modular, extensible interface between encoders, fusion modules, and task heads, behind a consistent API.

Installation

pip install -e .

Install with dev tools (pytest, ruff):

pip install -e ".[dev]"

Run tests from the repository root:

pytest

If imports fail, ensure the package is installed as above or run PYTHONPATH=src pytest.

Example usage

  • model(batch) / forward(batch) returns (predictions, modality_embeddings).
  • model.predict(batch) returns predictions only (one full forward pass; embeddings are dropped).
  • Trainer calls model(batch) and uses both outputs for task losses.
  • For optimizers, use iter_training_parameters(model, tasks) so stateful per-task losses (e.g. a critic) are included; model.parameters() alone can miss those weights.
import torch
from torch import nn

from multimodal.fusion import ConcatFusion
from multimodal.heads import MultiTaskLinearHead
from multimodal.model import MultimodalModel
from multimodal.tasks import ClassificationTask
from multimodal.train import Trainer, TrainerConfig, iter_training_parameters


embed_dim = 32
n_sentiment, n_topic = 3, 10  # two classification heads
fused_dim = embed_dim * 2

model = MultimodalModel(
    encoders={
        "vision": nn.Linear(10, embed_dim),
        "text": nn.Linear(8, embed_dim),
    },
    fusion=ConcatFusion(dim=-1),
    head=MultiTaskLinearHead(
        fused_dim,
        {"sentiment": n_sentiment, "topic": n_topic},
    ),
    fusion_modality_order=["vision", "text"],
)

batch = {
    "vision": torch.randn(16, 10),
    "text": torch.randn(16, 8),
    "sentiment_y": torch.randint(0, n_sentiment, (16,)),
    "topic_y": torch.randint(0, n_topic, (16,)),
}

preds, embs = model(batch)
assert preds["sentiment"].shape == (16, n_sentiment)
assert preds["topic"].shape == (16, n_topic)

logits_only = model.predict(batch)  # dict with the same two keys, no embeddings

tasks = [
    ClassificationTask("sentiment", "sentiment_y"),
    ClassificationTask("topic", "topic_y"),
]

optimizer = torch.optim.Adam(iter_training_parameters(model, tasks), lr=1e-3)
config = TrainerConfig(
    max_epochs=2,
    grad_accum_steps=1,
    mixed_precision=False,
    device="cpu",
)
trainer = Trainer(model, tasks, optimizer, config)

train_loader = [batch]
val_loader = [
    {
        "vision": torch.randn(8, 10),
        "text": torch.randn(8, 8),
        "sentiment_y": torch.randint(0, n_sentiment, (8,)),
        "topic_y": torch.randint(0, n_topic, (8,)),
    },
]
trainer.train(train_loader, val_loader=val_loader)

For GPU training, set device="cuda" and mixed_precision=True in TrainerConfig (requires a CUDA device).
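As a sketch, the CPU configuration from the example above would change only in those two fields (field names taken from the earlier example; mixed precision here is assumed to apply only on the CUDA device):

```python
from multimodal.train import TrainerConfig

# Same fields as the CPU example, switched to CUDA (requires a GPU).
config = TrainerConfig(
    max_epochs=2,
    grad_accum_steps=1,
    mixed_precision=True,  # mixed-precision training on the CUDA device
    device="cuda",
)
```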

Freezing encoders (TrainerConfig)

The trainer can freeze encoder weights when it is constructed (after model.to(device)):

  • freeze_all_encoders=True — sets requires_grad=False on every submodule in model.encoders.
  • freeze_encoder_ids=("vision",) — freeze only the listed encoder tower keys (must match keys in model.encoders). Ignored if freeze_all_encoders is True.
from multimodal.train import DDPConfig, TrainerConfig

config = TrainerConfig(
    max_epochs=2,
    grad_accum_steps=1,
    mixed_precision=False,
    device="cpu",
    freeze_encoder_ids=("vision",),  # train `text` encoder + fusion + head
    # freeze_all_encoders=True,  # alternative: freeze every encoder
    # ddp=DDPConfig(backend="nccl", sync_bn=True),  # when using DDP
)
trainer = Trainer(model, tasks, optimizer, config)

Optimizers created with model.parameters() still work: frozen parameters get no gradient and are not updated. To exclude frozen tensors from the optimizer entirely, use filter(lambda p: p.requires_grad, model.parameters()).
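A minimal sketch of that filtering, using a toy ModuleDict as a stand-in for MultimodalModel (the module names and shapes here are illustrative only):

```python
import torch
from torch import nn

# Toy stand-in for a two-encoder model; shapes are illustrative only.
model = nn.ModuleDict({
    "vision": nn.Linear(10, 4),
    "text": nn.Linear(8, 4),
})

# Freeze the vision tower, as freeze_encoder_ids=("vision",) would do.
for p in model["vision"].parameters():
    p.requires_grad = False

# Hand the optimizer only the tensors that will actually receive gradients.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)

print(len(trainable))  # 2: the text tower's weight and bias
```

Passing the full `model.parameters()` instead would also work, since frozen tensors never receive gradients; filtering just keeps optimizer state (e.g. Adam moments) from being allocated for them.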

You can still freeze manually before building the trainer if you prefer not to use these flags.

Overview

We can abstract any multimodal model into the following components:

  1. Encoders: each modality is encoded into a feature vector (embedding).
  2. Fusion (optional): a method to fuse the feature vectors into one or more fused representations.
  3. Heads / decoders: map fused representation(s) to task-specific outputs.

In this package, each encoder maps a modality tensor to an embedding. MultimodalModel.forward runs encode → fuse → head and returns (predictions, embeddings). MultimodalModel.predict returns only predictions. List-input fusions use fusion_modality_order so modalities are concatenated (or fused) in a fixed order.

Encoders output (B, latent_dim) per modality. Fusion yields (B, fusion_dim); the head maps that to task outputs.
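For the two-encoder example above (embed_dim = 32), the shape arithmetic can be checked directly with plain torch.cat, which is what a concat-style fusion is assumed to reduce to:

```python
import torch

B, embed_dim = 16, 32
vision_emb = torch.randn(B, embed_dim)  # encoder output for "vision"
text_emb = torch.randn(B, embed_dim)    # encoder output for "text"

# Concatenate in the fixed fusion_modality_order: ["vision", "text"].
fused = torch.cat([vision_emb, text_emb], dim=-1)

print(fused.shape)  # torch.Size([16, 64]), i.e. fused_dim = embed_dim * 2
```

This is why the example constructs the head with `fused_dim = embed_dim * 2`: the head's input width must match the fusion output.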


