DeepRewire
DeepRewire is a PyTorch-based project designed to simplify the creation and optimization of sparse neural networks using the concepts from the [Deep Rewiring](https://arxiv.org/abs/1711.05136) paper by Bellec et al. ⚠️ Note: This implementation was not written by any of the paper's authors. Please double-check everything before use.
Overview
DeepRewire provides tools to convert standard neural network parameters into a form that can be optimized with the DEEPR and SoftDEEPR optimizers, so that the network becomes sparse during training.
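For reference, the Deep Rewiring paper describes each connection k by a parameter θ_k and a fixed sign s_k. A connection is active while θ_k ≥ 0 and dormant otherwise, so the effective weight is as below. This only restates the paper's parametrization; whether deep_rewire stores its rewireable parameters in exactly this form is an assumption.

```latex
w_k =
\begin{cases}
  s_k\,\theta_k & \text{if } \theta_k \ge 0 \quad \text{(active)} \\
  0             & \text{if } \theta_k < 0 \quad \text{(dormant)}
\end{cases}
\qquad s_k \in \{-1, +1\}
```

Equivalently, w_k = s_k · max(θ_k, 0): a dormant connection contributes exactly zero, which is where the sparsity comes from.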
Installation
Install using `pip install deep_rewire`.
Features
- Conversion Functions: Convert networks to and from rewireable forms.
- Optimizers: Use DEEPR and SoftDEEPR to optimize sparse networks.
- Examples: Run the provided examples to see conversion and optimization in action.
Example Usage:
```python
import torch
from deep_rewire import convert, reconvert, SoftDEEPR

# Define your model
model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10)
)

# Convert model parameters to rewireable form
rewireable_params, other_params = convert(model)

# Define optimizers
optim1 = SoftDEEPR(rewireable_params, lr=0.05, l1=1e-5)
optim2 = torch.optim.SGD(other_params, lr=0.05)  # Optional, for parameters that are not rewireable

# ... Standard training loop ...

# Convert back to standard form
reconvert(model)
# Model has the same parameters but is now (hopefully) sparse.
```
[Figure: output of `examples/softdeepr.py`]
Functionality
Conversion Functions
convert
```python
deep_rewire.convert(module: nn.Module, handle_biases: str = "second_bias",
                    active_probability: float = None, keep_signs: bool = False)
```
Converts a PyTorch module into a rewireable form.
- Parameters:
  - `module` (nn.Module): The model to convert.
  - `handle_biases` (str): Strategy for handling biases. Options are `'ignore'`, `'as_connections'`, and `'second_bias'`.
  - `active_probability` (float): Probability for a connection to be active right after conversion.
  - `keep_signs` (bool): Retain the network's initial signs and start with all connections active (for pretrained networks).
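To make `active_probability` and `keep_signs` more concrete, here is a minimal sketch of how a single dense weight tensor could be split into a parameter θ and a fixed sign, assuming the paper's parametrization w = sign · max(θ, 0). The functions `to_rewireable` and `to_dense` are hypothetical and are not the deep_rewire implementation; in particular, drawing random signs when `keep_signs=False` is an assumption.

```python
from typing import Optional, Tuple

import torch


def to_rewireable(weight: torch.Tensor, active_probability: Optional[float] = None,
                  keep_signs: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical: split a dense weight into (theta, sign).

    Not deep_rewire.convert; only illustrates the Deep Rewiring parametrization.
    """
    magnitude = weight.detach().abs()
    ones = torch.ones_like(magnitude)
    if keep_signs:
        # Keep the (pretrained) signs and start with every connection active.
        sign = torch.where(weight.detach() >= 0, ones, -ones)
        theta = magnitude.clone()
    else:
        # Assumption: draw random signs and make each connection active with
        # probability `active_probability` (all active if it is None).
        sign = torch.where(torch.rand_like(magnitude) < 0.5, ones, -ones)
        p = 1.0 if active_probability is None else active_probability
        active = torch.rand_like(magnitude) < p
        # Dormant connections get a negative theta, so their weight is zero.
        theta = torch.where(active, magnitude, -magnitude)
    return theta, sign


def to_dense(theta: torch.Tensor, sign: torch.Tensor) -> torch.Tensor:
    """Collapse (theta, sign) back into a dense weight; dormant entries become 0."""
    return sign * theta.clamp(min=0)
```

With `keep_signs=True` the round trip `to_dense(*to_rewireable(w, keep_signs=True))` reproduces `w` exactly, which is why that mode suits pretrained networks.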
reconvert
```python
deep_rewire.reconvert(module: nn.Module)
```
Converts a rewireable module back into its original form, making its sparsity visible.
- Parameters:
  - `module` (nn.Module): The model to convert.
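Once a model has been reconverted, its sparsity can be inspected directly. The helper below is a hypothetical sketch (not part of deep_rewire) that reports the fraction of exactly-zero entries in every parameter whose name ends in `weight`, assuming `reconvert` restores ordinary `weight` parameters.

```python
import torch
from torch import nn


def sparsity_report(model: nn.Module) -> dict:
    """Hypothetical helper: fraction of exactly-zero entries per weight tensor."""
    report = {}
    for name, param in model.named_parameters():
        if name.endswith("weight"):
            # Count zeros and divide by the number of entries.
            report[name] = (param == 0).float().mean().item()
    return report
```

For the `Sequential` model in the usage example, calling `sparsity_report(model)` after `reconvert(model)` would return entries for `"0.weight"` and `"2.weight"`.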
Optimizers
DEEPR
```python
deep_rewire.DEEPR(params, nc=required, lr, l1, reset_val, temp)
```
The DEEPR algorithm keeps a fixed number of active connections: whenever a connection becomes inactive, a randomly chosen inactive connection is activated so that the total connectivity stays the same.
- `nc` (int): Fixed number of active connections.
- `lr` (float): Learning rate.
- `l1` (float): L1 regularization term.
- `reset_val` (float): Value for newly activated parameters.
- `temp` (float): Temperature affecting noise magnitude.
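One DEEPR step can be sketched as follows, based on the DEEPR pseudocode in the paper: active connections take a gradient step with L1 shrinkage and Gaussian noise of scale sqrt(2 · lr · temp), connections whose θ drops below zero become dormant, and random dormant connections are reactivated until `nc` are active again. `deepr_step` is a hypothetical function on a 1-D tensor, not the `deep_rewire.DEEPR` implementation; using `reset_val` as the value for reactivated connections is an assumption based on the parameter description.

```python
import math

import torch


def deepr_step(theta, sign, grad_w, nc, lr, l1, reset_val, temp):
    """Hypothetical sketch of one DEEPR update on 1-D tensors.

    theta:  connection parameters (active where theta >= 0)
    sign:   fixed signs (+1 / -1)
    grad_w: gradient of the loss w.r.t. the dense weights w = sign * max(theta, 0)
    """
    with torch.no_grad():
        active = theta >= 0
        # dE/dtheta = sign * dE/dw for active connections.
        grad_theta = sign * grad_w
        noise = torch.randn_like(theta) * math.sqrt(2 * lr * temp)
        update = -lr * grad_theta - lr * l1 + noise
        # Only active connections are updated; some may cross below zero.
        theta[active] += update[active]
        n_missing = nc - int((theta >= 0).sum())
        if n_missing > 0:
            # Reactivate randomly chosen dormant connections to restore nc.
            dormant_idx = torch.nonzero(theta < 0).flatten()
            pick = dormant_idx[torch.randperm(dormant_idx.numel())[:n_missing]]
            theta[pick] = reset_val  # assumes reset_val >= 0
    return theta
```

Because dormant connections are never updated directly, the number of active connections can only fall during the gradient step, and the reactivation loop brings it back to exactly `nc`.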
SoftDEEPR
deep_rewire.SoftDEEPR(params, lr=0.05, l1=1e-5, temp=None, min_weight=None)
The SoftDEEPR algorithm does not keep a fixed number of connections; instead, it also adds noise to its inactive connections so that they can be randomly reactivated.
- `lr` (float): Learning rate.
- `l1` (float): L1 regularization term.
- `temp` (float): Temperature affecting noise magnitude.
- `min_weight` (float): Minimum value for inactive parameters.
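A single SoftDEEPR step might look like the sketch below, which reflects one reading of the paper's soft-DEEPR variant: the gradient and L1 terms act only on active connections, noise is added to every connection (so dormant ones can random-walk back above zero), and θ is clamped from below at `min_weight`. `soft_deepr_step` is hypothetical and is not the `deep_rewire.SoftDEEPR` implementation; it requires explicit `temp` and `min_weight` values because the meaning of the library's `None` defaults is not documented here.

```python
import math

import torch


def soft_deepr_step(theta, sign, grad_w, lr, l1, temp, min_weight):
    """Hypothetical sketch of one soft-DEEPR update (no fixed connection count)."""
    with torch.no_grad():
        active = theta >= 0
        noise = torch.randn_like(theta) * math.sqrt(2 * lr * temp)
        # Gradient step and L1 shrinkage only for active connections.
        drift = torch.where(active, -lr * sign * grad_w - lr * l1,
                            torch.zeros_like(theta))
        # Noise for all connections, so dormant ones may become active again.
        theta += drift + noise
        # Keep dormant parameters from drifting arbitrarily far below zero.
        theta.clamp_(min=min_weight)
    return theta
```

Unlike the DEEPR sketch, nothing here restores a target count, so the number of active connections is free to change from step to step.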
Contributing
Contributions are welcome! Please open an issue or submit a pull request for any improvements, or to fix my mistakes :).
License
This project is licensed under the MIT License.
Acknowledgements
- Guillaume Bellec, David Kappel, Wolfgang Maass, Robert Legenstein for their paper on Deep Rewiring.
For more details, refer to their Deep Rewiring paper or their TensorFlow tutorial.
File details
Details for the file deep_rewire-1.0.2-py3-none-any.whl.
File metadata
- Download URL: deep_rewire-1.0.2-py3-none-any.whl
- Upload date:
- Size: 12.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.12.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `b3bd7512d6fc8e6b1c43c9fb4a1cfb3d7d1e299e942331ca539953f71a1f989b` |
| MD5 | `799330892fb85648c95b1aca20539c4c` |
| BLAKE2b-256 | `fae2e8350570571f8865b7ed9449b127a1215e2f47c2bf4f67a8eef509a37d7d` |
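A downloaded copy of the wheel can be checked against the SHA256 digest listed above with the standard library. This is a minimal sketch; the helper names `sha256_of` and `verify_wheel` are hypothetical, and the local file path is whatever location you downloaded the wheel to (for example with `pip download deep_rewire==1.0.2 --no-deps`).

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for deep_rewire-1.0.2-py3-none-any.whl.
EXPECTED_SHA256 = "b3bd7512d6fc8e6b1c43c9fb4a1cfb3d7d1e299e942331ca539953f71a1f989b"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_wheel(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's SHA256 matches the expected hex digest."""
    return sha256_of(path) == expected.lower()
```

Usage: `verify_wheel(Path("deep_rewire-1.0.2-py3-none-any.whl"))` returns `True` only for an unmodified copy of that wheel.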