
Automatic PyTorch Image Models


AutoTimm

Train state-of-the-art vision models with minimal code
From prototype to production in minutes, not hours


Documentation • Quick Start • Examples • API Reference


What is AutoTimm?

AutoTimm is a production-ready computer vision framework that combines timm (1000+ pretrained models) with PyTorch Lightning. Train image classifiers, object detectors, and segmentation models with any timm backbone using a simple, intuitive API.

Perfect for:

  • Researchers needing reproducible experiments and quick iterations
  • Engineers building production ML systems with minimal boilerplate
  • Students learning computer vision with modern best practices
  • Startups rapidly prototyping vision applications

Why AutoTimm?

Fast

From idea to trained model in minutes. Auto-tuning, mixed precision, and multi-GPU out of the box.

Flexible

1000+ backbones, 4 vision tasks, multiple transform backends. Use what works best.

Production Ready

200+ tests, comprehensive logging, checkpoint management. Deploy with confidence.

What's New in v0.7.1

  • Model Interpretation - Complete explainability toolkit with 6 interpretation methods, 6 quality metrics, interactive Plotly visualizations, and up to 100x speedup with optimization
  • Tutorial Notebook - Comprehensive Jupyter notebook covering all interpretation features end-to-end
  • YOLOX Models - Official YOLOX implementation (nano to X) with CSPDarknet backbone
  • Smart Backend Selection - AI-powered recommendation for optimal transform backends
  • TransformConfig - Unified transform configuration with presets and model-specific normalization
  • Optional Metrics - Metrics now optional for inference-only deployments
  • Python 3.10-3.14 - Latest Python support

Quick Start

Installation

pip install autotimm

Everything included: PyTorch, timm, PyTorch Lightning, torchmetrics, albumentations, pycocotools, and more.

Optional extras
# Logging backends
pip install autotimm[tensorboard]  # TensorBoard
pip install autotimm[wandb]        # Weights & Biases
pip install autotimm[mlflow]       # MLflow

# Interpretation
pip install autotimm[interactive]  # Interactive Plotly visualizations

# All extras
pip install autotimm[all]          # Everything

Your First Model in 30 Seconds

from autotimm import AutoTrainer, ImageClassifier, ImageDataModule, MetricConfig

# Data
data = ImageDataModule(
    data_dir="./data",
    dataset_name="CIFAR10",
    image_size=224,
    batch_size=64,
)

# Metrics
metrics = [
    MetricConfig(
        name="accuracy",
        backend="torchmetrics",
        metric_class="Accuracy",
        params={"task": "multiclass"},
        stages=["train", "val", "test"],
        prog_bar=True,
    )
]

# Model
model = ImageClassifier(
    backbone="resnet18",  # Try efficientnet_b0, vit_base_patch16_224, etc.
    num_classes=10,
    metrics=metrics,
    lr=1e-3,
)

# Train with auto-tuning (finds optimal LR and batch size automatically!)
trainer = AutoTrainer(max_epochs=10)
trainer.fit(model, datamodule=data)

Auto-tuning is enabled by default. Disable with tuner_config=False for manual control.

Key Features

4 Vision Tasks: Classification • Object Detection • Semantic Segmentation • Instance Segmentation
1000+ Backbones: ResNet • EfficientNet • ViT • ConvNeXt • Swin • DeiT • BEiT • and more from timm
Model Interpretation: 6 explanation methods • 6 quality metrics • Interactive visualizations • Up to 100x speedup
HuggingFace Integration: Load models from HF Hub with the hf-hub: prefix + direct Transformers support
YOLOX Support: Official YOLOX models (nano → X) + YOLOX-style heads with any timm backbone
Advanced Architectures: DeepLabV3+ • FCOS • YOLOX • Mask R-CNN • Feature Pyramids
Auto-Tuning: Automatic LR and batch-size finding, enabled by default
Smart Transforms: AI-powered backend recommendations + unified TransformConfig with presets
Multi-Logger Support: TensorBoard • MLflow • Weights & Biases • CSV, usable simultaneously
Production Ready: Mixed precision • Multi-GPU • Gradient accumulation • 200+ tests

Task Examples

Image Classification

from autotimm import ImageClassifier

# Use any timm backbone or HuggingFace model
model = ImageClassifier(
    backbone="efficientnet_b0",  # or "hf-hub:timm/resnet50.a1_in1k"
    num_classes=10,
    metrics=metrics,  # Optional for inference!
)

trainer = AutoTrainer(max_epochs=10)
trainer.fit(model, datamodule=data)

Object Detection with YOLOX

Official YOLOX (matches paper benchmarks):

from autotimm import YOLOXDetector, DetectionDataModule

model = YOLOXDetector(
    model_name="yolox-s",  # nano, tiny, s, m, l, x
    num_classes=80,
    lr=0.01,
    optimizer="sgd",
    scheduler="yolox",
    total_epochs=300,
)

trainer = AutoTrainer(max_epochs=300, precision="16-mixed")
trainer.fit(model, datamodule=DetectionDataModule(data_dir="./coco", image_size=640))

YOLOX-style head with any timm backbone:

from autotimm import ObjectDetector

model = ObjectDetector(
    backbone="resnet50",  # Experiment with any backbone!
    num_classes=80,
    detection_arch="yolox",
    fpn_channels=256,
)

Complete YOLOX Guide • Quick Reference

Semantic Segmentation

from autotimm import SemanticSegmentor, SegmentationDataModule

model = SemanticSegmentor(
    backbone="resnet50",
    num_classes=19,
    head_type="deeplabv3plus",
    loss_type="combined",  # CE + Dice for better boundaries
)

data = SegmentationDataModule(
    data_dir="./cityscapes",
    format="cityscapes",  # or "coco", "voc", "png"
    image_size=512,
)

trainer = AutoTrainer(max_epochs=100)
trainer.fit(model, datamodule=data)

Instance Segmentation

from autotimm import InstanceSegmentor, InstanceSegmentationDataModule

model = InstanceSegmentor(
    backbone="resnet50",
    num_classes=80,
    mask_loss_weight=1.0,
)

trainer = AutoTrainer(max_epochs=100)
trainer.fit(model, datamodule=InstanceSegmentationDataModule(data_dir="./coco"))

Model Interpretation & Explainability

Understand what your models learn and how they make decisions with comprehensive interpretation tools.

Quick Explanation

from autotimm.interpretation import quick_explain

# One-line explanation
result = quick_explain(
    model,
    image,
    method="gradcam",
    save_path="explanation.png"
)

6 Interpretation Methods

from autotimm.interpretation import (
    GradCAM,                # Fast, class-discriminative (CNNs)
    GradCAMPlusPlus,        # Better for multiple objects
    IntegratedGradients,    # Theoretically sound, pixel-level
    SmoothGrad,             # Noise-reduced gradients
    AttentionRollout,       # Vision Transformers
    AttentionFlow,          # Vision Transformers
)

# Use any method
explainer = GradCAM(model)
heatmap = explainer.explain(image, target_class=5)
explainer.visualize(image, heatmap, save_path="gradcam.png")
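
For intuition, the core of Grad-CAM can be sketched in a few lines of plain PyTorch. This is a minimal standalone sketch, independent of AutoTimm's implementation: capture a conv layer's activations and gradients with hooks, weight each channel by its average gradient, and ReLU the sum.

```python
import torch
import torch.nn as nn

# Tiny CNN stand-in for any convolutional backbone
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
)
target_layer = model[0]

# Hooks record the target layer's forward activations and backward gradients
activations, gradients = {}, {}
target_layer.register_forward_hook(lambda m, i, o: activations.update(a=o))
target_layer.register_full_backward_hook(lambda m, gi, go: gradients.update(g=go[0]))

image = torch.randn(1, 3, 32, 32)
logits = model(image)
logits[0, 5].backward()  # backprop from the target class score

# Channel weights = global-average-pooled gradients; weighted sum + ReLU
weights = gradients["g"].mean(dim=(2, 3), keepdim=True)
cam = torch.relu((weights * activations["a"]).sum(dim=1))
cam = cam / (cam.max() + 1e-8)  # normalize to [0, 1]
```

`cam` is a spatial heatmap with the same height and width as the conv features, ready to be upsampled over the input image.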

Quantitative Evaluation

from autotimm.interpretation import ExplanationMetrics

metrics = ExplanationMetrics(model, explainer)

# Faithfulness metrics
deletion = metrics.deletion(image, target_class=5, steps=50)
insertion = metrics.insertion(image, target_class=5, steps=50)

# Stability metric
sensitivity = metrics.sensitivity_n(image, n_samples=50)

# Sanity checks
param_test = metrics.model_parameter_randomization_test(image)
data_test = metrics.data_randomization_test(image)

# Localization metric
pointing = metrics.pointing_game(image, bbox=(50, 50, 150, 150))

print(f"Deletion AUC: {deletion['auc']:.4f}")  # Lower = better
print(f"Insertion AUC: {insertion['auc']:.4f}")  # Higher = better
print(f"Sensitivity: {sensitivity['sensitivity']:.4f}")  # Lower = more stable
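
Conceptually, the deletion metric ranks pixels by attribution, progressively masks the most important ones, and integrates the drop in the target-class probability. A hedged sketch (not AutoTimm's implementation) with a toy model:

```python
import torch

def deletion_auc(model, image, heatmap, target_class, steps=10):
    """Mask pixels in order of decreasing attribution; average the prob curve."""
    c, h, w = image.shape
    order = heatmap.flatten().argsort(descending=True)
    per_step = (h * w) // steps
    masked = image.clone()
    probs = []
    for s in range(steps + 1):
        with torch.no_grad():
            p = torch.softmax(model(masked.unsqueeze(0)), dim=1)[0, target_class]
        probs.append(p.item())
        idx = order[s * per_step:(s + 1) * per_step]
        masked.view(c, -1)[:, idx] = 0.0  # delete the next most-important pixels
    return sum(probs) / len(probs)  # crude AUC estimate of the deletion curve

# Toy usage: linear model on an 8x8 image with a random attribution map
toy_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 10))
auc = deletion_auc(toy_model, torch.rand(3, 8, 8), torch.rand(8, 8), target_class=0)
```

A faithful explanation makes the probability collapse quickly, so a lower deletion AUC is better, matching the comment above.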

Interactive Visualizations

from autotimm.interpretation import InteractiveVisualizer

viz = InteractiveVisualizer(model)

# Create interactive HTML with zoom/pan/hover
fig = viz.visualize_explanation(
    image,
    explainer,
    colorscale="Viridis",
    save_path="interactive.html"
)

# Compare methods side-by-side
explainers = {
    'GradCAM': GradCAM(model),
    'GradCAM++': GradCAMPlusPlus(model),
    'Integrated Gradients': IntegratedGradients(model),
}
viz.compare_methods(image, explainers, save_path="comparison.html")

# Generate comprehensive report
viz.create_report(image, explainer, save_path="report.html")

Performance Optimization

from autotimm.interpretation.optimization import (
    ExplanationCache,        # 10-50x speedup
    BatchProcessor,          # 2-5x speedup
    PerformanceProfiler,     # Identify bottlenecks
    optimize_for_inference,  # 1.5-3x speedup
)

# Enable caching
cache = ExplanationCache(cache_dir="./cache", max_size_mb=5000)

# Optimize model
model = optimize_for_inference(model, use_fp16=True)

# Batch processing
processor = BatchProcessor(model, explainer, batch_size=32)
heatmaps = processor.process_batch(images)

# Profile performance
profiler = PerformanceProfiler(enabled=True)
with profiler.profile("explanation"):
    heatmap = explainer.explain(image)
profiler.print_stats()

Training Integration

from autotimm import AutoTrainer
from autotimm.interpretation import InterpretationCallback

# Monitor interpretations during training
callback = InterpretationCallback(
    sample_images=val_images,
    method="gradcam",
    log_every_n_epochs=5,
)

trainer = AutoTrainer(
    max_epochs=100,
    callbacks=[callback],
    logger="tensorboard",
)
trainer.fit(model, datamodule=data)

Features:

  • 6 interpretation methods for different use cases
  • 6 quality metrics for quantitative evaluation
  • Interactive visualizations with Plotly (zoom/pan/hover)
  • Up to 100x speedup with caching and optimization
  • Feature visualization and receptive field analysis
  • Training callbacks for automatic monitoring
  • Comprehensive tutorial notebook included

Interpretation Guide • Tutorial Notebook

HuggingFace Integration

Three Approaches

  • HF Hub (best for timm CNNs, production): "hf-hub:timm/resnet50.a1_in1k"
  • HF Transformers Direct (best for Vision Transformers): ViTModel.from_pretrained(...)
  • HF Transformers Auto (best for quick prototyping): AutoModel.from_pretrained(...)

All approaches fully support AutoTrainer (checkpointing, early stopping, mixed precision, multi-GPU, auto-tuning).

from autotimm import ImageClassifier, list_hf_hub_backbones

# Discover models
models = list_hf_hub_backbones(model_name="resnet", limit=5)

# Use any HF Hub model (just add 'hf-hub:' prefix!)
model = ImageClassifier(
    backbone="hf-hub:timm/convnext_base.fb_in22k_ft_in1k",
    num_classes=100,
)

HF Integration Comparison • HF Hub Guide • HF Transformers Guide

Smart Features

Smart Backend Selection

from autotimm import recommend_backend, compare_backends

# Get AI-powered recommendation
rec = recommend_backend(task="detection")
config = rec.to_config(image_size=640)

# Compare backends side-by-side
compare_backends()

Unified Transform Configuration

from autotimm import TransformConfig, list_transform_presets

# Discover presets
list_transform_presets()  # ['default', 'autoaugment', 'randaugment', ...]

# Configure with model-specific normalization
config = TransformConfig(
    preset="randaugment",
    image_size=384,
    use_timm_config=True,  # Auto-detect mean/std from backbone
)

model = ImageClassifier(
    backbone="efficientnet_b4",
    num_classes=10,
    transform_config=config,
)

Custom Auto-Tuning

from autotimm import AutoTrainer, TunerConfig

# Default: Full auto-tuning
trainer = AutoTrainer(max_epochs=10)

# Disable auto-tuning
trainer = AutoTrainer(max_epochs=10, tuner_config=False)

# Custom configuration
trainer = AutoTrainer(
    max_epochs=10,
    tuner_config=TunerConfig(
        auto_lr=True,
        auto_batch_size=True,
        lr_find_kwargs={"min_lr": 1e-6, "max_lr": 1.0},
    ),
)

Optional Metrics for Inference

# Training with metrics
model = ImageClassifier(backbone="resnet50", num_classes=10, metrics=metrics)

# Inference without metrics (load_from_checkpoint is a classmethod)
model = ImageClassifier.load_from_checkpoint("checkpoint.ckpt")
predictions = model(image)

Explore Models

YOLOX Models

import autotimm

# List all YOLOX variants
autotimm.list_yolox_models()  # ['yolox-nano', 'yolox-tiny', 'yolox-s', ...]

# Get detailed specs (params, FLOPs, mAP)
autotimm.list_yolox_models(verbose=True)

# Get model info
info = autotimm.get_yolox_model_info("yolox-s")
print(f"Params: {info['params']}, mAP: {info['mAP']}")  # Params: 9.0M, mAP: 40.5

# List components
autotimm.list_yolox_backbones()
autotimm.list_yolox_necks()
autotimm.list_yolox_heads()

timm Backbones

# Search 1000+ timm models
autotimm.list_backbones("*efficientnet*", pretrained_only=True)
autotimm.list_backbones("*vit*")

# Search HuggingFace Hub
autotimm.list_hf_hub_backbones(model_name="resnet", limit=10)

# Inspect a model
backbone = autotimm.create_backbone("convnext_tiny")
print(f"Features: {backbone.num_features}, Params: {autotimm.count_parameters(backbone):,}")

Documentation & Examples

Documentation

  • Quick Start: Get up and running in 5 minutes
  • User Guide: In-depth guides for all features
  • Interpretation Guide: Model explainability and visualization
  • YOLOX Guide: Complete YOLOX implementation guide
  • API Reference: Complete API documentation
  • Examples: 40+ runnable code examples

Ready-to-Run Examples

  • Classification
  • Object Detection
  • Segmentation
  • Interpretation & Explainability
  • HuggingFace & Advanced

Browse all examples →

Supported Architectures

Classification

  • Models: Any timm backbone (1000+)
  • Losses: CrossEntropy with label smoothing, Mixup
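
Mixup itself is simple to sketch in plain PyTorch (illustrative only; AutoTimm wires this into training for you): blend pairs of images and their one-hot labels with a Beta-sampled weight.

```python
import torch
import torch.nn.functional as F

def mixup(x, y, num_classes, alpha=0.2):
    """Blend each sample with a randomly permuted partner; labels blend too."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(x.size(0))
    x_mixed = lam * x + (1 - lam) * x[perm]
    y_onehot = F.one_hot(y, num_classes).float()
    y_mixed = lam * y_onehot + (1 - lam) * y_onehot[perm]
    return x_mixed, y_mixed

images = torch.randn(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,))
xm, ym = mixup(images, labels, num_classes=10)
```

The blended soft labels are then used with a soft-target cross-entropy instead of hard class indices.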

Object Detection

  • Architectures: FCOS, YOLOX (official & custom)
  • Losses: Focal Loss, GIoU Loss, Centerness Loss
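
For reference, the GIoU loss listed above has a simple closed form. A minimal sketch for axis-aligned boxes in (x1, y1, x2, y2) format:

```python
import torch

def giou_loss(pred, target):
    """Generalized IoU loss: 1 - (IoU - (enclosing_area - union) / enclosing_area)."""
    # Intersection
    lt = torch.max(pred[:, :2], target[:, :2])
    rb = torch.min(pred[:, 2:], target[:, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, 0] * wh[:, 1]
    # Union
    area_p = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
    area_t = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
    union = area_p + area_t - inter
    iou = inter / union.clamp(min=1e-7)
    # Smallest enclosing box penalizes non-overlapping predictions
    lt_c = torch.min(pred[:, :2], target[:, :2])
    rb_c = torch.max(pred[:, 2:], target[:, 2:])
    area_c = (rb_c - lt_c).clamp(min=0).prod(dim=1)
    giou = iou - (area_c - union) / area_c.clamp(min=1e-7)
    return (1 - giou).mean()

# Identical boxes -> GIoU = 1 -> loss = 0
boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
loss = giou_loss(boxes, boxes)
```

Unlike plain IoU, GIoU stays informative when boxes do not overlap, which is why it is a common regression loss in FCOS- and YOLOX-style detectors.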

Semantic Segmentation

  • Architectures: DeepLabV3+, FCN
  • Losses: CrossEntropy, Dice, Focal, Combined, Tversky
  • Formats: PNG masks, COCO stuff, Cityscapes, Pascal VOC
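
The "combined" loss above pairs cross-entropy with soft Dice. A minimal sketch with an illustrative 50/50 weighting (not necessarily AutoTimm's exact formulation):

```python
import torch
import torch.nn.functional as F

def combined_loss(logits, target, num_classes, ce_weight=0.5, eps=1e-6):
    """Cross-entropy plus soft Dice over class channels (illustrative weighting)."""
    ce = F.cross_entropy(logits, target)
    probs = torch.softmax(logits, dim=1)
    onehot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    inter = (probs * onehot).sum(dim=(0, 2, 3))
    denom = probs.sum(dim=(0, 2, 3)) + onehot.sum(dim=(0, 2, 3))
    dice = (2 * inter + eps) / (denom + eps)
    return ce_weight * ce + (1 - ce_weight) * (1 - dice.mean())

logits = torch.randn(2, 19, 16, 16)          # (batch, classes, H, W)
target = torch.randint(0, 19, (2, 16, 16))   # class index per pixel
loss = combined_loss(logits, target, num_classes=19)
```

CE optimizes per-pixel accuracy while Dice directly targets region overlap, which tends to sharpen object boundaries.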

Instance Segmentation

  • Architecture: FCOS + Mask R-CNN style mask head
  • Losses: Detection losses + Binary mask loss

Testing

Comprehensive test suite with 200+ tests:

# Run all tests
pytest tests/ -v

# Specific modules
pytest tests/test_classification.py
pytest tests/test_yolox.py
pytest tests/test_interpretation.py
pytest tests/test_interpretation_metrics.py

# With coverage
pytest tests/ --cov=autotimm --cov-report=html

Contributing

We welcome contributions!

git clone https://github.com/theja-vanka/AutoTimm.git
cd AutoTimm
pip install -e ".[dev,all]"
pytest tests/ -v

For major changes, please open an issue first.

Citation

@software{autotimm2026,
  author = {Krishnatheja Vanka},
  title = {AutoTimm: Automatic PyTorch Image Models},
  url = {https://github.com/theja-vanka/AutoTimm},
  year = {2026},
  version = {0.7.1}
}

Built with ❤️ using timm and PyTorch Lightning

Star us on GitHub • Report Issues • Read the Docs
