
torchhyper: A PyTorch library for modular hypernetworks


Installation

Run the command below to install the package into your Python environment.

pip install torchhyper

For further development and to run the examples, clone the repository and install the package in editable mode. Make sure to adapt the CUDA version in setup.cfg to the one installed on your system.

# Create a new conda environment.
conda create --name torchhyper "python<=3.12"
conda activate torchhyper

# Clone the repository and install the package in editable mode.
git clone https://github.com/alisiahkoohi/torchhyper
cd torchhyper/
pip install -e .

Usage

The hypernetwork module HyperNetwork, found in torchhyper/models/architecture.py, adapts itself to the architecture of any given downstream network. As a result, training the hypernetwork to generate weights for a specific downstream network requires only the modifications outlined below. First, define the hypernetwork and its optimizer as follows:

import torch

from torchhyper.model import HyperNetwork

# Define your downstream network, e.g., a torch.nn.Module instance for a diffusion model.
downstream_network = YourDownstreamNetwork() # Any torch.nn.Module instance.

# No need to train the downstream network directly.
downstream_network.requires_grad_(False)

# Define the hypernetwork.
hypernetwork = HyperNetwork(
    input_dim,          # e.g., x[0, ...].numel() for input variable x (excluding batch dimension).
    [32, 64, 96],       # A list of hidden layer sizes of the hypernetwork layers.
    downstream_network,
)

# Define the optimizer to train the hypernetwork.
optimizer = torch.optim.Adam(hypernetwork.parameters(), lr=1e-3)
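For intuition, the pattern can be sketched in plain PyTorch. ToyHyperNetwork below is a hypothetical stand-in, much simpler than the library's HyperNetwork; it only illustrates how a network might map an input to a weight dictionary keyed by the downstream network's parameter names, as torch.func.functional_call expects:

```python
import torch
import torch.nn as nn


class ToyHyperNetwork(nn.Module):
    """Minimal stand-in (illustrative only, not the library's implementation)."""

    def __init__(self, input_dim, hidden_dim, downstream_network):
        super().__init__()
        # Record the name and shape of every downstream parameter.
        self.shapes = {
            name: p.shape for name, p in downstream_network.named_parameters()
        }
        total = sum(s.numel() for s in self.shapes.values())
        # Map the flattened input to one flat vector holding all parameters.
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, total),
        )

    def forward(self, x):
        # Average over the batch to predict a single shared weight set.
        flat = self.net(x.flatten(start_dim=1)).mean(dim=0)
        # Split and reshape the flat vector into a name -> tensor dictionary.
        weight_dict, offset = {}, 0
        for name, shape in self.shapes.items():
            n = shape.numel()
            weight_dict[name] = flat[offset:offset + n].view(shape)
            offset += n
        return weight_dict
```

The key design point is that the output dictionary mirrors named_parameters() of the downstream network, which is exactly the format torch.func.functional_call consumes.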

Next, replace every forward evaluation of the downstream network, downstream_network(x) for an input tensor x, with the following functional call to the downstream network using the weights predicted by the hypernetwork:

# Predicting the weight dictionary for the downstream network.
weight_dict = hypernetwork(x)

# Predicting the output of the downstream network using the predicted weights.
pred = torch.func.functional_call(
    downstream_network, # The downstream network.
    weight_dict,        # The predicted weight dictionary.
    x,                  # The input to the downstream network.
)

The downstream network output pred is now equivalent to downstream_network(x) evaluated with the weights predicted by the hypernetwork, and can be used for loss computation and backpropagation as usual.
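This gradient path can be verified with a self-contained toy in plain PyTorch (the weight predictor and parameter names below are illustrative stand-ins, not torchhyper's API):

```python
import torch
import torch.nn as nn

# Frozen downstream network: its own parameters receive no gradients.
downstream = nn.Linear(3, 1)
downstream.requires_grad_(False)

# Stand-in predictor emitting a flat vector covering all downstream parameters.
predictor = nn.Linear(3, sum(p.numel() for p in downstream.parameters()))

x = torch.randn(5, 3)
flat = predictor(x).mean(dim=0)  # One shared weight set for the batch.
weight_dict = {
    "weight": flat[:3].view(1, 3),
    "bias": flat[3:].view(1),
}

# Functional forward pass with the predicted weights.
pred = torch.func.functional_call(downstream, weight_dict, x)
loss = pred.pow(2).mean()
loss.backward()

# Gradients flow to the predictor, not to the frozen downstream network.
assert predictor.weight.grad is not None
assert downstream.weight.grad is None
```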

A simple test case for the gradient calculation with this approach can be found in tests/test_functional_call.py.

When calling other methods of the downstream network that do not require gradient calculation, e.g., downstream_network.sample(), the weights predicted by the hypernetwork can be copied directly into the downstream network before calling the method:

# Predicting the weight dictionary for the downstream network.
weight_dict = hypernetwork(x)

# Set the parameters of the downstream network to the computed ones.
for param, weight in zip(
        downstream_network.parameters(),
        weight_dict.values(),
):
    param.data = weight.data

# Calling the `sample` method of the downstream network (no need for gradient calculation).
downstream_network.sample()
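Since torch.func.functional_call expects the weight dictionary to be keyed by parameter names, an alternative to the order-dependent zip above is to load the dictionary by name with load_state_dict (a sketch with a stand-in downstream network and hand-made weights):

```python
import torch
import torch.nn as nn

# Stand-in downstream network; its parameter names are "weight" and "bias".
downstream = nn.Linear(3, 2)

# A weight dictionary keyed by parameter names, as functional_call uses.
weight_dict = {
    "weight": torch.ones(2, 3),
    "bias": torch.zeros(2),
}

# Copy the predicted weights into the module by name, not by iteration order.
downstream.load_state_dict(weight_dict)
```

Loading by name avoids any silent mismatch if the dictionary's iteration order ever differs from parameters().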

Questions

Please contact alisk@rice.edu for questions.

Author

Ali Siahkoohi
