Tools for tracking structured weight sparsity in PyTorch models.
torch-weighttracker
Track, regularize, and prune structured units in PyTorch models.
torch-weighttracker gives you a model-level view of sparsity: channels,
features, attention heads, head dimensions, and fused QKV slices are grouped into
canonical units, so metrics, pruning, and regularizers operate on the structure
you would actually compress.
```python
import torch
import timm
from torch_weighttracker import WeightTracker
from torch_weighttracker.integrations.timm import infer_vit_num_heads

model = timm.create_model("vit_base_patch16_224", pretrained=False)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
example_inputs = torch.randn(1, 3, 224, 224)

tracker = WeightTracker(
    model,
    example_inputs=example_inputs,
    num_heads=infer_vit_num_heads(model),
    prune_num_heads=True,
)
tracker.create_tracker("structured_bops", log_total_bops=True)
group_lasso = tracker.create_regularizer("group_lasso")

# dataloader and criterion are your own training data and loss function.
for inputs, targets in dataloader:
    optimizer.zero_grad()
    outputs = model(inputs)
    task_loss = criterion(outputs, targets)
    loss = task_loss + 1e-4 * group_lasso()
    loss.backward()
    optimizer.step()
    metrics = tracker.track()

# After sparsity-aware training, physically remove the zeroed canonical units.
tracker.prune_zero_units()
```
Installation
```bash
python -m pip install torch-weighttracker
```
BOPs MAC accounting uses fvcore to compute baseline per-module MACs. Install it with the optional extra:
```bash
python -m pip install "torch-weighttracker[structured-bops]"
```
Why Use It?
PyTorch makes it easy to inspect one parameter tensor at a time. Structured compression often needs a different view:
- A channel can be coupled across convolutions, batch norms, linear layers, and residual paths.
- A transformer unit can mean an attention head, a head dimension, or a fused QKV slice rather than a simple row or column.
- A metric such as "active BOPs" depends on sparsity, module shape, MAC counts, and bitrates at the same time.
- A regularizer such as group lasso should penalize the coupled structural unit, not each weight tensor independently.
WeightTracker turns those coupled structures into canonical units, then lets
calculations operate over the canonical units with reusable tensor programs.
Use Cases
Current use cases:
- Add structured group lasso to a training loss.
- Track active structured or unstructured BOPs and compression rate during pruning, sparsity-aware training, or quantization-aware training (QAT).
- Inspect which modules participate in each channel, feature, head, or head-dim group.
- Build structural metrics that aggregate many weight tensors into one value per pruning unit.
- Physically prune zeroed canonical units, including attention heads, after sparsity-aware training.
Current Pruning Notes
WeightTracker can inspect zeroed canonical units with view_zero_units() /
view_zero_structures() and physically remove them with prune_zero_units() /
prune_zero_structures(). You can also remove one canonical unit directly with
prune_unit(group_id, unit_id).
Zero detection can ignore module instances or module types, with the same filter semantics as trackers. The ignore filter only decides whether a structure counts as zero; once that structure is pruned, the full coupled Torch-Pruning group is still applied, including the ignored modules:
```python
zero_view = tracker.view_zero_structures(ignore=[torch.nn.BatchNorm2d])
tracker.prune_zero_structures(ignore=[torch.nn.BatchNorm2d])
```
Physical pruning changes module shapes and rebuilds the dependency state. Any
registered trackers or regularizers are cleared after prune_unit() or
prune_zero_units() / prune_zero_structures(), so recreate them before
collecting metrics or losses from the pruned model:
```python
metrics_before = tracker.create_tracker(
    "structured_bops",
    log_total_bops=True,
).track()

tracker.prune_zero_units()

metrics_after = tracker.create_tracker(
    "structured_bops",
    log_total_bops=True,
).track()
```
Fake pruning remains useful during training because it zeros the selected canonical unit while keeping the module shapes intact:
```python
tracker.create_tracker("structured_bops", log_total_bops=True)
metrics_before = tracker.track()

tracker.fake_prune_unit(group_id=3, unit_id=0)
tracker.fake_prune_unit(group_id=3, unit_id=2)
metrics_after = tracker.track()
```
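To illustrate why fake pruning is a faithful stand-in for physical removal, here is a plain-Python sketch on a hypothetical two-layer MLP whose hidden feature h couples row h of fc1_w, fc1_b[h], and column h of fc2_w. Zeroing all coupled slices gives the same outputs as deleting them, while the shapes stay intact. The helper functions are hypothetical and do not reflect how WeightTracker edits modules.

```python
# Hypothetical coupled unit: hidden feature h of a two-layer MLP.
# fc1_w is [hidden][in], fc1_b is [hidden], fc2_w is [out][hidden].


def fake_prune_hidden(fc1_w, fc1_b, fc2_w, h):
    """Zero every slice coupled to hidden feature h; all shapes stay the same."""
    fc1_w = [[0.0] * len(row) if i == h else list(row) for i, row in enumerate(fc1_w)]
    fc1_b = [0.0 if i == h else b for i, b in enumerate(fc1_b)]
    fc2_w = [[0.0 if j == h else w for j, w in enumerate(row)] for row in fc2_w]
    return fc1_w, fc1_b, fc2_w


def physically_prune_hidden(fc1_w, fc1_b, fc2_w, h):
    """Delete every slice coupled to hidden feature h; the hidden size shrinks by one."""
    fc1_w = [list(row) for i, row in enumerate(fc1_w) if i != h]
    fc1_b = [b for i, b in enumerate(fc1_b) if i != h]
    fc2_w = [[w for j, w in enumerate(row) if j != h] for row in fc2_w]
    return fc1_w, fc1_b, fc2_w


def forward(fc1_w, fc1_b, fc2_w, x):
    """Linear -> ReLU -> Linear (no output bias)."""
    hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(fc1_w, fc1_b)]
    return [sum(w * hi for w, hi in zip(row, hidden)) for row in fc2_w]


fc1_w = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]
fc1_b = [0.1, 0.2, 0.3]
fc2_w = [[1.0, 1.0, 1.0], [2.0, -1.0, 0.5]]
x = [1.0, 1.0]

fake = fake_prune_hidden(fc1_w, fc1_b, fc2_w, h=1)
real = physically_prune_hidden(fc1_w, fc1_b, fc2_w, h=1)
print(len(fake[1]), len(real[1]))  # hidden size: 3 vs 2
print(forward(*fake, x) == forward(*real, x))  # True: identical outputs
```

If only fc1_w[h] were zeroed, relu(fc1_b[h]) could still leak into fc2, which is why the whole coupled unit is zeroed together.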
timm ViT Attention Heads
The torch_weighttracker.integrations.timm helpers make timm ViT attention
blocks visible as head-level pruning groups. infer_vit_num_heads(model) maps
each fused Attention.qkv projection to its current head count, and
sync_vit_attention_metadata updates timm attention metadata after physical
head pruning.
```python
import timm
import torch
from torch_weighttracker import WeightTracker
from torch_weighttracker.integrations.timm import (
    infer_vit_num_heads,
    sync_vit_attention_metadata,
)

example_inputs = torch.rand(1, 3, 224, 224)
model = timm.create_model(
    "vit_base_patch16_224",
    pretrained=False,
    num_classes=10,
)

tracker = WeightTracker(
    model,
    example_inputs=example_inputs,
    num_heads=infer_vit_num_heads(model),
    prune_num_heads=True,
    post_prune_hooks=(sync_vit_attention_metadata,),
)
print(tracker.view_structures())

tracker.create_tracker("structured_bops", log_total_bops=True)
metrics_before = tracker.track()

# Example: zero two attention heads in the group reported by view_structures().
tracker.fake_prune_unit(group_id=3, unit_id=0)
tracker.fake_prune_unit(group_id=3, unit_id=2)
metrics_after_fake_prune = tracker.track()

# Convert zeroed units into real shape changes, then recreate metric trackers.
tracker.prune_zero_units()
metrics_after_physical_prune = tracker.create_tracker(
    "structured_bops",
    log_total_bops=True,
).track()

print(metrics_before["structured_bops"])
print(metrics_after_fake_prune["structured_bops"])
print(metrics_after_physical_prune["structured_bops"])
```
For timm ViTs, head pruning removes complete q/k/v head slices from the fused
qkv projection and the corresponding projection input channels. The sync hook
keeps num_heads, attn_dim, head_dim, and scale consistent with the new
shape so the pruned model can still run a forward pass.
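The index bookkeeping behind head pruning can be sketched in plain Python. This assumes timm's Attention layout, where the fused qkv output features are reshaped to (3, num_heads, head_dim) and the merged attention output fed to proj is ordered (num_heads, head_dim), and it assumes timm's default scale of head_dim ** -0.5. head_pruning_indices is a hypothetical helper, not part of WeightTracker or timm.

```python
def head_pruning_indices(num_heads, head_dim, heads_to_prune):
    """Hypothetical helper: indices kept after removing whole attention heads.

    Assumes qkv output features are ordered (3, num_heads, head_dim) and
    proj input features are ordered (num_heads, head_dim), as in timm.
    """
    pruned = set(heads_to_prune)
    kept_heads = [h for h in range(num_heads) if h not in pruned]
    dim = num_heads * head_dim  # attention width before pruning

    # qkv rows: for each of q, k, v keep the full head_dim slice of every kept head.
    qkv_rows = [
        s * dim + h * head_dim + d
        for s in range(3)
        for h in kept_heads
        for d in range(head_dim)
    ]
    # proj input columns: one head_dim slice per kept head.
    proj_cols = [h * head_dim + d for h in kept_heads for d in range(head_dim)]

    new_num_heads = len(kept_heads)
    new_attn_dim = new_num_heads * head_dim
    scale = head_dim ** -0.5  # unchanged, because head pruning keeps head_dim
    return qkv_rows, proj_cols, new_num_heads, new_attn_dim, scale


qkv_rows, proj_cols, n_heads, attn_dim, scale = head_pruning_indices(
    num_heads=4, head_dim=2, heads_to_prune=[0, 2]
)
print(qkv_rows)  # [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23]
print(proj_cols)  # [2, 3, 6, 7]
print(n_heads, attn_dim)  # 2 4
```

With real modules, these lists would select qkv.weight[qkv_rows], qkv.bias[qkv_rows], and proj.weight[:, proj_cols]. Because whole heads are removed, head_dim and scale stay the same while num_heads and the attention width shrink.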
Group Lasso
Structured group lasso regularizes coupled units together. Layers can be excluded per regularizer:
```python
from torch_weighttracker.regularizers import RegularizerType

group_lasso = tracker.create_regularizer(
    RegularizerType.GROUP_LASSO,
    ignore=[model.classifier],
)

loss = task_loss + 1e-4 * group_lasso()
loss.backward()
```
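As a rough illustration of what penalizing a coupled unit means, here is a minimal plain-PyTorch sketch of an unweighted group lasso over one hypothetical coupled group: the hidden features of an fc1 -> fc2 MLP. It is not WeightTracker's implementation; the library discovers the coupling automatically and may scale or stabilize the penalty differently.

```python
import torch
from torch import nn


def coupled_group_lasso(fc1: nn.Linear, fc2: nn.Linear) -> torch.Tensor:
    """Unweighted group lasso over the hidden features shared by fc1 and fc2.

    Hidden feature h couples fc1.weight[h, :], fc1.bias[h], and fc2.weight[:, h],
    so those values share one L2 norm instead of being penalized per tensor.
    """
    sq = fc1.weight.pow(2).sum(dim=1)       # [hidden]: rows of fc1.weight
    sq = sq + fc1.bias.pow(2)               # [hidden]: matching fc1.bias entries
    sq = sq + fc2.weight.pow(2).sum(dim=0)  # [hidden]: columns of fc2.weight
    return sq.sqrt().sum()


torch.manual_seed(0)
fc1, fc2 = nn.Linear(8, 16), nn.Linear(16, 4)
penalty = coupled_group_lasso(fc1, fc2)
penalty.backward()  # gradients reach every coupled tensor
print(penalty.item())
```

Because the square root has an unbounded gradient at zero, this sketch can produce NaN gradients once a whole unit is exactly zero; a practical version would add a small epsilon inside the square root.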
Structured BOPs
Structured BOPs reports compression against a dense 32-bit baseline by default. BOP trackers exclude normalization modules such as batchnorm, layernorm, groupnorm, and instancenorm layers by default:
```python
import torch
from torch_weighttracker.trackers import TrackerType

metrics = tracker.create_tracker(
    TrackerType.STRUCTURED_BOPS,
    include=[model.layer3, model.layer4],
).track()
print(metrics["structured_bops"]["compression"])

raw_metrics = tracker.create_tracker(
    TrackerType.STRUCTURED_BOPS,
    include=[model.layer3, model.layer4],
    log_total_bops=True,
    log_layerwise_stats=True,
).track()
structured = raw_metrics["structured_bops"]
print(structured["bops"])
print(structured["modules"])
```
create_tracker accepts a single TrackerType/string or a list of tracker
types/strings:
```python
tracker.create_tracker(
    [TrackerType.STRUCTURED_BOPS, "group_pruning_summary"]
)
metrics = tracker.track()
```
Tracker metrics are nested by tracker name and convert tensor values to Python
numbers/lists by default. Pass convert_tensors=False to preserve tensors in
the returned metrics.
Formulation of the Structured BOPs Metric
For each weighted module $m$, WeightTracker multiplies the active structured MAC count by that module's activation and weight bit widths [1]:
$$ \mathit{StructuredBOPs}_m = \mathit{ActiveMACs}_m \cdot b^{\mathrm{act}}_m \cdot b^{\mathrm{weight}}_m $$
The active MAC count scales the dense module MAC count by the active fraction of each structural cost axis:
$$ \mathit{ActiveMACs}_m = \mathit{BaselineMACs}_m \cdot \prod_{a \in A_m} \frac{n^{\mathrm{active}}_{m,a}}{n^{\mathrm{baseline}}_{m,a}} $$
Compression is reported against a dense 32-bit activation and 32-bit weight baseline:
$$ \mathit{BaselineBOPs}_m = \mathit{BaselineMACs}_m \cdot 32 \cdot 32 $$
$$ \mathit{CompressionRate} = 1 - \frac{\sum_m \mathit{StructuredBOPs}_m} {\sum_m \mathit{BaselineBOPs}_m} $$
Where:
- $\mathit{StructuredBOPs}_m$: active bit operations for weighted module $m$.
- $\mathit{ActiveMACs}_m$: active MAC count after structured units are masked or pruned.
- $\mathit{BaselineMACs}_m$: dense MAC count for module $m$ before structured pruning.
- $A_m$: structural cost axes for module $m$, such as input and output channel axes.
- $n^{\mathrm{active}}_{m,a}$: active size of cost axis $a$ for module $m$.
- $n^{\mathrm{baseline}}_{m,a}$: dense baseline size of cost axis $a$ for module $m$.
- $b^{\mathrm{act}}_m$: activation bit width for module $m$.
- $b^{\mathrm{weight}}_m$: weight bit width for module $m$.
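The formulas above can be checked with a small worked example in plain Python. The two modules, their MAC counts, active axis sizes, and bit widths are made-up illustration values, and structured_bops is a hypothetical helper, not part of WeightTracker.

```python
from math import prod


def structured_bops(baseline_macs, axes, act_bits, weight_bits):
    """StructuredBOPs_m = BaselineMACs_m * prod_a(n_active / n_baseline) * b_act * b_weight."""
    active_macs = baseline_macs * prod(active / base for active, base in axes)
    return active_macs * act_bits * weight_bits


# Hypothetical modules: (BaselineMACs, [(n_active, n_baseline) per cost axis], b_act, b_weight)
modules = {
    # 3x3 conv, 64 -> 128 channels, 8x8 output: 128 * 64 * 9 * 64 MACs.
    # Both its input and output channel axes lost a quarter of their units.
    "conv": (128 * 64 * 9 * 64, [(48, 64), (96, 128)], 8, 4),
    # 128 -> 10 linear layer: only its input axis lost units.
    "fc": (128 * 10, [(96, 128), (10, 10)], 8, 8),
}

total = sum(structured_bops(*spec) for spec in modules.values())
baseline = sum(macs * 32 * 32 for macs, _, _, _ in modules.values())
compression = 1 - total / baseline
print(f"compression rate: {compression:.4f}")  # compression rate: 0.9824
```

Most of the compression here comes from the bit widths: the conv keeps 56.25% of its MACs (0.75 x 0.75) but runs at 8 x 4 bits instead of 32 x 32.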
Comparison with Direct Removal and FLOP Count
For some model architectures, the BOPs calculation may differ from values reported by other libraries. These differences mainly come from which layers and operations are included. WeightTracker does not count elementwise operations such as ReLU activations or bias terms.
The repository includes sanity notebooks comparing fvcore.FlopCountAnalysis
on physically pruned models with WeightTracker on fake-pruned models, where
weights are zeroed to match the equivalent hard-pruned structure.
Local sanity notebooks compare WeightTracker MAC accounting with physically pruned models from Torch-Pruning. These dependencies are optional and are not installed with the base package:
```bash
python -m pip install -e ".[dev-local]"
```
Then start Jupyter from the repository root and open the notebooks in
sanity_checks/.
Unstructured Sparsity
Unstructured sparsity reports exact zero-weight fractions. The total is weighted by each layer's number of weight elements, not averaged across layer fractions:
```python
import torch
from torch_weighttracker.trackers import TrackerType

metrics = tracker.create_tracker(
    TrackerType.UNSTRUCTURED_SPARSITY,
    include=[model.layer3, model.layer4],
    ignore=[torch.nn.BatchNorm2d],
).track()
sparsity = metrics["unstructured_sparsity"]
print(sparsity["sparsity"])
print(sparsity["layers"])
```
Values are fractions in [0, 1]. Parametrized fake quantization is measured
through the effective module.weight, so quantized zeros count as sparse
weights.
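Weighting by element count rather than averaging layer fractions matters whenever layer sizes differ. A plain-Python illustration with two hypothetical layers:

```python
# Hypothetical layers: (number of weight elements, number of exact zeros)
layers = {"small": (100, 90), "large": (10_000, 1_000)}

zeros = sum(z for _, z in layers.values())
elements = sum(n for n, _ in layers.values())
weighted = zeros / elements  # fraction of all weights that are zero
averaged = sum(z / n for n, z in layers.values()) / len(layers)  # mean of per-layer fractions
print(f"element-weighted: {weighted:.4f}, layer-averaged: {averaged:.4f}")
# element-weighted: 0.1079, layer-averaged: 0.5000
```

The small, almost fully sparse layer dominates the plain average but barely moves the element-weighted total, which is the quantity this tracker reports.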
Unstructured BOPs
Unstructured BOPs combines each layer's dense runtime MAC count with its active weight fraction and activation/weight bit widths:
```python
import torch
from torch_weighttracker.trackers import TrackerType

metrics = tracker.create_tracker(
    TrackerType.UNSTRUCTURED_BOPS,
    include=[model.layer3, model.layer4],
).track()
print(metrics["unstructured_bops"]["compression"])

raw_metrics = tracker.create_tracker(
    TrackerType.UNSTRUCTURED_BOPS,
    include=[model.layer3, model.layer4],
    log_total_bops=True,
    log_layerwise_stats=True,
).track()
unstructured = raw_metrics["unstructured_bops"]
print(unstructured["bops"])
print(unstructured["modules"])
```
For each weighted module $m$:
$$ \mathit{UnstructuredBOPs}_m = \mathit{BaselineMACs}_m \cdot (1 - \mathit{Sparsity}_m) \cdot b^{\mathrm{act}}_m \cdot b^{\mathrm{weight}}_m $$
Compression uses the same dense 32-bit baseline as structured BOPs:
$$ \mathit{CompressionRate} = 1 - \frac{\sum_m \mathit{UnstructuredBOPs}_m} {\sum_m \mathit{BaselineMACs}_m \cdot 32 \cdot 32} $$
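A worked example of the unstructured formulas in plain Python, using made-up layer sizes, sparsities, and bit widths; unstructured_bops is a hypothetical helper, not part of WeightTracker.

```python
def unstructured_bops(baseline_macs, sparsity, act_bits, weight_bits):
    """UnstructuredBOPs_m = BaselineMACs_m * (1 - Sparsity_m) * b_act * b_weight."""
    return baseline_macs * (1 - sparsity) * act_bits * weight_bits


# Hypothetical layers: (dense MACs, zero-weight fraction, b_act, b_weight)
layers = [
    (1_000_000, 0.5, 8, 8),  # half the weights are zero, 8-bit activations and weights
    (250_000, 0.0, 32, 32),  # dense full-precision layer
]

total = sum(unstructured_bops(*layer) for layer in layers)
baseline = sum(macs * 32 * 32 for macs, *_ in layers)
compression = 1 - total / baseline
print(f"{total:.0f} / {baseline} BOPs, compression {compression:.3f}")
# 288000000 / 1280000000 BOPs, compression 0.775
```

Unlike the structured metric, the zero fraction scales BOPs regardless of where the zeros fall, so it does not require whole units to be removed.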
NVIDIA 2:4 Sparsity
NVIDIA 2:4 sparsity reports block eligibility for supported weighted layers.
Linear and MultiheadAttention projection weights are grouped in contiguous
blocks of four along the input axis. Convolution weights shaped [K, C, ...]
are grouped along C for each output/spatial position.
```python
import torch
from torch_weighttracker.trackers import TrackerType

metrics = tracker.create_tracker(
    TrackerType.NVIDIA_2_4_SPARSITY,
    include=[model.layer3, model.layer4],
    ignore=[torch.nn.BatchNorm2d],
    log_layerwise_stats=True,
).track()
nvidia_24 = metrics["nvidia_2_4_sparsity"]
print(nvidia_24["strict_block_fraction"])
print(nvidia_24["nvidia_eligible_block_fraction"])
print(nvidia_24["tail_elements"])
```
The strict fraction counts complete 4-value blocks with exactly two zeros. The NVIDIA-eligible fraction counts blocks with at least two zeros, matching the TensorRT eligibility rule. Tail elements are reported separately and prevent a layer from counting as strict or eligible.
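The block rules can be sketched in plain Python for a Linear-style weight stored as [out][in]. two_four_stats is a hypothetical helper, not WeightTracker's implementation; it covers only the Linear case, and its layer-level flags are one reading of the tail rule (any tail element disqualifies the layer).

```python
def two_four_stats(weight_rows):
    """Hypothetical helper: 2:4 block statistics for a Linear-style weight [out][in].

    Blocks are four contiguous values along the input axis. A strict block has
    exactly two zeros; an eligible block has at least two zeros. Values left over
    when in_features is not a multiple of 4 are counted as tail elements.
    """
    strict = eligible = blocks = tail = 0
    for row in weight_rows:
        full = len(row) - len(row) % 4  # elements covered by complete blocks
        tail += len(row) - full
        for start in range(0, full, 4):
            zeros = sum(1 for w in row[start:start + 4] if w == 0)
            blocks += 1
            strict += int(zeros == 2)
            eligible += int(zeros >= 2)
    return {
        "strict_block_fraction": strict / blocks if blocks else 0.0,
        "eligible_block_fraction": eligible / blocks if blocks else 0.0,
        "tail_elements": tail,
        # Assumption: any tail element disqualifies the layer as a whole.
        "layer_strict": blocks > 0 and tail == 0 and strict == blocks,
        "layer_eligible": blocks > 0 and tail == 0 and eligible == blocks,
    }


stats = two_four_stats([
    [0, 1, 0, 2, 3, 0, 0, 0],  # blocks: two zeros (strict), three zeros (eligible only)
    [1, 2, 3, 0, 0, 0, 5, 6],  # blocks: one zero (neither), two zeros (strict)
])
print(stats["strict_block_fraction"], stats["eligible_block_fraction"])  # 0.5 0.75
```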
Group Pruning Summary
Group pruning summary reports pruned canonical units and group-attributed pruned parameters in a nested tracker metrics dictionary:
```python
import torch
from torch_weighttracker.trackers import TrackerType

metrics = tracker.create_tracker(
    TrackerType.GROUP_PRUNING_SUMMARY,
    include=[model.layer3, model.layer4],
    ignore=[torch.nn.BatchNorm2d],
).track()
summary = metrics["group_pruning_summary"]
print(summary["pruned_units"])
print(summary["pruned_params"])
```
Per-group values are emitted under
summary["groups"]["layer3.0.conv1:prune_out_channels"]["pruned_units"]
and the corresponding "pruned_params" key.
Architecture
The main API is WeightTracker. Internally it is split into a few layers:
- Dependency discovery: WeightTracker builds dependency groups from the model and example_inputs using Torch-Pruning's dependency graph machinery [2], whose work we gratefully build on.
- Canonical units: canonical_units.py normalizes raw dependency groups into CanonicalUnitGroup objects. These give channels, features, attention heads, and head dimensions a shared unit index.
- Reduction plans: reductions/ and plans/ compile module and unit mappings into segment and index operations that use PyTorch's efficient tensor computations.
- Calculations: calculations/ defines named calculation specs such as per-unit L2 norm, active units, parameters per unit, active MACs, and bitrates. Calculations can depend on each other and cache constant results.
- Consumers: regularizers/ and trackers/ request the calculations they need, optionally with include and ignore contexts for selecting modules in a specific metric or regularizer.
The result is a small public surface with a reusable internal graph:
```
model + example inputs
        |
        v
dependency groups -> canonical units -> reduction plans -> calculations
                                                                |
                                                                v
                                                    regularizers and trackers
```
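The reduction-plan idea can be illustrated with a minimal plain-PyTorch sketch, assuming a plan is essentially a precomputed element-to-unit index: it is built once from tensor shapes and coupled axes, and each evaluation then needs only one concatenation and one index_add. build_unit_index and per_unit_l2 are hypothetical names; WeightTracker's actual reductions/ and plans/ modules may be organized differently.

```python
import torch
from torch import nn


def build_unit_index(shapes_and_axes, num_units):
    """Compile a flat element -> unit map once (the "plan").

    shapes_and_axes lists (shape, axis) pairs, where axis is the dimension that
    indexes the coupled unit. The result has one entry per element of every
    tensor, flattened row-major and concatenated in the given order.
    """
    parts = []
    for shape, axis in shapes_and_axes:
        if shape[axis] != num_units:
            raise ValueError(f"expected {num_units} units on axis {axis} of shape {tuple(shape)}")
        view = [1] * len(shape)
        view[axis] = num_units
        ids = torch.arange(num_units).view(*view).expand(*shape)
        parts.append(ids.reshape(-1))
    return torch.cat(parts)


def per_unit_l2(tensors, unit_index, num_units):
    """Run the plan: squared values of all tensors are summed per unit in one index_add."""
    flat_sq = torch.cat([t.reshape(-1).pow(2) for t in tensors])
    sums = flat_sq.new_zeros(num_units).index_add(0, unit_index, flat_sq)
    return sums.sqrt()


torch.manual_seed(0)
conv1, bn, conv2 = nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Conv2d(8, 16, 3)
# Output channel c of conv1 couples these slices; the int is the unit axis.
coupled = [
    (conv1.weight, 0), (conv1.bias, 0),
    (bn.weight, 0), (bn.bias, 0),
    (conv2.weight, 1),
]
plan = build_unit_index([(t.shape, axis) for t, axis in coupled], num_units=8)
norms = per_unit_l2([t for t, _ in coupled], plan, num_units=8)
print(norms.shape)  # torch.Size([8])
```

The plan depends only on shapes, so it can be reused every training step, and the result stays differentiable for use in a regularizer.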
Speed
Compared with a naive implementation, the current implementation gives the following speedups on ResNet-20 on an RTX 3060:
| Comparison | Speedup | Naive extra allocation | WeightTracker extra allocation |
|---|---|---|---|
| Group lasso | 15.421x | 197.0MiB | 197.0MiB |
| Structured BOPs | 2.582x | 1.7GiB | 195.9MiB |
Status
This package is pre-1.0. Public APIs may still change while the tracker, calculation, and regularizer surfaces settle.
Future Work
- Streamline definitions and method names across the codebase.
- Improve calculation caching so the same computation is not performed twice.
- Improve compilation of computation plans for bigger speedups.
- Improve memory management within calculations.
- Write more comprehensive docstrings.
Future custom use cases will need a broader top-level WeightTracker API for
custom operations, custom layers, and generic group definitions.
License
MIT
References
[1] Wang et al., Differentiable Joint Pruning and Quantization for Hardware Efficiency, 2020.
[2] Fang et al., Torch-Pruning.
File details
Details for the file torch_weighttracker-0.2.3.tar.gz.
File metadata
- Download URL: torch_weighttracker-0.2.3.tar.gz
- Upload date:
- Size: 210.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bfba010bdd97fce0f6f1e7b162c1dead907ebc47c3899021cb96e1cfe5a04ff2 |
| MD5 | a866585739cb3d0573cd4725196e3d68 |
| BLAKE2b-256 | 9ef09e9ce5166ceb148f8cde9692e3aaa43a120d9f06e1058ca7be4393600a25 |
Provenance
The following attestation bundles were made for torch_weighttracker-0.2.3.tar.gz:
- Publisher: publish-to-pypi.yml on dadyownes15/torch-weighttracker
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: torch_weighttracker-0.2.3.tar.gz
- Subject digest: bfba010bdd97fce0f6f1e7b162c1dead907ebc47c3899021cb96e1cfe5a04ff2
- Sigstore transparency entry: 1870690544
- Sigstore integration time:
- Permalink: dadyownes15/torch-weighttracker@09b94d9886f0ff202b7b20fd978db11d39ec4a0d
- Branch / Tag: refs/tags/v0.2.3
- Owner: https://github.com/dadyownes15
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish-to-pypi.yml@09b94d9886f0ff202b7b20fd978db11d39ec4a0d
- Trigger Event: push
File details
Details for the file torch_weighttracker-0.2.3-py3-none-any.whl.
File metadata
- Download URL: torch_weighttracker-0.2.3-py3-none-any.whl
- Upload date:
- Size: 118.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8f6c7d04bc3f8a60d5c41fe8491a05e7e97b512c4d56dbad8847443f37f7b6ac |
| MD5 | ed15717076c03d84a2b46f8165580924 |
| BLAKE2b-256 | 20f10677c4cd83142e8ce508eb5547fc0468518dd7b499034116ebf03359c3a8 |
Provenance
The following attestation bundles were made for torch_weighttracker-0.2.3-py3-none-any.whl:
- Publisher: publish-to-pypi.yml on dadyownes15/torch-weighttracker
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: torch_weighttracker-0.2.3-py3-none-any.whl
- Subject digest: 8f6c7d04bc3f8a60d5c41fe8491a05e7e97b512c4d56dbad8847443f37f7b6ac
- Sigstore transparency entry: 1870690570
- Sigstore integration time:
- Permalink: dadyownes15/torch-weighttracker@09b94d9886f0ff202b7b20fd978db11d39ec4a0d
- Branch / Tag: refs/tags/v0.2.3
- Owner: https://github.com/dadyownes15
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish-to-pypi.yml@09b94d9886f0ff202b7b20fd978db11d39ec4a0d
- Trigger Event: push