DeepRewire

DeepRewire is a PyTorch-based project designed to simplify the creation and optimization of sparse neural networks using concepts from the [Deep Rewiring](https://arxiv.org/abs/1711.05136) paper by Bellec et al. ⚠️ Note: This implementation is not made by any of the paper's authors. Please double-check everything before use.

Overview

DeepRewire provides tools to convert standard neural network parameters into a form that can be optimized with the DEEPR and SoftDEEPR optimizers, allowing networks to become sparse during training.

Installation

Install from PyPI:

pip install deep_rewire

Features

  • Conversion Functions: Convert networks to and from rewireable forms.
  • Optimizers: Use DEEPR and SoftDEEPR to optimize sparse networks.
  • Examples: Run provided examples to see the conversion and optimization in action.

Example Usage:

import torch
from deep_rewire import convert, reconvert, SoftDEEPR

# Define your model
model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10)
)

# Convert model parameters to rewireable form
rewireable_params, other_params = convert(model)

# Define optimizers
optim1 = SoftDEEPR(rewireable_params, lr=0.05, l1=1e-5) 
optim2 = torch.optim.SGD(other_params, lr=0.05) # Optional, for parameters that are not rewireable

# ... Standard training loop ...

# Convert back to standard form
reconvert(model)
# Model has the same parameters but is now (hopefully) sparse.

[Figure: SoftDEEPR performance plot, generated by examples/softdeepr.py]

Functionality

Conversion Functions

convert

deep_rewire.convert(module: nn.Module, handle_biases: str = "second_bias",
                           active_probability: float = None, keep_signs: bool = False)

Converts a PyTorch module into a rewireable form.

  • Parameters:
    • module (nn.Module): The model to convert.
    • handle_biases (str): Strategy to handle biases. Options are 'ignore', 'as_connections', and 'second_bias'.
    • active_probability (float): Probability for connections to be active right after conversion.
    • keep_signs (bool): Retain initial network signs and start with all connections active (for pretrained networks).

reconvert

deep_rewire.reconvert(module: nn.Module)

Converts a rewireable module back into its original form, making its sparsity visible.

  • Parameters:
    • module (nn.Module): The model to convert.

Optimizers

DEEPR

deep_rewire.DEEPR(params, nc, lr, l1, reset_val, temp)

The DEEPR algorithm maintains a fixed number of active connections: whenever a connection becomes inactive, another connection is activated at random so that the overall connectivity stays constant.

  • nc (int): Fixed number of active connections.
  • lr (float): Learning rate.
  • l1 (float): L1 regularization term.
  • reset_val (float): Value for newly activated parameters.
  • temp (float): Temperature affecting noise magnitude.

SoftDEEPR

deep_rewire.SoftDEEPR(params, lr=0.05, l1=1e-5, temp=None, min_weight=None)

The SoftDEEPR algorithm does not fix the number of active connections; it additionally adds noise to inactive connections so that they can randomly reactivate.

  • lr (float): Learning rate.

  • l1 (float): L1 regularization term.

  • temp (float): Temperature affecting noise magnitude.

  • min_weight (float): Minimum value for inactive parameters.

SoftDEEPRWrapper

deep_rewire.SoftDEEPRWrapper(params, base_optim, l1=1e-5, temp=None, min_weight=None, **optim_kwargs)

Applies the SoftDEEPR mechanism for keeping the connections sparse, but delegates the parameter updates to any chosen torch optimizer (e.g. SGD, Adam).

  • base_optim (torch.optim.Optimizer): The base optimizer used to update the parameters.

  • l1 (float): L1 regularization term.

  • temp (float): Temperature affecting noise magnitude.

  • min_weight (float): Minimum value for inactive parameters.

  • **optim_kwargs: Additional arguments passed to the base optimizer.

Contributing

Contributions are welcome! Please open an issue or submit a pull request for any improvements or fixes :).

License

This project is licensed under the MIT License.

Acknowledgements

  • Guillaume Bellec, David Kappel, Wolfgang Maass, and Robert Legenstein for their paper on Deep Rewiring.

For more details, refer to their Deep Rewiring paper or their TensorFlow tutorial.
