Project description

Torch Mutable Modules

Use in-place and assignment operations on PyTorch module parameters with support for autograd.


Why does this exist?

PyTorch does not allow in-place operations on module parameters that require gradients (which is usually the desired behavior):

import torch

linear_layer = torch.nn.Linear(1, 1)
linear_layer.weight.data += 69
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Valid, but will NOT store grad_fn=<AddBackward0>
linear_layer.weight += 420
# ^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

In some cases, however, it is useful to be able to modify module parameters in-place. For example, if we have a neural network (net_1) that predicts the parameter values of another neural network (net_2), we need to be able to modify the weights of net_2 in-place and backpropagate the gradients to net_1.

# create a parameter predictor network (net_1)
net_1 = torch.nn.Linear(1, 2)

# predict the weights and biases of net_2 using net_1
p_weight_and_bias = net_1(input_0).unsqueeze(2)
p_weight, p_bias = p_weight_and_bias[:, 0], p_weight_and_bias[:, 1]

# create a mutable network (net_2)
net_2 = convert_to_mutable_module(torch.nn.Linear(1, 1))

# hot-swap the weights and biases of net_2 with the predicted values
net_2.weight = p_weight
net_2.bias = p_bias

# compute the output and backpropagate the gradients to net_1
output = net_2(input_1)
loss = criterion(output, label)
loss.backward()
optimizer.step()

This library provides a way to easily convert PyTorch modules into mutable modules with the convert_to_mutable_module function and the @mutable_module decorator.

Installation

You can install torch-mutable-modules from PyPI.

pip install torch-mutable-modules

To upgrade an existing installation of torch-mutable-modules, use the following command:

pip install --upgrade --no-cache-dir torch-mutable-modules

Importing

You can use wildcard imports or import specific functions directly:

# import all functions
from torch_mutable_modules import *

# ... or import them manually
from torch_mutable_modules import convert_to_mutable_module, mutable_module

Usage

To convert an existing PyTorch module into a mutable module, use the convert_to_mutable_module function:

converted_module = convert_to_mutable_module(
    torch.nn.Linear(1, 1)
) # type of converted_module is still torch.nn.Linear

converted_module.weight *= 0
converted_module.weight += 69
converted_module.weight # tensor([[69.]], grad_fn=<AddBackward0>)

You can also declare entire PyTorch module classes as mutable, and all child modules will be recursively converted into mutable modules:

import torch.nn as nn

@mutable_module
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

my_module = MyModule()
my_module.linear.weight *= 0
my_module.linear.weight += 69
my_module.linear.weight # tensor([[69.]], grad_fn=<AddBackward0>)
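
Because these in-place updates and assignments are recorded by autograd, gradients can flow back to whatever tensor produced the new parameter values. Below is a minimal sketch of this behavior, assuming the assignment semantics shown above; the tensor names are illustrative and not part of the library's API:

import torch
from torch_mutable_modules import convert_to_mutable_module

# a tensor that will produce the new weight value
source = torch.ones(1, 1, requires_grad=True)

layer = convert_to_mutable_module(torch.nn.Linear(1, 1))
layer.weight = source * 3.0  # hot-swap the weight with a derived tensor

out = layer(torch.ones(1, 1))  # the forward pass uses the swapped weight
out.sum().backward()           # gradients flow back through the swap
print(source.grad)             # non-None if the swap was tracked by autograd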

Detailed examples

Please check out example.py for more detailed examples of using the convert_to_mutable_module function and the @mutable_module decorator.

Contributing

Please feel free to submit issues or pull requests!

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torch_mutable_modules-1.0.1.tar.gz (4.3 kB)

Uploaded Source

Built Distribution

torch_mutable_modules-1.0.1-py3-none-any.whl (4.7 kB)

Uploaded Python 3

File details

Details for the file torch_mutable_modules-1.0.1.tar.gz.

File metadata

  • Download URL: torch_mutable_modules-1.0.1.tar.gz
  • Upload date:
  • Size: 4.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12

File hashes

Hashes for torch_mutable_modules-1.0.1.tar.gz
  • SHA256: 577c5a30be52a1523938b1cd2c0891a2987b57f42cabefc227882f78c22395f9
  • MD5: d491977912fa73a7ee29fbf212645312
  • BLAKE2b-256: eb39af2f6f531349488f36842265c1831121c08b57929d2d2e285023027fcc32

See more details on using hashes here.
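
To check a downloaded file against the digests above, here is a minimal sketch using Python's standard hashlib module (the local file path is hypothetical):

import hashlib

# hypothetical path to the downloaded source distribution
path = "torch_mutable_modules-1.0.1.tar.gz"

with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

# compare against the SHA256 digest listed above
print(digest == "577c5a30be52a1523938b1cd2c0891a2987b57f42cabefc227882f78c22395f9")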

File details

Details for the file torch_mutable_modules-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: torch_mutable_modules-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 4.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12

File hashes

Hashes for torch_mutable_modules-1.0.1-py3-none-any.whl
  • SHA256: 646be60194f28e7e0b78483e08462255d92677601b3f73ff57d1f64b4f146a80
  • MD5: 082bc189a27590a0825462624879c479
  • BLAKE2b-256: 0cd3cd01a7863e3fa0ad4cb217b172a41c92b29dc4107e16b6a25492cd612953

See more details on using hashes here.
