

Project description

DeepRewire

DeepRewire is a PyTorch-based project that simplifies creating and optimizing sparse neural networks using the concepts from the [Deep Rewiring](https://arxiv.org/abs/1711.05136) paper by Bellec et al. ⚠️ Note: this implementation was not written by any of the paper's authors. Please double-check everything before use.

Overview

DeepRewire provides tools to convert standard neural network parameters into a rewireable form that can be optimized with the DEEPR and SoftDEEPR algorithms, allowing networks to become sparse during training.

Features

  • Conversion Functions: Convert networks to and from rewireable forms.
  • Optimizers: Use DEEPR and SoftDEEPR to optimize sparse networks.
  • Examples: Run provided examples to see the conversion and optimization in action.

Example Usage:

```python
import torch
from deep_rewire import convert_to_deep_rewireable, convert_from_deep_rewireable, SoftDEEPR

# Define your model
model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10)
)

# Convert model parameters to rewireable form
rewireable_params, other_params = convert_to_deep_rewireable(model)

# Define optimizers
optim1 = SoftDEEPR(rewireable_params, lr=0.05, l1=1e-5)
optim2 = torch.optim.SGD(other_params, lr=0.05)  # optional, for the non-rewireable parameters

# ... standard training loop ...

# Convert back to standard form (in place)
convert_from_deep_rewireable(model)
# The model has the same structure as before but is now sparse.
```
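
The resulting sparsity can be verified by counting zeroed weights after converting back. A minimal dependency-free sketch (a real model's weight tensors would be flattened into such a list first):

```python
def sparsity(weights):
    """Fraction of exactly-zero entries in a flat list of weights."""
    if not weights:
        return 0.0
    zeros = sum(1 for w in weights if w == 0.0)
    return zeros / len(weights)

# Example: 3 of 5 weights are zero -> sparsity 0.6
print(sparsity([0.0, 0.5, 0.0, -1.2, 0.0]))  # 0.6
```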

SoftDEEPR Performance

(Figure: SoftDEEPR performance plot, omitted here.)

Functionality

Conversion Functions

convert_to_deep_rewireable

convert_to_deep_rewireable(module: nn.Module, handle_biases: str = "second_bias",
                           active_probability: float = None, keep_signs: bool = False)

Converts a PyTorch module into a rewireable form.

  • Parameters:
    • module (nn.Module): The model to convert.
    • handle_biases (str): Strategy to handle biases. Options are 'ignore', 'as_connections', and 'second_bias'.
    • active_probability (float): Probability for connections to be active right after conversion.
    • keep_signs (bool): Retain initial network signs for pretrained networks.

convert_from_deep_rewireable

convert_from_deep_rewireable(module: nn.Module)

Converts a rewireable module back into its original form, making its sparsity visible.

  • Parameters:
    • module (nn.Module): The model to convert.
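
Conceptually, following the paper's parameterization (the library's internals may differ), each connection is stored as a fixed sign s and a trainable magnitude θ, and conversion back recovers the dense weight as s·θ when θ > 0 and 0 otherwise, which is where the visible zeros come from:

```python
def weight_from(sign, theta):
    """Effective weight of one rewireable connection (paper's w = s * max(theta, 0))."""
    return sign * theta if theta > 0 else 0.0

print(weight_from(+1, 0.3))   # 0.3  (active connection)
print(weight_from(-1, 0.3))   # -0.3 (active, negative sign)
print(weight_from(+1, -0.2))  # 0.0  (dormant connection -> sparse zero)
```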

Optimizers

DEEPR

DEEPR(params, nc, lr, l1, reset_val, temp)

The DEEPR algorithm maintains a fixed number of active connections: whenever a connection becomes inactive, a randomly chosen dormant connection is activated in its place, keeping the overall connectivity constant.

  • nc (int): Fixed number of active connections (required).
  • lr (float): Learning rate.
  • l1 (float): L1 regularization term.
  • reset_val (float): Value for newly activated parameters.
  • temp (float): Temperature affecting noise magnitude.
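
A single DEEPR step can be sketched in plain Python. This is a simplified reading of the paper's Algorithm 1 (the gradient-noise term is omitted, and `deepr_step`, `theta`, and `grads` are illustrative names, not the library's API):

```python
import random

def deepr_step(theta, grads, lr=0.05, l1=1e-5, reset_val=0.01, rng=random):
    """One simplified DEEPR update on magnitude parameters (theta > 0 = active).

    Active connections take a gradient + L1 step; any that cross zero become
    dormant, and an equal number of dormant connections are re-activated at
    random so the number of active connections stays fixed.
    """
    n_active_before = sum(1 for t in theta if t > 0)
    for i, t in enumerate(theta):
        if t > 0:  # only active connections are trained
            theta[i] = t - lr * (grads[i] + l1)
    # Rewire: reactivate as many dormant connections as just died.
    dormant = [i for i, t in enumerate(theta) if t <= 0]
    n_dead = n_active_before - sum(1 for t in theta if t > 0)
    for i in rng.sample(dormant, n_dead):
        theta[i] = reset_val
    return theta
```

With a strong gradient on one small magnitude, that connection dies and a random dormant one takes its place, so the active count is unchanged after the step.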

SoftDEEPR

SoftDEEPR(params, lr=0.05, l1=1e-5, temp=None, min_weight=None)

The SoftDEEPR algorithm does not fix the number of active connections; instead, it also adds noise to inactive connections so that they can reactivate at random.

  • lr (float): Learning rate.

  • l1 (float): L1 regularization term.

  • temp (float): Temperature affecting noise magnitude.

  • min_weight (float): Minimum value for inactive parameters.
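
The SoftDEEPR update can likewise be sketched (my reading of the paper, not the library's implementation: active parameters take a gradient + L1 step plus exploration noise of scale √(2·lr·temp); inactive parameters do a pure random walk clipped from below at min_weight; the `None` fallbacks below are illustrative, not the library's defaults):

```python
import math
import random

def soft_deepr_step(theta, grads, lr=0.05, l1=1e-5, temp=None, min_weight=None, rng=random):
    """One simplified SoftDEEPR update on magnitude parameters (theta > 0 = active)."""
    if temp is None:
        temp = lr / 3        # illustrative fallback, not the library's default
    if min_weight is None:
        min_weight = -2 * l1  # illustrative fallback, not the library's default
    noise_scale = math.sqrt(2 * lr * temp)
    for i, t in enumerate(theta):
        noise = noise_scale * rng.gauss(0, 1)
        if t > 0:  # active: gradient + L1 shrinkage + exploration noise
            theta[i] = t - lr * (grads[i] + l1) + noise
        else:      # inactive: random walk only, clipped from below
            theta[i] = max(t + noise, min_weight)
    return theta
```

Because inactive parameters keep drifting, any connection can eventually cross zero and become active again, which is what removes the need for a fixed connection count.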

Contributing

Contributions are welcome! Please open an issue or submit a pull request for any improvements.

License

This project is licensed under the MIT License.

Acknowledgements

  • Guillaume Bellec, David Kappel, Wolfgang Maass, Robert Legenstein for their work on Deep Rewiring.

For more details, refer to their Deep Rewiring paper (https://arxiv.org/abs/1711.05136) or their TensorFlow tutorial (https://github.com/guillaumeBellec/deep_rewiring).
