Sparling

Sparsity enforcement and adaptive sparsity update utilities for neural networks.

Installation

pip install sparling

Usage

Construct a sparsity layer with batch normalization (recommended):

from sparling import SparseLayerWithBatchNorm

sparse_layer = SparseLayerWithBatchNorm(
    underlying_sparsity_spec=dict(type="EnforceSparsityPerChannel2D"),
    starting_sparsity=0.9,
    channels=128,
    affine=True,
    input_dimensions=2,  # 2 for (N,C,H,W), 1 for (N,C,L)
)

# Training: calibrate thresholds on your data
sparse_layer.train()
for batch in training_batches:
    out = sparse_layer(batch)  # thresholds update via momentum

# Inference: thresholds are frozen
sparse_layer.eval()
out = sparse_layer(x)  # ~90% of values are zero
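
To build intuition for what calibration does, here is a dependency-free sketch of a single-channel threshold that tracks a target sparsity with momentum. It illustrates the general idea only and is not Sparling's implementation: the class name ToyChannelThreshold, the momentum default, and the thresholded-ReLU form of the output are all assumptions made for this example.

```python
import random


class ToyChannelThreshold:
    """Toy single-channel threshold tracker (hypothetical, not Sparling's code)."""

    def __init__(self, sparsity, momentum=0.1):
        self.sparsity = sparsity  # target fraction of zeros, e.g. 0.9
        self.momentum = momentum  # assumed moving-average rate for the threshold
        self.threshold = 0.0
        self.training = True

    def batch_quantile(self, values):
        # Value below which roughly `sparsity` of this batch falls.
        ordered = sorted(values)
        index = min(int(self.sparsity * len(ordered)), len(ordered) - 1)
        return ordered[index]

    def __call__(self, values):
        if self.training:
            # Move the running threshold toward this batch's quantile.
            target = self.batch_quantile(values)
            self.threshold += self.momentum * (target - self.threshold)
        # Thresholded ReLU: everything at or below the threshold becomes zero.
        return [max(v - self.threshold, 0.0) for v in values]


rng = random.Random(0)
layer = ToyChannelThreshold(sparsity=0.9)

# "Training": feed batches so the running threshold settles.
for _ in range(200):
    layer([rng.gauss(0.0, 1.0) for _ in range(1000)])

# "Inference": the threshold is frozen.
layer.training = False
out = layer([rng.gauss(0.0, 1.0) for _ in range(10000)])
zero_fraction = sum(1 for v in out if v == 0.0) / len(out)
print(f"threshold={layer.threshold:.3f}, zero fraction={zero_fraction:.3f}")
```

On standard normal inputs the zero fraction should land close to the 0.9 target, which is the behaviour the "~90% of values are zero" comment above describes.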

Wrap the optimizer with a sparsity update optimizer to adaptively increase sparsity when the model exceeds an accuracy threshold:

import torch
from sparling import LinearThresholdAdaptiveSUO

suo = LinearThresholdAdaptiveSUO(
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    initial_threshold=0.9,
    minimal_threshold=0.8,
    maximal_threshold=0.95,
    threshold_decrease_per_iter=1e-5,
    minimal_update_frequency=100,
    information_multiplier=0.5,
)

# In your training loop:
suo.zero_grad()
loss.backward()
suo.step()
suo.update_sparsity(model, step=step, acc_info=dict(acc=accuracy))
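
The constructor arguments above suggest a schedule along the following lines. This is one plausible reading of the parameter names, written as a standalone toy, and it is not taken from Sparling's source: the class ToyAdaptiveSchedule, its update method, and the exact rules (the bar relaxing every call, the density being scaled by information_multiplier, the bar being reset to the achieved accuracy) are assumptions.

```python
class ToyAdaptiveSchedule:
    """Toy accuracy-threshold schedule (hypothetical, not Sparling's source)."""

    def __init__(self, initial_threshold, minimal_threshold, maximal_threshold,
                 threshold_decrease_per_iter, minimal_update_frequency,
                 information_multiplier):
        self.threshold = initial_threshold
        self.minimal_threshold = minimal_threshold
        self.maximal_threshold = maximal_threshold
        self.threshold_decrease_per_iter = threshold_decrease_per_iter
        self.minimal_update_frequency = minimal_update_frequency
        self.information_multiplier = information_multiplier
        self.last_update_step = None

    def update(self, sparsity, step, acc):
        """Return the (possibly increased) sparsity after this call."""
        # Assumed: the bar relaxes a little on every call, down to a floor.
        self.threshold = max(
            self.threshold - self.threshold_decrease_per_iter,
            self.minimal_threshold,
        )
        too_soon = (
            self.last_update_step is not None
            and step - self.last_update_step < self.minimal_update_frequency
        )
        if too_soon or acc <= self.threshold:
            return sparsity
        # Assumed: scale the density (1 - sparsity) by the multiplier, then
        # raise the bar to the accuracy just achieved, up to a cap.
        self.last_update_step = step
        self.threshold = min(acc, self.maximal_threshold)
        return 1 - (1 - sparsity) * self.information_multiplier


schedule = ToyAdaptiveSchedule(
    initial_threshold=0.9,
    minimal_threshold=0.8,
    maximal_threshold=0.95,
    threshold_decrease_per_iter=1e-5,
    minimal_update_frequency=100,
    information_multiplier=0.5,
)
s0 = schedule.update(0.9, step=0, acc=0.85)   # below the bar: unchanged
s1 = schedule.update(s0, step=1, acc=0.93)    # above the bar: density halves
s2 = schedule.update(s1, step=50, acc=0.99)   # too soon after the last update
s3 = schedule.update(s2, step=150, acc=0.99)  # allowed again: density halves
print(s0, s1, s2, s3)
```

Under these assumed rules, sparsity only ever increases, and the relaxing bar keeps training from stalling when the model cannot reach the previous accuracy at the new sparsity.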

The model should have a setter on a property called sparsity_value that updates the sparsity of all sparsity layers in the model. For example:

from torch import nn
from sparling import SparseLayerWithBatchNorm

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.sparse_layer1 = SparseLayerWithBatchNorm(...)
        self.sparse_layer2 = SparseLayerWithBatchNorm(...)
        # Set this last: the setter below touches both layers.
        self.sparsity_value = 0.9

    @property
    def sparsity_value(self):
        return self._sparsity_value

    @sparsity_value.setter
    def sparsity_value(self, value):
        self._sparsity_value = value
        self.sparse_layer1.sparsity = value
        self.sparse_layer2.sparsity = value

Overview

Sparling provides a collection of torch.nn.Module-based sparsity layers and adaptive sparsity update optimizers for training sparse neural networks.

Sparsity layers

All sparsity layers extend the Sparsity base class (itself an nn.Module). The sparsity property can be updated at any time, and subclasses react via notify_sparsity().
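
The property-plus-hook pattern described here can be sketched without torch as follows. The classes ToySparsity and ToyTopK are hypothetical stand-ins written for this example; only the names sparsity and notify_sparsity come from the description above, and the real base class may differ in its details.

```python
class ToySparsity:
    """Toy base class showing the pattern (hypothetical, no torch)."""

    def __init__(self, starting_sparsity):
        self._sparsity = starting_sparsity

    @property
    def sparsity(self):
        return self._sparsity

    @sparsity.setter
    def sparsity(self, value):
        self._sparsity = value
        self.notify_sparsity()  # let subclasses refresh derived state

    def notify_sparsity(self):
        """Hook for subclasses; the default does nothing."""


class ToyTopK(ToySparsity):
    """Keeps the largest (1 - sparsity) fraction of a fixed-size input."""

    def __init__(self, starting_sparsity, size):
        super().__init__(starting_sparsity)
        self.size = size
        self.notify_sparsity()  # compute the initial derived state

    def notify_sparsity(self):
        # Derived state that must be recomputed whenever sparsity changes.
        self.keep = round((1 - self.sparsity) * self.size)


layer = ToyTopK(starting_sparsity=0.5, size=100)
keep_before = layer.keep
layer.sparsity = 0.9  # triggers notify_sparsity() through the setter
print(keep_before, layer.keep)
```

Routing every change through the setter means callers never have to remember to refresh a layer after changing its sparsity.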

Class                                 Description
EnforceSparsityPerChannel             Per-channel threshold with momentum
EnforceSparsityPerChannelAccumulated  Accumulated batches before threshold update
EnforceSparsityPerChannel2D           2-D (N,C,H,W) wrapper
EnforceSparsityPerChannel1D           1-D (N,C,L) wrapper
EnforceSparsityUniversally            Single global threshold

There are also sparsity combinators:

Class                     Description
SparseLayerWithBatchNorm  BatchNorm + sparsity wrapper. Absolutely necessary for performance.
ParallelSparsityLayers    Applies different sparsity layers to channel subsets

Use the sparsity_types() registry to construct layers from config dicts via dconstruct.construct.
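
The general registry-plus-config-dict pattern looks roughly like the sketch below. The signature of dconstruct.construct is not shown in this README, so the example does not call it; construct_from_spec, toy_sparsity_types, and the two toy classes are hypothetical names that only illustrate how a "type" key can select a class.

```python
from dataclasses import dataclass


@dataclass
class ToyPerChannel:
    """Hypothetical layer config holder, for illustration only."""
    sparsity: float
    num_channels: int


@dataclass
class ToyUniversal:
    """Hypothetical layer config holder, for illustration only."""
    sparsity: float


def toy_sparsity_types():
    """Hypothetical registry mapping the "type" key to a class."""
    return {"ToyPerChannel": ToyPerChannel, "ToyUniversal": ToyUniversal}


def construct_from_spec(registry, spec, **extra_kwargs):
    """Look up spec["type"] and call it with the remaining keys plus extras."""
    kwargs = dict(spec)  # copy so the caller's dict is not mutated
    type_name = kwargs.pop("type")
    if type_name not in registry:
        raise KeyError(f"unknown type {type_name!r}; known: {sorted(registry)}")
    return registry[type_name](**kwargs, **extra_kwargs)


spec = dict(type="ToyPerChannel", num_channels=128)
layer = construct_from_spec(toy_sparsity_types(), spec, sparsity=0.9)
print(layer)
```

Keeping the layer choice in a plain dict makes it easy to store in an experiment config and swap between runs without touching model code.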

Sparsity update optimizers

Class                       Description
NoopSUO                     Does nothing (baseline)
LinearThresholdAdaptiveSUO  Adaptively increases sparsity when accuracy exceeds a threshold

Use the suo_types() registry for construction.

Development

pip install -r requirements.txt
pip install -e .
python -m pytest tests
python -m pylint sparling tests

Baselines

Simple activation-based layers that do not enforce a target sparsity level, useful as baselines or when sparsity is controlled externally (e.g. via an L1 or KL loss).

Class                    Description
NoSparsity               Identity pass-through
SparsityForL1            ReLU activation
ChangingSparsityForL1    ReLU with density-scaled motif loss
SparsityForKL            Sigmoid activation
NoiseRatherThanSparsity  Gaussian noise bottleneck

License

MIT
