
Tools for tracking structured weight sparsity in PyTorch models.


torch-weighttracker

A package for tracking structured weight sparsity, regularization signals, and bit-operation (BOPs) estimates in PyTorch modules.

The API is centered on WeightTracker:

import torch
from torch import nn

from torch_weighttracker import WeightTracker

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
tracker = WeightTracker(model, example_inputs=torch.randn(1, 4))
print(tracker.view_structures())

Installation

python -m pip install torch-weighttracker

The Structured BOPs tracker uses fvcore to count the baseline MACs of each module. Install it through the optional extra:

python -m pip install "torch-weighttracker[structured-bops]"

Tensorized cross weight operations

WeightTracker provides an interface for performing cross-weight tensor operations on models efficiently. Using the "Computation plans" and "Calculation" classes, it compiles a set of torch modules that execute tensor and mapping operations on the training device with a minimal set of repeated operations.
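To make the idea of a compiled plan concrete, here is a minimal sketch in plain PyTorch. It is not the package's actual "Computation plans" or "Calculation" classes. It assumes the Linear(4, 8), ReLU, Linear(8, 2) model from the first example, where hidden unit i couples row i of the first weight, bias entry i, and column i of the second weight. The hypothetical `build_group_index` runs once and records which group every parameter element belongs to; each step then computes all group norms with one concatenation and one `index_add`, instead of a Python loop over groups.

```python
import torch
from torch import nn


def build_group_index(fc1: nn.Linear, fc2: nn.Linear) -> torch.Tensor:
    """Hypothetical plan: map every element of (fc1.weight, fc1.bias, fc2.weight)
    to the hidden unit (group) it belongs to. Built once, reused every step."""
    units = torch.arange(fc1.out_features)
    w1_groups = units.unsqueeze(1).expand_as(fc1.weight)  # row i of fc1.weight -> group i
    b1_groups = units                                      # fc1.bias[i] -> group i
    w2_groups = units.unsqueeze(0).expand_as(fc2.weight)  # column i of fc2.weight -> group i
    return torch.cat([w1_groups.reshape(-1), b1_groups, w2_groups.reshape(-1)])


def group_norms(fc1: nn.Linear, fc2: nn.Linear, group_index: torch.Tensor) -> torch.Tensor:
    """Execute the plan: L2 norm of every coupled group in one vectorized pass."""
    # Flatten in the same order the plan was built in
    flat = torch.cat([fc1.weight.reshape(-1), fc1.bias, fc2.weight.reshape(-1)])
    zeros = flat.new_zeros(fc1.out_features)
    # Out-of-place index_add keeps the result differentiable w.r.t. the weights
    sq_sums = zeros.index_add(0, group_index, flat.pow(2))
    return sq_sums.sqrt()


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
fc1, fc2 = model[0], model[2]
plan = build_group_index(fc1, fc2)   # compiled once
norms = group_norms(fc1, fc2, plan)  # executed every step
print(norms.shape)  # torch.Size([8])
```

The design choice here is to move all index bookkeeping out of the training loop, so that the per-step cost is a fixed, small number of tensor operations regardless of how many groups the model has.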

Use case

WeightTracker's primary use case is currently computing loss terms and metrics that depend on structural dependencies between layers, such as structured sparsity, structured compression rates, and group lasso. However, the code is designed so that, with some modifications, it can be used for any operation that traverses weights.
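As an illustration of a metric that depends on structural dependencies, the sketch below computes structured sparsity for the two-layer model from the first example. Under the assumed definition, a hidden unit only counts as removed when all of its coupled parameters (row i of the first weight, bias i, and column i of the second weight) are zero. The function name `structured_sparsity` and this exact definition are hypothetical, not the package's implementation.

```python
import torch
from torch import nn


@torch.no_grad()
def structured_sparsity(fc1: nn.Linear, fc2: nn.Linear) -> float:
    """Hypothetical metric: fraction of hidden units whose coupled parameters are all zero."""
    row_zero = (fc1.weight == 0).all(dim=1)  # weights producing unit i (row i of fc1)
    bias_zero = fc1.bias == 0                # bias of unit i
    col_zero = (fc2.weight == 0).all(dim=0)  # weights consuming unit i (column i of fc2)
    removable = row_zero & bias_zero & col_zero
    return removable.float().mean().item()


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
fc1, fc2 = model[0], model[2]

# Zero out every parameter coupled to hidden units 0 and 1
with torch.no_grad():
    fc1.weight[:2] = 0
    fc1.bias[:2] = 0
    fc2.weight[:, :2] = 0

print(structured_sparsity(fc1, fc2))  # 0.25
```

Zeroing only the first layer's row, or only the second layer's column, would not count here, because the unit's coupled structure is not fully removed.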

Group lasso

Structured group lasso regularizes coupled units together as a single group. Individual layers can be excluded per regularizer with `ignore`:

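import torch
from torch import nn

from torch_weighttracker.regularizers import RegularizerType

# Exclude the final Linear layer (the classifier of the model above)
group_lasso = tracker.create_regularizer(
    RegularizerType.GROUP_LASSO,
    ignore=[model[-1]],
)

x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))
task_loss = nn.functional.cross_entropy(model(x), y)
loss = task_loss + 1e-4 * group_lasso()
loss.backward()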
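For reference, a naive group lasso term for the same two-layer model can be written as a Python loop over coupled hidden units, summing the L2 norm of each unit's group. This is a sketch of the standard unweighted group lasso penalty under the assumption of one group per hidden unit; the package's grouping, any per-group weighting (some formulations scale each term by the square root of the group size), and its handling of ignored layers may differ. The name `naive_group_lasso` is hypothetical.

```python
import torch
from torch import nn


def naive_group_lasso(fc1: nn.Linear, fc2: nn.Linear) -> torch.Tensor:
    """Sum over hidden units i of ||(fc1.weight[i], fc1.bias[i], fc2.weight[:, i])||_2."""
    penalty = fc1.weight.new_zeros(())
    for i in range(fc1.out_features):
        sq = (
            fc1.weight[i].pow(2).sum()         # weights producing unit i
            + fc1.bias[i].pow(2)               # bias of unit i
            + fc2.weight[:, i].pow(2).sum()    # weights consuming unit i
        )
        penalty = penalty + sq.sqrt()
    return penalty


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))

task_loss = nn.functional.cross_entropy(model(x), y)
loss = task_loss + 1e-4 * naive_group_lasso(model[0], model[2])
loss.backward()  # gradients now include the group lasso term
```

Because each group's norm is not squared, the penalty pushes whole groups to exactly zero rather than merely shrinking individual weights, which is what makes the resulting sparsity structured.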

Structured BOPs

Structured BOPs compares the bit operations (BOPs) of the active structures against a dense 32-bit baseline:

import torch

from torch_weighttracker.trackers import TrackerType

metrics = tracker.create_tracker(
    TrackerType.STRUCTURED_BOPS,
    ignore=[torch.nn.BatchNorm2d],
    log_compression_rate=True,
).track()

print(metrics["structured_bops"])
print(metrics["structured_bops_baseline"])
print(metrics["structured_bops_compression_rate"])
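The exact BOPs formula used by the tracker is not spelled out here. A common definition counts BOPs per layer as MACs × weight bit-width × activation bit-width, so a dense 32-bit layer costs MACs × 32 × 32. The sketch below applies that assumed definition to the two-layer model from the first example, where pruning hidden units shrinks the MACs of both coupled layers. The function names are hypothetical, and the compression rate is assumed to be baseline BOPs divided by structured BOPs.

```python
def linear_macs(in_features: int, out_features: int) -> int:
    """MACs of one Linear forward pass for a single input sample (bias ignored)."""
    return in_features * out_features


def bops(macs: int, weight_bits: int = 32, act_bits: int = 32) -> int:
    """Assumed BOPs definition: MACs x weight bit-width x activation bit-width."""
    return macs * weight_bits * act_bits


# Linear(4, 8) -> ReLU -> Linear(8, 2), with 6 of 8 hidden units still active
in_f, hidden, out_f, active = 4, 8, 2, 6

baseline = bops(linear_macs(in_f, hidden) + linear_macs(hidden, out_f))
structured = bops(linear_macs(in_f, active) + linear_macs(active, out_f))
compression_rate = baseline / structured

print(baseline, structured, round(compression_rate, 3))  # 49152 36864 1.333
```

Removing a hidden unit saves MACs in both the layer that produces it and the layer that consumes it, which is why the structural coupling matters for this metric.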

Speed

Compared with a naive implementation, WeightTracker achieves the following speedups over 20 training steps:

  • Group lasso: 15.503x
    • Naive: 4.6540s total, 232.698ms/step
    • Weighttracker: 0.3002s total, 15.010ms/step
  • Structured BOPs: 2.531x
    • Naive: 0.6757s total, 33.783ms/step
    • Weighttracker: 0.2669s total, 13.346ms/step
Comparison        Speedup   Naive extra allocation   WeightTracker extra allocation
Group lasso       15.421x   197.0 MiB                197.0 MiB
Structured BOPs   2.582x    1.7 GiB                  195.9 MiB
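The per-step figures in the list above are the totals divided by the number of steps (for example, 4.6540 s / 20 steps ≈ 232.7 ms/step), and each speedup is the naive total divided by the WeightTracker total. The sketch below shows one way such a comparison could be measured on CPU, timing a loop-based group norm against a vectorized one for the same coupled layers. It is an assumed harness for illustration only, not the script that produced the numbers above, and its timings will vary by machine.

```python
import time

import torch
from torch import nn


def loop_norms(fc1: nn.Linear, fc2: nn.Linear) -> torch.Tensor:
    # A separate set of small tensor ops for every hidden unit
    return torch.stack([
        (fc1.weight[i].pow(2).sum() + fc1.bias[i].pow(2) + fc2.weight[:, i].pow(2).sum()).sqrt()
        for i in range(fc1.out_features)
    ])


def vectorized_norms(fc1: nn.Linear, fc2: nn.Linear) -> torch.Tensor:
    # All hidden units at once: row sums of fc1, column sums of fc2
    sq = fc1.weight.pow(2).sum(dim=1) + fc1.bias.pow(2) + fc2.weight.pow(2).sum(dim=0)
    return sq.sqrt()


def time_steps(fn, fc1: nn.Linear, fc2: nn.Linear, steps: int) -> float:
    """Total wall-clock seconds for `steps` forward and backward passes of fn."""
    start = time.perf_counter()
    for _ in range(steps):
        fc1.zero_grad()
        fc2.zero_grad()
        fn(fc1, fc2).sum().backward()
    return time.perf_counter() - start


model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
fc1, fc2 = model[0], model[2]
steps = 20

naive_total = time_steps(loop_norms, fc1, fc2, steps)
fast_total = time_steps(vectorized_norms, fc1, fc2, steps)
print(f"naive: {naive_total:.4f}s total, {1000 * naive_total / steps:.3f}ms/step")
print(f"fast:  {fast_total:.4f}s total, {1000 * fast_total / steps:.3f}ms/step")
print(f"speedup: {naive_total / fast_total:.3f}x")
```

On a GPU, a fair comparison would also need to synchronize the device before reading the clock, since kernel launches are asynchronous.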

Status

This package is pre-1.0. Public APIs may still change while the tracker, calculation, and regularizer surfaces settle.

Future work:

  1. Streamline definitions and methods across the codebase for a unified, more compact, and perhaps more understandable API.
  2. Implement Calculation caching so that computations are not performed twice.
  3. Improve the compilation of computation plans.
  4. Improve memory management within calculations.
  5. Write more comprehensive docstrings.

Supporting future use cases will require updating the top-level WeightTracker API, including the ability to accept custom operations, custom layers, and generic group definitions.

License

MIT



Download files


Source Distribution

torch_weighttracker-0.1.2.tar.gz (178.3 kB)

Built Distribution


torch_weighttracker-0.1.2-py3-none-any.whl (92.5 kB)

File details

Details for the file torch_weighttracker-0.1.2.tar.gz.

File metadata

  • Download URL: torch_weighttracker-0.1.2.tar.gz
  • Size: 178.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for torch_weighttracker-0.1.2.tar.gz
Algorithm Hash digest
SHA256 5c4cff95b514d64ee44a13a5d2bceb831d912577fb9b7d7e00d2d9e81aa72623
MD5 72da107d95d3f51ca3114ab29edf787b
BLAKE2b-256 63dfefc350ed2524b235fbf894ca10aa45ba4506841d4ca0f5994b6ba2170fa3


Provenance

The following attestation bundles were made for torch_weighttracker-0.1.2.tar.gz:

Publisher: publish-to-pypi.yml on dadyownes15/torch-weighttracker

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file torch_weighttracker-0.1.2-py3-none-any.whl.


File hashes

Hashes for torch_weighttracker-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 a20e5fd0bc50f6e924872ccf300cd3035f421abae54e09301b27de4a3fe1d447
MD5 a3b32e63dcb95329011b328dd842bdba
BLAKE2b-256 b630a2c4d586755df1f3cf797a152b231817780cef2cd971cb217957d297b005


Provenance

The following attestation bundles were made for torch_weighttracker-0.1.2-py3-none-any.whl:

Publisher: publish-to-pypi.yml on dadyownes15/torch-weighttracker

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
