Why use PyTorch Adapt?
PyTorch Adapt provides tools for domain adaptation, a type of machine learning algorithm that repurposes existing models to work in new domains. This library is:
1. Fully featured
Build a complete train/val domain adaptation pipeline in a few lines of code.
2. Modular
Use just the parts that suit your needs, whether it's the algorithms, loss functions, or validation methods.
3. Highly customizable
Customize and combine complex algorithms with ease.
4. Compatible with frameworks
Add additional functionality to your code by using one of the framework wrappers. Converting an algorithm into a PyTorch Lightning module is as simple as wrapping it with `Lightning`.
Documentation
Getting started
See the examples folder for notebooks you can download or run on Google Colab.
How to...
Use in vanilla PyTorch
```python
from tqdm import tqdm

from pytorch_adapt.hooks import DANNHook
from pytorch_adapt.utils.common_functions import batch_to_device

# Assuming that models, optimizers, and dataloader are already created.
hook = DANNHook(optimizers)
for data in tqdm(dataloader):
    data = batch_to_device(data, device)
    # Optimization is done inside the hook.
    # The returned loss is for logging.
    loss, _ = hook({}, {**models, **data})
```
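The snippet above assumes `models`, `optimizers`, `dataloader`, and `device` already exist. As a minimal sketch of what that setup might look like (the layer sizes, learning rate, and list-of-optimizers format here are assumptions, not part of the library's API; the G/C/D names follow the conventions used later in this README):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical models: G extracts features, C classifies, D discriminates domains.
G = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 128)).to(device)
C = torch.nn.Linear(128, 10).to(device)
D = torch.nn.Sequential(torch.nn.Linear(128, 1), torch.nn.Flatten(start_dim=0)).to(device)

models = {"G": G, "C": C, "D": D}
optimizers = [torch.optim.Adam(m.parameters(), lr=1e-4) for m in models.values()]
```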
Build complex algorithms
Let's customize `DANNHook` with:
- minimum class confusion
- virtual adversarial training
```python
import torch
from tqdm import tqdm

from pytorch_adapt.hooks import DANNHook, MCCHook, VATHook
from pytorch_adapt.utils.common_functions import batch_to_device

# G and C are the Generator and Classifier models.
G, C = models["G"], models["C"]
misc = {"combined_model": torch.nn.Sequential(G, C)}
hook = DANNHook(optimizers, post_g=[MCCHook(), VATHook()])
for data in tqdm(dataloader):
    data = batch_to_device(data, device)
    loss, _ = hook({}, {**models, **data, **misc})
```
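Other regularizers can be plugged in the same way. As a hedged sketch, assuming a `BSPHook` (batch spectral penalization, per the code references at the end of this README) is available in `pytorch_adapt.hooks`:

```python
from pytorch_adapt.hooks import BSPHook, DANNHook

# Hypothetical combination: add batch spectral penalization to DANN.
hook = DANNHook(optimizers, post_g=[BSPHook()])
```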
Wrap with your favorite PyTorch framework
First, set up the adapter and dataloaders:
```python
from pytorch_adapt.adapters import DANN
from pytorch_adapt.containers import Models
from pytorch_adapt.datasets import DataloaderCreator

models_cont = Models(models)
adapter = DANN(models=models_cont)
dc = DataloaderCreator(num_workers=2)
dataloaders = dc(**datasets)
```
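Here, `datasets` is a dict of train/val splits for the source and target domains. As a hedged sketch, one way to build it is with the library's MNIST→MNIST-M getter (the exact signature of `get_mnist_mnistm` below is an assumption):

```python
from pytorch_adapt.datasets import get_mnist_mnistm

# Returns a dict of datasets keyed by split,
# e.g. "train", "src_train", "src_val", "target_train", "target_val".
datasets = get_mnist_mnistm(["mnist"], ["mnistm"], folder=".", download=True)
```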
Then use a framework wrapper:
PyTorch Lightning
```python
import pytorch_lightning as pl

from pytorch_adapt.frameworks.lightning import Lightning

L_adapter = Lightning(adapter)
trainer = pl.Trainer(gpus=1, max_epochs=1)
trainer.fit(L_adapter, dataloaders["train"])
```
PyTorch Ignite
```python
from pytorch_adapt.frameworks.ignite import Ignite

trainer = Ignite(adapter)
trainer.run(datasets, dataloader_creator=dc)
```
Check your model's performance
You can do this in vanilla PyTorch:
```python
from pytorch_adapt.validators import SNDValidator

# Assuming predictions have been collected.
target_train = {"preds": preds}
validator = SNDValidator()
score = validator.score(target_train=target_train)
```
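As a hedged sketch of how `preds` might be collected (the `"target_imgs"` key and the dataloader name below are assumptions consistent with the snippets above, not a documented API):

```python
import torch
import torch.nn.functional as F

# Collect softmax predictions on the target training set.
preds = []
with torch.no_grad():
    for data in dataloaders["target_train"]:
        data = batch_to_device(data, device)
        logits = C(G(data["target_imgs"]))
        preds.append(F.softmax(logits, dim=1))
preds = torch.cat(preds, dim=0)
```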
You can also do this during training with a framework wrapper:
PyTorch Lightning
```python
import pytorch_lightning as pl

from pytorch_adapt.frameworks.lightning import Lightning
from pytorch_adapt.frameworks.utils import filter_datasets
from pytorch_adapt.validators import SNDValidator

validator = SNDValidator()
dataloaders = dc(**filter_datasets(datasets, validator))
train_loader = dataloaders.pop("train")

L_adapter = Lightning(adapter, validator=validator)
trainer = pl.Trainer(gpus=1, max_epochs=1)
trainer.fit(L_adapter, train_loader, list(dataloaders.values()))
```
PyTorch Ignite
```python
from pytorch_adapt.frameworks.ignite import Ignite
from pytorch_adapt.validators import ScoreHistory, SNDValidator

validator = ScoreHistory(SNDValidator())
trainer = Ignite(adapter, validator=validator)
trainer.run(datasets, dataloader_creator=dc)
```
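`ScoreHistory` wraps any validator and records the score over time. A minimal standalone sketch (the `epoch` keyword and the `best_score`/`best_epoch` attributes are assumptions based on typical usage, so check the documentation before relying on them):

```python
from pytorch_adapt.validators import ScoreHistory, SNDValidator

validator = ScoreHistory(SNDValidator())
# Hypothetical: score predictions collected at each epoch,
# then read off the best score and when it occurred.
for epoch in range(2):
    score = validator(epoch=epoch, target_train={"preds": preds})
print(validator.best_score, validator.best_epoch)
```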
Run the above examples
See this notebook and the examples page for other notebooks.
Installation
Pip
```
pip install pytorch-adapt
```

To get the latest dev version:

```
pip install pytorch-adapt --pre
```

To use `pytorch_adapt.frameworks.lightning`:

```
pip install pytorch-adapt[lightning]
```

To use `pytorch_adapt.frameworks.ignite`:

```
pip install pytorch-adapt[ignite]
```
Conda
Coming soon...
Dependencies
Required dependencies:
- numpy
- torch >= 1.6
- torchvision
- torchmetrics
- pytorch-metric-learning >= 1.0.0.dev5
Acknowledgements
Contributors
Pull requests are welcome!
Advisors
Thank you to Ser-Nam Lim, and my research advisor, Professor Serge Belongie.
Logo
Thanks to Jeff Musgrave for designing the logo.
Code references (in no particular order)
- https://github.com/wgchang/DSBN
- https://github.com/jihanyang/AFN
- https://github.com/thuml/Versatile-Domain-Adaptation
- https://github.com/tim-learn/ATDOC
- https://github.com/thuml/CDAN
- https://github.com/takerum/vat_chainer
- https://github.com/takerum/vat_tf
- https://github.com/RuiShu/dirt-t
- https://github.com/lyakaap/VAT-pytorch
- https://github.com/9310gaurav/virtual-adversarial-training
- https://github.com/thuml/Deep-Embedded-Validation
- https://github.com/lr94/abas
- https://github.com/thuml/Batch-Spectral-Penalization
- https://github.com/jvanvugt/pytorch-domain-adaptation
- https://github.com/ptrblck/pytorch_misc