torch-weighttracker
Tools for tracking structured weight sparsity, regularization signals, and bit-operation estimates in PyTorch models.
torch-weighttracker is useful when the question is not "what is in this one
tensor?" but "what is happening to the coupled channel, head, or feature unit
that is shared across several tensors?"
The package builds a structural view of a model, compiles tensorized reduction plans over that structure, and reuses those plans for training-time metrics and regularizers.
import torch
from torch import nn
from torch_weighttracker import WeightTracker
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
tracker = WeightTracker(model, example_inputs=torch.randn(1, 4))
print(tracker.view_structures())
Installation
python -m pip install torch-weighttracker
The structured BOPs tracker uses fvcore to count baseline per-module MACs. Install it through the optional extra:
python -m pip install "torch-weighttracker[structured-bops]"
Why Use It?
PyTorch makes it easy to inspect one parameter tensor at a time. Structured compression often needs a different view:
- A channel can be coupled across convolutions, batch norms, linear layers, and residual paths.
- A transformer unit can mean an attention head, a head dimension, or a fused QKV slice rather than a simple row or column.
- A metric such as "active BOPs" depends on sparsity, module shape, MAC counts, and bitrates at the same time.
- A regularizer such as group lasso should penalize the coupled structural unit, not each weight tensor independently.
WeightTracker turns those coupled structures into canonical units, then lets
calculations operate over the canonical units with reusable tensor programs.
Use Cases
Current high-value use cases:
- Add structured group lasso to a training loss.
- Track active structured BOPs and compression rate during structured pruning, sparsity-aware training, or quantization-aware training (QAT).
- Inspect which modules participate in each channel, feature, head, or head-dim group.
- Build structural metrics that aggregate many weight tensors into one value per pruning unit.
Group Lasso
Structured group lasso regularizes coupled units together. Layers can be excluded per regularizer:
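from torch_weighttracker.regularizers import RegularizerType

group_lasso = tracker.create_regularizer(
    RegularizerType.GROUP_LASSO,
    ignore=[model.classifier],  # placeholder: the module(s) to exclude from the penalty
)

# task_loss is the usual training loss; 1e-4 is the regularization strength.
loss = task_loss + 1e-4 * group_lasso()
loss.backward()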
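To make "penalize the coupled unit" concrete, here is a minimal sketch in plain PyTorch. It is not the package's implementation. It assumes a two-layer MLP in which hidden unit j owns row j of the first layer's weight, entry j of its bias, and column j of the second layer's weight. The helper name coupled_group_lasso is hypothetical.

```python
import torch
from torch import nn


def coupled_group_lasso(fc1: nn.Linear, fc2: nn.Linear) -> torch.Tensor:
    """Sum over hidden units of the L2 norm of all weights tied to that unit.

    Hypothetical helper for illustration; assumes fc1 has a bias and feeds fc2.
    """
    # Squared mass of each hidden unit, gathered from every tensor it touches.
    unit_sq = fc1.weight.pow(2).sum(dim=1)           # row j of fc1.weight
    unit_sq = unit_sq + fc1.bias.pow(2)              # entry j of fc1.bias
    unit_sq = unit_sq + fc2.weight.pow(2).sum(dim=0) # column j of fc2.weight
    # One norm per coupled unit, then a plain sum over units.
    return unit_sq.sqrt().sum()


torch.manual_seed(0)
fc1, fc2 = nn.Linear(4, 8), nn.Linear(8, 2)
penalty = coupled_group_lasso(fc1, fc2)
penalty.backward()  # gradients reach both layers through the shared unit norm
```

Penalizing each tensor on its own would instead compute three separate norms per unit, so a unit could look small in one tensor while staying large in another. Taking one norm per coupled unit pushes all of its weights toward zero together.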
Structured BOPs
Structured BOPs compares active bit operations against a dense 32-bit baseline:
import torch
from torch_weighttracker.trackers import TrackerType

metrics = tracker.create_tracker(
    TrackerType.STRUCTURED_BOPS,
    ignore=[torch.nn.BatchNorm2d],
    log_compression_rate=True,
).track()

print(metrics["structured_bops"])
print(metrics["structured_bops_baseline"])
print(metrics["structured_bops_compression_rate"])
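As a rough worked example of how sparsity and bit width combine, the sketch below uses one common convention: a layer's bit operations are its MACs times the weight bit width times the activation bit width. Whether the package uses exactly this formula is an assumption, as is defining the compression rate as baseline divided by active. The helper names are hypothetical.

```python
def linear_macs(in_features: int, out_features: int) -> int:
    """MACs for one forward pass of a dense linear layer on a single input."""
    return in_features * out_features


def bops(macs: int, weight_bits: int, act_bits: int) -> int:
    """Bit operations under the assumed MACs * weight_bits * act_bits convention."""
    return macs * weight_bits * act_bits


# Dense 32-bit baseline for a 4 -> 8 -> 2 MLP.
baseline = bops(linear_macs(4, 8), 32, 32) + bops(linear_macs(8, 2), 32, 32)

# Structured pruning keeps 4 of the 8 hidden units, and the model runs at 8 bits.
# Removing a hidden unit shrinks both layers, because the unit is coupled.
active_hidden = 4
active = (
    bops(linear_macs(4, active_hidden), 8, 8)
    + bops(linear_macs(active_hidden, 2), 8, 8)
)

compression_rate = baseline / active
print(baseline, active, compression_rate)  # 49152 1536 32.0
```

Halving the hidden width halves the MACs of both layers (2x), and dropping from 32 to 8 bits for weights and activations contributes another 16x, which gives 32x in total.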
Architecture
The main API is WeightTracker. Internally it is split into a few layers:
- Dependency discovery: WeightTracker builds dependency groups from the model and example_inputs, or accepts precomputed groups.
- Canonical units: canonical_units.py normalizes raw dependency groups into CanonicalUnitGroup objects. These give channels, features, attention heads, and head dimensions a shared unit index.
- Reduction plans: reductions/ and plans/ compile module and unit mappings into segment and index operations that use PyTorch's efficient tensor computations.
- Calculations: calculations/ defines named calculation specs such as per-unit L2 norm, active units, parameters per unit, active MACs, and bitrates. Calculations can depend on each other and cache constant results.
- Consumers: regularizers/ and trackers/ request the calculations they need, optionally with an ignore context for excluding modules from a specific metric or regularizer.
The result is a small public surface with a reusable internal graph:
model + example inputs
        |
        v
dependency groups -> canonical units -> reduction plans -> calculations
                                                                |
                                                                v
                                                    regularizers and trackers
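The idea of a reusable reduction plan can be sketched with a precomputed index tensor and a single index_add_ call. This is an illustration under assumed names and layout (unit_index, per_unit_l2, and the row-to-unit mapping are hypothetical), not the package's internal plan format.

```python
import torch

# Two weight tensors whose rows belong to shared structural units.
w_a = torch.tensor([[3.0, 4.0], [0.0, 0.0], [1.0, 0.0]])  # rows -> units 0, 1, 2
w_b = torch.tensor([[0.0, 0.0], [6.0, 8.0]])              # rows -> units 0, 1

# The "plan": built once from the structure, then reused on every step.
unit_index = torch.tensor([0, 1, 2, 0, 1])  # unit id for each concatenated row
num_units = 3


def per_unit_l2(tensors, unit_index, num_units):
    """L2 norm per unit, reducing rows from several tensors in one scatter-add."""
    # Squared norm of every row, concatenated in the same order as unit_index.
    row_sq = torch.cat([t.pow(2).sum(dim=1) for t in tensors])
    # Accumulate each row's squared norm into the slot of its unit.
    unit_sq = torch.zeros(num_units, dtype=row_sq.dtype)
    unit_sq.index_add_(0, unit_index, row_sq)
    return unit_sq.sqrt()


norms = per_unit_l2([w_a, w_b], unit_index, num_units)
print(norms)  # tensor([ 5., 10.,  1.])
```

Because the index tensor depends only on the model structure, it can be computed once and shared by any calculation that reduces per-row values to per-unit values, which avoids a Python loop over units at every training step.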
Speed
Compared with a naive implementation, the current implementation gives the following speedups:
- Group lasso: 15.503x
  - Naive: 4.6540s total, 232.698ms/step
  - WeightTracker: 0.3002s total, 15.010ms/step
- Structured BOPs: 2.531x
  - Naive: 0.6757s total, 33.783ms/step
  - WeightTracker: 0.2669s total, 13.346ms/step
| Comparison | Speedup | Naive extra allocation | WeightTracker extra allocation |
|---|---|---|---|
| Group lasso | 15.421x | 197.0MiB | 197.0MiB |
| Structured BOPs | 2.582x | 1.7GiB | 195.9MiB |
Status
This package is pre-1.0. Public APIs may still change while the tracker, calculation, and regularizer surfaces settle.
Future Work
- Streamline definitions and method names across the codebase.
- Improve calculation caching so repeated computations are not performed twice.
- Improve compilation of computation plans for bigger speedups.
- Improve memory management within calculations.
- Write more comprehensive docstrings.
Future custom use cases will need a broader top-level WeightTracker API for
custom operations, custom layers, and generic group definitions.
License
MIT