
Dirty-MNIST from "Deterministic Neural Networks with Appropriate Inductive Biases Capture Epistemic and Aleatoric Uncertainty"


DDU's Dirty-MNIST

You'll never want to use MNIST again for out-of-distribution (OOD) detection or active learning (AL).

PyTorch 1.8.1 | License: Apache

This repository contains the Dirty-MNIST dataset described in Deterministic Neural Networks with Appropriate Inductive Biases Capture Epistemic and Aleatoric Uncertainty.

The official repository for the paper is at https://github.com/omegafragger/DDU.

If the code or the paper has been useful in your research, please add a citation to our work:

@article{mukhoti2021deterministic,
  title={Deterministic Neural Networks with Appropriate Inductive Biases Capture Epistemic and Aleatoric Uncertainty},
  author={Mukhoti, Jishnu and Kirsch, Andreas and van Amersfoort, Joost and Torr, Philip HS and Gal, Yarin},
  journal={arXiv preprint arXiv:2102.11582},
  year={2021}
}

DirtyMNIST is a concatenation of MNIST and AmbiguousMNIST, with 60k sample-label pairs each in the training set. AmbiguousMNIST contains generated ambiguous MNIST samples with varying entropies: 6k unique samples with 10 labels each.
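The flattened layout can be made concrete with a little index arithmetic. The sketch below is a minimal illustration, not part of the ddu_dirty_mnist API. It assumes that the MNIST half occupies the first indices of a DirtyMNIST split (60k in the training split and, judging by the lengths reported under "How to use", 10k in the test split) and that the 10 labels of each ambiguous sample are stored at consecutive indices. The helper `locate` and the `Location` tuple are hypothetical names.

```python
from typing import NamedTuple

# Each unique ambiguous digit is repeated once per label (assumed: 10 labels).
LABELS_PER_AMBIGUOUS_SAMPLE = 10


class Location(NamedTuple):
    source: str      # "mnist" or "ambiguous"
    sample_id: int   # index of the unique image within its source
    label_slot: int  # which of the 10 labels (always 0 for MNIST)


def locate(index: int, mnist_len: int = 60_000, ambiguous_unique: int = 6_000) -> Location:
    """Map a flat DirtyMNIST index to its source and unique sample.

    Hypothetical helper; assumes MNIST comes first and that the labels of
    each ambiguous sample are stored at consecutive indices.
    """
    total = mnist_len + ambiguous_unique * LABELS_PER_AMBIGUOUS_SAMPLE
    if not 0 <= index < total:
        raise IndexError(f"index {index} out of range for length {total}")
    if index < mnist_len:
        return Location("mnist", index, 0)
    offset = index - mnist_len
    return Location(
        "ambiguous",
        offset // LABELS_PER_AMBIGUOUS_SAMPLE,
        offset % LABELS_PER_AMBIGUOUS_SAMPLE,
    )


print(60_000 + 6_000 * LABELS_PER_AMBIGUOUS_SAMPLE)  # 120000, the training-set length
print(locate(59_999))  # last MNIST sample
print(locate(60_013))  # fourth label (slot 3) of the second ambiguous sample
```

For the test split, pass `mnist_len=10_000`; the total then comes to 70k.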

[Figure: AmbiguousMNIST digits from each class, with increasing entropy]


Install

pip install ddu_dirty_mnist

How to use

After installing, you can create a Dirty-MNIST train or test set just as you would create MNIST in PyTorch.

# gpu

import ddu_dirty_mnist

dirty_mnist_train = ddu_dirty_mnist.DirtyMNIST(".", train=True, download=True, device="cuda")
dirty_mnist_test = ddu_dirty_mnist.DirtyMNIST(".", train=False, download=True, device="cuda")
len(dirty_mnist_train), len(dirty_mnist_test)
# (120000, 70000)

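Create torch.utils.data.DataLoaders with num_workers=0 and pin_memory=False for maximum throughput: the dataset is created directly on the device passed via device=..., so extra worker processes and pinned host memory bring no benefit. See the documentation for details.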

# gpu
import torch

dirty_mnist_train_dataloader = torch.utils.data.DataLoader(
    dirty_mnist_train,
    batch_size=128,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
)
dirty_mnist_test_dataloader = torch.utils.data.DataLoader(
    dirty_mnist_test,
    batch_size=128,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
)

Ambiguous-MNIST

If you only care about Ambiguous-MNIST, you can use:

# gpu

import ddu_dirty_mnist

ambiguous_mnist_train = ddu_dirty_mnist.AmbiguousMNIST(".", train=True, download=True, device="cuda")
ambiguous_mnist_test = ddu_dirty_mnist.AmbiguousMNIST(".", train=False, download=True, device="cuda")

ambiguous_mnist_train, ambiguous_mnist_test
# (Dataset AmbiguousMNIST
#      Number of datapoints: 60000
#      Root location: .,
#  Dataset AmbiguousMNIST
#      Number of datapoints: 60000
#      Root location: .)

Again, create torch.utils.data.DataLoaders with num_workers=0 and pin_memory=False for maximum throughput; see the documentation for details.

# gpu
import torch

ambiguous_mnist_train_dataloader = torch.utils.data.DataLoader(
    ambiguous_mnist_train,
    batch_size=128,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
)
ambiguous_mnist_test_dataloader = torch.utils.data.DataLoader(
    ambiguous_mnist_test,
    batch_size=128,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
)

Additional Guidance

  1. The current AmbiguousMNIST contains 6k unique samples with 10 labels each. This multi-label dataset is flattened to 60k samples. The assumption is that ambiguous samples have multiple "valid" labels precisely because they are ambiguous. MNIST samples are intentionally undersampled in comparison, which benefits AL acquisition functions that can select unambiguous samples.
  2. Pick your initial training samples (for warm-starting active learning) from the MNIST half of DirtyMNIST to avoid starting training with potentially very ambiguous samples, which might add a lot of variance to your experiments.
  3. Make sure to pick your validation set from the MNIST half as well, for the same reason as above.
  4. Make sure that your acquisition batch size is (probably) at least 10, given that each Ambiguous-MNIST sample carries 10 labels.
  5. By default, Gaussian noise with stddev 0.05 is added to each sample to prevent acquisition functions from cheating by discarding "duplicates".
  6. If you want to split Ambiguous-MNIST into subsets (or Dirty-MNIST within the second ambiguous half), make sure to split by multiples of 10 to avoid splits within a flattened multi-label sample.
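Guidelines 2, 3 and 6 come down to index bookkeeping. A minimal sketch follows, assuming (as guidelines 2 and 6 imply) that the MNIST half occupies the first indices of DirtyMNIST and that the 10 labels of each ambiguous sample sit at consecutive indices after it. The helper names `mnist_half_split` and `split_ambiguous_half`, the seed, and the split sizes are hypothetical choices for illustration, not part of ddu_dirty_mnist.

```python
import random

# Assumed number of labels per unique Ambiguous-MNIST sample.
LABELS_PER_AMBIGUOUS_SAMPLE = 10


def mnist_half_split(mnist_len, n_initial, n_val, seed=0):
    """Draw disjoint initial-training and validation indices from the MNIST half.

    Assumes the MNIST half occupies indices [0, mnist_len) of DirtyMNIST.
    """
    if n_initial + n_val > mnist_len:
        raise ValueError("not enough MNIST samples for the requested split")
    rng = random.Random(seed)
    # Sampling without replacement guarantees the two lists are disjoint.
    chosen = rng.sample(range(mnist_len), n_initial + n_val)
    return chosen[:n_initial], chosen[n_initial:]


def split_ambiguous_half(mnist_len, ambiguous_unique, first_fraction, seed=0):
    """Split the ambiguous half by whole unique samples (blocks of 10 indices).

    Each returned list has a length that is a multiple of 10, so no
    flattened multi-label sample is cut in two.
    """
    rng = random.Random(seed)
    sample_ids = list(range(ambiguous_unique))
    rng.shuffle(sample_ids)
    cut = round(len(sample_ids) * first_fraction)

    def expand(ids):
        # Turn unique-sample ids back into their 10 consecutive flat indices.
        return [
            mnist_len + sid * LABELS_PER_AMBIGUOUS_SAMPLE + slot
            for sid in sorted(ids)
            for slot in range(LABELS_PER_AMBIGUOUS_SAMPLE)
        ]

    return expand(sample_ids[:cut]), expand(sample_ids[cut:])


initial, val = mnist_half_split(60_000, n_initial=20, n_val=1_000)
pool_a, pool_b = split_ambiguous_half(60_000, 6_000, first_fraction=0.8)
print(len(initial), len(val), len(pool_a), len(pool_b))  # 20 1000 48000 12000
```

The resulting index lists can be wrapped with torch.utils.data.Subset(dirty_mnist_train, indices). Splitting by unique sample ids first and only then expanding each id to its 10 indices is what keeps every flattened multi-label sample on one side of the split.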



Download files


Source Distribution

ddu_dirty_mnist-1.1.1.tar.gz (13.0 kB)

Uploaded Source

Built Distribution


ddu_dirty_mnist-1.1.1-py3-none-any.whl (11.5 kB)

Uploaded Python 3

File details

Details for the file ddu_dirty_mnist-1.1.1.tar.gz.

File metadata

  • Download URL: ddu_dirty_mnist-1.1.1.tar.gz
  • Upload date:
  • Size: 13.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.10.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.8.8

File hashes

Hashes for ddu_dirty_mnist-1.1.1.tar.gz
Algorithm     Hash digest
SHA256        e92e4c72c89210c5b64eed154209e5b5620e64379a223c6cd0b1dd01bf819eb8
MD5           0069c0b313def432de7ddb67b56d286e
BLAKE2b-256   22abd03fa2541328b9a9bd792ed2812de287f6cc6768bb8bfa35f498898790e9


File details

Details for the file ddu_dirty_mnist-1.1.1-py3-none-any.whl.

File metadata

  • Download URL: ddu_dirty_mnist-1.1.1-py3-none-any.whl
  • Upload date:
  • Size: 11.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.10.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.8.8

File hashes

Hashes for ddu_dirty_mnist-1.1.1-py3-none-any.whl
Algorithm     Hash digest
SHA256        54df9287c269a9b6b88e81816d32b5ca6951f1fcb5d17a133e5af4f2ecfed031
MD5           adf02ce262424d2c7d980a8c2270fdb5
BLAKE2b-256   0ff969a9c59be11c712c44dd5019bc2e4a71254a84b2b2904630b0bad073ae58

