
🥗 salad
========

**S**\ emi-supervised **A**\ daptive **L**\ earning **A**\ cross **D**\ omains

.. figure:: img/domainshift.png
   :alt:


``salad`` is a library for easily setting up experiments with current
state-of-the-art techniques in domain adaptation. It features several
recent approaches, with the goal of enabling fair comparisons between
algorithms and transferring them to real-world use cases. The toolbox
is under active development and will be extended as new approaches are
published.

Contribute on GitHub: https://github.com/domainadaptation/salad

Currently implements the following techniques (in ``salad.solver``):

- VADA (``VADASolver``),
  `arxiv:1802.08735 <https://arxiv.org/abs/1802.08735>`__
- DIRT-T (``DIRTTSolver``),
  `arxiv:1802.08735 <https://arxiv.org/abs/1802.08735>`__
- Self-Ensembling for Visual Domain Adaptation
  (``SelfEnsemblingSolver``),
  `arxiv:1706.05208 <https://arxiv.org/abs/1706.05208>`__
- Associative Domain Adaptation (``AssociativeSolver``),
  `arxiv:1708.00938 <https://arxiv.org/pdf/1708.00938.pdf>`__
- Domain Adversarial Training (``DANNSolver``; see the sketch after
  this list),
  `jmlr:v17/15-239.html <http://jmlr.org/papers/v17/15-239.html>`__
- Generalizing Across Domains via Cross-Gradient Training
  (``CrossGradSolver``),
  `arxiv:1804.10745 <http://arxiv.org/abs/1804.10745>`__
- Adversarial Dropout Regularization (``AdversarialDropoutSolver``),
  `arxiv:1711.01575 <https://arxiv.org/abs/1711.01575>`__
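
To make the idea behind ``DANNSolver`` concrete, here is a minimal,
self-contained sketch of domain adversarial training via gradient
reversal in plain PyTorch. It is independent of salad's own solver API;
the feature extractor, classifier, and discriminator below are
illustrative stand-ins, not salad components.

.. code:: python

    import torch
    from torch import nn
    from torch.autograd import Function

    class GradReverse(Function):
        """Identity in the forward pass; scales the gradient by -lambd backwards."""

        @staticmethod
        def forward(ctx, x, lambd):
            ctx.lambd = lambd
            return x.view_as(x)

        @staticmethod
        def backward(ctx, grad_output):
            return -ctx.lambd * grad_output, None

    def grad_reverse(x, lambd=1.0):
        return GradReverse.apply(x, lambd)

    # Illustrative stand-ins for feature extractor, label classifier,
    # and domain discriminator (source = 0, target = 1).
    features = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU())
    classifier = nn.Linear(128, 10)
    discriminator = nn.Linear(128, 2)

    def dann_loss(x_src, y_src, x_tgt):
        f_src, f_tgt = features(x_src), features(x_tgt)
        # Supervised loss on labeled source samples.
        cls_loss = nn.functional.cross_entropy(classifier(f_src), y_src)
        # The discriminator learns to separate the domains, while the
        # reversed gradient pushes the features towards domain invariance.
        f_all = grad_reverse(torch.cat([f_src, f_tgt]))
        d_labels = torch.cat([torch.zeros(len(f_src)),
                              torch.ones(len(f_tgt))]).long()
        return cls_loss + nn.functional.cross_entropy(discriminator(f_all), d_labels)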

Implements the following features (in ``salad.layers``):

- Weight Ensembling using Exponential Moving Averages or Stored
  Weights (EMA variant sketched below)
- WalkerLoss and Visit Loss
  (`arxiv:1708.00938 <https://arxiv.org/pdf/1708.00938.pdf>`__)
- Virtual Adversarial Training
  (`arxiv:1704.03976 <https://arxiv.org/abs/1704.03976>`__)
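
As an illustration of the exponential-moving-average variant (a generic
sketch, not salad's own implementation), a teacher network can track a
student's weights like this:

.. code:: python

    import torch

    @torch.no_grad()
    def ema_update(teacher, student, alpha=0.99):
        """Blend student weights into the teacher: t = alpha * t + (1 - alpha) * s."""
        for t, s in zip(teacher.parameters(), student.parameters()):
            t.mul_(alpha).add_(s, alpha=1 - alpha)

In self-ensembling setups the same update is typically also applied to
buffers such as batch-norm statistics.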

Coming soon:

- Deep Joint Distribution Optimal Transport (``DJDOTSolver``),
  `arxiv:1803.10081 <https://arxiv.org/abs/1803.10081>`__
- Translation-based approaches

📊 Benchmarking Results
----------------------

One of salad's purposes is to continuously track the state of the art for a
variety of domain adaptation algorithms. The latest results can be
reproduced with the scripts in the ``scripts/`` directory.

.. figure:: img/benchmarks.svg
   :alt:


💻 Installation
---------------

Requirements are listed in ``requirements.txt`` and can be installed
via

.. code:: bash

    pip install -r requirements.txt

Install the package via

.. code:: bash

    pip install torch-salad

For the latest development version, install via

.. code:: bash

    pip install git+https://github.com/bethgelab/domainadaptation
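
A quick way to verify the installation (assuming the package is
importable as ``salad``, as in the examples below):

.. code:: python

    import salad  # should succeed without errors after installation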


📚 Using this library
---------------------

Along with the implementation of domain adaptation routines, this
library comprises code to easily set up deep learning experiments in
general.

This section will be extended upon pre-release.

Quick Start
~~~~~~~~~~~

To get started, the ``scripts/`` directory contains several Python scripts
for running both replication studies on digit benchmarks and studies on
a different dataset (toy example: adaptation to noisy images).

.. code:: bash

    $ cd scripts
    $ python train_digits.py --log ./log --teach --source svhn --target mnist

Refer to the help pages for all options:

.. code::

    usage: train_digits.py [-h] [--gpu GPU] [--cpu] [--njobs NJOBS] [--log LOG]
                           [--epochs EPOCHS] [--checkpoint CHECKPOINT]
                           [--learningrate LEARNINGRATE] [--dryrun]
                           [--source {mnist,svhn,usps,synth,synth-small}]
                           [--target {mnist,svhn,usps,synth,synth-small}]
                           [--sourcebatch SOURCEBATCH] [--targetbatch TARGETBATCH]
                           [--seed SEED] [--print] [--null] [--adv] [--vada]
                           [--dann] [--assoc] [--coral] [--teach]

    Domain Adaptation Comparison and Reproduction Study

    optional arguments:
      -h, --help            show this help message and exit
      --gpu GPU             Specify GPU
      --cpu                 Use CPU Training
      --njobs NJOBS         Number of processes per dataloader
      --log LOG             Log directory. Will be created if non-existing
      --epochs EPOCHS       Number of Epochs (Full passes through the unsupervised
                            training set)
      --checkpoint CHECKPOINT
                            Checkpoint path
      --learningrate LEARNINGRATE
                            Learning rate for Adam. Defaults to Karpathy's
                            constant ;-)
      --dryrun              Perform a test run, without actually training a
                            network.
      --source {mnist,svhn,usps,synth,synth-small}
                            Source Dataset. Choose mnist or svhn
      --target {mnist,svhn,usps,synth,synth-small}
                            Target Dataset. Choose mnist or svhn
      --sourcebatch SOURCEBATCH
                            Batch size of Source
      --targetbatch TARGETBATCH
                            Batch size of Target
      --seed SEED           Random Seed
      --print
      --null
      --adv                 Train a model with Adversarial Domain Regularization
      --vada                Train a model with Virtual Adversarial Domain
                            Adaptation
      --dann                Train a model with Domain Adversarial Training
      --assoc               Train a model with Associative Domain Adaptation
      --coral               Train a model with Deep Correlation Alignment
      --teach               Train a model with Self-Ensembling



Reasons for using solver abstractions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The chosen abstraction organizes each experiment as a subclass of
``Solver``.

Quickstart: MNIST Experiment
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

As a quick MNIST experiment:

.. code:: python

    import torch
    from salad.solvers import Solver

    class MNISTSolver(Solver):

        def __init__(self, model, dataset, **kwargs):
            self.model = model
            super().__init__(dataset, **kwargs)

        def _init_optims(self, lr=1e-4, **kwargs):
            super()._init_optims(**kwargs)

            opt = torch.optim.Adam(self.model.parameters(), lr=lr)
            self.register_optimizer(opt)

        def _init_losses(self):
            pass

For a task as simple as MNIST, the code is quite long compared to other
PyTorch examples `TODO <#>`__.
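
A hypothetical usage sketch; the dataset wiring, the forwarded ``lr``
keyword, and the ``optimize`` entry point are assumptions rather than
confirmed salad API:

.. code:: python

    import torch
    from torchvision import datasets, transforms

    # Standard MNIST loader; any labeled dataset would work here.
    loader = torch.utils.data.DataLoader(
        datasets.MNIST("./data", download=True, transform=transforms.ToTensor()),
        batch_size=64, shuffle=True)

    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10))
    solver = MNISTSolver(model, loader, lr=1e-4)
    solver.optimize()  # training entry point; name assumed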

💡 Domain Adaptation Problems
-----------------------------

Legend: Implemented (✓), Under Construction (🚧)

📷 Vision
~~~~~~~~~

- Digits: MNIST ↔ SVHN ↔ USPS ↔ SYNTH (✓)
- `VisDA 2018 Openset and Detection <http://ai.bu.edu/visda-2018>`__ (✓)
- Synthetic (GAN) ↔ Real (🚧)
- CIFAR ↔ STL (🚧)
- ImageNet to
  `iCubWorld <https://robotology.github.io/iCubWorld/#datasets>`__ (🚧)

🎤 Audio
~~~~~~~~

- `Mozilla Common Voice Dataset <https://voice.mozilla.org/>`__ (🚧)

፨ Neuroscience
~~~~~~~~~~~~~~

- White Noise ↔ Gratings ↔ Natural Images (🚧)
- `Deep Lab Cut Tracking <https://github.com/AlexEMG/DeepLabCut>`__ (🚧)

🔗 References to open source software
-------------------------------------

Part of the code in this repository is inspired by or borrowed from the
original implementations, especially:

- https://github.com/Britefury/self-ensemble-visual-domain-adapt
- https://github.com/Britefury/self-ensemble-visual-domain-adapt-photo/
- https://github.com/RuiShu/dirt-t
- https://github.com/gpascualg/CrossGrad
- https://github.com/stes/torch-associative
- https://github.com/haeusser/learning_by_association
- https://mil-tokyo.github.io/adr_da/

An excellent list of domain adaptation resources:
https://github.com/artix41/awesome-transfer-learning

👤 Contact
----------

Maintained by `Steffen Schneider <https://code.stes.io>`__. This work is
part of my thesis project at the `Bethge Lab <http://bethgelab.org>`__.
This README is also available as a webpage at
`salad.domainadaptation.org <http://salad.domainadaptation.org>`__. We
welcome issues and pull requests `on the official GitHub
repository <https://github.com/bethgelab/domainadaptation>`__.

