
Project description

Torch Mutable Modules

Use in-place and assignment operations on PyTorch module parameters with support for autograd.


Why does this exist?

PyTorch does not allow in-place operations on module parameters that require gradients (usually a desirable safeguard):

import torch

linear_layer = torch.nn.Linear(1, 1)
linear_layer.weight.data += 69
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Valid, but will NOT store grad_fn=<AddBackward0>
linear_layer.weight += 420
# ^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

In some cases, however, it is useful to be able to modify module parameters in-place. For example, if we have a neural network (net_1) that predicts the parameter values of another neural network (net_2), we need to be able to modify the weights of net_2 in-place and backpropagate the gradients to net_1.

# (assumes input_0, input_1, label, and criterion are already defined)
from torch_mutable_modules import to_mutable_module

# create a parameter predictor network (net_1)
net_1 = torch.nn.Linear(1, 2)
# create an optimizer for net_1
optimizer_1 = torch.optim.SGD(net_1.parameters(), lr=0.01)

# predict the weights and biases of net_2 using net_1
p_weight_and_bias = net_1(input_0).unsqueeze(2)
p_weight, p_bias = p_weight_and_bias[:, 0], p_weight_and_bias[:, 1]

# create a mutable network (net_2)
net_2 = to_mutable_module(torch.nn.Linear(1, 1))

# hot-swap the weights and biases of net_2 with the predicted values
net_2.weight = p_weight
net_2.bias = p_bias

# compute the output and backpropagate the gradients to net_1
output = net_2(input_1)
loss = criterion(output, label)
loss.backward()
optimizer_1.step()

This library provides a way to easily convert PyTorch modules into mutable modules with the to_mutable_module function.
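
Conceptually, this is possible because PyTorch only blocks in-place operations on leaf tensors that require gradients, which is exactly what nn.Parameter objects are. The following is a rough sketch of one way such a conversion could work; naive_to_mutable is a hypothetical illustration, not the library's actual implementation:

import torch

def naive_to_mutable(module: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical sketch: re-register every nn.Parameter as a plain,
    # non-leaf tensor attribute. In-place operations on non-leaf tensors
    # that require grad are tracked by autograd instead of raising the
    # leaf-Variable error shown above.
    for name, param in list(module.named_parameters(recurse=False)):
        delattr(module, name)                 # unregister the Parameter
        setattr(module, name, param.clone())  # plain non-leaf tensor, requires_grad=True
    for child in module.children():
        naive_to_mutable(child)               # convert submodules recursively
    return module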

Installation

You can install torch-mutable-modules from PyPI.

pip install torch-mutable-modules

To upgrade an existing installation of torch-mutable-modules, use the following command:

pip install --upgrade --no-cache-dir torch-mutable-modules

Importing

You can use wildcard imports or import specific functions directly:

# import all functions
from torch_mutable_modules import *

# ... or import the function manually
from torch_mutable_modules import to_mutable_module

Usage

To convert an existing PyTorch module into a mutable module, use the to_mutable_module function:

converted_module = to_mutable_module(
    torch.nn.Linear(1, 1)
) # type of converted_module is still torch.nn.Linear

converted_module.weight *= 0
converted_module.weight += 69
converted_module.weight # tensor([[69.]], grad_fn=<AddBackward0>)
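
Because the swapped-in tensors stay in the autograd graph, gradients flow back to whatever computation produced them. A small sanity check, assuming the behavior documented above:

import torch
from torch_mutable_modules import to_mutable_module

converted_module = to_mutable_module(torch.nn.Linear(1, 1))

# hot-swap the weight with a tensor produced by another differentiable computation
source = torch.ones(1, requires_grad=True)
converted_module.weight = (source * 3).reshape(1, 1)

output = converted_module(torch.ones(1, 1))
output.backward()
print(source.grad)  # tensor([3.]): the gradient reached the weight's producer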

You can also declare your own PyTorch module classes as mutable, and all child modules will be recursively converted into mutable modules:

import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

my_module = to_mutable_module(MyModule())
my_module.linear.weight *= 0
my_module.linear.weight += 69
my_module.linear.weight # tensor([[69.]], grad_fn=<AddBackward0>)
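
Converted child modules behave the same way, so nested parameters can be hot-swapped just like top-level ones. A brief sketch, assuming MyModule is defined as above:

# hot-swap the child layer's weight with the output of a predictor network,
# as in the parameter-predictor example earlier
predictor = torch.nn.Linear(1, 1)
my_module = to_mutable_module(MyModule())
my_module.linear.weight = predictor(torch.ones(1, 1))

output = my_module(torch.ones(1, 1))
output.backward()
predictor.weight.grad  # gradients flow back into the predictor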

Usage with CUDA

To create a module on the GPU, simply pass a PyTorch module that is already on the GPU to the to_mutable_module function:

converted_module = to_mutable_module(
    torch.nn.Linear(1, 1).cuda()
) # converted_module is now a mutable module on the GPU

Moving a module to the GPU with .to() and .cuda() after instantiation is NOT supported. Instead, hot-swap the module parameter tensors with their CUDA counterparts:

# both of these are valid
converted_module.weight = converted_module.weight.cuda()
converted_module.bias = converted_module.bias.to("cuda")
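
Putting both steps together, a minimal end-to-end sketch (assumes a CUDA-capable GPU is available):

import torch
from torch_mutable_modules import to_mutable_module

# convert on the CPU, then hot-swap the parameters onto the GPU
converted_module = to_mutable_module(torch.nn.Linear(1, 1))
converted_module.weight = converted_module.weight.cuda()
converted_module.bias = converted_module.bias.cuda()

# inputs must live on the same device as the swapped-in parameters
output = converted_module(torch.ones(1, 1, device="cuda"))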

Detailed examples

Please check out example.py for more detailed usage examples of the to_mutable_module function.

Contributing

Please feel free to submit issues or pull requests!

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torch_mutable_modules-1.1.2.tar.gz (4.4 kB)

Built Distribution

torch_mutable_modules-1.1.2-py3-none-any.whl (4.8 kB)

File details

Details for the file torch_mutable_modules-1.1.2.tar.gz.

File metadata

  • Download URL: torch_mutable_modules-1.1.2.tar.gz
  • Upload date:
  • Size: 4.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.8.12

File hashes

Hashes for torch_mutable_modules-1.1.2.tar.gz

  • SHA256: 01cd2d1f37f06489ab086cb1e4064cf16cec9ddc59968feb35c74b43aee23a18
  • MD5: e4078c5bdcae1467a7b5fa3864a6933d
  • BLAKE2b-256: 2ceb4a4792986f0f7d688a46deefa34dbeb06dc6135e63c257978319386191ad

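To check a downloaded file against the SHA256 digest above, you can use Python's hashlib; this is a generic sketch, not a PyPI-specific tool:

import hashlib

# compare a downloaded sdist against the SHA256 digest listed above
with open("torch_mutable_modules-1.1.2.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

expected = "01cd2d1f37f06489ab086cb1e4064cf16cec9ddc59968feb35c74b43aee23a18"
assert digest == expected, "hash mismatch, do not install this file"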

File details

Details for the file torch_mutable_modules-1.1.2-py3-none-any.whl.

File hashes

Hashes for torch_mutable_modules-1.1.2-py3-none-any.whl

  • SHA256: 114616174f1215c7bbb39c2f4e4bcf393d231e8dbba1e2719a7e70e8087f5606
  • MD5: f168b9fcdbee8106be470d8c977ac0bd
  • BLAKE2b-256: 0eae9b455c0bbc40e0c971bc96b8a94d4178ba4814d7a719b70d342c6f5a7845

