News
November 19: Git repo is now public
Documentation
Google Colab Examples
See the examples folder for notebooks you can download or run on Google Colab.
Overview
This library consists of 11 modules:
Module | Description |
---|---|
Adapters | Wrappers for training and inference steps |
Containers | Dictionaries for simplifying object creation |
Datasets | Commonly used datasets and tools for domain adaptation |
Frameworks | Wrappers for training/testing pipelines |
Hooks | Modular building blocks for domain adaptation algorithms |
Layers | Loss functions and helper layers |
Meta Validators | Post-processing of metrics, for hyperparameter optimization |
Models | Architectures used for benchmarking and in examples |
Utils | Various tools |
Validators | Metrics for determining and estimating accuracy |
Weighters | Functions for weighting losses |
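The Weighters row refers to turning several named losses into one value that can be optimized. Here is a minimal sketch of that idea. It assumes a plain weighted sum where any unlisted loss gets weight 1.0. `weighted_total` and its parameters are hypothetical and are not the library's API:

```python
def weighted_total(losses, weights=None, scale=1.0):
    """Combine named scalar losses into a single number.

    Hypothetical helper: each loss is multiplied by its weight
    (default 1.0), the products are summed, and the sum is scaled.
    """
    weights = weights or {}
    total = sum(value * weights.get(name, 1.0) for name, value in losses.items())
    return total * scale


losses = {"src_class_loss": 0.5, "src_domain_loss": 0.25, "target_domain_loss": 0.75}
print(weighted_total(losses, weights={"src_class_loss": 2.0}))  # 2.0
```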
How to...
Use in vanilla PyTorch
```python
import torch
from pytorch_adapt.hooks import DANNHook
from pytorch_adapt.utils.common_functions import batch_to_device

# Assuming that models, optimizers, and dataloader are already created.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hook = DANNHook(optimizers)
for data in dataloader:
    data = batch_to_device(data, device)
    # Optimization is done inside the hook.
    # The returned loss is for logging.
    loss, _ = hook({}, {**models, **data})
```
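The call `hook({}, {**models, **data})` passes two dictionaries: the losses computed so far and the inputs. It gets back a pair. The toy stand-in below assumes only that calling convention and shows how the dictionaries flow through one call. `ToyHook` and its key names are hypothetical, and it does no optimization:

```python
class ToyHook:
    """Hypothetical hook: reads inputs by key and adds one named loss."""

    def __call__(self, losses, inputs):
        # Compute a stand-in "loss" from the inputs; a real hook would
        # run models and step optimizers here.
        preds = inputs["src_preds"]
        labels = inputs["src_labels"]
        errors = sum(1 for p, y in zip(preds, labels) if p != y)
        # Keep earlier losses and add this hook's own entry.
        new_losses = {**losses, "src_error_rate": errors / len(labels)}
        # Second element: intermediate outputs a caller could reuse.
        outputs = {"num_src": len(labels)}
        return new_losses, outputs


hook = ToyHook()
loss, outputs = hook({}, {"src_preds": [0, 1, 1, 2], "src_labels": [0, 1, 2, 2]})
print(loss)  # {'src_error_rate': 0.25}
```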
Build complex algorithms
Let's customize DANNHook with:
- virtual adversarial training
- entropy conditioning
```python
import torch
from pytorch_adapt.hooks import DANNHook, EntropyReducer, MeanReducer, VATHook
from pytorch_adapt.utils.common_functions import batch_to_device

# G and C are the Generator and Classifier models
misc = {"combined_model": torch.nn.Sequential(G, C)}
# Entropy-based reduction for the domain losses; other losses are averaged
reducer = EntropyReducer(
    apply_to=["src_domain_loss", "target_domain_loss"], default_reducer=MeanReducer()
)
# Add virtual adversarial training via post_g
hook = DANNHook(optimizers, reducer=reducer, post_g=[VATHook()])
for data in dataloader:
    data = batch_to_device(data, device)
    loss, _ = hook({}, {**models, **data, **misc})
```
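The README names EntropyReducer but does not spell out its formula. One common formulation of entropy conditioning comes from CDAN+E (Long et al., 2018). It weights each sample's domain loss by 1 + exp(-H(p)), where H(p) is the entropy of that sample's predicted class distribution, and then takes a normalized weighted mean. The pure-Python sketch below assumes that formulation. It may differ from what `EntropyReducer` actually computes:

```python
import math


def entropy(probs):
    """Shannon entropy (natural log) of one probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def entropy_weighted_mean(per_sample_losses, per_sample_probs):
    """Average per-sample losses, giving confident samples more weight.

    Each weight is 1 + exp(-H(p)); weights are normalized to sum to 1.
    """
    weights = [1 + math.exp(-entropy(p)) for p in per_sample_probs]
    total = sum(weights)
    return sum(w / total * loss for w, loss in zip(weights, per_sample_losses))


losses = [1.0, 3.0]
probs = [[1.0, 0.0], [0.5, 0.5]]  # first sample confident, second uncertain
print(entropy_weighted_mean(losses, probs))  # ≈ 1.857 (13/7)
```

The confident sample gets weight 2 and the uncertain one gets weight 1.5. As a result, the weighted mean falls below the plain mean of 2.0.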
Remove some boilerplate
Adapters and containers can simplify object creation.
```python
import torch
from pytorch_adapt.adapters import DANN
from pytorch_adapt.containers import Models, Optimizers

# Assume G, C and D are existing models
models = Models({"G": G, "C": C, "D": D})
# Override the default optimizer for G and C
optimizers = Optimizers((torch.optim.Adam, {"lr": 0.123}), keys=["G", "C"])
adapter = DANN(models=models, optimizers=optimizers)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for data in dataloader:
    adapter.training_step(data, device)
```
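In this example, Optimizers receives a class and its keyword arguments rather than ready-made optimizers. That suggests construction is deferred until the models are known. The sketch below illustrates the idea with stand-in classes so that it runs without torch. `FakeModel`, `FakeAdam`, `build_optimizers`, and the default learning rate of 0.001 are all hypothetical. The sketch is not how the container is implemented:

```python
class FakeModel:
    """Stand-in for a torch.nn.Module: only provides parameters()."""

    def __init__(self, name):
        self.name = name

    def parameters(self):
        return [f"{self.name}.weight", f"{self.name}.bias"]


class FakeAdam:
    """Stand-in for torch.optim.Adam: records what it was built with."""

    def __init__(self, params, lr=0.001):
        self.params = list(params)
        self.lr = lr


def build_optimizers(models, override, keys, default=(FakeAdam, {"lr": 0.001})):
    """Hypothetical container logic: one optimizer per model.

    Models named in `keys` use the (class, kwargs) pair in `override`;
    all others fall back to `default`.
    """
    optimizers = {}
    for name, model in models.items():
        opt_class, kwargs = override if name in keys else default
        optimizers[name] = opt_class(model.parameters(), **kwargs)
    return optimizers


models = {k: FakeModel(k) for k in ["G", "C", "D"]}
optimizers = build_optimizers(models, (FakeAdam, {"lr": 0.123}), keys=["G", "C"])
print({name: opt.lr for name, opt in optimizers.items()})
# {'G': 0.123, 'C': 0.123, 'D': 0.001}
```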
Wrap with your favorite PyTorch framework
For additional functionality, adapters can be wrapped with a framework (currently just PyTorch Ignite).
```python
from pytorch_adapt.frameworks import Ignite

wrapped_adapter = Ignite(adapter)
wrapped_adapter.run(datasets)
```
Wrappers for other frameworks (e.g. PyTorch Lightning and Catalyst) are coming soon.
Check accuracy of your model
You can do this in vanilla PyTorch:
```python
from pytorch_adapt.validators import SNDValidator

# Assuming predictions have been collected
target_train = {"preds": preds}
validator = SNDValidator()
score = validator.score(epoch=1, target_train=target_train)
```
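SNDValidator appears to refer to Soft Neighborhood Density (Saito et al., 2021). That method scores a model using only unlabeled target predictions, and a higher value is better. The pure-Python sketch below follows the method's general recipe:

1. L2-normalize the predictions.
2. Compute each sample's similarity to every other sample.
3. Sharpen the similarities with a temperature and apply a softmax.
4. Average the resulting entropies.

The temperature of 0.05 and all function names are assumptions. This sketch is not the library's implementation, and it needs at least two samples:

```python
import math


def softmax(values):
    """Numerically stable softmax of a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def snd_score(preds, temperature=0.05):
    """Soft Neighborhood Density sketch: mean entropy of each sample's
    softmax-normalized similarities to the other samples.

    preds: list of per-sample class-probability vectors (target domain).
    """
    # L2-normalize so that dot products are cosine similarities.
    normed = []
    for p in preds:
        norm = math.sqrt(sum(x * x for x in p))
        normed.append([x / norm for x in p])
    n = len(normed)
    entropies = []
    for i in range(n):
        # Similarity to every other sample (self excluded), divided by temperature.
        sims = [
            sum(a * b for a, b in zip(normed[i], normed[j])) / temperature
            for j in range(n)
            if j != i
        ]
        probs = softmax(sims)
        entropies.append(-sum(q * math.log(q) for q in probs if q > 0))
    return sum(entropies) / n


print(snd_score([[0.9, 0.1]] * 4))  # ≈ 1.0986 (log 3)
```

When all predictions are identical, each sample's neighbors are equally similar. Each entropy then reaches its maximum, log(n - 1).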
You can also do this using a framework wrapper:
```python
from pytorch_adapt.validators import SNDValidator

validator = SNDValidator()
wrapped_adapter.run(datasets, validator=validator)
```
Load a toy dataset
```python
import torch
from pytorch_adapt.datasets import get_mnist_mnistm

# mnist is the source domain
# mnistm is the target domain
datasets = get_mnist_mnistm(["mnist"], ["mnistm"], ".", download=True)
dataloader = torch.utils.data.DataLoader(
    datasets["train"], batch_size=32, num_workers=2
)
```
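A single DataLoader is built from `datasets["train"]` here, so that training set must supply data from both domains. The README does not say how it does this. One plausible design pairs each labeled source sample with a randomly drawn unlabeled target sample. `PairedDataset` and its keys ("src_imgs", "src_labels", "target_imgs") are hypothetical names for this sketch:

```python
import random


class PairedDataset:
    """Hypothetical combined dataset: each item holds one source sample
    and one randomly chosen target sample."""

    def __init__(self, source, target, seed=0):
        self.source = source  # list of (input, label) pairs
        self.target = target  # list of inputs; target is treated as unlabeled
        self.rng = random.Random(seed)

    def __len__(self):
        # In this sketch, one pass over the source data defines an epoch.
        return len(self.source)

    def __getitem__(self, idx):
        src_x, src_y = self.source[idx]
        tgt_x = self.target[self.rng.randrange(len(self.target))]
        return {"src_imgs": src_x, "src_labels": src_y, "target_imgs": tgt_x}


source = [("s0", 0), ("s1", 1), ("s2", 2)]
target = ["t0", "t1"]
dataset = PairedDataset(source, target)
print(len(dataset), dataset[1]["src_labels"])  # 3 1
```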
Run the above examples
See this notebook and the examples page for other notebooks.
Installation
Pip
```shell
pip install pytorch-adapt
```
To get the latest dev version:
```shell
pip install pytorch-adapt --pre
```
Conda
Coming soon...
Dependencies
Coming soon...
Acknowledgements
Contributors
Pull requests are welcome!
Advisors
Thank you to Ser-Nam Lim and my research advisor, Professor Serge Belongie.
Logo
Thanks to Jeff Musgrave for designing the logo.
Download files
Hashes for pytorch_adapt-0.0.29-py3-none-any.whl
Algorithm | Hash digest |
---|---|
SHA256 | 91bb127fb84a54cfa83f0670de78b9b351cc982bcc45d0bf64a9ada3a590648e |
MD5 | 483f592b0ad89edbde8c7e582878792a |
BLAKE2b-256 | d6e18ced1d67dd9ac8b6cf226059c80c037a0785bbf7251ce84a290cb62948b6 |