Torch Mutable Modules
Use in-place and assignment operations on PyTorch module parameters with support for autograd.
Why does this exist?
PyTorch does not allow autograd-tracked in-place operations on module parameters, which is usually the desired behavior:
import torch

linear_layer = torch.nn.Linear(1, 1)
linear_layer.weight.data += 69
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Valid, but will NOT store grad_fn=<AddBackward0>
linear_layer.weight += 420
# ^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
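Both behaviors can be checked with plain PyTorch. The following is a minimal sketch; the increment of 1.0 and the names data_grad_fn and error_message are arbitrary choices for illustration and are not part of this library.

```python
import torch

linear_layer = torch.nn.Linear(1, 1)

# Writing through .data bypasses autograd: the update is applied,
# but the parameter stays a leaf tensor and records no grad_fn.
linear_layer.weight.data += 1.0
data_grad_fn = linear_layer.weight.grad_fn

# An in-place operation on the parameter itself is rejected by autograd.
error_message = None
try:
    linear_layer.weight += 1.0
except RuntimeError as err:
    error_message = str(err)

print(data_grad_fn)
print(error_message)
```

Catching the RuntimeError lets the script run to completion, which makes it easier to compare the two cases side by side.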
In some cases, however, it is useful to be able to modify module parameters in-place. For example, if we have a neural network (net_1) that predicts the parameter values of another neural network (net_2), we need to be able to modify the weights of net_2 in-place and backpropagate the gradients to net_1.
# create a parameter predictor network (net_1)
net_1 = torch.nn.Linear(1, 2)
# predict the weights and biases of net_2 using net_1
p_weight_and_bias = net_1(input_0).unsqueeze(2)
p_weight, p_bias = p_weight_and_bias[:, 0], p_weight_and_bias[:, 1]
# create a mutable network (net_2)
net_2 = to_mutable_module(torch.nn.Linear(1, 1))
# hot-swap the weights and biases of net_2 with the predicted values
net_2.weight = p_weight
net_2.bias = p_bias
# compute the output and backpropagate the gradients to net_1
output = net_2(input_1)
loss = criterion(output, label)
loss.backward()
optimizer.step()
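To see what backpropagating to net_1 means without this library, the same data flow can be written with torch.nn.functional.linear, which takes the weight and bias as ordinary tensors. This is a plain-PyTorch sketch of the idea, not a description of how torch-mutable-modules is implemented; the inputs, label, loss, and optimizer below are placeholders assumed for illustration.

```python
import torch
import torch.nn.functional as F

# Placeholder data, assumed for illustration only.
input_0 = torch.ones(1, 1)
input_1 = torch.ones(1, 1)
label = torch.zeros(1, 1)

# Parameter predictor network (net_1) and its optimizer.
net_1 = torch.nn.Linear(1, 2)
optimizer = torch.optim.SGD(net_1.parameters(), lr=0.01)

# Predict a (1, 1) weight and a (1,) bias for a Linear(1, 1) layer.
p_weight_and_bias = net_1(input_0).unsqueeze(2)  # shape (1, 2, 1)
p_weight = p_weight_and_bias[:, 0]               # shape (1, 1)
p_bias = p_weight_and_bias[:, 1].squeeze(0)      # shape (1,)

# Apply the predicted parameters functionally instead of storing them
# on a module, so the autograd graph back to net_1 stays intact.
output = F.linear(input_1, p_weight, p_bias)
loss = F.mse_loss(output, label)
loss.backward()
optimizer.step()

# net_1 received gradients through the predicted weight and bias.
print(net_1.weight.grad.shape)  # torch.Size([2, 1])
```

The functional form works for a single layer, but it means rewriting the forward pass by hand. Assigning the predicted tensors to a mutable module keeps net_2 as a regular module object while preserving the same gradient path.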
This library provides a way to easily convert PyTorch modules into mutable modules with the to_mutable_module function.
Installation
You can install torch-mutable-modules from PyPI:
pip install torch-mutable-modules
To upgrade an existing installation of torch-mutable-modules, use the following command:
pip install --upgrade --no-cache-dir torch-mutable-modules
Importing
You can use wildcard imports or import specific functions directly:
# import all functions
from torch_mutable_modules import *
# ... or import the function manually
from torch_mutable_modules import to_mutable_module
Usage
To convert an existing PyTorch module into a mutable module, use the to_mutable_module function:
import torch
from torch_mutable_modules import to_mutable_module

converted_module = to_mutable_module(
    torch.nn.Linear(1, 1)
)  # type of converted_module is still torch.nn.Linear
converted_module.weight *= 0
converted_module.weight += 69
converted_module.weight  # tensor([[69.]], grad_fn=<AddBackward0>)
You can also declare your own PyTorch module classes as mutable, and all child modules will be recursively converted into mutable modules:
from torch import nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

my_module = to_mutable_module(MyModule())
my_module.linear.weight *= 0
my_module.linear.weight += 69
my_module.linear.weight  # tensor([[69.]], grad_fn=<AddBackward0>)
Usage with CUDA
To create a mutable module on the GPU, pass a PyTorch module that is already on the GPU to the to_mutable_module function:
converted_module = to_mutable_module(
    torch.nn.Linear(1, 1).cuda()
)  # converted_module is now a mutable module on the GPU
Moving a module to the GPU with .to() or .cuda() after instantiation is NOT supported. Instead, hot-swap the module parameter tensors with their CUDA counterparts:
# both of these are valid
converted_module.weight = converted_module.weight.cuda()
converted_module.bias = converted_module.bias.to("cuda")
Detailed examples
Please check out example.py for more detailed usage examples of the to_mutable_module function.
Contributing
Please feel free to submit issues or pull requests!