DeepRewire is a PyTorch-based project designed to simplify the creation and optimization of sparse neural networks using the concepts from the [Deep Rewiring](https://arxiv.org/abs/1711.05136) paper by Bellec et al. ⚠️ Note: This implementation was not written by any of the paper's authors. Please double-check everything before use.

DeepRewire

Overview

DeepRewire provides tools to convert standard neural network parameters into a form that the DEEPR and SoftDEEPR optimizers can train. This lets the network become sparse during training.
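As background, the Deep Rewiring paper parameterizes each connection by a fixed sign and a trainable parameter θ, and the connection carries a weight only while θ is positive. The plain-Python sketch below illustrates that mapping. The function name `effective_weight` is hypothetical, and the library's internal representation may differ.

```python
def effective_weight(sign: int, theta: float) -> float:
    """Map a (sign, theta) pair to the weight the network would use.

    Assumption: this follows the paper's parameterization and is not
    taken from deep_rewire's source code.
    """
    if sign not in (-1, 1):
        raise ValueError("sign must be -1 or +1")
    # A connection contributes only while theta is positive; otherwise it is pruned.
    return sign * theta if theta > 0.0 else 0.0


signs = [1, -1, 1, -1]
thetas = [0.5, 0.2, -0.1, -0.3]
weights = [effective_weight(s, t) for s, t in zip(signs, thetas)]
print(weights)
```

The last two connections have negative θ, so they contribute a weight of zero regardless of their sign.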

Installation

Install from PyPI with pip:

pip install deep_rewire

Features

  • Conversion Functions: Convert networks to and from rewireable forms.
  • Optimizers: Use DEEPR and SoftDEEPR to optimize sparse networks.
  • Examples: Run provided examples to see the conversion and optimization in action.

Example Usage:

import torch
from deep_rewire import convert, reconvert, SoftDEEPR

# Define your model
model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10)
)

# Convert model parameters to rewireable form
rewireable_params, other_params = convert(model)

# Define optimizers
optim1 = SoftDEEPR(rewireable_params, lr=0.05, l1=1e-5) 
optim2 = torch.optim.SGD(other_params, lr=0.05) # Optional, for parameters that are not rewireable

# ... Standard training loop ...

# Convert back to standard form
reconvert(model)
# Model has the same parameters but is now (hopefully) sparse.

[Figure: examples/softdeepr.py: SoftDEEPR performance]

Functionality

Conversion Functions

convert

deep_rewire.convert(module: nn.Module, handle_biases: str = "second_bias",
                           active_probability: float = None, keep_signs: bool = False)

Converts a PyTorch module into a rewireable form.

  • Parameters:
    • module (nn.Module): The model to convert.
    • handle_biases (str): Strategy to handle biases. Options are 'ignore', 'as_connections', and 'second_bias'.
    • active_probability (float): Probability for connections to be active right after conversion.
    • keep_signs (bool): Retain initial network signs and start with all connections active (for pretrained networks).

reconvert

deep_rewire.reconvert(module: nn.Module)

Converts a rewireable module back into its original form, making its sparsity visible.

  • Parameters:
    • module (nn.Module): The model to convert.
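To check how sparse a model is after `reconvert`, one option is to count the entries that are exactly zero. The helper below is a minimal plain-Python sketch. The name `sparsity` is hypothetical, and it assumes that inactive connections end up stored as exact zeros, which the source does not state explicitly.

```python
from typing import Iterable


def sparsity(values: Iterable[float]) -> float:
    """Return the fraction of entries that are exactly zero."""
    values = list(values)
    if not values:
        raise ValueError("no parameters given")
    zeros = sum(1 for v in values if v == 0.0)
    return zeros / len(values)


# Toy weights standing in for one layer of a reconverted model.
toy_weights = [0.0, 0.31, 0.0, -0.07, 0.0, 0.0, 1.2, 0.0]
print(f"sparsity: {sparsity(toy_weights):.2%}")
```

For a PyTorch model, each parameter `p` in `model.parameters()` could be passed as `sparsity(p.flatten().tolist())`.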

Optimizers

DEEPR

deep_rewire.DEEPR(params, nc=required, lr, l1, reset_val, temp)

The DEEPR algorithm keeps a fixed number of active connections: whenever a connection becomes inactive, a new one is activated at random so that the connectivity stays constant.

  • nc (int): Fixed number of active connections.
  • lr (float): Learning rate.
  • l1 (float): L1 regularization term.
  • reset_val (float): Value for newly activated parameters.
  • temp (float): Temperature affecting noise magnitude.
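The parameters above can be read against a single update step. The following plain-Python sketch is loosely based on the algorithm described in the Deep Rewiring paper and is not the library's implementation. The function `deepr_step`, the convention that a connection is active while its parameter is non-negative, and the noise scale `sqrt(2 * lr * temp)` are assumptions made for illustration.

```python
import math
import random


def deepr_step(thetas, grads, nc, lr, l1, reset_val, temp, rng):
    """One illustrative DEEPR-style update on a flat list of parameters.

    Assumption: a connection is active while its theta is >= 0.
    `grads` holds the loss gradient with respect to each theta.
    """
    new = list(thetas)
    noise_scale = math.sqrt(2.0 * lr * temp)
    for k, theta in enumerate(thetas):
        if theta >= 0.0:
            # Only active connections receive gradient, L1, and noise terms.
            noise = noise_scale * rng.gauss(0.0, 1.0)
            new[k] = theta - lr * grads[k] - lr * l1 + noise
    # Connections whose parameter dropped below zero are now inactive.
    dormant = [k for k, theta in enumerate(new) if theta < 0.0]
    n_active = len(new) - len(dormant)
    # Activate randomly chosen inactive connections until nc are active again.
    n_missing = max(nc - n_active, 0)
    for k in rng.sample(dormant, min(n_missing, len(dormant))):
        new[k] = reset_val
    return new


rng = random.Random(0)
thetas = [0.05, 0.001, -1.0, -1.0, 0.3, -1.0]
grads = [0.1, 0.5, 0.0, 0.0, -0.2, 0.0]
updated = deepr_step(thetas, grads, nc=3, lr=0.05, l1=1e-5,
                     reset_val=0.0, temp=1e-5, rng=rng)
n_active = sum(theta >= 0.0 for theta in updated)
print(n_active)
```

In this sketch the number of active connections after the step always equals `nc`, as long as no more than `nc` were active beforehand.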

SoftDEEPR

deep_rewire.SoftDEEPR(params, lr=0.05, l1=1e-5, temp=None, min_weight=None)

The SoftDEEPR algorithm does not fix the number of connections. Instead, it also adds noise to inactive connections so that they can become active again at random.

  • lr (float): Learning rate.
  • l1 (float): L1 regularization term.
  • temp (float): Temperature affecting noise magnitude.
  • min_weight (float): Minimum value for inactive parameters.
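As a rough illustration of how these parameters might interact, here is a plain-Python sketch of a single SoftDEEPR-style step. It is one reading of the description above and of the Deep Rewiring paper, not the library's code. The function `soft_deepr_step`, the noise scale `sqrt(2 * lr * temp)`, and the use of `min_weight` as a lower clipping bound are assumptions.

```python
import math
import random


def soft_deepr_step(thetas, grads, lr, l1, temp, min_weight, rng):
    """One illustrative SoftDEEPR-style update on a flat list of parameters.

    Assumption: a connection is active while its theta is >= 0.
    """
    noise_scale = math.sqrt(2.0 * lr * temp)
    new = []
    for theta, grad in zip(thetas, grads):
        noise = noise_scale * rng.gauss(0.0, 1.0)
        if theta >= 0.0:
            # Active: gradient step, L1 shrinkage, and noise.
            theta = theta - lr * grad - lr * l1 + noise
        else:
            # Inactive: noise only, so the parameter can drift back above zero.
            theta = theta + noise
        # Keep parameters from sinking below min_weight.
        new.append(max(theta, min_weight))
    return new


rng = random.Random(0)
thetas = [0.4, 0.01, -0.02, -0.3]
grads = [0.2, 0.8, 0.0, 0.0]
updated = soft_deepr_step(thetas, grads, lr=0.05, l1=1e-5,
                          temp=1e-4, min_weight=-0.1, rng=rng)
print(updated)
```

Unlike the DEEPR description, nothing here forces a fixed count of active connections. The balance between the L1 term and the noise determines how many stay active.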

Contributing

Contributions are welcome! Please open an issue or submit a pull request with improvements or fixes for my mistakes :).

License

This project is licensed under the MIT License.

Acknowledgements

  • Guillaume Bellec, David Kappel, Wolfgang Maass, Robert Legenstein for their paper on Deep Rewiring.

For more details, refer to their Deep Rewiring paper or their TensorFlow tutorial.
