Automated deep learning for computer vision — train image classification, object detection, and segmentation models with 1000+ backbones in minimal code
Train state-of-the-art vision models with minimal code.
Production-ready computer vision framework powered by timm and PyTorch Lightning.
Overview · Features · Quick Start · Tasks · Architecture · Docs · Contributing · License
Overview
AutoTimm is a production-ready computer vision framework that combines timm (1000+ pretrained models) with PyTorch Lightning. Train image classifiers, object detectors, and segmentation models with any timm backbone using a simple, intuitive API.
Whether you're training an image classifier, an object detector, or a segmentation model, AutoTimm handles:
- Automated model training with auto-tuning of learning rate and batch size
- 4 vision tasks — classification (single & multi-label), object detection, semantic & instance segmentation
- 1000+ backbones — ResNet, EfficientNet, ViT, ConvNeXt, Swin, and more via timm + HuggingFace Hub
- Model interpretation — GradCAM, Integrated Gradients, Attention Rollout, and more
- Production export — TorchScript and ONNX for deployment anywhere
From prototype to production in minutes, not hours.
How It Compares
| | Manual PyTorch | AutoTimm |
|---|---|---|
| Boilerplate | Hundreds of lines per task | 5–10 lines for any task |
| Backbone variety | Wire up manually | 1000+ backbones, swap with one arg |
| Auto-tuning | Implement yourself | LR + batch size finding built-in |
| Experiment tracking | External tooling needed | Multi-logger (TensorBoard, MLflow, W&B, CSV) |
| Reproducibility | Manual seed management | Opt-in deterministic mode (seed=42) |
| Model export | Custom export scripts | One-line TorchScript & ONNX |
| Interpretation | Separate libraries | 6 methods + metrics built-in |
Features
Vision Tasks · Smart Training · Interpretation & Export · Data & Integration

Who Is It For? Researchers needing reproducible experiments · Engineers building production ML systems · Students learning computer vision · Startups rapidly prototyping vision applications
Quick Start
Installation
pip install autotimm
Everything included: PyTorch, timm, PyTorch Lightning, torchmetrics, albumentations, and more.
Optional extras
pip install autotimm[tensorboard] # TensorBoard
pip install autotimm[wandb] # Weights & Biases
pip install autotimm[mlflow] # MLflow
pip install autotimm[onnx] # ONNX export (onnx + onnxruntime + onnxscript)
pip install autotimm[all] # Everything
Your First Model
import autotimm as at # recommended alias
from autotimm import AutoTrainer, ImageClassifier, ImageDataModule, MetricConfig
# Data
data = ImageDataModule(
data_dir="./data",
dataset_name="CIFAR10",
image_size=224,
batch_size=64,
)
# Metrics
metrics = [
MetricConfig(
name="accuracy",
backend="torchmetrics",
metric_class="Accuracy",
params={"task": "multiclass"},
stages=["train", "val", "test"],
prog_bar=True,
)
]
# Model — try efficientnet_b0, vit_base_patch16_224, hf-hub:timm/resnet50.a1_in1k, etc.
model = ImageClassifier(backbone="resnet18", num_classes=10, metrics=metrics, lr=1e-3)
# Train with auto-tuning (finds optimal LR and batch size automatically)
trainer = AutoTrainer(max_epochs=10)
trainer.fit(model, datamodule=data)
Auto-tuning is enabled by default. Disable with tuner_config=False for manual control.
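For example, a minimal sketch (assuming tuner_config is a keyword argument on the AutoTrainer constructor, as the note above implies):

trainer = AutoTrainer(max_epochs=10, tuner_config=False)  # skip automatic LR/batch-size finding
trainer.fit(model, datamodule=data)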
Command-Line Interface
Train from YAML configs — no Python scripts needed:
autotimm fit --config config.yaml
autotimm test --config config.yaml --ckpt_path best.ckpt

A matching config.yaml:
model:
class_path: autotimm.ImageClassifier
init_args:
backbone: resnet18
num_classes: 10
data:
class_path: autotimm.ImageDataModule
init_args:
dataset_name: CIFAR10
data_dir: ./data
batch_size: 32
image_size: 224
trainer:
max_epochs: 10
accelerator: auto
Task Examples
Image Classification
from autotimm import ImageClassifier
model = ImageClassifier(
backbone="efficientnet_b0", # or "hf-hub:timm/resnet50.a1_in1k"
num_classes=10,
metrics=metrics,
)
trainer = AutoTrainer(max_epochs=10)
trainer.fit(model, datamodule=data)
Object Detection with YOLOX
from autotimm import YOLOXDetector, DetectionDataModule
model = YOLOXDetector(
model_name="yolox-s", # nano, tiny, s, m, l, x
num_classes=80,
lr=0.01,
optimizer="sgd",
scheduler="yolox",
total_epochs=300,
)
trainer = AutoTrainer(max_epochs=300, precision="16-mixed")
trainer.fit(model, datamodule=DetectionDataModule(data_dir="./coco", image_size=640))
Semantic Segmentation
from autotimm import SemanticSegmentor, SegmentationDataModule
model = SemanticSegmentor(
backbone="resnet50",
num_classes=19,
head_type="deeplabv3plus",
loss_type="combined", # CE + Dice for better boundaries
)
data = SegmentationDataModule(data_dir="./cityscapes", format="cityscapes", image_size=512)
trainer = AutoTrainer(max_epochs=100)
trainer.fit(model, datamodule=data)
Instance Segmentation
from autotimm import InstanceSegmentor, InstanceSegmentationDataModule
model = InstanceSegmentor(backbone="resnet50", num_classes=80, mask_loss_weight=1.0)
trainer = AutoTrainer(max_epochs=100)
trainer.fit(model, datamodule=InstanceSegmentationDataModule(data_dir="./coco"))
Multi-Label Classification
from autotimm import ImageClassifier, MultiLabelImageDataModule, MetricConfig
data = MultiLabelImageDataModule(
train_csv="train.csv", image_dir="./images", val_csv="val.csv",
image_size=224, batch_size=32,
)
data.setup("fit")  # populate num_labels before building the model
model = ImageClassifier(
backbone="resnet50",
num_classes=data.num_labels,
multi_label=True,
threshold=0.5,
metrics=[MetricConfig(
name="accuracy", backend="torchmetrics", metric_class="MultilabelAccuracy",
params={"num_labels": data.num_labels}, stages=["train", "val"], prog_bar=True,
)],
)
trainer = AutoTrainer(max_epochs=10)
trainer.fit(model, datamodule=data)
Model Interpretation
from autotimm.interpretation import (
GradCAM, GradCAMPlusPlus, IntegratedGradients,
SmoothGrad, AttentionRollout, AttentionFlow,
ExplanationMetrics, InteractiveVisualizer,
)
# Explain a prediction
explainer = GradCAM(model)
heatmap = explainer.explain(image, target_class=5)
explainer.visualize(image, heatmap, save_path="gradcam.png")
# Quantitative evaluation
metrics = ExplanationMetrics(model, explainer)
deletion = metrics.deletion(image, target_class=5, steps=50)
insertion = metrics.insertion(image, target_class=5, steps=50)
# Interactive comparison
viz = InteractiveVisualizer(model)
viz.compare_methods(image, {
'GradCAM': GradCAM(model),
'IntegratedGradients': IntegratedGradients(model),
}, save_path="comparison.html")
Model Export
from autotimm import ImageClassifier, export_to_torchscript, export_to_onnx
import torch
model = ImageClassifier.load_from_checkpoint("model.ckpt")
example_input = torch.randn(1, 3, 224, 224)
# TorchScript
export_to_torchscript(model, "model.pt", example_input=example_input)
# ONNX
export_to_onnx(model, "model.onnx", example_input=example_input)
| Format | Use Case | Runtimes |
|---|---|---|
| TorchScript | PyTorch ecosystem, C++, mobile | LibTorch, PyTorch Mobile |
| ONNX | Cross-platform, hardware-optimized | ONNX Runtime, TensorRT, OpenVINO, CoreML |
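To sanity-check the exported artifacts, you can load them back and run a forward pass. A minimal sketch, assuming the exported graphs take a single 1x3x224x224 float input as in the example above:

import numpy as np
import onnxruntime as ort
import torch

# TorchScript round-trip with torch.jit.load
scripted = torch.jit.load("model.pt")
scripted.eval()
with torch.no_grad():
    ts_out = scripted(torch.randn(1, 3, 224, 224))

# ONNX round-trip with ONNX Runtime
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
onnx_out = session.run(None, {input_name: np.random.randn(1, 3, 224, 224).astype(np.float32)})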
Loading from Checkpoints
All task classes inherit PyTorch Lightning's load_from_checkpoint. Most parameters are auto-restored from saved hyperparameters — only ignored parameters (non-serializable objects) need to be re-supplied.
import autotimm as at
# Basic — works for all tasks (ignored params default to None)
model = at.ImageClassifier.load_from_checkpoint("checkpoint.ckpt")
# For inference on CPU
model = at.ImageClassifier.load_from_checkpoint("checkpoint.ckpt", map_location="cpu")
model.eval()
# Override any saved hyperparameter
model = at.ImageClassifier.load_from_checkpoint("checkpoint.ckpt", num_classes=20)
# Disable torch.compile for faster loading (recommended for inference/export)
model = at.ImageClassifier.load_from_checkpoint("checkpoint.ckpt", compile_model=False)
Ignored parameters by task (not saved in checkpoint, re-supply if needed)
| Task | Ignored Parameters |
|---|---|
| ImageClassifier | metrics, logging_config, transform_config, loss_fn |
| ObjectDetector | metrics, logging_config, transform_config, cls_loss_fn, reg_loss_fn |
| SemanticSegmentor | metrics, logging_config, transform_config, class_weights, loss_fn |
| InstanceSegmentor | metrics, logging_config, transform_config, cls_loss_fn, reg_loss_fn, mask_loss_fn |
| YOLOXDetector | metrics, logging_config, transform_config |
All ignored parameters default to None, so load_from_checkpoint works without passing them. Re-supply only if you need custom losses or metrics for continued training.
# Resume training with custom metrics and loss
model = at.ObjectDetector.load_from_checkpoint(
"checkpoint.ckpt",
metrics=[at.MetricConfig(name="map", backend="torchmetrics", metric_class="MeanAveragePrecision")],
cls_loss_fn="focal",
reg_loss_fn="giou",
)
Note: backbone, num_classes, and all other constructor arguments are saved as hyperparameters and restored automatically — no need to re-specify them.
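Once loaded, inference follows the usual PyTorch pattern. A short sketch, assuming the classifier's forward returns raw logits:

import torch
import autotimm as at

model = at.ImageClassifier.load_from_checkpoint("checkpoint.ckpt", map_location="cpu")
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
    probs = torch.softmax(logits, dim=1)
    pred = probs.argmax(dim=1)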
Architecture
graph TD
subgraph API["Public API · autotimm"]
Init["__init__.py"] --> Tasks
Init --> Data
Init --> Interpret
Init --> Export
end
subgraph Core["Core · autotimm.core"]
Backbone["backbone.py · 1000+ timm models"]
Metrics["metrics.py · torchmetrics"]
Loggers["loggers.py · TB / MLflow / W&B / CSV"]
Utils["utils.py · seeding, optimizers"]
Logging["logging.py · loguru"]
end
subgraph Tasks["Tasks · autotimm.tasks"]
Cls["ImageClassifier"]
Det["ObjectDetector"]
Seg["SemanticSegmentor"]
Inst["InstanceSegmentor"]
YOLOX["YOLOXDetector"]
end
subgraph Data["Data · autotimm.data"]
DM["DataModules"] --> DS["Datasets"]
DM --> TX["Transforms"]
end
subgraph Interpret["Interpretation · autotimm.interpretation"]
Methods["6 Methods"] --> Viz["Visualization"]
Methods --> QM["Quality Metrics"]
end
subgraph Export["Export · autotimm.export"]
JIT["TorchScript"]
ONNX["ONNX"]
end
Tasks --> Core
Data --> Core
Package Structure
| Module | Purpose |
|---|---|
| autotimm | Public API — all exports via __init__.py |
| autotimm.core | Backbone factory, metrics, loggers, logging, utilities |
| autotimm.tasks | Task models (LightningModule subclasses) |
| autotimm.data | DataModules, datasets, and transform pipelines |
| autotimm.heads | Task-specific prediction heads (Classification, Detection, FPN, DeepLabV3+, Mask, YOLOX) |
| autotimm.losses | Loss functions and registry (Focal, GIoU, Dice, Tversky, etc.) |
| autotimm.models | YOLOX-specific components (CSPDarknet, PAFPN) |
| autotimm.interpretation | Explanation methods, metrics, visualization, callbacks |
| autotimm.export | TorchScript and ONNX export utilities + CLIs |
| autotimm.training | AutoTrainer wrapper with auto-tuning |
| autotimm.cli | Command-line interface (fit/test/validate + interpretation CLI) |
| autotimm.callbacks | JSON progress callback for frontend integration |
Documentation & Examples
| Section | Description |
|---|---|
| Quick Start | Get up and running in 5 minutes |
| User Guide | In-depth guides for all features |
| Interpretation Guide | Model explainability and visualization |
| YOLOX Guide | Complete YOLOX implementation guide |
| CLI Guide | Train from YAML configs on the command line |
| API Reference | Complete API documentation |
| Examples | 50+ runnable code examples |
Ready-to-Run Examples
Getting Started — classify_cifar10.py · classify_custom_folder.py · vit_finetuning.py
Computer Vision — yolox_official.py · object_detection_yolox.py · semantic_segmentation.py · instance_segmentation.py
HuggingFace Hub — huggingface_hub_models.py · hf_interpretation.py · hf_transfer_learning.py · hf_ensemble.py · hf_deployment.py
Data & Training — csv_classification.py · csv_detection.py · multilabel_classification.py · multi_gpu_training.py
Interpretation — interpretation_demo.py · interpretation_metrics_demo.py · interactive_visualization_demo.py
CLI Configs — classification.yaml · detection.yaml · segmentation.yaml
Testing
# Run all tests
pytest tests/ -v
# Specific modules
pytest tests/test_classification.py
pytest tests/test_yolox.py
pytest tests/test_interpretation.py
# With coverage
pytest tests/ --cov=autotimm --cov-report=html
Contributing
Contributions, issues, and feature requests are welcome! See the issues page to get started.
- Fork the repository
- Create a feature branch — git checkout -b feat/your-feature
- Commit your changes — git commit -m 'feat: add your feature'
- Push to the branch — git push origin feat/your-feature
- Open a Pull Request
Development setup:

git clone https://github.com/theja-vanka/AutoTimm.git && cd AutoTimm
pip install -e ".[dev,all]"
pytest tests/ -v
Citation
@software{autotimm,
author = {Krishnatheja Vanka},
title = {AutoTimm: Automatic PyTorch Image Models},
url = {https://github.com/theja-vanka/AutoTimm},
year = {2026},
}
License
Distributed under the Apache License 2.0. See LICENSE for details.