NN-Library
Common PyTorch utilities for research projects on neural networks.
At the BONSAI Lab, our research on neural networks requires loading, training, and reconfiguring neural network models. This library is a work-in-progress suite of in-house tools that addresses some pain points we've encountered in our research workflow.
We make no guarantees about the stability or usability of this library, but we hope that it can be useful to others in the research community. If you have any questions or suggestions, please feel free to reach out to us or open an issue on the GitHub repository.
Installation
Using pip:
pip install bonsai-nn-library
Using uv (in a project):
uv add bonsai-nn-library
Or, as a drop-in replacement for pip:
uv pip install bonsai-nn-library
Usage (Python >= 3.10)
Say you want to use one of our "fancy layers" like a low-rank convolution. You can do so like this:
from torch import nn
from nn_lib.models.fancy_layers import LowRankConv2d
model = nn.Sequential(
LowRankConv2d(in_channels=3, out_channels=64, kernel_size=3, rank=8),
nn.ReLU(),
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
nn.ReLU(),
nn.Flatten(),
nn.LazyLinear(10)
)
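To see why a low-rank convolution is attractive, it helps to count parameters. The following is a hand-rolled numpy sketch of one common factorization scheme (flattening the kernel into a matrix and factoring it through rank r), not nn_lib's actual implementation; the function names and the exact scheme are assumptions for illustration:

```python
import numpy as np

def low_rank_conv_params(in_channels: int, out_channels: int, kernel_size: int, rank: int):
    """Compare parameter counts of a full conv kernel vs. a rank-r factorization.

    A full Conv2d kernel has out*in*k*k weights. One common low-rank scheme
    factors the flattened kernel matrix (out, in*k*k) into (out, r) @ (r, in*k*k).
    """
    full = out_channels * in_channels * kernel_size ** 2
    factored = out_channels * rank + rank * in_channels * kernel_size ** 2
    return full, factored

def low_rank_approx(W: np.ndarray, rank: int) -> np.ndarray:
    """Best rank-r approximation of a flattened kernel matrix, via truncated SVD."""
    U, s, Vt = np.linalg.svd(W, full_matrices=False)
    return (U[:, :rank] * s[:rank]) @ Vt[:rank]

full, factored = low_rank_conv_params(in_channels=3, out_channels=64, kernel_size=3, rank=8)
print(full, factored)  # 1728 728
```

With the shapes from the example above, the rank-8 factorization uses 728 parameters in place of 1728, at the cost of restricting the kernel matrix to rank 8.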
Useful thing #1: Improved GraphModules.
PyTorch was not originally designed around explicit computation graphs; support for them was added
later in the torch.fx module. Others might use TensorFlow or JAX for this, but we like PyTorch.
The torch.fx.GraphModule class is the built-in way to handle computation graphs in PyTorch, but it
lacks some features that we find useful. We have extended the GraphModule class in our
GraphModulePlus class, which inherits from GraphModule and adds some further functionality.
A motivating use-case is that we want to be able to "stitch" models together or extract out hidden
layer activity. This is a little tricky to get right using GraphModule alone, but we've added some
utilities like:
- GraphModulePlus.set_output(layer_name): use this to chop off the head of a model and make it output from a specific layer.
- GraphModulePlus.new_from_merge(...): use this to merge or "stitch" existing models together. See demos/demo_stitching.py for a worked-out example.
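The idea behind these two operations can be illustrated without torch.fx at all. Below is a toy sketch in which a "model" is just an ordered list of named steps; GraphModulePlus performs the analogous surgery on real torch.fx graphs, and all names here are hypothetical:

```python
# Toy sketch of graph surgery: a "model" is an ordered list of (name, fn) steps.
# nn_lib's GraphModulePlus does this on real torch.fx graphs; this just
# illustrates the two operations described above.

def run(graph, x):
    for name, fn in graph:
        x = fn(x)
    return x

def set_output(graph, layer_name):
    """Truncate the graph so it outputs the activations of `layer_name`."""
    idx = [name for name, _ in graph].index(layer_name)
    return graph[: idx + 1]

def new_from_merge(head, tail):
    """Stitch two graphs: feed head's output into tail's input."""
    return head + tail

model_a = [("double", lambda x: 2 * x), ("inc", lambda x: x + 1)]
model_b = [("square", lambda x: x * x)]

hidden = set_output(model_a, "double")      # chop off everything after "double"
stitched = new_from_merge(hidden, model_b)  # feed "double" output into model_b
print(run(stitched, 3))  # 36
```

The real versions must additionally handle branching graphs, parameter ownership, and shape mismatches at the stitch point, which is exactly the part that is tricky to get right with GraphModule alone.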
We've also done some metaprogramming trickery so that if you import GraphModulePlus anywhere in
your code, it will automatically inject itself into the torch.fx module. The surprising but
convenient behavior is:
from torch import nn
from torch.fx import symbolic_trace
from nn_lib.models import GraphModulePlus
my_regular_torch_model = nn.Sequential(
nn.Conv2d(3, 64, 3),
nn.ReLU(),
nn.Conv2d(64, 64, 3),
nn.ReLU(),
nn.Flatten(),
nn.LazyLinear(10)
)
# Natively, symbolic_trace is expected to return a GraphModule, but we've injected GraphModulePlus
graphified_model = symbolic_trace(my_regular_torch_model)
assert isinstance(graphified_model, GraphModulePlus)
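One way such self-injection can work is for the subclass to replace the base class in the host module's namespace at import time, so that library factories construct the subclass instead. The following standalone toy (fakelib and every name in it are made up) demonstrates that mechanism; it is not necessarily how nn_lib does it:

```python
import types

# A stand-in for a third-party library module.
fakelib = types.ModuleType("fakelib")

class GraphModule:
    def __init__(self, graph):
        self.graph = graph

def symbolic_trace(model):
    # The factory looks the class up through the module attribute each call,
    # so patching fakelib.GraphModule changes what it constructs.
    return fakelib.GraphModule(graph=f"traced({model})")

fakelib.GraphModule = GraphModule
fakelib.symbolic_trace = symbolic_trace

class GraphModulePlus(GraphModule):
    def set_output(self, layer_name):
        return f"output set to {layer_name}"

# "Injection" at import time: replace the base class in the host module.
fakelib.GraphModule = GraphModulePlus

gm = fakelib.symbolic_trace("my_model")
print(isinstance(gm, GraphModulePlus))  # True
```

The convenience is that existing code calling the library's own factory silently gets the enriched subclass; the cost is exactly the kind of spooky action-at-a-distance that makes the behavior "surprising".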
Useful thing #2: Fancy layers.
We have implemented a few "fancy" layers, available via nn_lib.models or
nn_lib.models.fancy_layers that we find useful in our research. These include:
- Regressable linear layers: a Protocol that allows linear layers to be initialized by least-squares regression. This is useful for initializing a linear layer to approximate a function learned by a different model.
- RegressableLinear: a regressable version of nn.Linear.
- LowRankLinear: a regressable linear layer with a low-rank factorization.
- ProcrustesLinear: a regressable linear layer constrained to rotation, with optional shift (bias) and optional scaling.
- A conv2d version of each of the above.
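The "regressable" idea can be sketched in plain numpy: given inputs X and a teacher's outputs Y, solve for the layer's weights by least squares, and for the Procrustes variant constrain the solution to an orthogonal map via SVD. This is a from-scratch sketch of the underlying math; nn_lib's actual Protocol and method signatures may differ:

```python
import numpy as np

def regress_linear(X, Y):
    """Least-squares init: find W, b minimizing ||X @ W + b - Y||^2."""
    X1 = np.hstack([X, np.ones((X.shape[0], 1))])  # append a bias column
    coef, *_ = np.linalg.lstsq(X1, Y, rcond=None)
    return coef[:-1], coef[-1]                     # W, b

def procrustes_linear(X, Y):
    """Constrain the map to a rotation: orthogonal Procrustes via SVD."""
    U, _, Vt = np.linalg.svd(X.T @ Y)
    return U @ Vt                                  # orthogonal W

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 4))
W_true = rng.normal(size=(4, 4))
Y = X @ W_true + 0.5                               # noise-free teacher outputs

W, b = regress_linear(X, Y)
print(np.allclose(W, W_true), np.allclose(b, 0.5))  # True True

R = procrustes_linear(X, X @ W_true)
print(np.allclose(R @ R.T, np.eye(4)))  # True: the solution is orthogonal
```

Initializing a student layer this way, from paired (input, teacher-output) activations, is exactly the use case described above.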
Useful thing #3: MLflow utilities.
We use MLflow to track our experiments. We have a few utilities in nn_lib.utils.mlflow that remove
a bit of boilerplate from our code.
Useful thing #4: Dataset utilities and torchvision wrappers.
We have implemented some LightningDataModule classes for a few standard vision datasets. See
nn_lib.datasets. We've also implemented a few simple helpers for downloading pretrained models
from torch hub using the torchvision API. See nn_lib/models/__init__.py.
Useful thing #5: NTK utilities.
See nn_lib.analysis.ntk for some neural tangent kernel utilities.
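For context, the empirical neural tangent kernel of a model f with parameters θ is K(x, x') = ∇_θ f(x) · ∇_θ f(x'). The nn_lib.analysis.ntk API is not shown here; the following is a from-scratch numpy sketch for the simplest case, a scalar-output linear model f(x) = w · x, where the Jacobian with respect to w is just x and the NTK reduces to the Gram matrix:

```python
import numpy as np

def empirical_ntk_linear(X: np.ndarray) -> np.ndarray:
    """Empirical NTK of f(x) = w . x: since df/dw = x, K[i, j] = x_i . x_j."""
    return X @ X.T

X = np.array([[1.0, 0.0],
              [0.0, 2.0],
              [1.0, 2.0]])
K = empirical_ntk_linear(X)
print(K)
# [[1. 0. 1.]
#  [0. 4. 4.]
#  [1. 4. 5.]]
```

For nonlinear networks the same definition applies with per-parameter Jacobians (in PyTorch, e.g. via torch.func.jacrev), and the kernel is symmetric positive semi-definite by construction.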
Forthcoming/Planned features
- More fancy layers
- Vector Quantization utilities (but see nn_lib.models.sparse_auto_encoder, which has some already)
- Further analysis utilities, especially focused on calculating neural similarity measures.
Obsolete/deprecated features
- Lightning training and overly complex CLI utilities. Some straggler files may still need to be cleaned up.