🥗 salad
========

**S**\ emi-supervised **A**\ daptive **L**\ earning **A**\ cross **D**\ omains

.. figure:: img/domainshift.png
   :alt:


``salad`` is a library for easily setting up experiments with current
state-of-the-art techniques in domain adaptation. It implements several
recent approaches, with the goal of enabling fair comparisons between
algorithms and of transferring them to real-world use cases. The toolbox
is under active development and will be extended as new approaches are
published.

Contribute on GitHub: https://github.com/domainadaptation/salad


Currently implements the following techniques (in ``salad.solver``):

- VADA (``VADASolver``),
  `arxiv:1802.08735 <https://arxiv.org/abs/1802.08735>`__
- DIRT-T (``DIRTTSolver``),
  `arxiv:1802.08735 <https://arxiv.org/abs/1802.08735>`__
- Self-Ensembling for Visual Domain Adaptation
  (``SelfEnsemblingSolver``),
  `arxiv:1706.05208 <https://arxiv.org/abs/1706.05208>`__
- Associative Domain Adaptation (``AssociativeSolver``),
  `arxiv:1708.00938 <https://arxiv.org/pdf/1708.00938.pdf>`__
- Domain Adversarial Training (``DANNSolver``),
  `jmlr:v17/15-239.html <http://jmlr.org/papers/v17/15-239.html>`__
- Generalizing Across Domains via Cross-Gradient Training
  (``CrossGradSolver``),
  `arxiv:1804.10745 <http://arxiv.org/abs/1804.10745>`__
- Adversarial Dropout Regularization (``AdversarialDropoutSolver``),
  `arxiv:1711.01575 <https://arxiv.org/abs/1711.01575>`__

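
Associative domain adaptation (``AssociativeSolver``) trains with the walker
and visit losses of arXiv:1708.00938. To illustrate the idea, the sketch below
computes both losses in plain Python for a small batch of labeled source
embeddings and unlabeled target embeddings, following the formulation in that
paper. It is not salad's implementation (which operates on PyTorch tensors),
and all function names are hypothetical.

.. code:: python

    import math
    from typing import List, Sequence, Tuple

    Matrix = List[List[float]]

    def matmul(a: Matrix, b: Matrix) -> Matrix:
        return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

    def transpose(m: Matrix) -> Matrix:
        return [list(col) for col in zip(*m)]

    def softmax_rows(m: Matrix) -> Matrix:
        out = []
        for row in m:
            peak = max(row)  # subtract the row maximum for numerical stability
            exps = [math.exp(v - peak) for v in row]
            total = sum(exps)
            out.append([e / total for e in exps])
        return out

    def association_losses(source: Matrix, target: Matrix, labels: Sequence[int],
                           eps: float = 1e-8) -> Tuple[float, float]:
        """Return (walker_loss, visit_loss) for one batch (hypothetical helper)."""
        sim = matmul(source, transpose(target))  # M[i][j] = <a_i, b_j>
        p_ab = softmax_rows(sim)                 # step source -> target
        p_ba = softmax_rows(transpose(sim))      # step target -> source
        p_aba = matmul(p_ab, p_ba)               # round trip source -> target -> source

        # Walker loss: cross-entropy between the round-trip probabilities and a
        # target distribution that is uniform over same-label source samples
        walker = 0.0
        for i, row in enumerate(p_aba):
            same = [j for j, lab in enumerate(labels) if lab == labels[i]]
            walker -= sum(math.log(row[j] + eps) for j in same) / len(same)
        walker /= len(p_aba)

        # Visit loss: cross-entropy between a uniform distribution over target
        # samples and the average probability of visiting each of them
        n_src, n_tgt = len(p_ab), len(p_ab[0])
        visit_prob = [sum(p_ab[i][j] for i in range(n_src)) / n_src for j in range(n_tgt)]
        visit = -sum(math.log(p + eps) for p in visit_prob) / n_tgt
        return walker, visit

    if __name__ == "__main__":
        src = [[5.0, 0.0], [5.0, 0.0], [0.0, 5.0], [0.0, 5.0]]
        tgt = [[5.0, 0.0], [0.0, 5.0]]
        print(association_losses(src, tgt, labels=[0, 0, 1, 1]))  # both close to log(2)
        print(association_losses(src, tgt, labels=[0, 1, 0, 1]))  # walker loss much larger

In the first call, every round trip returns to a sample of the correct class,
so the walker loss reaches its lower bound log(2), the entropy of a uniform
target over two same-class samples. With shuffled labels, the round trips
land on the wrong class and the walker loss grows sharply.
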

Implements the following features (in ``salad.layers``):

- Weight Ensembling using Exponential Moving Averages or Stored
  Weights
- WalkerLoss and Visit Loss
  (`arxiv:1708.00938 <https://arxiv.org/pdf/1708.00938.pdf>`__)
- Virtual Adversarial Training
  (`arxiv:1704.03976 <https://arxiv.org/abs/1704.03976>`__)

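
Weight ensembling with an exponential moving average keeps a second
("teacher") copy of the weights that slowly follows the trained ("student")
weights, as in the mean-teacher setup used by self-ensembling
(arXiv:1706.05208). The plain-Python sketch below shows only the update rule
on named lists of floats. The helper name and the parameter layout are
hypothetical and not part of ``salad.layers``.

.. code:: python

    from typing import Dict, List

    Params = Dict[str, List[float]]

    def ema_update(teacher: Params, student: Params, alpha: float = 0.99) -> None:
        """Update teacher in place: teacher <- alpha * teacher + (1 - alpha) * student."""
        if not 0.0 <= alpha <= 1.0:
            raise ValueError("alpha must lie in [0, 1]")
        for name, values in student.items():
            teacher[name] = [alpha * t + (1.0 - alpha) * s
                             for t, s in zip(teacher[name], values)]

    if __name__ == "__main__":
        teacher = {"w": [0.0, 0.0]}
        student = {"w": [1.0, -1.0]}
        for _ in range(10):
            ema_update(teacher, student, alpha=0.5)
        print(teacher)  # {'w': [0.9990234375, -0.9990234375]}

With PyTorch parameters, the same update is commonly written as
``t.mul_(alpha).add_(s, alpha=1 - alpha)`` inside a ``torch.no_grad()``
block. A larger ``alpha`` yields a smoother, slower-moving teacher.
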

Coming soon:

- Deep Joint Optimal Transport (``DJDOTSolver``),
  `arxiv:1803.10081 <https://arxiv.org/abs/1803.10081>`__
- Translation-based approaches


📊 Benchmarking Results
-----------------------

One of salad's goals is to continuously track the state of the art across a
variety of domain adaptation algorithms. The latest results can be
reproduced with the scripts in the ``scripts/`` directory.

.. figure:: img/benchmarks.svg
   :alt:

💻 Installation
---------------

Requirements are listed in ``requirements.txt`` and can be installed
via

.. code:: bash

    pip install -r requirements.txt

Install the package via

.. code:: bash

    pip install torch-salad

For the latest development version, install via

.. code:: bash

    pip install git+https://github.com/bethgelab/domainadaptation

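
To check that the installation worked, a quick lookup such as the following
can help. It assumes that the package is imported as ``salad`` (as in the
code examples in this README) and that PyTorch is importable as ``torch``.
The helper name is hypothetical.

.. code:: python

    import importlib.util

    def is_installed(module_name: str) -> bool:
        """Return True if a top-level module can be found in the current environment."""
        return importlib.util.find_spec(module_name) is not None

    if __name__ == "__main__":
        # The PyPI distribution is torch-salad; its import name is assumed to be salad
        for name in ("torch", "salad"):
            print(f"{name}: {'found' if is_installed(name) else 'missing'}")

``find_spec`` only locates a top-level module without importing it, so the
check is cheap. A subsequent ``import salad`` can still fail if one of its
dependencies is missing.
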
📚 Using this library
---------------------

Along with implementations of domain adaptation routines, this library
includes code for easily setting up deep learning experiments in general.

This section will be extended upon pre-release.

Quick Start
~~~~~~~~~~~

To get started, the ``scripts/`` directory contains several Python
scripts, both for replication studies on the digit benchmarks and for
studies on a different dataset (toy example: adaptation to noisy images).

.. code:: bash

    $ cd scripts
    $ python train_digits.py --log ./log --teach --source svhn --target mnist


Run ``python train_digits.py --help`` to list all options:

.. code::

    usage: train_digits.py [-h] [--gpu GPU] [--cpu] [--njobs NJOBS] [--log LOG]
                           [--epochs EPOCHS] [--checkpoint CHECKPOINT]
                           [--learningrate LEARNINGRATE] [--dryrun]
                           [--source {mnist,svhn,usps,synth,synth-small}]
                           [--target {mnist,svhn,usps,synth,synth-small}]
                           [--sourcebatch SOURCEBATCH] [--targetbatch TARGETBATCH]
                           [--seed SEED] [--print] [--null] [--adv] [--vada]
                           [--dann] [--assoc] [--coral] [--teach]

    Domain Adaptation Comparision and Reproduction Study

    optional arguments:
      -h, --help            show this help message and exit
      --gpu GPU             Specify GPU
      --cpu                 Use CPU Training
      --njobs NJOBS         Number of processes per dataloader
      --log LOG             Log directory. Will be created if non-existing
      --epochs EPOCHS       Number of Epochs (Full passes through the unsupervised
                            training set)
      --checkpoint CHECKPOINT
                            Checkpoint path
      --learningrate LEARNINGRATE
                            Learning rate for Adam. Defaults to Karpathy's
                            constant ;-)
      --dryrun              Perform a test run, without actually training a
                            network.
      --source {mnist,svhn,usps,synth,synth-small}
                            Source Dataset. Choose mnist or svhn
      --target {mnist,svhn,usps,synth,synth-small}
                            Target Dataset. Choose mnist or svhn
      --sourcebatch SOURCEBATCH
                            Batch size of Source
      --targetbatch TARGETBATCH
                            Batch size of Target
      --seed SEED           Random Seed
      --print
      --null
      --adv                 Train a model with Adversarial Domain Regularization
      --vada                Train a model with Virtual Adversarial Domain
                            Adaptation
      --dann                Train a model with Domain Adversarial Training
      --assoc               Train a model with Associative Domain Adaptation
      --coral               Train a model with Deep Correlation Alignment
      --teach               Train a model with Self-Ensembling

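
To sweep several methods and dataset pairs, the command lines can be
generated programmatically. The sketch below uses only flags that appear in
the ``train_digits.py`` help output; the helper is hypothetical, assumes one
method flag per run, and assumes it is used from inside ``scripts/``. It
prints the commands instead of executing them.

.. code:: python

    import itertools
    import shlex
    from typing import Iterable, List, Tuple

    # Method flags and dataset choices as listed in the train_digits.py help output
    METHOD_FLAGS = ("--adv", "--vada", "--dann", "--assoc", "--coral", "--teach")
    DATASETS = ("mnist", "svhn", "usps", "synth", "synth-small")

    def build_commands(methods: Iterable[str], pairs: Iterable[Tuple[str, str]],
                       log_dir: str = "./log", dryrun: bool = False) -> List[List[str]]:
        """Build one argv list per (method, source, target) combination."""
        commands = []
        for method, (source, target) in itertools.product(methods, pairs):
            if method not in METHOD_FLAGS:
                raise ValueError(f"unknown method flag: {method}")
            if source not in DATASETS or target not in DATASETS:
                raise ValueError(f"unknown dataset pair: {source} -> {target}")
            cmd = ["python", "train_digits.py", "--log", log_dir, method,
                   "--source", source, "--target", target]
            if dryrun:
                cmd.append("--dryrun")  # test run without actually training
            commands.append(cmd)
        return commands

    if __name__ == "__main__":
        pairs = [("svhn", "mnist"), ("mnist", "usps")]
        for cmd in build_commands(["--teach", "--dann"], pairs, dryrun=True):
            print(shlex.join(cmd))

Passing each list to ``subprocess.run(cmd, check=True)`` would launch the runs
one after another; keeping the argv as a list avoids shell-quoting issues.
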
Reasons for using solver abstractions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The chosen abstraction style organizes each experiment as a subclass of
``Solver``.

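
With this style, a subclass overrides only the hooks it needs, while the base
class decides when they run. The toy classes below illustrate that
template-method pattern in plain Python. They are a hypothetical stand-in
that assumes, as the MNIST example in the next section suggests, that the
base constructor calls hooks such as ``_init_optims``; they are not salad's
actual ``Solver``.

.. code:: python

    from typing import Any, Callable, List

    class ToySolver:
        """Hypothetical stand-in for a solver base class; not salad's Solver."""

        def __init__(self, dataset: Any, **kwargs: Any) -> None:
            self.dataset = dataset
            self.optimizers: List[Any] = []
            self.losses: List[Callable] = []
            # The base constructor fixes the order of the setup steps and
            # calls hooks that subclasses may override
            self._init_optims(**kwargs)
            self._init_losses()

        def register_optimizer(self, opt: Any) -> None:
            self.optimizers.append(opt)

        def _init_optims(self, **kwargs: Any) -> None:
            pass  # default hook: nothing to set up

        def _init_losses(self) -> None:
            pass  # default hook: no losses

    class ToyMNISTSolver(ToySolver):
        def __init__(self, model: Any, dataset: Any, **kwargs: Any) -> None:
            # Must be assigned before super().__init__(), because the base
            # constructor calls _init_optims(), which reads self.model
            self.model = model
            super().__init__(dataset, **kwargs)

        def _init_optims(self, lr: float = 1e-4, **kwargs: Any) -> None:
            super()._init_optims(**kwargs)
            # A real solver would build e.g. torch.optim.Adam(self.model.parameters(), lr=lr);
            # a tuple stands in for the optimizer to keep the sketch dependency-free
            self.register_optimizer(("adam", self.model, lr))

Swapping the two lines in ``ToyMNISTSolver.__init__`` raises an
``AttributeError``, because ``_init_optims`` would then run before
``self.model`` exists.
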
Quickstart: MNIST Experiment
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

A minimal MNIST experiment looks as follows:

.. code:: python

    import torch

    from salad.solver import Solver

    class MNISTSolver(Solver):

        def __init__(self, model, dataset, **kwargs):
            # Set the model before calling the base constructor, which may
            # call _init_optims and therefore needs self.model
            self.model = model
            super().__init__(dataset, **kwargs)

        def _init_optims(self, lr=1e-4, **kwargs):
            super()._init_optims(**kwargs)
            # Optimize all model parameters with Adam
            opt = torch.optim.Adam(self.model.parameters(), lr=lr)
            self.register_optimizer(opt)

        def _init_losses(self):
            # No losses are registered in this minimal skeleton
            pass

For a simple task such as MNIST, this code is quite long compared to other
PyTorch examples `TODO <#>`__.

💡 Domain Adaptation Problems
-----------------------------

Legend: Implemented (✓), Under Construction (🚧)

📷 Vision
~~~~~~~~~

- Digits: MNIST ↔ SVHN ↔ USPS ↔ SYNTH (✓)
- `VisDA 2018 Openset and Detection <http://ai.bu.edu/visda-2018>`__
  (✓)
- Synthetic (GAN) ↔ Real (🚧)
- CIFAR ↔ STL (🚧)
- ImageNet to
  `iCubWorld <https://robotology.github.io/iCubWorld/#datasets>`__ (🚧)

🎤 Audio
~~~~~~~~

- `Mozilla Common Voice Dataset <https://voice.mozilla.org/>`__ (🚧)

፨ Neuroscience
~~~~~~~~~~~~~~

- White Noise ↔ Gratings ↔ Natural Images (🚧)
- `DeepLabCut Tracking <https://github.com/AlexEMG/DeepLabCut>`__ (🚧)

🔗 References to open source software
-------------------------------------

Parts of the code in this repository are inspired by or borrowed from the
original implementations, especially:

- https://github.com/Britefury/self-ensemble-visual-domain-adapt
- https://github.com/Britefury/self-ensemble-visual-domain-adapt-photo/
- https://github.com/RuiShu/dirt-t
- https://github.com/gpascualg/CrossGrad
- https://github.com/stes/torch-associative
- https://github.com/haeusser/learning\_by\_association
- https://mil-tokyo.github.io/adr\_da/

An excellent list of domain adaptation resources:
https://github.com/artix41/awesome-transfer-learning

👤 Contact
----------

Maintained by `Steffen Schneider <https://code.stes.io>`__. This work is
part of my thesis project at the `Bethge Lab <http://bethgelab.org>`__.
This README is also available as a webpage at
`salad.domainadaptation.org <http://salad.domainadaptation.org>`__. We
welcome issues and pull requests `to the official GitHub
repository <https://github.com/bethgelab/domainadaptation>`__.
