torchaid
torchaid is a PyTorch utility library that provides structured abstractions and reusable components to streamline the deep learning training pipeline.
Features
- Structured training abstractions — base classes for metrics and settings built on Pydantic v2
- Training framework — a full training loop with mixed-precision support, automatic checkpointing, metric logging (CSV), and early stopping
- Transformer modules — standard and relative-position-aware Transformer encoder layers with multi-head self-attention
- Task templates — a ready-to-use implementation for multi-label classification
- Utilities — dataset splitting, random seed management, attention mask generation, and JSON-to-Pydantic loading
- Learning rate schedulers — cosine decay with linear warm-up, and triangular2 cyclic scheduling
Requirements
- Python 3.10+
- PyTorch 2.0+
Installation
```bash
pip install torchaid
```
Or install from source:
```bash
git clone https://github.com/harunori-kawano/torchaid.git
cd torchaid
pip install -e .
```
Quick Start
1. Implement your model
Batches are plain `dict[str, Any]`. `forward` returns an `(outputs, error)` tuple: set `error` to a non-`None` value to signal a recoverable per-batch error, and the framework will skip backpropagation and log it to stderr.
```python
from torchaid import TaskModule, Mode
from typing import Any, Optional
from torch import nn


class MyModel(TaskModule):
    def __init__(self, vocab_size: int, num_classes: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, 128)
        self.classifier = nn.Linear(128, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, mode: Mode, batch: dict[str, Any]) -> tuple[dict[str, Any], Optional[Any]]:
        x = self.embed(batch["input_ids"]).mean(dim=1)
        logits = self.classifier(x)
        loss = self.criterion(logits, batch["labels"])
        if mode == Mode.TRAIN:
            return {"loss": loss}, None
        return {"loss": loss, "logits": logits}, None
```
2. Define metrics and settings
```python
from torchaid import BaseMetrics, BaseSettings, BaseMetricCalculator
from typing import Optional, Any
import statistics


class MyMetrics(BaseMetrics):
    train_loss: Optional[float] = None
    val_loss: Optional[float] = None


class MySettings(BaseSettings):
    batch_size: int = 32
    max_epoch_num: int = 10


class MyCalculator(BaseMetricCalculator[MyMetrics]):
    def __init__(self):
        super().__init__(MyMetrics())
        self._losses: list[float] = []

    def train_step(self, outputs: dict[str, Any], batch: dict[str, Any]) -> dict[str, Any]:
        loss = outputs["loss"].item()
        self._losses.append(loss)
        return {"loss": loss}

    def val_step(self, outputs: dict[str, Any], batch: dict[str, Any]) -> dict[str, Any]:
        return self.train_step(outputs, batch)

    def test_step(self, outputs: dict[str, Any], batch: dict[str, Any]) -> dict[str, Any]:
        return self.train_step(outputs, batch)

    def check(self) -> bool:
        self.metrics.train_loss = statistics.mean(self._losses)
        return True

    def test(self):
        pass

    def reset(self):
        self._losses.clear()
```
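The `MyMetrics` class above also declares `val_loss`, which this calculator never populates. One way to fill it is to keep a separate list for validation losses. A sketch, continuing the example above (same imports); when exactly `check()` and `reset()` are invoked by the framework is an assumption here, not something stated in this guide:

```python
class MyCalculatorWithVal(MyCalculator):
    def __init__(self):
        super().__init__()
        self._val_losses: list[float] = []

    def val_step(self, outputs: dict[str, Any], batch: dict[str, Any]) -> dict[str, Any]:
        loss = outputs["loss"].item()
        self._val_losses.append(loss)
        return {"loss": loss}

    def check(self) -> bool:
        self.metrics.train_loss = statistics.mean(self._losses)
        if self._val_losses:
            self.metrics.val_loss = statistics.mean(self._val_losses)
        return True

    def reset(self):
        super().reset()
        self._val_losses.clear()
```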
3. Train
```python
import torch
from torchaid.core.trainer import TrainFramework

settings = MySettings(batch_size=32, max_epoch_num=10, device="cuda")
model = MyModel(vocab_size=1000, num_classes=5)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

framework = TrainFramework(
    model=model,
    ls=settings,
    metric_calculator=MyCalculator(),
    optimizer=optimizer,
)

# train_dataset / val_dataset: your training and validation datasets
framework.train(train_dataset, val_dataset, save_dir="./outputs")
```
Module Overview
| Module | Description |
|---|---|
| `torchaid.core` | Base classes (`BaseMetrics`, `BaseSettings`, `TaskModule`, `BaseMetricCalculator`, `Mode`) and `TrainFramework` |
| `torchaid.templates.multilabel_classification` | Complete template for multi-label classification |
| `torchaid.extras.modules.transformer` | `Transformer`, `TransformerWithRelativePosition`, and sub-modules |
| `torchaid.extras.modules.positional_encoders` | `PositionalEmbedding`, `RelativePositionEmbedding` |
| `torchaid.extras.utils` | `split_dataset`, `set_random_seed`, `make_attention_mask`, `json_to_instance` |
| `torchaid.extras.scheduler` | `get_cosine_scheduler`, `get_cycle_scheduler` |
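The helpers in `torchaid.extras.utils` cover common experiment boilerplate. Below is a minimal sketch of how two of them might be used; the exact signatures of `set_random_seed` and `json_to_instance` are assumptions (only the names come from the module listing above):

```python
from torchaid.extras.utils import json_to_instance, set_random_seed

# Assumed signature: seed the relevant RNGs in one call.
set_random_seed(42)

# Assumed signature and argument order: load a JSON file into a Pydantic
# model, e.g. the MySettings class from the Quick Start above.
settings = json_to_instance("settings.json", MySettings)
```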
Template: Multi-Label Classification
```python
from torchaid.templates import multilabel_classification as mlc
from torchaid.core.trainer import TrainFramework
from torch import nn
import torch

backbone = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
model = mlc.MultiLabelClassification(backbone)
optimizer = torch.optim.Adam(model.parameters())

framework = TrainFramework(
    model=model,
    ls=settings,  # a BaseSettings instance, as in the Quick Start
    metric_calculator=mlc.MetricsCalculator(),
    optimizer=optimizer,
)
```
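Training then proceeds exactly as in the Quick Start:

```python
framework.train(train_dataset, val_dataset, save_dir="./outputs")
```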
Extras
Cosine Decay Scheduler
```python
from torchaid.extras.scheduler import get_cosine_scheduler

scheduler = get_cosine_scheduler(
    optimizer, warmup_steps=500, max_steps=10000
)
```
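If you drive your own loop instead of `TrainFramework`, and assuming `get_cosine_scheduler` returns a standard `torch.optim.lr_scheduler` object, it is stepped once per optimization step. A sketch with hypothetical `dataloader` and `compute_loss` placeholders:

```python
for batch in dataloader:
    loss = compute_loss(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # advance the warm-up / cosine-decay schedule
```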
Dataset Split
```python
from torchaid.extras.utils import split_dataset

train, val, test = split_dataset(dataset, ratios=[8, 1, 1], seed=42)
```
Relative Position Transformer
```python
from torchaid.extras.modules.transformer import TransformerWithRelativePosition

layer = TransformerWithRelativePosition(
    hidden_size=256,
    intermediate_size=1024,
    num_attention_heads=8,
    dropout_probability=0.1,
    max_length=512,
    with_cls=True,
)
```
License
MIT License. See LICENSE for details.
File details
Details for the file torchaid-0.1.6.tar.gz.
File metadata
- Download URL: torchaid-0.1.6.tar.gz
- Upload date:
- Size: 26.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `440a0f0b24f0152e2681b944b3517ea938b6504cd87628edc13a5dce4c3e8773` |
| MD5 | `0355fc5fd4f0913d339bed248d172925` |
| BLAKE2b-256 | `d7830c57d9f9e23aed4533422986c97b5bcba254027db0b6cf3c25440ffe4e5d` |
File details
Details for the file torchaid-0.1.6-py3-none-any.whl.
File metadata
- Download URL: torchaid-0.1.6-py3-none-any.whl
- Upload date:
- Size: 29.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `4b10b1bcef705eeb4a9b2c9a22fa950d7d4a40789ddd46ef83f648d17a96cc48` |
| MD5 | `132259d45a2bfff82956d25386779d32` |
| BLAKE2b-256 | `c72cc4ab65646ec694347d46422f035911be36e8639c65d5a11387f1280e3a68` |