
Control random number generator state via reproducible blocks


Torch-reproducible-block

License: GPL v3

Random number generation is hard to control when experimenting with neural networks. Setting a random seed only gets us so far, since each random operation affects the random number generator state.

Changes to the model's hyper-parameters or architecture can affect how each layer is initialised, how regularisation behaves, how the data is presented to the network during training, and more.

This package aims to reduce variability by limiting the side effects of random number generation. The main goal is to keep the rest of the network unchanged when experimenting with different hyper-parameters.

What is the problem?

The weight initialisation of a layer constitutes a random operation, so the initialisation order has an impact on all subsequent layers.

[Figure: problem definition]

In this small toy model, the initial weights of the fully connected layers will be different if we have a different number of convolutional layers. Initialising a pre-trained feature extractor may also involve random operations that affect the rest of the network. A different random state will also affect the data loading process, since it relies on random operations to select examples when creating batches.
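
A minimal sketch of the effect (the helper name here is illustrative, not part of the package): building the same fully connected layer after a different number of convolutional layers yields different weights, even with an identical seed.

import torch
import torch.nn as nn

def build_fc(n_conv_layers):
    # Same seed for both builds; only the number of conv layers differs.
    torch.manual_seed(42)
    for i in range(n_conv_layers):
        nn.Conv2d(3 if i == 0 else 32, 32, kernel_size=2)  # each init consumes random draws
    # The fully connected layer is initialised *after* the convolutions,
    # so its weights depend on how many random draws happened before it.
    return nn.Linear(64, 128)

fc_a = build_fc(n_conv_layers=1)
fc_b = build_fc(n_conv_layers=2)
print(torch.equal(fc_a.weight, fc_b.weight))  # False: the extra layer shifted the RNG state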

Solution

We isolate different parts of the network by wrapping them inside a Reproducible_Block.

[Figure: Reproducible Block solution]

How does it work?

Reproducible_Block.set_seed(SEED_VALUE) must be called before any random operation. This sets the Python, NumPy and torch seeds and saves a copy of the initial random number generator state.

When entering a Reproducible_Block, the random number generator state is reset to that initial state. The state is then mutated according to the block seed value, ensuring that each block gets a different state. To mutate the state, we simply run X random operations, where X is the block seed.
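
The actual implementation lives in the package; the following is only a rough sketch of that mechanism, with illustrative helper names.

import random

import numpy as np
import torch

_initial_state = None  # snapshot of the torch RNG state, captured by set_seed()

def set_seed(seed):
    # Seed Python, NumPy and torch, then save the resulting RNG state.
    global _initial_state
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    _initial_state = torch.get_rng_state()

def enter_block(block_seed):
    # Rewind the generator to the saved initial state...
    torch.set_rng_state(_initial_state)
    # ...then mutate it by burning `block_seed` random draws, so each
    # block seed deterministically lands on its own distinct state.
    for _ in range(block_seed):
        torch.rand(1)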

Feel free to take a look at the code; it's only about 100 lines.

Installation

The package can be installed via pip:

pip install torch-reproducible-block

The following packages are required:

torch
numpy

Usage

import torch.nn as nn

from reproducible_block import Reproducible_Block

class My_model(nn.Module):
    def __init__(self):
        super().__init__()  # must be called before registering sub-modules

        with Reproducible_Block(block_seed=64):
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=[2, 2])
            self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=[2, 2])

        with Reproducible_Block(block_seed=44):
            self.fc = nn.Linear(64, 128)
            self.out = nn.Linear(128, 10)
            
    def forward(self, batch):
        ...
    
if __name__ == "__main__":
    # Will set seed and save the initial random state
    Reproducible_Block.set_seed(42)
    
    model = My_model()
    # Data loading and other configurations which might do random operations....
    
    # Ensure that we always have the same random state when starting training
    with Reproducible_Block(block_seed=128):
        train_model(model, data, ...)

Reproducible_Block can also be used as a function decorator:

@Reproducible_Block(block_seed=128)
def train_model(model, dataloader, ...):
    ...

Remarks

  • Using a different initial seed (via Reproducible_Block.set_seed()) will result in a different random state for each Reproducible_Block.
  • The block seed is "part of the code". You should not tweak it the way you would a normal seed; to create different initial conditions for your training, change the initial seed instead.
  • Using the same block seed for different Reproducible_Block instances will result in the same random state. Make sure you use a different block seed for each block (see the sketch after this list).
  • Tested on Ubuntu 18.04 with Python 3.6. It should run without problems on other platforms; file an issue if you run into any.
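
A small sketch of that last point, using the API shown in the Usage section:

import torch

from reproducible_block import Reproducible_Block

Reproducible_Block.set_seed(42)

with Reproducible_Block(block_seed=64):
    a = torch.rand(3)

with Reproducible_Block(block_seed=64):
    b = torch.rand(3)  # same block seed -> same state -> same numbers

with Reproducible_Block(block_seed=44):
    c = torch.rand(3)  # different block seed -> different numbers

print(torch.equal(a, b))  # True
print(torch.equal(a, c))  # False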

Other sources of randomness

This package won't make your research 100% reproducible. It simply aims to isolate parts of your program from random side effects.

  • Setting the PYTHONHASHSEED environment variable is always a good idea (see the snippet after this list).
  • The order in which Python reads files from a directory (e.g. os.listdir()) can vary across operating systems (Linux vs Windows).
  • Different versions of Python, CUDA, cuDNN and other libraries can affect reproducibility.
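
A few commonly used settings that address some of these points (a sketch; the "data" directory is a placeholder for your own):

import os

import torch

# PYTHONHASHSEED must be set before the interpreter starts, e.g.
#   PYTHONHASHSEED=0 python train.py

# Commonly used torch switches for more deterministic behaviour
# (they can slow training down):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Sort directory listings to avoid OS-dependent file order:
files = sorted(os.listdir("data"))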

Contributing

The code is pretty simple; have a look at it, and if you have ideas for improvements, feel free to submit a pull request!
