
Test-Time Training / Adaptation engine collection for PyTorch.



torch-ttt

torch-ttt is a comprehensive PyTorch library for Test-Time Training (TTT) and Test-Time Adaptation techniques. It helps make your neural networks more robust and generalizable to distribution shifts, corruptions, and out-of-distribution data—without requiring access to training data or labels at test time.

The library is designed to be modular, easy to integrate into existing PyTorch pipelines, and collaborative—we aim to include as many TTT methods as possible. If you've developed a TTT method, reach out to add yours!

>> You can find our webpage and documentation here: torch-ttt.github.io

torch-ttt is under active development. The API may change as we add new features and methods. Contributions are highly welcome! If you encounter any bugs or have feature requests, please submit an issue.

What is Test-Time Training?

Test-Time Training (TTT) is a paradigm where models adapt to test data during inference by optimizing self-supervised auxiliary objectives—without accessing training data or test labels. This helps models handle distribution shifts, corruptions, and out-of-distribution data.
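In its simplest form, the adaptation can be written as a few gradient steps taken before predicting. The notation below is illustrative and not taken from the library. It assumes trained weights θ, a self-supervised loss L_aux, plain gradient descent with step size η, and K steps. Each method chooses its own loss, optimizer, and set of trainable parameters.

```latex
% K gradient steps on an unlabeled test batch x, starting from the trained weights \theta
\theta_0 = \theta, \qquad
\theta_{k+1} = \theta_k - \eta \, \nabla_{\theta} \, \mathcal{L}_{\text{aux}}(x; \theta_k), \quad k = 0, \dots, K-1

% prediction with the adapted weights
\hat{y} = f_{\theta_K}(x)
```

No label enters the update. Only the unlabeled test input x does.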


torch-ttt implements TTT methods through a unified Engine abstraction. Each Engine encapsulates the complete adaptation logic of a specific TTT method, allowing you to:

  • Wrap any PyTorch model with a single line of code
  • Switch between different TTT methods seamlessly
  • Adapt models at inference time without modifying your existing pipeline

This modular design makes it easy to experiment with different adaptation strategies and find the best approach for your specific use case.
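The Engine idea can be made concrete with a self-contained sketch of what such a wrapper does. `MinimalEntropyEngine` is a hypothetical class written for illustration. It is not part of torch-ttt and does not reproduce its BaseEngine interface. The sketch makes three assumptions: prediction entropy serves as the self-supervised loss, optimization uses plain SGD, and adaptation is episodic, meaning each batch starts from a fresh copy of the weights. Real engines may differ in their losses, in which parameters they update, and in whether they reset between batches.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class MinimalEntropyEngine(nn.Module):
    """Hypothetical engine, not part of torch-ttt.

    Wraps a model and, on every forward call, adapts a fresh copy of it on the
    incoming unlabeled batch before returning predictions.
    """

    def __init__(self, model: nn.Module, lr: float = 1e-3, num_steps: int = 1):
        super().__init__()
        self.model = model
        self.lr = lr
        self.num_steps = num_steps

    @staticmethod
    def aux_loss(logits: torch.Tensor) -> torch.Tensor:
        # Stand-in self-supervised objective: mean prediction entropy (no labels needed).
        log_probs = F.log_softmax(logits, dim=1)
        return -(log_probs.exp() * log_probs).sum(dim=1).mean()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Episodic adaptation: every batch starts from the original weights,
        # so self.model itself is never modified.
        adapted = copy.deepcopy(self.model)
        optimizer = torch.optim.SGD(adapted.parameters(), lr=self.lr)
        # Gradients are needed even if the caller wrapped inference in torch.no_grad().
        with torch.enable_grad():
            for _ in range(self.num_steps):
                optimizer.zero_grad()
                loss = self.aux_loss(adapted(x))
                loss.backward()
                optimizer.step()
        # Predict with the adapted weights.
        with torch.no_grad():
            return adapted(x)


# Usage: wrap a toy classifier and run one adapted forward pass.
torch.manual_seed(0)
base = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
engine = MinimalEntropyEngine(base, lr=1e-2, num_steps=2)
engine.eval()
outputs = engine(torch.randn(4, 3, 8, 8))
print(outputs.shape)  # torch.Size([4, 10])
```

Deep-copying the model on every call leaves the wrapped model untouched, at the cost of extra memory and time. An online variant would update `self.model` in place instead.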

Key Features

torch-ttt provides a streamlined API through Engines—lightweight wrappers around your PyTorch models. All Engines follow the same interface, making them easy to use and highly modular.


You can add test-time adaptation with just a few lines of code, and switch between methods seamlessly. The library includes comprehensive tutorials and examples for every method, with efficient implementations suitable for both research and production deployment.

Check out the Quick Start guide or the API reference for more details.

Supported Methods

torch-ttt includes implementations of the following test-time training and adaptation methods:

| Method | Class | Paper | Description |
|---|---|---|---|
| TTT | TTTEngine | Sun et al. 2020 | Original test-time training with self-supervised rotation prediction |
| TTT++ | TTTPPEngine | Liu et al. 2021 | Improved TTT with contrastive learning |
| Masked TTT | MaskedTTTEngine | Gandelsman et al. 2022 | Self-supervised masked reconstruction for adaptation |
| TENT | TentEngine | Wang et al. 2021 | Entropy minimization for test-time adaptation |
| EATA | EataEngine | Niu et al. 2022 | Efficient anti-forgetting test-time adaptation |
| MEMO | MemoEngine | Zhang et al. 2022 | Marginal entropy minimization with one test point |
| ActMAD | ActMADEngine | Mirza et al. 2022 | Activation matching to align feature distributions |
| DeYO | DeYOEngine | Lee et al. 2024 | Entropy minimization with sample selection based on object-destructive transformations |
| IT3 | IT3Engine | Durasov et al. 2024 | Idempotent test-time training |
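The two entropy-based objectives in the table can be written out explicitly. These follow the TENT and MEMO papers as cited above. The notation is ours, not the library's: C classes, a test batch of N inputs, and B augmentations a_1, …, a_B.

```latex
% TENT: mean Shannon entropy of softmax predictions p^{(n)} = \mathrm{softmax}(f_\theta(x_n))
% over a test batch of N inputs and C classes
\mathcal{L}_{\text{TENT}} = -\frac{1}{N} \sum_{n=1}^{N} \sum_{c=1}^{C} p^{(n)}_c \log p^{(n)}_c

% MEMO: entropy of the marginal prediction over B augmentations a_1, \dots, a_B
% of a single test point x
\bar{p}_c(x) = \frac{1}{B} \sum_{i=1}^{B} \mathrm{softmax}\big(f_\theta(a_i(x))\big)_c,
\qquad
\mathcal{L}_{\text{MEMO}} = -\sum_{c=1}^{C} \bar{p}_c(x) \log \bar{p}_c(x)
```

TENT needs a batch to produce a meaningful prediction distribution. MEMO instead builds a marginal distribution from augmented copies of a single input.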

Want to see your method here? We welcome contributions!

Installation

Requirements: Python 3.10+, PyTorch 1.12+

Install from PyPI:

pip install torch-ttt

Install from source:

pip install git+https://github.com/nikitadurasov/torch-ttt.git

Quick Start

Here's a minimal example showing how to use torch-ttt to adapt a model at test time:

```python
import torch
import torchvision.models as models
from torch_ttt.engine.tent_engine import TentEngine

# Load your pre-trained model
# (torchvision >= 0.13 uses `weights`; the old `pretrained=True` flag is deprecated)
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Wrap it with a TTT Engine (here TENT, which minimizes prediction entropy)
engine = TentEngine(
    model=model,
    optimization_parameters={
        "lr": 2e-3,
        "num_steps": 1
    }
)

# Switch to eval mode
engine.eval()

# At test time, adapt to new data (random tensors stand in for real images)
test_images = torch.randn(8, 3, 224, 224)

# The engine automatically adapts the model during the forward pass
adapted_output = engine(test_images)
```
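The following from-scratch sketch shows roughly what entropy-minimization adaptation involves under the hood. It follows the TENT recipe as described by Wang et al. 2021: BatchNorm normalizes with test-batch statistics, only the BatchNorm affine parameters are updated, and the loss is mean prediction entropy. `configure_tent` and `tent_step` are hypothetical helper names. This is not torch-ttt's TentEngine implementation, and a toy CNN stands in for the ResNet-50.

```python
import torch
import torch.nn as nn


def configure_tent(model: nn.Module) -> list:
    """Prepare a model for TENT-style adaptation (hypothetical helper).

    BatchNorm layers normalize with current-batch statistics, and only their
    affine parameters (weight, bias) stay trainable; everything else is frozen.
    """
    model.train()  # train mode makes BatchNorm use batch statistics
    model.requires_grad_(False)
    params = []
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.requires_grad_(True)
            # Discard running statistics so BN never falls back to source-domain stats.
            m.track_running_stats = False
            m.running_mean = None
            m.running_var = None
            params += [m.weight, m.bias]
    return params


def tent_step(model: nn.Module, x: torch.Tensor, optimizer: torch.optim.Optimizer) -> torch.Tensor:
    """One entropy-minimization step on an unlabeled batch; returns the logits."""
    logits = model(x)
    log_probs = logits.log_softmax(dim=1)
    loss = -(log_probs.exp() * log_probs).sum(dim=1).mean()  # mean Shannon entropy
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return logits.detach()


# Usage with a toy CNN standing in for a pre-trained network.
torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
)
params = configure_tent(model)
optimizer = torch.optim.SGD(params, lr=2e-3, momentum=0.9)
logits = tent_step(model, torch.randn(8, 3, 32, 32), optimizer)
```

The returned logits come from the same forward pass used for the loss, so the update benefits the next batch. Restricting updates to the BatchNorm affine parameters keeps adaptation cheap and limits how far the model can drift.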

Documentation

Comprehensive documentation, including tutorials, the Quick Start guide, and the API reference, is available at torch-ttt.github.io.

Contributing

We welcome contributions! To add a new TTT method, report bugs, or improve documentation:

  1. Fork the repository
  2. Create a new engine inheriting from BaseEngine
  3. Add tests and documentation
  4. Submit a pull request

See GitHub Issues for bug reports and feature requests.

Citation

If you use torch-ttt in your research, please cite:

@software{durasov2024torchttt,
  author    = {Durasov, Nikita},
  title     = {torch-ttt: A Unified PyTorch Library for Test-Time Training},
  year      = {2024},
  doi       = {10.5281/zenodo.17620711},
  url       = {https://github.com/nikitadurasov/torch-ttt},
}

Also cite the original papers of the methods you use. See our Papers page.

License

MIT License - see the LICENSE file for details.



Download files


Source Distribution

torch_ttt-0.0.2.tar.gz (29.9 kB)

Uploaded Source

Built Distribution


torch_ttt-0.0.2-py3-none-any.whl (37.8 kB)

Uploaded Python 3

File details

Details for the file torch_ttt-0.0.2.tar.gz.

File metadata

  • Download URL: torch_ttt-0.0.2.tar.gz
  • Upload date:
  • Size: 29.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.4

File hashes

Hashes for torch_ttt-0.0.2.tar.gz
Algorithm Hash digest
SHA256 b38a37fa1d39e74abd89b7a23766e0d5edfcc07a2066f4e756613e0f103dc136
MD5 172e1a5780531bfb94eca15bedf50468
BLAKE2b-256 6cdaf7b5a481d70446363e7d6828acac64d7864fb301dc6d8f44e8d92d10895a

See more details on using hashes here.

File details

Details for the file torch_ttt-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: torch_ttt-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 37.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.4

File hashes

Hashes for torch_ttt-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 9891a4aa3d1b86a2d84801a560f8c516e242291b1b54988f194a5c9bdc73da13
MD5 64f862677989b128643bdc1dfc5d49b9
BLAKE2b-256 415c6b446d86a4c7102bc1cff63339f19b774fdc79c0d3b9e3f158a626719168

See more details on using hashes here.
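To check a downloaded file against the published SHA256 digest, a short standard-library script is enough. This is a minimal sketch. It assumes the wheel sits in the current directory, and `sha256_of` and `verify` are hypothetical helper names.

```python
import hashlib
from pathlib import Path

# SHA256 digest of the wheel, copied from the hashes listed above.
WHEEL_SHA256 = "9891a4aa3d1b86a2d84801a560f8c516e242291b1b54988f194a5c9bdc73da13"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected_sha256: str) -> bool:
    """True if the file's SHA256 matches the published digest."""
    return sha256_of(path) == expected_sha256.strip().lower()


# Assumes the wheel was downloaded into the current directory.
wheel = Path("torch_ttt-0.0.2-py3-none-any.whl")
if wheel.exists():
    print("OK" if verify(wheel, WHEEL_SHA256) else "MISMATCH")
```

Alternatively, `pip hash torch_ttt-0.0.2-py3-none-any.whl` prints the same digest. For pinned installs, pip's hash-checking mode accepts a requirements-file line of the form `torch-ttt==0.0.2 --hash=sha256:<digest>`.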
