
hyper-nn -- Easy Hypernetworks in Pytorch and Flax


Note: This library is experimental and currently under development - the Flax implementations in particular are far from perfect and can be improved. If you have any suggestions on how to improve this library, please open a GitHub issue or feel free to reach out directly!

hyper-nn gives users the ability to create easily customizable hypernetworks for almost any generic torch.nn.Module from PyTorch and flax.linen.Module from Flax. Our hypernetwork objects are themselves torch.nn.Modules and flax.linen.Modules, allowing for easy integration with existing systems. For PyTorch, we make use of the amazing functorch library.

Demos:

  • Generating policy weights for Lunar Lander
  • Dynamic weights for each character in a name generator

Install

hyper-nn is tested on Python 3.8+.

Installing with pip

$ pip install hyper-nn

Installing from source

$ git clone git@github.com:shyamsn97/hyper-nn.git
$ cd hyper-nn
$ python setup.py install

For GPU functionality with JAX, you will need to follow the installation instructions in the JAX repository.

Hypernetworks, simply put, are neural networks that generate parameters for another neural network. They can be incredibly powerful, being able to represent large networks while using only a fraction of their parameters.

hyper-nn represents Hypernetworks with two key components:

  • an EmbeddingModule, which holds information about layer(s) in the target network - or, more generally, a chunk of the target network's weights
  • a WeightGenerator, which takes in the embedding and outputs a parameter vector for the target network

Hypernetworks generally come in two variants: static and dynamic. A static hypernetwork has a fixed or learned embedding, and its weight generator outputs the target network's weights deterministically. A dynamic hypernetwork instead conditions on its inputs and uses them to generate input-dependent weights.
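
As a rough illustration of the static case, here is a conceptual sketch (not hyper-nn's actual API): a learned embedding matrix is pushed through a small generator, and the flattened output is used as every parameter of a tiny target layer.

import torch
import torch.nn as nn

# Conceptual sketch of a static hypernetwork (not hyper-nn's API).
# Target: a bias-free 3x3 linear layer, i.e. 9 parameters in total.
num_embeddings, embedding_dim, num_target_parameters = 3, 4, 9

embedding = nn.Embedding(num_embeddings, embedding_dim)  # learned embedding matrix
generator = nn.Linear(embedding_dim, num_target_parameters // num_embeddings)

# Each embedding row is mapped to one chunk of the target's parameters.
chunks = generator(embedding(torch.arange(num_embeddings)))  # shape: (3, 3)
params = chunks.flatten()                                    # shape: (9,)

# Use the generated vector as the target layer's weight matrix.
x = torch.randn(1, 3)
out = x @ params.view(3, 3).T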


Quick Usage

For detailed examples, see the notebooks.

The main classes to use are TorchHyperNetwork and JaxHyperNetwork, along with the classes that inherit from them. Instead of constructing them directly, use the from_target method, as shown below. After that, you can use the hypernetwork exactly like any other nn.Module!

Pytorch

import torch
import torch.nn as nn

# static hypernetwork
from hypernn.torch.hypernet import TorchHyperNetwork

# any module
target_network = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Linear(64, 32)
)

EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32

hypernetwork = TorchHyperNetwork.from_target(
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS
)

# now we can use the hypernetwork like any other nn.Module
inp = torch.zeros((1, 32))

# by default we only output what we'd expect from the target network
output = hypernetwork(inp=[inp])

# return aux_output
output, generated_params, aux_output = hypernetwork(inp=[inp], has_aux=True)

# generate params separately
generated_params, aux_output = hypernetwork.generate_params(inp=[inp])
output = hypernetwork(inp=[inp], generated_params=generated_params)
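
Since the hypernetwork is itself an nn.Module, it can be trained with any standard optimizer. A minimal sketch, using placeholder data and an arbitrary MSE loss:

import torch
import torch.nn.functional as F

optimizer = torch.optim.Adam(hypernetwork.parameters(), lr=1e-3)

# placeholder batch; substitute your task's inputs and targets
x, y = torch.randn(8, 32), torch.randn(8, 32)

output = hypernetwork(inp=[x])
loss = F.mse_loss(output, y)
loss.backward()
optimizer.step()
optimizer.zero_grad()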

Jax

import flax.linen as nn
import jax.numpy as jnp
from jax import random

# static hypernetwork
from hypernn.jax.hypernet import JaxHyperNetwork

# any module
target_network = nn.Sequential(
    [
        nn.Dense(64),
        nn.relu,
        nn.Dense(32)
    ]
)

EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32

hypernetwork = JaxHyperNetwork.from_target(
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS,
    inputs=jnp.zeros((1, 32)) # jax needs this to initialize target weights
)

# now we can use the hypernetwork like any other nn.Module
inp = jnp.zeros((1, 32))
key = random.PRNGKey(0)
hypernetwork_params = hypernetwork.init(key, inp=[inp]) # flax needs to initialize hypernetwork parameters first

# by default we only output what we'd expect from the target network
output = hypernetwork.apply(hypernetwork_params, inp=[inp])

# return aux_output
output, generated_params, aux_output = hypernetwork.apply(hypernetwork_params, inp=[inp], has_aux=True)

# generate params separately
generated_params, aux_output = hypernetwork.apply(hypernetwork_params, inp=[inp], method=hypernetwork.generate_params)

output = hypernetwork.apply(hypernetwork_params, inp=[inp], generated_params=generated_params)
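
Training follows the usual flax pattern of differentiating a loss with respect to the hypernetwork parameters. A minimal sketch, using placeholder data and an arbitrary MSE loss:

import jax

def loss_fn(params, x, y):
    output = hypernetwork.apply(params, inp=[x])
    return jnp.mean((output - y) ** 2)

# placeholder batch; substitute your task's inputs and targets
x, y = jnp.zeros((8, 32)), jnp.ones((8, 32))
loss, grads = jax.value_and_grad(loss_fn)(hypernetwork_params, x, y)
# grads can then be applied with an optimizer library such as optax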

Advanced: Using vmap for batching operations

This is useful when dealing with dynamic hypernetworks that generate different params depending on inputs.

Pytorch

import torch
import torch.nn as nn
from functorch import vmap

# dynamic hypernetwork
from hypernn.torch.dynamic_hypernet import TorchDynamicHyperNetwork

# any module
target_network = nn.Sequential(
    nn.Linear(8, 256),
    nn.ReLU(),
    nn.Linear(256, 32)
)

EMBEDDING_DIM = 4
NUM_EMBEDDINGS = 32

# conditioned on input to generate param vector
hypernetwork = TorchDynamicHyperNetwork.from_target(
    target_network = target_network,
    embedding_dim = EMBEDDING_DIM,
    num_embeddings = NUM_EMBEDDINGS,
    input_dim = 8
)

# batch of 10 inputs
inp = torch.randn((10, 1, 8))

# use with a for loop
outputs = []
for i in range(10):
    outputs.append(hypernetwork(inp=[inp[i]]))
outputs = torch.stack(outputs)
assert outputs.size() == (10, 1, 32)

# using vmap
outputs = vmap(hypernetwork)([inp])
assert outputs.size() == (10, 1, 32)
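
The same batching idea carries over to JAX. Below is a sketch reusing hypernetwork and hypernetwork_params from the JAX Quick Usage example above (a dynamic JAX hypernetwork would follow the same pattern):

import jax
import jax.numpy as jnp

# batch of 10 inputs for the Quick Usage JAX hypernetwork
batched_inp = jnp.zeros((10, 1, 32))

# vmap over the leading batch axis; the params are closed over, not mapped
outputs = jax.vmap(lambda x: hypernetwork.apply(hypernetwork_params, inp=[x]))(batched_inp)
assert outputs.shape == (10, 1, 32)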

Detailed Explanation

EmbeddingModule

The EmbeddingModule stores information about layer(s) in the target network - or, more generally, a chunk of the target network's weights. The standard representation is a matrix of size num_embeddings x embedding_dim. hyper-nn uses torch's nn.Embedding and flax's nn.Embed classes to represent this.
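
In PyTorch terms, the default embedding module is essentially an embedding table indexed by 0..num_embeddings - 1:

import torch
import torch.nn as nn

embedding_module = nn.Embedding(num_embeddings=32, embedding_dim=4)
embedding_matrix = embedding_module(torch.arange(32))  # shape: (32, 4)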

WeightGenerator

The WeightGenerator takes in the embedding matrix from the EmbeddingModule and outputs a parameter vector of size num_target_parameters, equal to the total number of parameters in the target network. To ensure the output size matches num_target_parameters, the WeightGenerator produces a matrix of size num_embeddings x weight_chunk_dim, where weight_chunk_dim = num_target_parameters // num_embeddings, and then flattens it.
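
As a quick sanity check of the arithmetic, using the Quick Usage target network from above:

import torch.nn as nn

target_network = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))
num_target_parameters = sum(p.numel() for p in target_network.parameters())  # 4192

num_embeddings = 32
weight_chunk_dim = num_target_parameters // num_embeddings  # 131
# the generator emits a (32, 131) matrix, flattened to a 4192-dim vector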

Hypernetwork

The Hypernetwork by default uses a setup function to initialize the embedding_module and weight_generator, either from user-provided modules or from the factory methods make_embedding_module and make_weight_generator. This makes it easy to plug in your own modules in place of the basic versions provided. generate_params generates the target parameters, and forward combines the generated parameters with the target network to compute a forward pass.

Instead of creating the Hypernetwork class directly, use from_target.
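
For example, a custom weight generator can be supplied by overriding make_weight_generator. The sketch below assumes the hypernetwork exposes the embedding dimension and per-chunk output size as embedding_dim and weight_chunk_dim attributes - hypothetical names, so check the library source for the exact ones:

import torch.nn as nn
from hypernn.torch.hypernet import TorchHyperNetwork

class MLPWeightGeneratorHyperNetwork(TorchHyperNetwork):
    def make_weight_generator(self) -> nn.Module:
        # attribute names embedding_dim / weight_chunk_dim are assumed here
        return nn.Sequential(
            nn.Linear(self.embedding_dim, 64),
            nn.ReLU(),
            nn.Linear(64, self.weight_chunk_dim),
        )

hypernetwork = MLPWeightGeneratorHyperNetwork.from_target(
    target_network=target_network,
    embedding_dim=4,
    num_embeddings=32,
)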

Base class: code

from __future__ import annotations

import abc
from typing import Any, Dict, Iterable, Optional, Tuple


class HyperNetwork(metaclass=abc.ABCMeta):
    embedding_module = None
    weight_generator = None

    def setup(self) -> None:
        if self.embedding_module is None:
            self.embedding_module = self.make_embedding_module()

        if self.weight_generator is None:
            self.weight_generator = self.make_weight_generator()

    @abc.abstractmethod
    def make_embedding_module(self):
        """
        Makes an embedding module to be used

        Returns:
            a torch.nn.Module or flax.linen.Module that can be used to return an embedding matrix to be used to generate weights
        """

    @abc.abstractmethod
    def make_weight_generator(self):
        """
        Makes an embedding module to be used

        Returns:
            a torch.nn.Module or flax.linen.Module that can be used to return an embedding matrix to be used to generate weights
        """

    @classmethod
    @abc.abstractmethod
    def count_params(
        cls,
        target,
        target_input_shape: Optional[Any] = None,
    ):
        """
        Counts parameters of target nn.Module

        Args:
            target (Union[torch.nn.Module, flax.linen.Module]): target network whose parameters are counted
            target_input_shape (Optional[Any], optional): input shape of the target network (used to initialize flax modules). Defaults to None.
        """

    @classmethod
    @abc.abstractmethod
    def from_target(cls, target, *args, **kwargs) -> HyperNetwork:
        """
        Creates a hypernetwork from a target network

        Args:
            target (Union[torch.nn.Module, flax.linen.Module]): target network
        """

    @abc.abstractmethod
    def generate_params(self, inp: Optional[Any] = None, *args, **kwargs) -> Tuple[Any, Dict[str, Any]]:
        """
        Generate a vector of parameters for target network

        Args:
            inp (Optional[Any], optional): input, may be useful when creating dynamic hypernetworks

        Returns:
            Any: vector of parameters for target network and a dictionary of extra info
        """

    @abc.abstractmethod
    def forward(
        self,
        inp: Iterable[Any] = [],
        generated_params=None,
        has_aux: bool = True,
        *args,
        **kwargs,
    ):
        """
        Computes a forward pass with generated parameters or with parameters that are passed in

        Args:
            inp (Any): input from system
            generated_params (Optional[Union[torch.tensor, jnp.array]], optional): Generated params. Defaults to None.
            has_aux (bool): flag to indicate whether to return auxiliary info
        Returns:
            the target network output, along with the generated params and auxiliary info when has_aux is True
        """

Citing hyper-nn

If you use this software in your academic work, please cite:

@misc{sudhakaran2022,
  author = {Sudhakaran, Shyam},
  title = {hyper-nn},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/shyamsn97/hyper-nn}}
}

Projects used in hyper-nn

@Misc{functorch2021,
  author =       {Horace He and Richard Zou},
  title =        {functorch: JAX-like composable function transforms for PyTorch},
  howpublished = {\url{https://github.com/pytorch/functorch}},
  year =         {2021}
}
