Google Colab Examples
See the examples folder for notebooks you can download or run on Google Colab.
Overview
This library consists of 11 modules:
Module | Description
---|---
Adapters | Wrappers for training and inference steps
Containers | Dictionaries for simplifying object creation
Datasets | Commonly used datasets and tools for domain adaptation
Frameworks | Wrappers for training/testing pipelines
Hooks | Modular building blocks for domain adaptation algorithms
Layers | Loss functions and helper layers
Meta Validators | Post-processing of metrics, for hyperparameter optimization
Models | Architectures used for benchmarking and in examples
Utils | Various tools
Validators | Metrics for determining and estimating accuracy
Weighters | Functions for weighting losses
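For orientation, these modules correspond to submodules of the package. The sketch below is partly confirmed and partly assumed: adapters, containers, datasets, frameworks, hooks, utils, and validators appear in imports later in this README, while layers, meta_validators, models, and weighters are inferred from the table.

```python
# Submodule names partly confirmed by imports elsewhere in this README,
# partly inferred from the table above.
from pytorch_adapt import adapters, containers, datasets, frameworks, hooks
from pytorch_adapt import layers, meta_validators, models, utils, validators, weighters
```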
How to...
Use in vanilla PyTorch
```python
import torch
from pytorch_adapt.hooks import DANNHook
from pytorch_adapt.utils.common_functions import batch_to_device

# Assuming that models, optimizers, and dataloader are already created.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hook = DANNHook(optimizers)
for data in dataloader:
    data = batch_to_device(data, device)
    # Optimization is done inside the hook.
    # The returned loss is for logging.
    loss, _ = hook({}, {**models, **data})
```
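For logging, you can iterate over the returned values. A minimal sketch, assuming `loss` behaves like a mapping from loss names to values (the snippet above does not guarantee this structure):

```python
# Assumption: `loss` is a dict of name -> value pairs.
for name, value in loss.items():
    print(f"{name}: {value}")
```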
Build complex algorithms
Let's customize DANNHook with:
- virtual adversarial training
- entropy conditioning
```python
import torch
from pytorch_adapt.hooks import DANNHook, EntropyReducer, MeanReducer, VATHook
from pytorch_adapt.utils.common_functions import batch_to_device

# G and C are the Generator and Classifier models.
misc = {"combined_model": torch.nn.Sequential(G, C)}
# Entropy conditioning: apply the entropy-based reducer to the domain losses.
reducer = EntropyReducer(
    apply_to=["src_domain_loss", "target_domain_loss"], default_reducer=MeanReducer()
)
hook = DANNHook(optimizers, reducer=reducer, post_g=[VATHook()])
for data in dataloader:
    data = batch_to_device(data, device)
    loss, _ = hook({}, {**models, **data, **misc})
```
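Since these are independent keyword arguments, each customization can also be applied on its own. A minimal sketch, assuming the omitted argument falls back to DANNHook's default behavior:

```python
# Entropy conditioning only (default post_g hooks assumed):
hook = DANNHook(optimizers, reducer=reducer)

# Virtual adversarial training only (default reducer assumed):
hook = DANNHook(optimizers, post_g=[VATHook()])
```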
Remove some boilerplate
Adapters and containers can simplify object creation.
```python
import torch
from pytorch_adapt.adapters import DANN
from pytorch_adapt.containers import Models, Optimizers

# Assume G, C, and D are existing models.
models = Models({"G": G, "C": C, "D": D})
# Override the default optimizer for G and C.
optimizers = Optimizers((torch.optim.Adam, {"lr": 0.123}), keys=["G", "C"])
adapter = DANN(models=models, optimizers=optimizers)
for data in dataloader:
    adapter.training_step(data, device)
```
Wrap with your favorite PyTorch framework
For additional functionality, adapters can be wrapped with a framework (currently just PyTorch Ignite).
```python
from pytorch_adapt.frameworks import Ignite

wrapped_adapter = Ignite(adapter)
wrapped_adapter.run(datasets)
```
Wrappers for other frameworks (e.g. PyTorch Lightning and Catalyst) are coming soon.
Check accuracy of your model
You can do this in vanilla PyTorch:
```python
from pytorch_adapt.validators import SNDValidator

# Assuming predictions have been collected
target_train = {"preds": preds}
validator = SNDValidator()
score = validator.score(epoch=1, target_train=target_train)
```
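One way to use this score for model selection is to keep the best value across epochs. A minimal sketch, assuming a higher SND score is better and that `preds_per_epoch` is a hypothetical list of predictions collected per epoch:

```python
# Hypothetical: preds_per_epoch[i] holds predictions collected during epoch i.
best_score, best_epoch = float("-inf"), None
for epoch, preds in enumerate(preds_per_epoch):
    score = validator.score(epoch=epoch, target_train={"preds": preds})
    if score > best_score:
        best_score, best_epoch = score, epoch
print(f"best epoch: {best_epoch} (score: {best_score})")
```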
You can also do this using a framework wrapper:
```python
from pytorch_adapt.validators import SNDValidator

validator = SNDValidator()
wrapped_adapter.run(datasets, validator=validator)
```
Load a toy dataset
```python
import torch
from pytorch_adapt.datasets import get_mnist_mnistm

# mnist is the source domain
# mnistm is the target domain
datasets = get_mnist_mnistm(["mnist"], ["mnistm"], ".")
dataloader = torch.utils.data.DataLoader(
    datasets["train"], batch_size=32, num_workers=2
)
```
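To verify the pipeline end to end, you can pull a single batch and move it to a device, reusing batch_to_device from the earlier examples. A minimal sketch; the CPU fallback is an assumption for machines without a GPU:

```python
import torch
from pytorch_adapt.utils.common_functions import batch_to_device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for data in dataloader:
    data = batch_to_device(data, device)
    break  # one batch is enough to confirm the dataset and loader work
```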
Run the above examples
See this notebook and the examples page for other notebooks.
Installation
Pip

```
pip install pytorch-adapt
```

To get the latest dev version:

```
pip install pytorch-adapt --pre
```
Conda
Coming soon...
Dependencies
Coming soon...
Acknowledgements
Contributors
Pull requests are welcome!
Advisors
Thank you to Ser-Nam Lim, and to my research advisor, Professor Serge Belongie.
Logo
Thanks to Jeff Musgrave for designing the logo.