Torch Mutable Modules
Use in-place and assignment operations on PyTorch module parameters with support for autograd.
Why does this exist?
PyTorch does not allow in-place operations on module parameters that require gradients (which is usually the desired behavior):
import torch
linear_layer = torch.nn.Linear(1, 1)
linear_layer.weight.data += 69
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Valid, but will NOT store grad_fn=<AddBackward0>
linear_layer.weight += 420
# ^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
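The two behaviors described in the comments above can be checked directly. The following is a minimal sketch in plain PyTorch (it does not use this library); the variable names `data_edit_grad_fn` and `in_place_error` are hypothetical and exist only to capture the results for inspection.

```python
import torch

linear_layer = torch.nn.Linear(1, 1)

# Editing .data bypasses autograd: the parameter remains a leaf
# tensor and no grad_fn is recorded for the addition.
linear_layer.weight.data += 1.0
data_edit_grad_fn = linear_layer.weight.grad_fn

# An in-place operation on the parameter itself is rejected by autograd
# because the parameter is a leaf tensor that requires grad.
try:
    linear_layer.weight += 1.0
    in_place_error = None
except RuntimeError as err:
    in_place_error = str(err)

print("grad_fn after .data edit:", data_edit_grad_fn)
print("error from in-place op:", in_place_error)
```

In both cases the modification is invisible to autograd, which is why neither approach can carry gradients back to whatever produced the new values.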
In some cases, however, it is useful to be able to modify module parameters in-place. For example, if we have a neural network (net_1) that predicts the parameter values of another neural network (net_2), we need to be able to modify the weights of net_2 in-place and backpropagate the gradients to net_1.
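import torch
from torch_mutable_modules import convert_to_mutable_module

# input_0, input_1, label, criterion, and optimizer (over net_1's parameters) are assumed to be defined elsewhere
# create a parameter predictor network (net_1)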
net_1 = torch.nn.Linear(1, 2)
# predict the weights and biases of net_2 using net_1
p_weight_and_bias = net_1(input_0).unsqueeze(2)
p_weight, p_bias = p_weight_and_bias[:, 0], p_weight_and_bias[:, 1]
# create a mutable network (net_2)
net_2 = convert_to_mutable_module(torch.nn.Linear(1, 1))
# hot-swap the weights and biases of net_2 with the predicted values
net_2.weight = p_weight
net_2.bias = p_bias
# compute the output and backpropagate the gradients to net_1
output = net_2(input_1)
loss = criterion(output, label)
loss.backward()
optimizer.step()
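The gradient flow in the example above, from the loss through the predicted parameters back to net_1, can be illustrated for a single linear layer in plain PyTorch. This sketch is not this library's API: it uses `torch.nn.functional.linear` as a stand-in for the mutable net_2, and the toy tensors, the SGD optimizer, the learning rate, and the MSE criterion are all assumptions chosen only to make it self-contained.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy data: a single sample for each tensor.
input_0 = torch.ones(1, 1)
input_1 = torch.ones(1, 1)
label = torch.zeros(1, 1)

# Parameter predictor network (net_1) and its optimizer (assumed: SGD).
net_1 = torch.nn.Linear(1, 2)
optimizer = torch.optim.SGD(net_1.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()

# Predict a (1, 1) weight and a (1,) bias for a 1-in, 1-out linear map.
predicted = net_1(input_0)    # shape (1, 2)
p_weight = predicted[:, 0:1]  # shape (1, 1)
p_bias = predicted[:, 1]      # shape (1,)

# Apply the predicted parameters functionally; both tensors keep their
# grad_fn, so the loss stays connected to net_1.
output = F.linear(input_1, p_weight, p_bias)
loss = criterion(output, label)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print("net_1 weight grad:", net_1.weight.grad)
```

The functional form works for one layer, but it requires rewriting the forward pass of net_2 by hand. Assigning the predicted tensors to an existing module, as in the example above, avoids that for larger modules.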
This library provides a way to easily convert PyTorch modules into mutable modules with the convert_to_mutable_module function and the @mutable_module decorator.
Installation
You can install torch-mutable-modules from PyPI.
pip install torch-mutable-modules
To upgrade an existing installation of torch-mutable-modules, use the following command:
pip install --upgrade --no-cache-dir torch-mutable-modules
Importing
You can use wildcard imports or import specific functions directly:
# import all functions
from torch_mutable_modules import *
# ... or import them manually
from torch_mutable_modules import convert_to_mutable_module, mutable_module
Usage
To convert an existing PyTorch module into a mutable module, use the convert_to_mutable_module function:
converted_module = convert_to_mutable_module(
    torch.nn.Linear(1, 1)
)  # type of converted_module is still torch.nn.Linear
You can also declare entire PyTorch module classes as mutable, and all child modules will be recursively converted into mutable modules:
from torch import nn
from torch_mutable_modules import mutable_module

@mutable_module
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)
More examples
Check out example.py to see more example usages of the convert_to_mutable_module function and the @mutable_module decorator.
Contributing
Please feel free to submit issues or pull requests!