Train many independent PyTorch models simultaneously on a single GPU using vectorized operations

ModelBatch

Requires Python 3.9+ and PyTorch 2.0+.

⚠️ Current Status

ModelBatch is under active development. Core functionality is tested and working, but the API may change between releases.

🚀 Quick Start

Installation

From PyPI:

# recommended
uv add modelbatch

# alternative
pip install modelbatch

From source:

uv sync --dev
uv pip install -e ".[dev]"
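After installing, you can check the environment against the stated requirements (Python 3.9+, PyTorch 2.0+) using only the standard library. The following is a minimal sketch, and the script is hypothetical, not part of ModelBatch. It reads installed distribution metadata through `importlib.metadata` instead of importing the packages, and it assumes the distributions are named `torch` and `modelbatch` as on PyPI.

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import List, Optional, Tuple


def major_minor(ver: str) -> Tuple[int, int]:
    """Extract (major, minor) from a version string such as '2.1.0+cu118'."""
    match = re.match(r"(\d+)\.(\d+)", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    return int(match.group(1)), int(match.group(2))


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def check_environment() -> List[str]:
    """Return a list of unmet requirements (empty if everything looks fine)."""
    problems = []
    if sys.version_info < (3, 9):
        problems.append(f"Python 3.9+ required, found {sys.version.split()[0]}")
    torch_ver = installed_version("torch")
    if torch_ver is None:
        problems.append("PyTorch is not installed")
    elif major_minor(torch_ver) < (2, 0):
        problems.append(f"PyTorch 2.0+ required, found {torch_ver}")
    if installed_version("modelbatch") is None:
        problems.append("modelbatch is not installed")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("\n".join(issues) if issues else "Environment looks OK")
```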

Basic Example

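import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from modelbatch import ModelBatch

# Placeholder model for this example; substitute your own
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# Toy regression data so the example runs end to end
dataset = TensorDataset(torch.randn(256, 10), torch.randn(256, 1))
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Create multiple models
num_models = 4  # choose the number of models to batch
models = [SimpleNet() for _ in range(num_models)]

# Wrap with ModelBatch - that's it!
mb = ModelBatch(models, lr_list=[0.001] * num_models, optimizer_cls=torch.optim.Adam)

# Train as usual; each step processes all models together on the same batch
for inputs, targets in dataloader:
    mb.zero_grad()
    outputs = mb(inputs)
    loss = mb.compute_loss(outputs, targets)
    loss.backward()
    mb.step()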
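The example above does not show how the models are batched. This page does not describe ModelBatch's internals, but the general technique named in its tagline can be sketched with PyTorch's own `torch.func` API (available since PyTorch 2.0), following the style of PyTorch's model-ensembling tutorial. The idea is to stack the weights of identically shaped models and map a single forward pass over them. This is an illustration under assumptions, not ModelBatch's implementation. `TinyNet`, the random data, and the per-model learning rates in `lrs` (a loose analogue of `lr_list`) are all hypothetical. The update is plain SGD rather than Adam, which keeps per-model learning rates to a single broadcast.

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)


class TinyNet(nn.Module):
    """Hypothetical example model; all copies must share one architecture."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(10, 16)
        self.fc2 = nn.Linear(16, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))


num_models = 4
models = [TinyNet() for _ in range(num_models)]
lrs = torch.tensor([0.005, 0.01, 0.02, 0.05])  # hypothetical: one LR per model

# Stack every parameter along a new leading "model" dimension,
# e.g. fc1.weight goes from (16, 10) to (4, 16, 10). The stacked
# tensors are fresh leaf tensors that require grad.
params, buffers = stack_module_state(models)

# A weightless copy on the "meta" device supplies only the forward logic.
base = copy.deepcopy(models[0]).to("meta")


def run_one(p, b, x):
    """Forward pass of a single model using its own slice of the weights."""
    return functional_call(base, (p, b), (x,))


# in_dims=(0, 0, None): map over the model dimension of params/buffers and
# broadcast the same input batch to every model.
batched_forward = vmap(run_one, in_dims=(0, 0, None))

x = torch.randn(32, 10)  # one input batch shared by all models
y = torch.randn(32, 1)
outputs = batched_forward(params, buffers, x)  # shape (4, 32, 1)
initial_loss = ((outputs - y) ** 2).mean(dim=(1, 2)).detach()

for step in range(50):
    preds = batched_forward(params, buffers, x)
    # Per-model MSE, shape (num_models,)
    per_model_loss = ((preds - y) ** 2).mean(dim=(1, 2))
    per_model_loss.sum().backward()
    with torch.no_grad():
        for p in params.values():
            # Reshape lrs to (num_models, 1, ..., 1) so it broadcasts over p
            p -= lrs.view(-1, *([1] * (p.dim() - 1))) * p.grad
            p.grad = None

final_loss = ((batched_forward(params, buffers, x) - y) ** 2).mean(dim=(1, 2)).detach()
```

Each model's loss depends only on its own slice of the stacked weights. Summing the per-model losses before `backward()` therefore produces exactly the gradients each model would receive if it were trained alone, and a single vectorized pass does the work of `num_models` separate ones.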

See here for more examples.

📚 Documentation

See docs.

🛠️ Development

Environment Setup

uv sync --dev

Commands

# Run the test suite (some tests currently fail)
uv run -m pytest

# Lint and format
uv run ruff check --fix . && uv run ruff format .

# Documentation
uv run mkdocs serve

📄 License

This project is licensed under the MIT License.

Download files

Source Distribution

modelbatch-0.1.0.tar.gz (45.2 kB)

Uploaded Source

Built Distribution

modelbatch-0.1.0-py3-none-any.whl (29.0 kB)

Uploaded Python 3

File details

Details for the file modelbatch-0.1.0.tar.gz.

File metadata

  • Download URL: modelbatch-0.1.0.tar.gz
  • Size: 45.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for modelbatch-0.1.0.tar.gz
Algorithm    Hash digest
SHA256       1671de7f667c62d35c87e88717e79c8d358f55309a9819f4b728ba20268a8f7c
MD5          da6a8f5e5c88b4a4f0928cb7fa2b0560
BLAKE2b-256  4a57a0c368dfc14c9c0ca0d61782957a6c9dab661eadf4a67b06ca22ac01029e
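You can compare a downloaded file against its published SHA256 digest using Python's standard `hashlib`. The following is a minimal sketch. It assumes the source distribution was saved as `modelbatch-0.1.0.tar.gz` in the current directory, and the helper names are illustrative.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected: str) -> bool:
    """Compare a file's SHA256 digest with a published (hex) value."""
    return sha256_of(path) == expected.strip().lower()


# Published SHA256 for modelbatch-0.1.0.tar.gz, copied from the table above
EXPECTED_SDIST_SHA256 = "1671de7f667c62d35c87e88717e79c8d358f55309a9819f4b728ba20268a8f7c"

if __name__ == "__main__":
    sdist = Path("modelbatch-0.1.0.tar.gz")  # assumed download location
    if sdist.exists():
        print("OK" if verify(sdist, EXPECTED_SDIST_SHA256) else "MISMATCH")
    else:
        print(f"{sdist} not found")
```

pip can also enforce digests at install time through its hash-checking mode, using `--require-hashes` with `--hash=sha256:...` entries in a requirements file.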

Provenance

The following attestation bundles were made for modelbatch-0.1.0.tar.gz:

Publisher: publish.yml on Rock-Z/ModelBatch

File details

Details for the file modelbatch-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: modelbatch-0.1.0-py3-none-any.whl
  • Size: 29.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for modelbatch-0.1.0-py3-none-any.whl
Algorithm    Hash digest
SHA256       b17edaef56a878e2e78d512f0a2886c5fad5be0b0bc5a07e2ec7a162cfec5633
MD5          f79e12c97b2904cf7b40d6f599d74a6a
BLAKE2b-256  afbbe54d8e7cf5fbb1d396c0085044869e4619cd0ab26403a42ed172b1681f9c

Provenance

The following attestation bundles were made for modelbatch-0.1.0-py3-none-any.whl:

Publisher: publish.yml on Rock-Z/ModelBatch
