
Automated deep learning for computer vision — train image classification, object detection, and segmentation models with 1000+ backbones in minimal code

Project description


AutoTimm logo

Train state-of-the-art vision models with minimal code.

Production-ready computer vision framework powered by timm and PyTorch Lightning.




Overview · Features · Quick Start · Tasks · Architecture · Docs · Contributing · License


Overview

AutoTimm is a production-ready computer vision framework that combines timm (1000+ pretrained models) with PyTorch Lightning. Train image classifiers, object detectors, and segmentation models with any timm backbone using a simple, intuitive API.

Whether you're training an image classifier, an object detector, or a segmentation model, AutoTimm handles:

  • Automated model training with auto-tuning of learning rate and batch size
  • 4 vision tasks — classification (single & multi-label), object detection, semantic & instance segmentation
  • 1000+ backbones — ResNet, EfficientNet, ViT, ConvNeXt, Swin, and more via timm + HuggingFace Hub
  • Model interpretation — GradCAM, Integrated Gradients, Attention Rollout, and more
  • Production export — TorchScript and ONNX for deployment anywhere

From prototype to production in minutes, not hours.


How It Compares

Aspect              | Manual PyTorch             | AutoTimm
Boilerplate         | Hundreds of lines per task | 5–10 lines for any task
Backbone variety    | Wire up manually           | 1000+ backbones, swap with one arg
Auto-tuning         | Implement yourself         | LR + batch size finding built-in
Experiment tracking | External tooling needed    | Multi-logger (TensorBoard, MLflow, W&B, CSV)
Reproducibility     | Manual seed management     | Opt-in deterministic mode (seed=42)
Model export        | Custom export scripts      | One-line TorchScript & ONNX
Interpretation      | Separate libraries         | 6 methods + metrics built-in

Features

Vision Tasks

  • Image Classification — single-label and multi-label with any timm backbone
  • Object Detection — FCOS and YOLOX architectures (official + custom)
  • Semantic Segmentation — DeepLabV3+ and FCN with combined losses
  • Instance Segmentation — FCOS + Mask R-CNN style mask head

Smart Training

  • Auto-tuning — automatic LR and batch size discovery
  • torch.compile — PyTorch 2.0+ optimization enabled by default
  • Reproducibility — deterministic seeding across Python, NumPy, PyTorch
  • Mixed precision — 16-bit and bf16 training out of the box
  • Multi-GPU — distributed training with zero config
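The deterministic-seeding idea can be sketched in plain Python. This is a conceptual sketch, not AutoTimm's actual implementation; a real version would also seed NumPy and PyTorch, as noted in the comments:

```python
import os
import random


def seed_everything(seed: int = 42) -> int:
    """Seed every RNG source so runs are repeatable (conceptual sketch)."""
    random.seed(seed)                         # Python's built-in RNG
    os.environ["PYTHONHASHSEED"] = str(seed)  # hash randomization
    # A full implementation would also call numpy.random.seed(seed),
    # torch.manual_seed(seed), and torch.cuda.manual_seed_all(seed).
    return seed


seed_everything(42)
first = random.random()
seed_everything(42)
assert random.random() == first  # identical draws after re-seeding
```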

Interpretation & Export

  • 6 explanation methods — GradCAM, GradCAM++, Integrated Gradients, SmoothGrad, Attention Rollout, Attention Flow
  • 6 quality metrics — faithfulness, sensitivity, localization, sanity checks
  • Interactive visualizations — Plotly-powered HTML reports
  • Model export — TorchScript (.pt) and ONNX (.onnx) for production
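The faithfulness metrics work by perturbing inputs and watching the model's score respond. Here is a toy, framework-free sketch of the deletion-style metric; the function name and the sum-based stand-in "model" are illustrative, not AutoTimm's API:

```python
def deletion_curve(pixels, saliency, score_fn, steps=4):
    """Zero out pixels from most to least salient, recording the score each step.

    A faithful explanation makes the score drop quickly.
    """
    order = sorted(range(len(pixels)), key=lambda i: -saliency[i])
    remaining = list(pixels)
    curve = [score_fn(remaining)]
    per_step = max(1, len(pixels) // steps)
    for start in range(0, len(order), per_step):
        for i in order[start:start + per_step]:
            remaining[i] = 0.0  # "delete" the pixel
        curve.append(score_fn(remaining))
    return curve


# Toy example: the "model score" is just the pixel sum.
print(deletion_curve([1.0, 2.0, 3.0, 4.0], [1, 2, 3, 4], sum))
# → [10.0, 6.0, 3.0, 1.0, 0.0]
```

The area under this curve is what a deletion metric reports: lower is better, since it means the explanation found the pixels the score actually depends on.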

Data & Integration

  • Flexible data loading — folder structure, COCO JSON, CSV, HuggingFace datasets
  • Smart transforms — AI-powered backend selection, unified TransformConfig
  • HuggingFace Hub — load models with hf-hub: prefix
  • Multi-logger — TensorBoard, MLflow, W&B, CSV simultaneously
  • CLI — YAML-driven training from the command line

Who Is It For?

Researchers needing reproducible experiments · Engineers building production ML systems · Students learning computer vision · Startups rapidly prototyping vision applications


Quick Start

Installation

pip install autotimm

Everything included: PyTorch, timm, PyTorch Lightning, torchmetrics, albumentations, and more.

Optional extras
pip install autotimm[tensorboard]  # TensorBoard
pip install autotimm[wandb]        # Weights & Biases
pip install autotimm[mlflow]       # MLflow
pip install autotimm[onnx]         # ONNX export (onnx + onnxruntime + onnxscript)
pip install autotimm[all]          # Everything

Your First Model

import autotimm as at  # recommended alias
from autotimm import AutoTrainer, ImageClassifier, ImageDataModule, MetricConfig

# Data
data = ImageDataModule(
    data_dir="./data",
    dataset_name="CIFAR10",
    image_size=224,
    batch_size=64,
)

# Metrics
metrics = [
    MetricConfig(
        name="accuracy",
        backend="torchmetrics",
        metric_class="Accuracy",
        params={"task": "multiclass"},
        stages=["train", "val", "test"],
        prog_bar=True,
    )
]

# Model — try efficientnet_b0, vit_base_patch16_224, hf-hub:timm/resnet50.a1_in1k, etc.
model = ImageClassifier(backbone="resnet18", num_classes=10, metrics=metrics, lr=1e-3)

# Train with auto-tuning (finds optimal LR and batch size automatically)
trainer = AutoTrainer(max_epochs=10)
trainer.fit(model, datamodule=data)

Auto-tuning is enabled by default. Disable with tuner_config=False for manual control.
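The idea behind LR finding can be illustrated with a toy range test on a one-dimensional quadratic. This is a sketch of the underlying concept only; it is not the tuner's actual algorithm:

```python
def find_lr(loss_fn, grad_fn, w0=0.0, candidates=(1e-3, 1e-2, 1e-1, 0.5), steps=20):
    """Try each candidate learning rate on a short run; keep the lowest final loss."""
    best_lr, best_loss = None, float("inf")
    for lr in candidates:
        w = w0
        for _ in range(steps):
            w -= lr * grad_fn(w)  # plain gradient descent
        loss = loss_fn(w)
        if loss < best_loss:
            best_lr, best_loss = lr, loss
    return best_lr


# Minimizing (w - 3)^2: the largest stable step converges fastest.
print(find_lr(lambda w: (w - 3) ** 2, lambda w: 2 * (w - 3)))  # → 0.5
```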

Command-Line Interface

Train from YAML configs — no Python scripts needed:

autotimm fit --config config.yaml
autotimm test --config config.yaml --ckpt_path best.ckpt

Example config.yaml:

model:
  class_path: autotimm.ImageClassifier
  init_args:
    backbone: resnet18
    num_classes: 10

data:
  class_path: autotimm.ImageDataModule
  init_args:
    dataset_name: CIFAR10
    data_dir: ./data
    batch_size: 32
    image_size: 224

trainer:
  max_epochs: 10
  accelerator: auto

Task Examples

Image Classification

from autotimm import ImageClassifier

model = ImageClassifier(
    backbone="efficientnet_b0",  # or "hf-hub:timm/resnet50.a1_in1k"
    num_classes=10,
    metrics=metrics,
)

trainer = AutoTrainer(max_epochs=10)
trainer.fit(model, datamodule=data)

Object Detection with YOLOX

from autotimm import YOLOXDetector, DetectionDataModule

model = YOLOXDetector(
    model_name="yolox-s",  # nano, tiny, s, m, l, x
    num_classes=80,
    lr=0.01,
    optimizer="sgd",
    scheduler="yolox",
    total_epochs=300,
)

trainer = AutoTrainer(max_epochs=300, precision="16-mixed")
trainer.fit(model, datamodule=DetectionDataModule(data_dir="./coco", image_size=640))

Semantic Segmentation

from autotimm import SemanticSegmentor, SegmentationDataModule

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    head_type="deeplabv3plus",
    loss_type="combined",  # CE + Dice for better boundaries
)

data = SegmentationDataModule(data_dir="./cityscapes", format="cityscapes", image_size=512)

trainer = AutoTrainer(max_epochs=100)
trainer.fit(model, datamodule=data)
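The "combined" loss pairs cross-entropy with a Dice term, which rewards mask overlap directly rather than per-pixel correctness. A minimal, framework-free sketch of the Dice part on flat probability/label lists (illustrative only; the library's version operates on tensors):

```python
def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss on flat lists: 1 - 2*|P∩T| / (|P| + |T|)."""
    intersection = sum(p * t for p, t in zip(pred, target))
    denom = sum(pred) + sum(target)
    return 1.0 - (2.0 * intersection + eps) / (denom + eps)


print(dice_loss([1.0, 1.0, 0.0], [1.0, 1.0, 0.0]))  # perfect overlap → 0.0
print(dice_loss([1.0, 0.0, 0.0], [0.0, 0.0, 1.0]))  # no overlap → ~1.0
```

Because Dice is computed over the whole mask, it stays informative when foreground pixels are rare, which is why pairing it with cross-entropy tends to sharpen boundaries.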

Instance Segmentation

from autotimm import InstanceSegmentor, InstanceSegmentationDataModule

model = InstanceSegmentor(backbone="resnet50", num_classes=80, mask_loss_weight=1.0)

trainer = AutoTrainer(max_epochs=100)
trainer.fit(model, datamodule=InstanceSegmentationDataModule(data_dir="./coco"))

Multi-Label Classification

from autotimm import ImageClassifier, MultiLabelImageDataModule, MetricConfig

data = MultiLabelImageDataModule(
    train_csv="train.csv", image_dir="./images", val_csv="val.csv",
    image_size=224, batch_size=32,
)
data.setup("fit")

model = ImageClassifier(
    backbone="resnet50",
    num_classes=data.num_labels,
    multi_label=True,
    threshold=0.5,
    metrics=[MetricConfig(
        name="accuracy", backend="torchmetrics", metric_class="MultilabelAccuracy",
        params={"num_labels": data.num_labels}, stages=["train", "val"], prog_bar=True,
    )],
)

trainer = AutoTrainer(max_epochs=10)
trainer.fit(model, datamodule=data)
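A multi-label CSV maps each image to several labels at once. As an illustration of how such a file turns into binary target vectors, here is a stdlib-only sketch; the pipe-separated format and column names are assumptions for the example, not necessarily what MultiLabelImageDataModule expects:

```python
import csv
import io

# Hypothetical CSV: one row per image, labels pipe-separated.
sample = """image,labels
img001.jpg,cat|indoor
img002.jpg,dog
img003.jpg,cat|dog|outdoor
"""


def to_binary_targets(csv_text):
    """Build a label vocabulary and a 0/1 target vector per image."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    vocab = sorted({label for r in rows for label in r["labels"].split("|")})
    targets = {
        r["image"]: [1 if v in r["labels"].split("|") else 0 for v in vocab]
        for r in rows
    }
    return vocab, targets


vocab, targets = to_binary_targets(sample)
print(vocab)                  # → ['cat', 'dog', 'indoor', 'outdoor']
print(targets["img001.jpg"])  # → [1, 0, 1, 0]
```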

Model Interpretation

from autotimm.interpretation import (
    GradCAM, GradCAMPlusPlus, IntegratedGradients,
    SmoothGrad, AttentionRollout, AttentionFlow,
    ExplanationMetrics, InteractiveVisualizer,
)

# Explain a single prediction (image: a preprocessed input tensor, e.g. shape (1, 3, 224, 224))
explainer = GradCAM(model)
heatmap = explainer.explain(image, target_class=5)
explainer.visualize(image, heatmap, save_path="gradcam.png")

# Quantitative evaluation
metrics = ExplanationMetrics(model, explainer)
deletion = metrics.deletion(image, target_class=5, steps=50)
insertion = metrics.insertion(image, target_class=5, steps=50)

# Interactive comparison
viz = InteractiveVisualizer(model)
viz.compare_methods(image, {
    'GradCAM': GradCAM(model),
    'IntegratedGradients': IntegratedGradients(model),
}, save_path="comparison.html")

Model Export

from autotimm import ImageClassifier, export_to_torchscript, export_to_onnx
import torch

model = ImageClassifier.load_from_checkpoint("model.ckpt")
example_input = torch.randn(1, 3, 224, 224)

# TorchScript
export_to_torchscript(model, "model.pt", example_input=example_input)

# ONNX
export_to_onnx(model, "model.onnx", example_input=example_input)

Format      | Use Case                           | Runtimes
TorchScript | PyTorch ecosystem, C++, mobile     | LibTorch, PyTorch Mobile
ONNX        | Cross-platform, hardware-optimized | ONNX Runtime, TensorRT, OpenVINO, CoreML

Loading from Checkpoints

All task classes inherit PyTorch Lightning's load_from_checkpoint. Most parameters are auto-restored from saved hyperparameters — only ignored parameters (non-serializable objects) need to be re-supplied.

import autotimm as at

# Basic — works for all tasks (ignored params default to None)
model = at.ImageClassifier.load_from_checkpoint("checkpoint.ckpt")

# For inference on CPU
model = at.ImageClassifier.load_from_checkpoint("checkpoint.ckpt", map_location="cpu")
model.eval()

# Override any saved hyperparameter
model = at.ImageClassifier.load_from_checkpoint("checkpoint.ckpt", num_classes=20)

# Disable torch.compile for faster loading (recommended for inference/export)
model = at.ImageClassifier.load_from_checkpoint("checkpoint.ckpt", compile_model=False)

Ignored parameters by task (not saved in the checkpoint; re-supply if needed):

Task              | Ignored Parameters
ImageClassifier   | metrics, logging_config, transform_config, loss_fn
ObjectDetector    | metrics, logging_config, transform_config, cls_loss_fn, reg_loss_fn
SemanticSegmentor | metrics, logging_config, transform_config, class_weights, loss_fn
InstanceSegmentor | metrics, logging_config, transform_config, cls_loss_fn, reg_loss_fn, mask_loss_fn
YOLOXDetector     | metrics, logging_config, transform_config

All ignored parameters default to None, so load_from_checkpoint works without passing them. Re-supply only if you need custom losses or metrics for continued training.

# Resume training with custom metrics and loss
model = at.ObjectDetector.load_from_checkpoint(
    "checkpoint.ckpt",
    metrics=[at.MetricConfig(name="map", backend="torchmetrics", metric_class="MeanAveragePrecision")],
    cls_loss_fn="focal",
    reg_loss_fn="giou",
)

Note: backbone, num_classes, and all other constructor arguments are saved as hyperparameters and restored automatically — no need to re-specify them.
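The save/restore mechanism can be mimicked with a small JSON sketch: constructor hyperparameters are stored alongside the weights, and loading merges the saved values with any keyword overrides. This is a conceptual sketch of the Lightning-style behaviour, not AutoTimm's code:

```python
import json
import tempfile


def save_checkpoint(path, state_dict, hparams):
    """Store weights and constructor hyperparameters together."""
    with open(path, "w") as f:
        json.dump({"state_dict": state_dict, "hyper_parameters": hparams}, f)


def load_from_checkpoint(path, **overrides):
    """Restore saved hyperparameters, letting keyword overrides win."""
    with open(path) as f:
        ckpt = json.load(f)
    hparams = {**ckpt["hyper_parameters"], **overrides}
    return hparams, ckpt["state_dict"]


with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    path = f.name
save_checkpoint(path, {"w": [0.1, 0.2]}, {"backbone": "resnet18", "num_classes": 10})
hparams, _ = load_from_checkpoint(path, num_classes=20)  # override one hyperparameter
print(hparams)  # → {'backbone': 'resnet18', 'num_classes': 20}
```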


Architecture

graph TD
    subgraph API["Public API · autotimm"]
        Init["__init__.py"] --> Tasks
        Init --> Data
        Init --> Interpret
        Init --> Export
    end

    subgraph Core["Core · autotimm.core"]
        Backbone["backbone.py · 1000+ timm models"]
        Metrics["metrics.py · torchmetrics"]
        Loggers["loggers.py · TB / MLflow / W&B / CSV"]
        Utils["utils.py · seeding, optimizers"]
        Logging["logging.py · loguru"]
    end

    subgraph Tasks["Tasks · autotimm.tasks"]
        Cls["ImageClassifier"]
        Det["ObjectDetector"]
        Seg["SemanticSegmentor"]
        Inst["InstanceSegmentor"]
        YOLOX["YOLOXDetector"]
    end

    subgraph Data["Data · autotimm.data"]
        DM["DataModules"] --> DS["Datasets"]
        DM --> TX["Transforms"]
    end

    subgraph Interpret["Interpretation · autotimm.interpretation"]
        Methods["6 Methods"] --> Viz["Visualization"]
        Methods --> QM["Quality Metrics"]
    end

    subgraph Export["Export · autotimm.export"]
        JIT["TorchScript"]
        ONNX["ONNX"]
    end

    Tasks --> Core
    Data --> Core

Package Structure

Module                  | Purpose
autotimm                | Public API; all exports via __init__.py
autotimm.core           | Backbone factory, metrics, loggers, logging, utilities
autotimm.tasks          | Task models (LightningModule subclasses)
autotimm.data           | DataModules, datasets, and transform pipelines
autotimm.heads          | Task-specific prediction heads (Classification, Detection, FPN, DeepLabV3+, Mask, YOLOX)
autotimm.losses         | Loss functions and registry (Focal, GIoU, Dice, Tversky, etc.)
autotimm.models         | YOLOX-specific components (CSPDarknet, PAFPN)
autotimm.interpretation | Explanation methods, metrics, visualization, callbacks
autotimm.export         | TorchScript and ONNX export utilities + CLIs
autotimm.training       | AutoTrainer wrapper with auto-tuning
autotimm.cli            | Command-line interface (fit/test/validate + interpretation CLI)
autotimm.callbacks      | JSON progress callback for frontend integration

Documentation & Examples

Section              | Description
Quick Start          | Get up and running in 5 minutes
User Guide           | In-depth guides for all features
Interpretation Guide | Model explainability and visualization
YOLOX Guide          | Complete YOLOX implementation guide
CLI Guide            | Train from YAML configs on the command line
API Reference        | Complete API documentation
Examples             | 50+ runnable code examples

Ready-to-Run Examples

Getting Started: classify_cifar10.py · classify_custom_folder.py · vit_finetuning.py

Computer Vision: yolox_official.py · object_detection_yolox.py · semantic_segmentation.py · instance_segmentation.py

HuggingFace Hub: huggingface_hub_models.py · hf_interpretation.py · hf_transfer_learning.py · hf_ensemble.py · hf_deployment.py

Data & Training: csv_classification.py · csv_detection.py · multilabel_classification.py · multi_gpu_training.py

Interpretation: interpretation_demo.py · interpretation_metrics_demo.py · interactive_visualization_demo.py

CLI Configs: classification.yaml · detection.yaml · segmentation.yaml

Browse all examples


Testing

# Run all tests
pytest tests/ -v

# Specific modules
pytest tests/test_classification.py
pytest tests/test_yolox.py
pytest tests/test_interpretation.py

# With coverage
pytest tests/ --cov=autotimm --cov-report=html

Contributing

Contributions, issues, and feature requests are welcome! See the issues page to get started.

  1. Fork the repository
  2. Create a feature branch — git checkout -b feat/your-feature
  3. Commit your changes — git commit -m 'feat: add your feature'
  4. Push to the branch — git push origin feat/your-feature
  5. Open a Pull Request

Development setup:

git clone https://github.com/theja-vanka/AutoTimm.git && cd AutoTimm
pip install -e ".[dev,all]"
pytest tests/ -v

Citation

@software{autotimm,
  author = {Krishnatheja Vanka},
  title = {AutoTimm: Automatic PyTorch Image Models},
  url = {https://github.com/theja-vanka/AutoTimm},
  year = {2026},
}

License

Distributed under the Apache License 2.0. See LICENSE for details.



Built with care by Krishnatheja Vanka

If AutoTimm saves you time, consider giving it a ⭐



