


Torch Mutable Modules

Use in-place and assignment operations on PyTorch module parameters with support for autograd.


Why does this exist?

PyTorch does not allow in-place operations on module parameters that require gradients (which is usually desirable):

import torch

linear_layer = torch.nn.Linear(1, 1)
linear_layer.weight.data += 69
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Allowed, but bypasses autograd: no grad_fn (e.g. <AddBackward0>) is recorded
linear_layer.weight += 420
# ^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
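
For context, the standard PyTorch way to change a parameter's values in place is inside a `torch.no_grad()` block, which deliberately leaves the edit out of the autograd graph. Out-of-place arithmetic, in contrast, is tracked, but it produces a new non-leaf tensor rather than updating the layer's parameter. A short illustration using only core PyTorch:

```python
import torch

linear_layer = torch.nn.Linear(1, 1)

# Standard way to edit a parameter in place: autograd does not record it
with torch.no_grad():
    linear_layer.weight += 69
print(linear_layer.weight.grad_fn)  # None: the edit left no graph node

# With grad tracking on, the same in-place edit on a leaf parameter is rejected
raised = False
try:
    linear_layer.weight += 420
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# Out-of-place arithmetic records a graph node, but the result is a new
# tensor; the layer's own weight is unchanged and still a Parameter
shifted = linear_layer.weight + 420
print(type(shifted.grad_fn).__name__)  # AddBackward0
```

The gap this library fills is the last case: getting a tracked, non-leaf tensor like `shifted` to actually be used as the layer's weight.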

In some cases, however, it is useful to modify or replace module parameters with tensors that keep their autograd history. For example, if one neural network (net_1) predicts the parameter values of another neural network (net_2), we need to assign those predictions to net_2's weights and then backpropagate the gradients through net_2 to net_1.

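import torch
from torch_mutable_modules import convert_to_mutable_module

# placeholder data, loss and labels (illustrative values only)
input_0 = torch.randn(1, 1)  # input to net_1
input_1 = torch.randn(1, 1)  # input to net_2
label = torch.randn(1, 1)
criterion = torch.nn.MSELoss()

# create a parameter predictor network (net_1)
net_1 = torch.nn.Linear(1, 2)
optimizer = torch.optim.SGD(net_1.parameters(), lr=0.01)

# predict the weights and biases of net_2 using net_1
p_weight_and_bias = net_1(input_0).unsqueeze(2)
p_weight, p_bias = p_weight_and_bias[:, 0], p_weight_and_bias[:, 1]

# create a mutable network (net_2)
net_2 = convert_to_mutable_module(torch.nn.Linear(1, 1))

# hot-swap the weights and biases of net_2 with the predicted values
net_2.weight = p_weight
net_2.bias = p_bias

# compute the output and backpropagate the gradients to net_1
optimizer.zero_grad()
output = net_2(input_1)
loss = criterion(output, label)
loss.backward()
optimizer.step()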
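
For comparison only: in PyTorch 2.0 and later, the same data flow can be expressed without converting any module by calling `torch.func.functional_call`, which runs a module with substitute tensors in place of its registered parameters. This is a different technique from the one this library provides; the sketch below assumes PyTorch >= 2.0 and uses random placeholder data.

```python
import torch
from torch.func import functional_call

torch.manual_seed(0)

net_1 = torch.nn.Linear(1, 2)  # parameter predictor
net_2 = torch.nn.Linear(1, 1)  # target network; its own parameters go unused

input_0 = torch.randn(1, 1)
input_1 = torch.randn(4, 1)
label = torch.randn(4, 1)

pred = net_1(input_0)                # shape (1, 2)
p_weight = pred[:, 0].reshape(1, 1)  # same shape as net_2.weight
p_bias = pred[:, 1].reshape(1)       # same shape as net_2.bias

# Run net_2 with the predicted tensors substituted for its parameters
output = functional_call(net_2, {"weight": p_weight, "bias": p_bias}, (input_1,))
loss = torch.nn.functional.mse_loss(output, label)
loss.backward()

# Gradients reach net_1; net_2's own (unused) parameters receive none
print(net_1.weight.grad is not None, net_2.weight.grad is None)  # True True
```

The trade-off is that the substitute tensors must be passed on every call, whereas a mutable module keeps them as attributes between calls.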

This library provides a way to easily convert PyTorch modules into mutable modules with the convert_to_mutable_module function and the @mutable_module decorator.

Installation

You can install torch-mutable-modules from PyPI.

pip install torch-mutable-modules

To upgrade an existing installation of torch-mutable-modules, use the following command:

pip install --upgrade --no-cache-dir torch-mutable-modules

Importing

You can use wildcard imports or import specific functions directly:

# import all functions
from torch_mutable_modules import *

# ... or import them manually
from torch_mutable_modules import convert_to_mutable_module, mutable_module

Usage

To convert an existing PyTorch module into a mutable module, use the convert_to_mutable_module function:

import torch
from torch_mutable_modules import convert_to_mutable_module

converted_module = convert_to_mutable_module(
    torch.nn.Linear(1, 1)
)  # type of converted_module is still torch.nn.Linear
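
This README does not describe how the conversion works internally. Purely as a mental model, here is a hypothetical sketch (`make_mutable_sketch` is an illustrative name, not part of torch_mutable_modules) of one way such a conversion could be done: each registered parameter is removed from the module's private `_parameters` registry and re-attached as an ordinary tensor attribute, so that later assignments of non-leaf tensors are accepted and keep their autograd history.

```python
import torch
from torch import nn


def make_mutable_sketch(module: nn.Module) -> nn.Module:
    """Hypothetical helper, NOT the torch_mutable_modules implementation.

    Unregisters every parameter of `module` and its children and re-attaches
    each one as a plain tensor attribute, so it can later be replaced by any
    tensor, including non-leaf tensors that carry a grad_fn.
    """
    for submodule in module.modules():
        # Copy the names first because the dict is mutated inside the loop
        for name in list(submodule._parameters):
            param = submodule._parameters.pop(name)
            if param is not None:
                # detach() returns a plain Tensor (not a Parameter), so
                # nn.Module.__setattr__ stores it as an ordinary attribute
                param = param.detach().requires_grad_(True)
            setattr(submodule, name, param)
    return module


linear = make_mutable_sketch(nn.Linear(1, 1))
print(len(list(linear.parameters())))  # 0: nothing is registered any more

# Assign a non-leaf tensor; the graph back to `source` is preserved
source = torch.tensor([[3.0]], requires_grad=True)
linear.weight = source * 2
linear.bias = torch.zeros(1)
out = linear(torch.tensor([[1.0]]))
out.sum().backward()
print(out.item(), source.grad.item())  # 6.0 2.0
```

One consequence of this particular approach is that the converted module no longer reports any parameters, so an optimizer built from its `.parameters()` would have nothing to update.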

You can also declare entire PyTorch module classes as mutable, and all child modules will be recursively converted into mutable modules:

from torch import nn
from torch_mutable_modules import mutable_module

@mutable_module
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)
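
As with the conversion function, the decorator's internals are not documented here. A hypothetical sketch of how a class decorator could achieve recursive conversion (`mutable_module_sketch` and `_unregister_parameters` are illustrative names, not the library's API) is to wrap `__init__`, so the conversion runs only after the child modules have been created:

```python
import functools

import torch
from torch import nn


def _unregister_parameters(module: nn.Module) -> None:
    # Hypothetical: keep each parameter as a plain tensor attribute instead
    for submodule in module.modules():
        for name in list(submodule._parameters):
            param = submodule._parameters.pop(name)
            if param is not None:
                param = param.detach().requires_grad_(True)
            setattr(submodule, name, param)


def mutable_module_sketch(cls):
    """Hypothetical class decorator, NOT the library's @mutable_module."""
    original_init = cls.__init__

    @functools.wraps(original_init)
    def __init__(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        # Children exist only after __init__ has run, so convert afterwards
        _unregister_parameters(self)

    cls.__init__ = __init__
    return cls


@mutable_module_sketch
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)


model = MyModule()
model.linear.weight = torch.full((1, 1), 0.5, requires_grad=True) * 4  # non-leaf
model.linear.bias = torch.ones(1)
print(model(torch.tensor([[1.0]])).item())  # 3.0
```

Wrapping `__init__` is only one possible design: under this sketch, a subclass that registers extra parameters after calling `super().__init__()` would leave those new parameters unconverted.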

More examples

Check out example.py for more examples of how to use the convert_to_mutable_module function and the @mutable_module decorator.

Contributing

Please feel free to submit issues or pull requests!

Download files

Source Distribution

torch_mutable_modules-1.0.0.tar.gz (4.2 kB)

Built Distribution

torch_mutable_modules-1.0.0-py3-none-any.whl (4.7 kB, Python 3)

File details

Details for the file torch_mutable_modules-1.0.0.tar.gz.

File metadata

  • Download URL: torch_mutable_modules-1.0.0.tar.gz
  • Upload date:
  • Size: 4.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12

File hashes

Hashes for torch_mutable_modules-1.0.0.tar.gz
Algorithm Hash digest
SHA256 ef6391b72ea7aa1dbd0a693e7d9c49e1feedc02785f6df032556a9cd425d1630
MD5 0cb21c1d56eaab0fab981fd49a771043
BLAKE2b-256 64b8448c76d161ecde4d93e77fd62c522978150867caa873eb57d5fd80f8e8b8


File details

Details for the file torch_mutable_modules-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: torch_mutable_modules-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 4.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12

File hashes

Hashes for torch_mutable_modules-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 cd1c20d7b436c99ee055f3c27ecf2988fb18b13b2ef400a2ce1b930f105f4ca7
MD5 d9cc8206ccc0ef99bb62ed62099244fe
BLAKE2b-256 0a77e670dee17e63dafcaf6ef8cb885d8892cbfde1872d4a69b2c26725a9b857

