Automatic PyTorch Image Models
Project description
Train state-of-the-art vision models with minimal code
From prototype to production in minutes, not hours
Documentation • Quick Start • Examples • API Reference
What is AutoTimm?
AutoTimm is a production-ready computer vision framework that combines timm (1000+ pretrained models) with PyTorch Lightning. Train image classifiers, object detectors, and segmentation models with any timm backbone using a simple, intuitive API.
Perfect for:
- Researchers needing reproducible experiments and quick iterations
- Engineers building production ML systems with minimal boilerplate
- Students learning computer vision with modern best practices
- Startups rapidly prototyping vision applications
Why AutoTimm?
| Fast | From idea to trained model in minutes. Auto-tuning, mixed precision, and multi-GPU out of the box. |
| Flexible | 1000+ backbones, 4 vision tasks, multiple transform backends. Use what works best. |
| Production Ready | 200+ tests, comprehensive logging, checkpoint management. Deploy with confidence. |
What's New in v0.7.1
- Model Interpretation - Complete explainability toolkit with 6 interpretation methods, 6 quality metrics, interactive Plotly visualizations, and up to 100x speedup with optimization
- Tutorial Notebook - Comprehensive Jupyter notebook covering all interpretation features end-to-end
- YOLOX Models - Official YOLOX implementation (nano to X) with CSPDarknet backbone
- Smart Backend Selection - AI-powered recommendation for optimal transform backends
- TransformConfig - Unified transform configuration with presets and model-specific normalization
- Optional Metrics - Metrics now optional for inference-only deployments
- Python 3.10-3.14 - Latest Python support
Quick Start
Installation
pip install autotimm
Everything included: PyTorch, timm, PyTorch Lightning, torchmetrics, albumentations, pycocotools, and more.
Optional extras
# Logging backends
pip install autotimm[tensorboard] # TensorBoard
pip install autotimm[wandb] # Weights & Biases
pip install autotimm[mlflow] # MLflow
# Interpretation
pip install autotimm[interactive] # Interactive Plotly visualizations
# All extras
pip install autotimm[all] # Everything
Your First Model in 30 Seconds
from autotimm import AutoTrainer, ImageClassifier, ImageDataModule, MetricConfig
# Data
data = ImageDataModule(
data_dir="./data",
dataset_name="CIFAR10",
image_size=224,
batch_size=64,
)
# Metrics
metrics = [
MetricConfig(
name="accuracy",
backend="torchmetrics",
metric_class="Accuracy",
params={"task": "multiclass"},
stages=["train", "val", "test"],
prog_bar=True,
)
]
# Model
model = ImageClassifier(
backbone="resnet18", # Try efficientnet_b0, vit_base_patch16_224, etc.
num_classes=10,
metrics=metrics,
lr=1e-3,
)
# Train with auto-tuning (finds optimal LR and batch size automatically!)
trainer = AutoTrainer(max_epochs=10)
trainer.fit(model, datamodule=data)
Auto-tuning is enabled by default. Disable with tuner_config=False for manual control.
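The idea behind the automatic LR search is an LR range test: sweep the learning rate exponentially, take a step at each value, and pick the rate where the loss drops fastest. A minimal plain-Python sketch on a toy quadratic objective (lr_range_test is illustrative, not AutoTimm's implementation):

```python
def lr_range_test(min_lr=1e-6, max_lr=1.0, steps=50):
    """Sweep the learning rate exponentially and record the loss after
    one SGD step on a toy quadratic objective f(w) = (w - 3)^2."""
    results = []
    for i in range(steps):
        # Exponential interpolation between min_lr and max_lr
        lr = min_lr * (max_lr / min_lr) ** (i / (steps - 1))
        w = 0.0                 # reset the toy parameter for each trial
        grad = 2 * (w - 3)      # d/dw of (w - 3)^2
        w -= lr * grad          # one SGD step
        loss = (w - 3) ** 2     # loss after the step
        results.append((lr, loss))
    # Heuristic: pick the lr with the lowest post-step loss
    best_lr, best_loss = min(results, key=lambda t: t[1])
    return best_lr, best_loss

best_lr, best_loss = lr_range_test()
print(f"suggested lr ~ {best_lr:.3f}")
```

The real tuner runs this against the actual model and data loader, but the selection logic is the same shape: sweep, measure, pick.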
Key Features
| 4 Vision Tasks | Classification • Object Detection • Semantic Segmentation • Instance Segmentation |
| 1000+ Backbones | ResNet • EfficientNet • ViT • ConvNeXt • Swin • DeiT • BEiT • and more from timm |
| Model Interpretation | 6 explanation methods • 6 quality metrics • Interactive visualizations • Up to 100x speedup |
| HuggingFace Integration | Load models from HF Hub with hf-hub: prefix + Direct Transformers support |
| YOLOX Support | Official YOLOX models (nano → X) + YOLOX-style heads with any timm backbone |
| Advanced Architectures | DeepLabV3+ • FCOS • YOLOX • Mask R-CNN • Feature Pyramids |
| Auto-Tuning | Automatic LR and batch size finding—enabled by default |
| Smart Transforms | AI-powered backend recommendations + unified TransformConfig with presets |
| Multi-Logger Support | TensorBoard • MLflow • Weights & Biases • CSV—use simultaneously |
| Production Ready | Mixed precision • Multi-GPU • Gradient accumulation • 200+ tests |
Task Examples
Image Classification
from autotimm import ImageClassifier
# Use any timm backbone or HuggingFace model
model = ImageClassifier(
backbone="efficientnet_b0", # or "hf-hub:timm/resnet50.a1_in1k"
num_classes=10,
metrics=metrics, # Optional for inference!
)
trainer = AutoTrainer(max_epochs=10)
trainer.fit(model, datamodule=data)
Object Detection with YOLOX
Official YOLOX (matches paper benchmarks):
from autotimm import YOLOXDetector, DetectionDataModule
model = YOLOXDetector(
model_name="yolox-s", # nano, tiny, s, m, l, x
num_classes=80,
lr=0.01,
optimizer="sgd",
scheduler="yolox",
total_epochs=300,
)
trainer = AutoTrainer(max_epochs=300, precision="16-mixed")
trainer.fit(model, datamodule=DetectionDataModule(data_dir="./coco", image_size=640))
YOLOX-style head with any timm backbone:
from autotimm import ObjectDetector
model = ObjectDetector(
backbone="resnet50", # Experiment with any backbone!
num_classes=80,
detection_arch="yolox",
fpn_channels=256,
)
Complete YOLOX Guide • Quick Reference
Semantic Segmentation
from autotimm import SemanticSegmentor, SegmentationDataModule
model = SemanticSegmentor(
backbone="resnet50",
num_classes=19,
head_type="deeplabv3plus",
loss_type="combined", # CE + Dice for better boundaries
)
data = SegmentationDataModule(
data_dir="./cityscapes",
format="cityscapes", # or "coco", "voc", "png"
image_size=512,
)
trainer = AutoTrainer(max_epochs=100)
trainer.fit(model, datamodule=data)
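The combined loss above pairs cross-entropy with a Dice term, which directly rewards mask overlap. The Dice term for a flat binary mask can be sketched in plain Python (a simplified illustration, not AutoTimm's implementation):

```python
def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss over flat binary masks.
    pred: predicted foreground probabilities in [0, 1]
    target: ground-truth labels in {0, 1}
    Returns 1 - Dice coefficient, so 0.0 means perfect overlap."""
    intersection = sum(p * t for p, t in zip(pred, target))
    denom = sum(pred) + sum(target)
    dice = (2 * intersection + eps) / (denom + eps)
    return 1.0 - dice

# Perfect prediction -> loss ~0; disjoint prediction -> loss ~1
print(dice_loss([1, 1, 0, 0], [1, 1, 0, 0]))  # ~0.0
print(dice_loss([1, 0, 0, 0], [0, 0, 0, 1]))  # ~1.0
```

Because Dice is computed over region overlap rather than per-pixel, it complements cross-entropy on thin structures and boundaries.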
Instance Segmentation
from autotimm import InstanceSegmentor, InstanceSegmentationDataModule
model = InstanceSegmentor(
backbone="resnet50",
num_classes=80,
mask_loss_weight=1.0,
)
trainer = AutoTrainer(max_epochs=100)
trainer.fit(model, datamodule=InstanceSegmentationDataModule(data_dir="./coco"))
Model Interpretation & Explainability
Understand what your models learn and how they make decisions with comprehensive interpretation tools.
Quick Explanation
from autotimm.interpretation import quick_explain
# One-line explanation
result = quick_explain(
model,
image,
method="gradcam",
save_path="explanation.png"
)
6 Interpretation Methods
from autotimm.interpretation import (
GradCAM, # Fast, class-discriminative (CNNs)
GradCAMPlusPlus, # Better for multiple objects
IntegratedGradients, # Theoretically sound, pixel-level
SmoothGrad, # Noise-reduced gradients
AttentionRollout, # Vision Transformers
AttentionFlow, # Vision Transformers
)
# Use any method
explainer = GradCAM(model)
heatmap = explainer.explain(image, target_class=5)
explainer.visualize(image, heatmap, save_path="gradcam.png")
Quantitative Evaluation
from autotimm.interpretation import ExplanationMetrics
metrics = ExplanationMetrics(model, explainer)
# Faithfulness metrics
deletion = metrics.deletion(image, target_class=5, steps=50)
insertion = metrics.insertion(image, target_class=5, steps=50)
# Stability metric
sensitivity = metrics.sensitivity_n(image, n_samples=50)
# Sanity checks
param_test = metrics.model_parameter_randomization_test(image)
data_test = metrics.data_randomization_test(image)
# Localization metric
pointing = metrics.pointing_game(image, bbox=(50, 50, 150, 150))
print(f"Deletion AUC: {deletion['auc']:.4f}") # Lower = better
print(f"Insertion AUC: {insertion['auc']:.4f}") # Higher = better
print(f"Sensitivity: {sensitivity['sensitivity']:.4f}") # Lower = more stable
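The deletion/insertion AUCs summarize how confidence changes as the most important pixels are progressively removed (or added back). The area under that curve is just a trapezoidal integral; a minimal sketch of the computation (illustrative data, not the library's code):

```python
def curve_auc(fractions, confidences):
    """Area under a deletion/insertion curve via the trapezoidal rule.
    fractions: fraction of pixels removed (or inserted), ascending in [0, 1]
    confidences: model confidence for the target class at each fraction."""
    auc = 0.0
    for i in range(1, len(fractions)):
        width = fractions[i] - fractions[i - 1]
        auc += width * (confidences[i] + confidences[i - 1]) / 2.0
    return auc

# A faithful explanation makes confidence collapse quickly as the pixels
# it flagged are deleted, giving a small deletion AUC.
fractions = [0.0, 0.25, 0.5, 0.75, 1.0]
faithful = [0.95, 0.30, 0.10, 0.05, 0.02]    # steep drop -> low AUC
unfaithful = [0.95, 0.90, 0.80, 0.60, 0.40]  # slow drop -> high AUC
print(curve_auc(fractions, faithful) < curve_auc(fractions, unfaithful))  # True
```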
Interactive Visualizations
from autotimm.interpretation import InteractiveVisualizer
viz = InteractiveVisualizer(model)
# Create interactive HTML with zoom/pan/hover
fig = viz.visualize_explanation(
image,
explainer,
colorscale="Viridis",
save_path="interactive.html"
)
# Compare methods side-by-side
explainers = {
'GradCAM': GradCAM(model),
'GradCAM++': GradCAMPlusPlus(model),
'Integrated Gradients': IntegratedGradients(model),
}
viz.compare_methods(image, explainers, save_path="comparison.html")
# Generate comprehensive report
viz.create_report(image, explainer, save_path="report.html")
Performance Optimization
from autotimm.interpretation.optimization import (
ExplanationCache, # 10-50x speedup
BatchProcessor, # 2-5x speedup
PerformanceProfiler, # Identify bottlenecks
optimize_for_inference, # 1.5-3x speedup
)
# Enable caching
cache = ExplanationCache(cache_dir="./cache", max_size_mb=5000)
# Optimize model
model = optimize_for_inference(model, use_fp16=True)
# Batch processing
processor = BatchProcessor(model, explainer, batch_size=32)
heatmaps = processor.process_batch(images)
# Profile performance
profiler = PerformanceProfiler(enabled=True)
with profiler.profile("explanation"):
heatmap = explainer.explain(image)
profiler.print_stats()
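The speedup from caching comes from content addressing: hash the input plus the method name and reuse the stored heatmap on a hit. A toy in-memory stand-in (not ExplanationCache's actual implementation) shows the idea:

```python
import hashlib

class SimpleExplanationCache:
    """Toy in-memory cache keyed on (input bytes, method name)."""

    def __init__(self):
        self._store = {}
        self.hits = 0

    def _key(self, image_bytes, method):
        h = hashlib.sha256()
        h.update(method.encode())
        h.update(image_bytes)
        return h.hexdigest()

    def get_or_compute(self, image_bytes, method, compute_fn):
        key = self._key(image_bytes, method)
        if key in self._store:
            self.hits += 1                   # hit: skip the expensive call
        else:
            self._store[key] = compute_fn()  # miss: compute and store
        return self._store[key]

cache = SimpleExplanationCache()
fake_image = b"\x00\x01\x02"
heatmap1 = cache.get_or_compute(fake_image, "gradcam", lambda: [0.1, 0.9])
heatmap2 = cache.get_or_compute(fake_image, "gradcam", lambda: [0.1, 0.9])
print(cache.hits)  # 1 -- the second call was served from the cache
```

The real cache adds disk persistence and a size budget (max_size_mb), but the hit path is the same: identical input plus identical method means the explanation never has to be recomputed.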
Training Integration
from autotimm import AutoTrainer
from autotimm.interpretation import InterpretationCallback
# Monitor interpretations during training
callback = InterpretationCallback(
sample_images=val_images,
method="gradcam",
log_every_n_epochs=5,
)
trainer = AutoTrainer(
max_epochs=100,
callbacks=[callback],
logger="tensorboard",
)
trainer.fit(model, datamodule=data)
Features:
- 6 interpretation methods for different use cases
- 6 quality metrics for quantitative evaluation
- Interactive visualizations with Plotly (zoom/pan/hover)
- Up to 100x speedup with caching and optimization
- Feature visualization and receptive field analysis
- Training callbacks for automatic monitoring
- Comprehensive tutorial notebook included
Interpretation Guide • Tutorial Notebook
HuggingFace Integration
Three Approaches
| Approach | Best For | Example |
|---|---|---|
| HF Hub timm | CNNs, Production | "hf-hub:timm/resnet50.a1_in1k" |
| HF Transformers Direct | Vision Transformers | ViTModel.from_pretrained(...) |
| HF Transformers Auto | Quick Prototyping | AutoModel.from_pretrained(...) |
All approaches fully support AutoTrainer (checkpointing, early stopping, mixed precision, multi-GPU, auto-tuning).
from autotimm import ImageClassifier, list_hf_hub_backbones
# Discover models
models = list_hf_hub_backbones(model_name="resnet", limit=5)
# Use any HF Hub model (just add 'hf-hub:' prefix!)
model = ImageClassifier(
backbone="hf-hub:timm/convnext_base.fb_in22k_ft_in1k",
num_classes=100,
)
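The routing implied by the hf-hub: prefix can be sketched as a tiny dispatcher; resolve_backbone here is a hypothetical helper for illustration, not part of AutoTimm's API:

```python
def resolve_backbone(name):
    """Split a backbone spec into (source, identifier).
    'hf-hub:timm/resnet50.a1_in1k' -> ('hf-hub', 'timm/resnet50.a1_in1k')
    'resnet50'                     -> ('timm', 'resnet50')"""
    prefix = "hf-hub:"
    if name.startswith(prefix):
        # Route to the HuggingFace Hub loader with the repo id
        return "hf-hub", name[len(prefix):]
    # Plain names resolve against timm's built-in registry
    return "timm", name

print(resolve_backbone("hf-hub:timm/convnext_base.fb_in22k_ft_in1k"))
print(resolve_backbone("resnet18"))
```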
HF Integration Comparison • HF Hub Guide • HF Transformers Guide
Smart Features
Smart Backend Selection
from autotimm import recommend_backend, compare_backends
# Get AI-powered recommendation
rec = recommend_backend(task="detection")
config = rec.to_config(image_size=640)
# Compare backends side-by-side
compare_backends()
Unified Transform Configuration
from autotimm import TransformConfig, list_transform_presets
# Discover presets
list_transform_presets() # ['default', 'autoaugment', 'randaugment', ...]
# Configure with model-specific normalization
config = TransformConfig(
preset="randaugment",
image_size=384,
use_timm_config=True, # Auto-detect mean/std from backbone
)
model = ImageClassifier(
backbone="efficientnet_b4",
num_classes=10,
transform_config=config,
)
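What use_timm_config=True buys you is the backbone's own normalization statistics rather than hard-coded constants; the normalization step itself is just (x - mean) / std per channel. A plain-Python sketch with the common ImageNet defaults:

```python
# ImageNet RGB statistics -- the usual default when a backbone
# does not publish its own mean/std
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Normalize one RGB pixel whose values are already scaled to [0, 1]."""
    return tuple((c - m) / s for c, m, s in zip(rgb, mean, std))

# A pixel equal to the dataset mean maps to zero in every channel
print(normalize_pixel((0.485, 0.456, 0.406)))  # (0.0, 0.0, 0.0)
```

Models trained with non-ImageNet statistics (some ViT and CLIP-derived backbones, for example) degrade noticeably under the wrong mean/std, which is why auto-detecting them from the backbone config matters.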
Custom Auto-Tuning
from autotimm import AutoTrainer, TunerConfig
# Default: Full auto-tuning
trainer = AutoTrainer(max_epochs=10)
# Disable auto-tuning
trainer = AutoTrainer(max_epochs=10, tuner_config=False)
# Custom configuration
trainer = AutoTrainer(
max_epochs=10,
tuner_config=TunerConfig(
auto_lr=True,
auto_batch_size=True,
lr_find_kwargs={"min_lr": 1e-6, "max_lr": 1.0},
),
)
Optional Metrics for Inference
# Training with metrics
model = ImageClassifier(backbone="resnet50", num_classes=10, metrics=metrics)
# Inference without metrics
model = ImageClassifier.load_from_checkpoint("checkpoint.ckpt")  # metrics not required
predictions = model(image)
Explore Models
YOLOX Models
import autotimm
# List all YOLOX variants
autotimm.list_yolox_models() # ['yolox-nano', 'yolox-tiny', 'yolox-s', ...]
# Get detailed specs (params, FLOPs, mAP)
autotimm.list_yolox_models(verbose=True)
# Get model info
info = autotimm.get_yolox_model_info("yolox-s")
print(f"Params: {info['params']}, mAP: {info['mAP']}") # Params: 9.0M, mAP: 40.5
# List components
autotimm.list_yolox_backbones()
autotimm.list_yolox_necks()
autotimm.list_yolox_heads()
timm Backbones
# Search 1000+ timm models
autotimm.list_backbones("*efficientnet*", pretrained_only=True)
autotimm.list_backbones("*vit*")
# Search HuggingFace Hub
autotimm.list_hf_hub_backbones(model_name="resnet", limit=10)
# Inspect a model
backbone = autotimm.create_backbone("convnext_tiny")
print(f"Features: {backbone.num_features}, Params: {autotimm.count_parameters(backbone):,}")
Documentation & Examples
Documentation
| Section | Description |
|---|---|
| Quick Start | Get up and running in 5 minutes |
| User Guide | In-depth guides for all features |
| Interpretation Guide | Model explainability and visualization |
| YOLOX Guide | Complete YOLOX implementation guide |
| API Reference | Complete API documentation |
| Examples | 40+ runnable code examples |
Ready-to-Run Examples
Classification
- classify_cifar10.py - Basic classification with auto-tuning
- classify_custom_folder.py - Train on custom dataset
- vit_finetuning.py - Two-phase ViT fine-tuning
- inference_without_metrics.py - Production deployment
Object Detection
- yolox_official.py - Official YOLOX models
- object_detection_yolox.py - YOLOX-style with timm
- object_detection_coco.py - FCOS detection
- object_detection_rtdetr.py - RT-DETR (no NMS!)
- explore_yolox_models.py - Interactive YOLOX explorer
Segmentation
- semantic_segmentation.py - DeepLabV3+
- instance_segmentation.py - Mask R-CNN style
Interpretation & Explainability
- comprehensive_interpretation_tutorial.ipynb - Complete tutorial (40+ cells)
- interpretation_metrics_demo.py - Quality metrics
- interactive_visualization_demo.py - Plotly visualizations
- performance_optimization_demo.py - Caching & optimization
HuggingFace & Advanced
- huggingface_hub_models.py - HF Hub basics
- hf_hub_*.py - Comprehensive HF examples
- multi_gpu_training.py - Distributed training
- mlflow_tracking.py - MLflow tracking
- preset_manager.py - Smart backend selection
Supported Architectures
Classification
- Models: Any timm backbone (1000+)
- Losses: CrossEntropy with label smoothing, Mixup
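Mixup blends pairs of training images and their labels with a coefficient drawn from a Beta distribution. The core operation can be sketched in plain Python (illustrative, not AutoTimm's implementation):

```python
import random

def mixup(x1, y1, x2, y2, alpha=0.2):
    """Blend two samples: x = lam*x1 + (1-lam)*x2, same for the labels.
    x1/x2 are flat pixel lists, y1/y2 are one-hot label lists."""
    lam = random.betavariate(alpha, alpha)  # mixing coefficient in (0, 1)
    x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1 - lam) * b for a, b in zip(y1, y2)]
    return x, y, lam

random.seed(0)
x, y, lam = mixup([1.0, 0.0], [1, 0], [0.0, 1.0], [0, 1])
print(lam, y)  # mixed one-hot label still sums to 1.0
```

Training against the soft label y (rather than a hard class index) is what gives mixup its regularizing effect.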
Object Detection
- Architectures: FCOS, YOLOX (official & custom)
- Losses: Focal Loss, GIoU Loss, Centerness Loss
Semantic Segmentation
- Architectures: DeepLabV3+, FCN
- Losses: CrossEntropy, Dice, Focal, Combined, Tversky
- Formats: PNG masks, COCO stuff, Cityscapes, Pascal VOC
Instance Segmentation
- Architecture: FCOS + Mask R-CNN style mask head
- Losses: Detection losses + Binary mask loss
Testing
Comprehensive test suite with 200+ tests:
# Run all tests
pytest tests/ -v
# Specific modules
pytest tests/test_classification.py
pytest tests/test_yolox.py
pytest tests/test_interpretation.py
pytest tests/test_interpretation_metrics.py
# With coverage
pytest tests/ --cov=autotimm --cov-report=html
Contributing
We welcome contributions!
git clone https://github.com/theja-vanka/AutoTimm.git
cd AutoTimm
pip install -e ".[dev,all]"
pytest tests/ -v
For major changes, please open an issue first.
Citation
@software{autotimm2026,
author = {Krishnatheja Vanka},
title = {AutoTimm: Automatic PyTorch Image Models},
url = {https://github.com/theja-vanka/AutoTimm},
year = {2026},
version = {0.7.1}
}
Built with ❤️ using timm and PyTorch Lightning
Project details
Release history
Download files
Source Distribution
Built Distribution
File details
Details for the file autotimm-0.7.1.tar.gz.
File metadata
- Download URL: autotimm-0.7.1.tar.gz
- Upload date:
- Size: 4.2 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5b73e4e2cfae169e02d5ca6f35de78f1afc7a22faebd3da7da98edd6a441a02e |
| MD5 | e06d6219b8102631d1368654dff3c519 |
| BLAKE2b-256 | 3d3942189e1ec5f24f1c5d38f8a424e60dc5109327ff02c219aca97a376feb3f |
File details
Details for the file autotimm-0.7.1-py3-none-any.whl.
File metadata
- Download URL: autotimm-0.7.1-py3-none-any.whl
- Upload date:
- Size: 171.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b7410f9fe0bbf0fab19139cb6c3337685376823ee50ed6d6d6a6f8d2d80255aa |
| MD5 | ef4bf77b190fc6dec30d102948ea59d6 |
| BLAKE2b-256 | ff1ea835b3823873ff6776f845c0c8f6030d0e1310b075c6dd8ba3ed9e684824 |