
Domain adaptation made easy. Fully featured, modular, and customizable.


PyTorch Adapt


Why use PyTorch Adapt?

PyTorch Adapt provides tools for domain adaptation, a family of machine learning techniques that adapt models trained on one domain so that they work in a new domain. This library is:

1. Fully featured

Build a complete train/val domain adaptation pipeline in a few lines of code.

2. Modular

Use just the parts that suit your needs, whether it's the algorithms, loss functions, or validation methods.

3. Highly customizable

Customize and combine complex algorithms with ease.

4. Compatible with frameworks

Add functionality to your code by using one of the framework wrappers. Converting an algorithm into a PyTorch Lightning module is as simple as wrapping it with Lightning.

Documentation

Examples

See the examples folder for notebooks you can download or run on Google Colab.

How to...

Use in vanilla PyTorch

import torch
from tqdm import tqdm
from pytorch_adapt.hooks import DANNHook
from pytorch_adapt.utils.common_functions import batch_to_device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assuming that models (already on `device`), optimizers, and dataloader are already created.
hook = DANNHook(optimizers)
for data in tqdm(dataloader):
    data = batch_to_device(data, device)
    # Optimization is done inside the hook.
    # The returned loss is for logging.
    _, loss = hook({**models, **data})
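
For readers new to DANN, here is a from-scratch sketch of the loss terms a DANN-style training step combines, written in plain Python (no PyTorch) so the arithmetic is easy to follow. It is an illustration, not pytorch-adapt's implementation: the function names, the loss-dictionary keys, and the convention of labeling source samples 0 and target samples 1 for the domain discriminator are assumptions. In DANN, the generator and classifier minimize the source classification loss, the discriminator minimizes the domain loss, and a gradient reversal layer flips the sign of the domain-loss gradient flowing back into the generator, pushing it toward domain-invariant features.

```python
import math
from typing import Dict, List


def cross_entropy(logits: List[float], label: int) -> float:
    """Softmax cross-entropy for one sample, computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def bce_with_logits(logit: float, label: float) -> float:
    """Binary cross-entropy on a raw logit, in the numerically stable form."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def grl_forward(x: float) -> float:
    """Gradient reversal layer, forward pass: identity."""
    return x


def grl_backward(grad: float, lambda_: float = 1.0) -> float:
    """Gradient reversal layer, backward pass: flip (and scale) the gradient."""
    return -lambda_ * grad


def dann_losses(
    src_class_logits: List[List[float]],
    src_labels: List[int],
    src_domain_logits: List[float],
    target_domain_logits: List[float],
) -> Dict[str, float]:
    """Per-batch DANN loss terms (hypothetical names; source=0, target=1 assumed)."""
    # Classification loss uses labeled source data only.
    c_loss = sum(
        cross_entropy(z, y) for z, y in zip(src_class_logits, src_labels)
    ) / len(src_labels)
    # Domain loss: the discriminator tries to tell source (0) from target (1).
    src_domain_loss = sum(
        bce_with_logits(z, 0.0) for z in src_domain_logits
    ) / len(src_domain_logits)
    target_domain_loss = sum(
        bce_with_logits(z, 1.0) for z in target_domain_logits
    ) / len(target_domain_logits)
    return {
        "c_loss": c_loss,
        "src_domain_loss": src_domain_loss,
        "target_domain_loss": target_domain_loss,
        "total": c_loss + src_domain_loss + target_domain_loss,
    }
```

The `total` entry simply sums the three terms; the adversarial game comes from which parameters each term updates, with `grl_backward` standing in for the sign flip applied to the generator's gradient.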

Build complex algorithms

Let's customize DANNHook with:

  • minimum class confusion
  • virtual adversarial training
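
import torch
from tqdm import tqdm
from pytorch_adapt.hooks import DANNHook, MCCHook, VATHook
from pytorch_adapt.utils.common_functions import batch_to_device

# Assuming models, optimizers, dataloader, and device are set up as in the previous example.
# G and C are the Generator and Classifier models
G, C = models["G"], models["C"]
# misc holds extra inputs, e.g. the combined G -> C model that VAT perturbs
misc = {"combined_model": torch.nn.Sequential(G, C)}
hook = DANNHook(optimizers, post_g=[MCCHook(), VATHook()])
for data in tqdm(dataloader):
    data = batch_to_device(data, device)
    _, loss = hook({**models, **data, **misc})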
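
Minimum class confusion penalizes a classifier that spreads probability mass across several classes for the same target samples. The following plain-Python sketch computes an MCC-style loss from a batch of target logits, following our reading of the method: temperature-scaled softmax, entropy-based sample weights, a class-correlation matrix, row ("category") normalization, and the mean off-diagonal mass. It is not pytorch-adapt's `MCCHook` code, and the temperature of 2.5 is an assumed default chosen for illustration.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float) -> List[float]:
    """Temperature-scaled softmax for one sample."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def mcc_loss(target_logits: List[List[float]], temperature: float = 2.5) -> float:
    """MCC-style class confusion loss for a batch of target logits (sketch)."""
    probs = [softmax(row, temperature) for row in target_logits]
    batch_size, num_classes = len(probs), len(probs[0])

    # Confident (low-entropy) samples get larger weights; weights sum to batch_size.
    entropies = [-sum(p * math.log(p) for p in row if p > 0) for row in probs]
    raw = [1.0 + math.exp(-h) for h in entropies]
    raw_total = sum(raw)
    weights = [batch_size * r / raw_total for r in raw]

    # Class correlation matrix: C[j][k] = sum_i w_i * p_ij * p_ik
    corr = [
        [sum(weights[i] * probs[i][j] * probs[i][k] for i in range(batch_size))
         for k in range(num_classes)]
        for j in range(num_classes)
    ]
    # Category normalization: each row sums to 1.
    normalized = [[c / sum(row) for c in row] for row in corr]

    # Average off-diagonal mass per class, i.e. how often classes are confused.
    off_diagonal = sum(
        normalized[j][k]
        for j in range(num_classes)
        for k in range(num_classes)
        if j != k
    )
    return off_diagonal / num_classes
```

A batch of identical, maximally uncertain predictions over K classes gives (K - 1)/K, while confident predictions that pick different classes drive the loss toward 0.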

Wrap with your favorite PyTorch framework

First, set up the adapter and dataloaders:

from pytorch_adapt.adapters import DANN
from pytorch_adapt.containers import Models
from pytorch_adapt.datasets import DataloaderCreator

# Assuming models (a dict of models) and datasets (a dict of datasets) already exist.
models_cont = Models(models)
adapter = DANN(models=models_cont)
dc = DataloaderCreator(num_workers=2)
dataloaders = dc(**datasets)

Then use a framework wrapper:

PyTorch Lightning

import pytorch_lightning as pl
from pytorch_adapt.frameworks.lightning import Lightning

L_adapter = Lightning(adapter)
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=1)
trainer.fit(L_adapter, dataloaders["train"])

PyTorch Ignite

from pytorch_adapt.frameworks.ignite import Ignite

trainer = Ignite(adapter)
trainer.run(datasets, dataloader_creator=dc)

Check your model's performance

You can do this in vanilla PyTorch:

from pytorch_adapt.validators import SNDValidator

# Assuming preds (softmax predictions on the target training set) have been collected
target_train = {"preds": preds}
validator = SNDValidator()
score = validator(target_train=target_train)
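
SND (soft neighborhood density) scores a model without target labels: roughly, it rewards predictions whose neighbors form dense, soft clusters. The following plain-Python sketch shows the core computation as we understand it: L2-normalize each prediction vector, compute cosine similarities to every other sample, apply a temperature-scaled softmax per row, and average the row entropies (higher is better). It is not pytorch-adapt's `SNDValidator` code; excluding each sample's similarity to itself and the temperature of 0.05 are assumptions made for illustration, and prediction vectors are assumed to be non-zero.

```python
import math
from typing import List


def l2_normalize(v: List[float]) -> List[float]:
    """Scale a non-zero vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def snd_score(preds: List[List[float]], temperature: float = 0.05) -> float:
    """Mean entropy of each sample's softmax-weighted neighborhood (sketch)."""
    feats = [l2_normalize(p) for p in preds]
    n = len(feats)
    entropies = []
    for i in range(n):
        # Cosine similarity to every other sample, scaled by the temperature.
        sims = [
            sum(a * b for a, b in zip(feats[i], feats[j])) / temperature
            for j in range(n)
            if j != i
        ]
        # Softmax over neighbors (shift by the max for numerical stability).
        m = max(sims)
        exps = [math.exp(s - m) for s in sims]
        total = sum(exps)
        probs = [e / total for e in exps]
        entropies.append(-sum(p * math.log(p) for p in probs if p > 0))
    return sum(entropies) / n
```

If every prediction is identical, each sample's neighbors are equally similar and the score reaches its maximum of log(n - 1); tightly separated clusters concentrate each row's softmax on a few neighbors and lower the score.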

You can also do this during training with a framework wrapper:

PyTorch Lightning

from pytorch_adapt.frameworks.utils import filter_datasets

validator = SNDValidator()
dataloaders = dc(**filter_datasets(datasets, validator))
train_loader = dataloaders.pop("train")

L_adapter = Lightning(adapter, validator=validator)
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=1)
trainer.fit(L_adapter, train_loader, list(dataloaders.values()))

PyTorch Ignite

from pytorch_adapt.validators import ScoreHistory

validator = ScoreHistory(SNDValidator())
trainer = Ignite(adapter, validator=validator)
trainer.run(datasets, dataloader_creator=dc)
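
Judging by its name and usage, `ScoreHistory` wraps a validator so that scores can be tracked across epochs. The class below is a hypothetical, plain-Python stand-in (`SimpleScoreHistory` is not part of pytorch-adapt) that illustrates the idea: it forwards its inputs to the wrapped validator, records each score with its epoch, and remembers the best one, assuming higher scores are better, as with SND.

```python
from typing import Any, Callable, List, Optional


class SimpleScoreHistory:
    """Hypothetical validator wrapper that records scores per epoch."""

    def __init__(self, validator: Callable[..., float]):
        self.validator = validator
        self.epochs: List[int] = []
        self.scores: List[float] = []
        self.best_epoch: Optional[int] = None
        self.best_score: Optional[float] = None

    def __call__(self, epoch: int, **kwargs: Any) -> float:
        # Delegate scoring to the wrapped validator.
        score = self.validator(**kwargs)
        self.epochs.append(epoch)
        self.scores.append(score)
        # Assumes higher is better; keeps the earliest epoch on ties.
        if self.best_score is None or score > self.best_score:
            self.best_epoch, self.best_score = epoch, score
        return score


# Usage with a toy validator that just sums the predictions.
history = SimpleScoreHistory(lambda target_train: sum(target_train["preds"]))
for epoch, preds in enumerate([[0.1, 0.2], [0.5, 0.4], [0.3, 0.3]]):
    history(epoch=epoch, target_train={"preds": preds})
```

Keeping the wrapper separate from the validator means any scoring function can gain history tracking without changing its own code.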

Run the above examples

See this notebook and the examples page for other notebooks.

Installation

Pip

pip install pytorch-adapt

To get the latest dev version:

pip install pytorch-adapt --pre

To use pytorch_adapt.frameworks.lightning:

pip install "pytorch-adapt[lightning]"

To use pytorch_adapt.frameworks.ignite:

pip install "pytorch-adapt[ignite]"

Conda

Coming soon...

Dependencies

See setup.py

Acknowledgements

Contributors

Thanks to the contributors who made pull requests!

Contributor Highlights
  • deepseek-eoghan: improved the TargetDataset class

Advisors

Thank you to Ser-Nam Lim, and my research advisor, Professor Serge Belongie.

Logo

Thanks to Jeff Musgrave for designing the logo.

Citing this library

To cite pytorch-adapt in your paper, copy-paste this BibTeX reference:

@article{Musgrave2022PyTorchA,
  title={PyTorch Adapt},
  author={Kevin Musgrave and Serge J. Belongie and Ser Nam Lim},
  journal={ArXiv},
  year={2022},
  volume={abs/2211.15673}
}



Download files

Download the file for your platform.

Source Distribution

pytorch-adapt-0.0.83.tar.gz (95.5 kB)

Uploaded Source

Built Distribution

pytorch_adapt-0.0.83-py3-none-any.whl (158.2 kB)

Uploaded Python 3

File details

Details for the file pytorch-adapt-0.0.83.tar.gz.

File metadata

  • Download URL: pytorch-adapt-0.0.83.tar.gz
  • Upload date:
  • Size: 95.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for pytorch-adapt-0.0.83.tar.gz
Algorithm Hash digest
SHA256 6b9dbfa5cb1ac55c7223bb89ec19bbc18d3cc17bb74fad9af6d17d0af717f69e
MD5 52168fe8ce709c40b145e07905658192
BLAKE2b-256 dbe596520821bbb5f2f38d3f77458e9b47e155b37a35e1e577b34f6dd5a55a49


File details

Details for the file pytorch_adapt-0.0.83-py3-none-any.whl.

File metadata

  • Download URL: pytorch_adapt-0.0.83-py3-none-any.whl
  • Upload date:
  • Size: 158.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for pytorch_adapt-0.0.83-py3-none-any.whl
Algorithm Hash digest
SHA256 d27109f0488f3c76ca4d3a7e3367bf6a69dd0fb246246f2de013d7710509cd05
MD5 bc0f25c0e0833a74aff488aa983afac7
BLAKE2b-256 c871d37e9f35faa1575092e7ffba9556b0f01c71be4c6ffada31eed5d47e928d

