News
November 19: Git repo is now public
Documentation
Google Colab Examples
See the examples folder for notebooks you can download or run on Google Colab.
Overview
This library consists of 11 modules:
Module | Description |
---|---|
Adapters | Wrappers for training and inference steps |
Containers | Dictionaries for simplifying object creation |
Datasets | Commonly used datasets and tools for domain adaptation |
Frameworks | Wrappers for training/testing pipelines |
Hooks | Modular building blocks for domain adaptation algorithms |
Layers | Loss functions and helper layers |
Meta Validators | Post-processing of metrics, for hyperparameter optimization |
Models | Architectures used for benchmarking and in examples |
Utils | Various tools |
Validators | Metrics for determining and estimating accuracy |
Weighters | Functions for weighting losses |
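To make the split between Hooks ("modular building blocks") and Weighters ("functions for weighting losses") more concrete, here is a toy sketch of chaining loss-computing steps and then combining their named losses. Everything in it is hypothetical: `chain`, `weighted_total`, the two toy hooks, and the `(losses, inputs) -> (losses, outputs)` shape. That shape is only loosely modeled on the `hook({}, {...})` calls in the examples below and is not the library's API.

```python
from typing import Callable, Dict, Tuple

# Hypothetical toy types: a "hook" maps (losses, inputs) to (new losses, outputs).
Losses = Dict[str, float]
Inputs = Dict[str, object]
Hook = Callable[[Losses, Inputs], Tuple[Losses, Inputs]]


def chain(*hooks: Hook) -> Hook:
    """Run hooks in order; each one sees the losses and outputs produced so far."""

    def run(losses: Losses, inputs: Inputs) -> Tuple[Losses, Inputs]:
        outputs: Inputs = {}
        for hook in hooks:
            new_losses, new_outputs = hook(losses, {**inputs, **outputs})
            losses = {**losses, **new_losses}
            outputs = {**outputs, **new_outputs}
        return losses, outputs

    return run


def weighted_total(losses: Losses, weights: Dict[str, float]) -> float:
    """Toy 'weighter': weighted sum of named losses, defaulting to weight 1.0."""
    return sum(weights.get(name, 1.0) * value for name, value in losses.items())


# Two toy hooks that each contribute one named loss.
def classification_hook(losses: Losses, inputs: Inputs) -> Tuple[Losses, Inputs]:
    return {"c_loss": inputs["src_error"]}, {}


def domain_hook(losses: Losses, inputs: Inputs) -> Tuple[Losses, Inputs]:
    return {"domain_loss": inputs["domain_error"]}, {}


hook = chain(classification_hook, domain_hook)
losses, _ = hook({}, {"src_error": 0.5, "domain_error": 0.25})
total = weighted_total(losses, {"domain_loss": 2.0})
```

Keeping each loss under its own name until the final weighting step is what makes it easy to swap one building block, or one weight, without touching the others.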
How to...
Use in vanilla PyTorch
from tqdm import tqdm
from pytorch_adapt.hooks import DANNHook
from pytorch_adapt.utils.common_functions import batch_to_device
# Assuming that models, optimizers, dataloader, and device are already created.
hook = DANNHook(optimizers)
for data in tqdm(dataloader):
    data = batch_to_device(data, device)
    # Optimization is done inside the hook.
    # The returned loss is for logging.
    loss, _ = hook({}, {**models, **data})
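Since the hook hides the optimization step, it may help to see what a generic DANN-style update looks like in plain PyTorch. The sketch uses a gradient reversal layer, so a single backward pass trains the discriminator `D` to tell source from target features while pushing the generator `G` to make them indistinguishable. It is an illustration of the technique under assumed shapes (toy linear models, binary domain labels with 1 = source, one SGD optimizer over all parameters), not the source of `DANNHook`; `GradientReversal` and `dann_step` are hypothetical names.

```python
import torch
import torch.nn.functional as F


class GradientReversal(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by -alpha in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # alpha is a plain float, so it receives no gradient (the second None).
        return -ctx.alpha * grad_output, None


def dann_step(G, C, D, optimizer, src_x, src_y, target_x, alpha=1.0):
    """One DANN-style update: source classification loss plus a reversed-gradient domain loss."""
    optimizer.zero_grad()
    src_features = G(src_x)
    target_features = G(target_x)
    c_loss = F.cross_entropy(C(src_features), src_y)
    features = torch.cat([src_features, target_features])
    # Assumed convention: domain label 1 for source, 0 for target.
    domain_labels = torch.cat([torch.ones(len(src_x)), torch.zeros(len(target_x))])
    domain_logits = D(GradientReversal.apply(features, alpha)).squeeze(1)
    domain_loss = F.binary_cross_entropy_with_logits(domain_logits, domain_labels)
    (c_loss + domain_loss).backward()
    optimizer.step()
    # Return plain floats for logging, mirroring "the returned loss is for logging".
    return {"c_loss": c_loss.item(), "domain_loss": domain_loss.item()}


torch.manual_seed(0)
G = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU())
C = torch.nn.Linear(16, 3)
D = torch.nn.Linear(16, 1)
params = [*G.parameters(), *C.parameters(), *D.parameters()]
optimizer = torch.optim.SGD(params, lr=0.01)
logs = dann_step(
    G, C, D, optimizer, torch.randn(4, 8), torch.randint(0, 3, (4,)), torch.randn(4, 8)
)
```

Because the reversal sits between `G` and `D`, `D`'s own gradients are untouched while `G` receives the negated domain gradient, which is why one optimizer step serves both sides of the adversarial game here.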
Build complex algorithms
Let's customize DANNHook with:
- virtual adversarial training
- entropy conditioning
import torch
from tqdm import tqdm
from pytorch_adapt.hooks import DANNHook, EntropyReducer, MeanReducer, VATHook
from pytorch_adapt.utils.common_functions import batch_to_device
# G and C are the Generator and Classifier models
misc = {"combined_model": torch.nn.Sequential(G, C)}
reducer = EntropyReducer(
    apply_to=["src_domain_loss", "target_domain_loss"], default_reducer=MeanReducer()
)
hook = DANNHook(optimizers, reducer=reducer, post_g=[VATHook()])
for data in tqdm(dataloader):
    data = batch_to_device(data, device)
    loss, _ = hook({}, {**models, **data, **misc})
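Entropy conditioning, as introduced for CDAN+E, weights each example's domain loss by how confident its prediction is, so uncertain examples contribute less. A minimal sketch, assuming weights of the form `1 + exp(-H(p))` normalized to sum to 1 over the batch; the normalization and which predictions the weights are computed from are assumptions, and `EntropyReducer` may differ in both. `entropy_weights` and `entropy_conditioned_loss` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def entropy_weights(logits: torch.Tensor) -> torch.Tensor:
    """CDAN+E-style weights 1 + exp(-H(p)), normalized (by assumption) to sum to 1."""
    probs = F.softmax(logits, dim=1)
    entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=1)
    weights = 1 + torch.exp(-entropy)
    return weights / weights.sum()


def entropy_conditioned_loss(per_sample_loss: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    """Weighted sum of per-sample losses; the weights are detached and act as constants."""
    weights = entropy_weights(logits).detach()
    return (weights * per_sample_loss).sum()


logits = torch.tensor(
    [
        [10.0, 0.0, 0.0],  # confident prediction -> low entropy -> larger weight
        [0.0, 0.0, 0.0],  # uniform prediction -> high entropy -> smaller weight
    ]
)
w = entropy_weights(logits)
loss = entropy_conditioned_loss(torch.tensor([1.0, 1.0]), logits)
```

Detaching the weights keeps the conditioning from becoming a second objective: gradients flow only through the per-sample losses, not through the entropy term.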
Wrap with your favorite PyTorch framework
For additional functionality, adapters can be wrapped with a framework (currently just PyTorch Ignite).
import torch
from pytorch_adapt.adapters import DANN
from pytorch_adapt.containers import Models, Optimizers
from pytorch_adapt.datasets import DataloaderCreator
from pytorch_adapt.frameworks.ignite import Ignite
# Assume models is a dict of existing models with keys "G", "C" and "D",
# and that datasets have already been created.
models_cont = Models(models)
# Override the default optimizer for G and C
optimizers_cont = Optimizers((torch.optim.Adam, {"lr": 0.123}), keys=["G", "C"])
adapter = DANN(models=models_cont, optimizers=optimizers_cont)
dc = DataloaderCreator(num_workers=2)
trainer = Ignite(adapter)
trainer.run(datasets, dataloader_creator=dc)
Wrappers for other frameworks (e.g. PyTorch Lightning and Catalyst) are planned to be added.
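The `Optimizers((torch.optim.Adam, {"lr": 0.123}), keys=["G", "C"])` line pairs an optimizer class with its keyword arguments and applies that pair only to the listed models. The hypothetical `make_optimizers` helper below mimics that idea in plain PyTorch so the pattern is visible; it is not how the `Optimizers` container is implemented, and it does not show how the container supplies a default optimizer for `D`.

```python
import torch


def make_optimizers(models, spec, keys=None):
    """Hypothetical helper: build one optimizer per selected model from an (optimizer class, kwargs) spec."""
    optimizer_class, kwargs = spec
    keys = list(models) if keys is None else keys
    return {k: optimizer_class(models[k].parameters(), **kwargs) for k in keys}


models = {
    "G": torch.nn.Linear(8, 16),
    "C": torch.nn.Linear(16, 3),
    "D": torch.nn.Linear(16, 1),
}
# Only G and C get Adam with lr=0.123; D is left out of this spec.
optimizers = make_optimizers(models, (torch.optim.Adam, {"lr": 0.123}), keys=["G", "C"])
```

Deferring construction until the models are known is the point of the (class, kwargs) pair: the same spec can be reused across models without creating any optimizer up front.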
Check your model's performance
You can do this in vanilla PyTorch:
from pytorch_adapt.validators import SNDValidator
# Assuming predictions have been collected
target_train = {"preds": preds}
validator = SNDValidator()
score = validator.score(epoch=1, target_train=target_train)
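SND (Soft Neighborhood Density) scores a model from unlabeled target predictions, which is why only `preds` is needed above. The sketch follows the published description as we understand it: L2-normalize the prediction vectors, compute pairwise similarities, exclude each sample's similarity to itself, apply a temperature-scaled softmax over each row, and average the row entropies. The temperature of 0.05 and the name `snd_score` are assumptions; `SNDValidator` may differ in details.

```python
import torch
import torch.nn.functional as F


def snd_score(preds: torch.Tensor, temperature: float = 0.05) -> float:
    """Sketch of SND: mean entropy of each sample's softmax over its similarities to the others."""
    features = F.normalize(preds, dim=1)
    similarity = features @ features.t()
    # Exclude self-similarity so each row is a distribution over the other samples.
    similarity.fill_diagonal_(float("-inf"))
    neighbor_probs = F.softmax(similarity / temperature, dim=1)
    # The clamp makes the diagonal's 0 * log(0) evaluate to 0 instead of NaN.
    log_probs = torch.log(neighbor_probs.clamp_min(1e-12))
    entropy = -(neighbor_probs * log_probs).sum(dim=1)
    return entropy.mean().item()


torch.manual_seed(0)
preds = F.softmax(torch.randn(32, 5), dim=1)
score = snd_score(preds)
# Sanity check: identical predictions make every row uniform over the other N - 1 samples.
uniform_score = snd_score(torch.ones(10, 3) / 3)
```

With N samples the score is bounded by log(N - 1), which the identical-prediction case reaches exactly (log 9 for N = 10).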
You can also do this using a framework wrapper:
validator = SNDValidator()
trainer = Ignite(adapter, validator=validator)
trainer.run(datasets, dataloader_creator=dc)
Run the above examples
See this notebook and the examples page for other notebooks.
Installation
Pip
pip install pytorch-adapt
To get the latest dev version:
pip install pytorch-adapt --pre
To use pytorch_adapt.frameworks.ignite:
pip install pytorch-adapt[ignite]
Conda
Coming soon...
Dependencies
Required dependencies:
- numpy
- torch >= 1.6
- torchvision
- torchmetrics
- pytorch-metric-learning >= 1.0.0.dev5
Acknowledgements
Contributors
Pull requests are welcome!
Advisors
Thank you to Ser-Nam Lim, and my research advisor, Professor Serge Belongie.
Logo
Thanks to Jeff Musgrave for designing the logo.
Code references (in no particular order)
- https://github.com/wgchang/DSBN
- https://github.com/jihanyang/AFN
- https://github.com/thuml/Versatile-Domain-Adaptation
- https://github.com/tim-learn/ATDOC
- https://github.com/thuml/CDAN
- https://github.com/takerum/vat_chainer
- https://github.com/takerum/vat_tf
- https://github.com/RuiShu/dirt-t
- https://github.com/lyakaap/VAT-pytorch
- https://github.com/9310gaurav/virtual-adversarial-training
- https://github.com/thuml/Deep-Embedded-Validation
- https://github.com/lr94/abas
- https://github.com/thuml/Batch-Spectral-Penalization
- https://github.com/jvanvugt/pytorch-domain-adaptation
- https://github.com/ptrblck/pytorch_misc
File details
Details for the file pytorch-adapt-0.0.37.tar.gz.
File metadata
- Download URL: pytorch-adapt-0.0.37.tar.gz
- Upload date:
- Size: 79.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.6.1 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.8.10
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | a2f93abfc05a8b4eeabf960b889a15148f1589a82131a822108a3e97e80d4ff3 |
MD5 | 6e9a8001d4248f448f0191e28faec365 |
BLAKE2b-256 | f144da1f4d8852428eccfde522b3335288dd24550f7705fc6627a2af4628ea0b |
File details
Details for the file pytorch_adapt-0.0.37-py3-none-any.whl.
File metadata
- Download URL: pytorch_adapt-0.0.37-py3-none-any.whl
- Upload date:
- Size: 127.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.6.1 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.8.10
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 0dadb8f07d28a5f1779c01f978e7387fcaacb33736b8bb6c98faff295eecc275 |
MD5 | 99bab8f15bf1ccef7e650f75beb72347 |
BLAKE2b-256 | 37d2fe5ebc97ee0f9960afd17a03d14e58c52c014f672f050bd080450c99bc52 |