Tools for tracking structured weight sparsity in PyTorch models.

torch-weighttracker

Tools for tracking structured weight sparsity, regularization signals, and bit-operation estimates in PyTorch models.

The package builds a structural view of a model, compiles tensorized reduction plans over that structure, and reuses those plans for training-time metrics and regularizers.

import torch
from torch import nn

from torch_weighttracker import WeightTracker

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
tracker = WeightTracker(model, example_inputs=torch.randn(1, 4))

print(tracker.view_structures())

Note: Transformers are not fully supported yet.

Installation

python -m pip install torch-weighttracker

Structured BOPs accounting relies on fvcore for baseline per-module MAC counts. Install the optional extra to enable it:

python -m pip install "torch-weighttracker[structured-bops]"

Why Use It?

PyTorch makes it easy to inspect one parameter tensor at a time. Structured compression often needs a different view:

  • A channel can be coupled across convolutions, batch norms, linear layers, and residual paths.
  • A transformer unit can mean an attention head, a head dimension, or a fused QKV slice rather than a simple row or column.
  • A metric such as "active BOPs" depends on sparsity, module shape, MAC counts, and bitrates at the same time.
  • A regularizer such as group lasso should penalize the coupled structural unit, not each weight tensor independently.

WeightTracker turns those coupled structures into canonical units, then lets calculations operate over the canonical units with reusable tensor programs.

Use Cases

Current use cases:

  • Add structured group lasso to a training loss.
  • Track active structured BOPs and compression rate during structured pruning, sparsity-aware training, or quantization-aware training (QAT).
  • Inspect which modules participate in each channel, feature, head, or head-dim group.
  • Build structural metrics that aggregate many weight tensors into one value per pruning unit.

Group Lasso

Structured group lasso regularizes coupled units together. Layers can be excluded per regularizer:

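from torch_weighttracker.regularizers import RegularizerType

# Build the regularizer once; `model.classifier` stands for any submodule to exclude.
group_lasso = tracker.create_regularizer(
    RegularizerType.GROUP_LASSO,
    ignore=[model.classifier],
)

# `task_loss` is the ordinary training loss; 1e-4 is the regularization strength.
loss = task_loss + 1e-4 * group_lasso()
loss.backward()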
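To make the idea of a coupled penalty concrete, here is a minimal sketch in plain PyTorch that does not use the WeightTracker API. It assumes a two-layer MLP in which hidden unit j is tied to row j of the first layer's weight and column j of the second layer's weight. The helper name `coupled_group_lasso` is hypothetical, and biases are left out for brevity.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
fc1, fc2 = model[0], model[2]


def coupled_group_lasso(fc1: nn.Linear, fc2: nn.Linear) -> torch.Tensor:
    """Sum over hidden units of the L2 norm of all weights tied to that unit.

    Hypothetical helper for illustration only; not part of torch-weighttracker.
    """
    # Row j of fc1.weight produces hidden unit j; column j of fc2.weight consumes it.
    squared = fc1.weight.pow(2).sum(dim=1) + fc2.weight.pow(2).sum(dim=0)  # shape (8,)
    # A small epsilon keeps the gradient finite when a whole unit is exactly zero.
    return torch.sqrt(squared + 1e-12).sum()


penalty = coupled_group_lasso(fc1, fc2)
penalty.backward()  # gradients reach both layers through the shared unit norm
```

Because each hidden unit contributes a single norm, the penalty pushes both the producing row and the consuming column toward zero together, which is what makes the unit removable as a whole.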

Structured BOPs

Structured BOPs reports compression against a dense 32-bit baseline by default:

import torch

from torch_weighttracker.trackers import TrackerType

metrics = tracker.create_tracker(
    TrackerType.STRUCTURED_BOPS,
    include=[model.layer3, model.layer4],
    ignore=[torch.nn.BatchNorm2d],
).track()

print(metrics["structured_bops_compression"])

raw_metrics = tracker.create_tracker(
    TrackerType.STRUCTURED_BOPS,
    include=[model.layer3, model.layer4],
    ignore=[torch.nn.BatchNorm2d],
    log_total_bops=True,
    log_layerwise_stats=True,
).track()

print(raw_metrics["structured_bops"])
print(raw_metrics["structured_bops_pr_module"])
print(raw_metrics["structured_bops_compression_rate_pr_module"])

create_tracker accepts a single TrackerType/string or a list of tracker types/strings:

tracker.create_tracker([TrackerType.STRUCTURED_BOPS, "unstructured_sparsity"])
metrics = tracker.track()

Formulation of the Structured BOPs Metric

For each weighted module $m$, WeightTracker multiplies the active structured MAC count by that module's activation and weight bit widths [1]:

$$ \operatorname{StructuredBOPs}_m = \operatorname{ActiveMACs}_m \cdot b^{\mathrm{act}}_m \cdot b^{\mathrm{weight}}_m $$

The active MAC count scales the dense module MAC count by the active fraction of each structural cost axis:

$$ \operatorname{ActiveMACs}_m = \operatorname{BaselineMACs}_m \cdot \prod_{a \in A_m} \frac{n^{\mathrm{active}}_{m,a}}{n^{\mathrm{baseline}}_{m,a}} $$

Compression is reported against a dense 32-bit activation and 32-bit weight baseline:

$$ \operatorname{BaselineBOPs}_m = \operatorname{BaselineMACs}_m \cdot 32 \cdot 32 $$

$$ \operatorname{CompressionRate} = 1 - \frac{\sum_m \operatorname{StructuredBOPs}_m} {\sum_m \operatorname{BaselineBOPs}_m} $$

Where:

  • $\operatorname{StructuredBOPs}_m$: active bit operations for weighted module $m$.
  • $\operatorname{ActiveMACs}_m$: active MAC count after structured units are masked or pruned.
  • $\operatorname{BaselineMACs}_m$: dense MAC count for module $m$ before structured pruning.
  • $A_m$: structural cost axes for module $m$, such as input and output channel axes.
  • $n^{\mathrm{active}}_{m,a}$: active size of cost axis $a$ for module $m$.
  • $n^{\mathrm{baseline}}_{m,a}$: dense baseline size of cost axis $a$ for module $m$.
  • $b^{\mathrm{act}}_m$: activation bit width for module $m$.
  • $b^{\mathrm{weight}}_m$: weight bit width for module $m$.
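As a worked example of the formulas above, the following pure-Python sketch evaluates them by hand for a hypothetical two-layer MLP. The layer sizes, the pruning pattern, and the 8-bit widths are assumptions chosen for illustration, and the function names are not part of the package.

```python
from math import prod


def structured_bops(baseline_macs, axes, act_bits, weight_bits):
    """axes is a list of (n_active, n_baseline) pairs, one per cost axis."""
    active_fraction = prod(n_active / n_baseline for n_active, n_baseline in axes)
    active_macs = baseline_macs * active_fraction
    return active_macs * act_bits * weight_bits


def baseline_bops(baseline_macs, bits=32):
    """Dense baseline with the same bit width for activations and weights."""
    return baseline_macs * bits * bits


# Hypothetical model: Linear(4, 8) -> Linear(8, 2), one input sample.
# Assume 4 of the 8 hidden units are pruned and both layers run at 8 bits.
modules = [
    # (baseline MACs, [(active, baseline) per cost axis])
    (4 * 8, [(4, 4), (4, 8)]),  # first layer: input axis intact, output axis halved
    (8 * 2, [(4, 8), (2, 2)]),  # second layer: input axis halved, output axis intact
]

structured = sum(structured_bops(macs, axes, 8, 8) for macs, axes in modules)
baseline = sum(baseline_bops(macs) for macs, _ in modules)
compression_rate = 1 - structured / baseline

print(structured, baseline, compression_rate)
```

In this example, halving the hidden width halves the active MACs of both layers, and moving from 32-bit to 8-bit operands divides the bit-operation cost per MAC by 16, so both effects show up in the single compression figure.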

Comparison with Direct Removal and FLOP Count

For some model architectures, the BOPs calculation may differ from values reported by other libraries. These differences mainly come from which layers and operations are included. WeightTracker does not count elementwise operations such as ReLU activations or bias terms.

The repository includes sanity notebooks comparing fvcore.FlopCountAnalysis on physically pruned models with WeightTracker on fake-pruned models, where weights are zeroed to match the equivalent hard-pruned structure.

The notebooks compare WeightTracker's MAC accounting with models physically pruned using Torch-Pruning. Their dependencies are optional and are not installed with the base package:

python -m pip install -e ".[dev-local]"

Then start Jupyter from the repository root and open the notebooks in sanity_checks/.
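The comparison the notebooks make can be illustrated in miniature. The sketch below assumes a single bias-free linear layer and counts MACs by hand; it uses neither fvcore nor Torch-Pruning, and the helper `linear_macs` is hypothetical.

```python
def linear_macs(in_features, out_features):
    """MACs for one forward pass of a bias-free linear layer on one sample."""
    return in_features * out_features


# Fake-pruned weight matrix (out_features x in_features): rows 1 and 3 are zeroed.
weight = [
    [0.5, -1.0, 0.2],
    [0.0, 0.0, 0.0],
    [1.5, 0.3, -0.7],
    [0.0, 0.0, 0.0],
]
out_features, in_features = len(weight), len(weight[0])
active_rows = sum(any(w != 0.0 for w in row) for row in weight)

# Masked accounting: dense MACs scaled by the active fraction of the output axis.
masked_macs = linear_macs(in_features, out_features) * active_rows / out_features

# Physical pruning: rebuild the layer with only the surviving rows and count again.
physical_macs = linear_macs(in_features, active_rows)

print(masked_macs, physical_macs)
```

Both counts agree here because a zeroed output row removes exactly the MACs that a physically smaller layer would never perform.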

Unstructured Sparsity

Unstructured sparsity reports exact zero-weight fractions. The total is weighted by each layer's number of weight elements, not averaged across layer fractions:

import torch

from torch_weighttracker.trackers import TrackerType

metrics = tracker.create_tracker(
    TrackerType.UNSTRUCTURED_SPARSITY,
    include=[model.layer3, model.layer4],
    ignore=[torch.nn.BatchNorm2d],
).track()

print(metrics["unstructured_sparsity"])
print(metrics["layers"])

Values are fractions in [0, 1]. Parametrized fake quantization is measured through the effective module.weight, so quantized zeros count as sparse weights.
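The difference between an element-weighted total and a plain average of layer fractions is easy to see with made-up numbers. The layer names and counts below are hypothetical.

```python
# Hypothetical per-layer counts: (zero weights, total weights).
layers = {
    "conv1": (10, 100),
    "fc": (900, 1000),
}

per_layer = {name: zeros / total for name, (zeros, total) in layers.items()}

# Element-weighted total: all zeros divided by all weights.
weighted_total = sum(z for z, _ in layers.values()) / sum(t for _, t in layers.values())

# Unweighted alternative, shown only for contrast.
mean_of_fractions = sum(per_layer.values()) / len(per_layer)

print(per_layer, weighted_total, mean_of_fractions)
```

The large layer dominates the weighted total, so the figure reflects the share of all weights that are zero rather than treating a 100-weight layer and a 1000-weight layer as equals.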

Architecture

The main API is WeightTracker. Internally it is split into a few layers:

  1. Dependency discovery: WeightTracker builds dependency groups from the model and example_inputs, or accepts precomputed groups.
  2. Canonical units: canonical_units.py normalizes raw dependency groups into CanonicalUnitGroup objects. These give channels, features, attention heads, and head dimensions a shared unit index.
  3. Reduction plans: reductions/ and plans/ compile module and unit mappings into segment and index operations that use PyTorch's efficient tensor computations.
  4. Calculations: calculations/ defines named calculation specs such as per-unit L2 norm, active units, parameters per unit, active MACs, and bitrates. Calculations can depend on each other and cache constant results.
  5. Consumers: regularizers/ and trackers/ request the calculations they need, optionally with include and ignore contexts for selecting modules in a specific metric or regularizer.
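The kind of segment and index operation that step 3 refers to can be sketched with `torch.Tensor.index_add_`. This is an illustration of the general technique under an assumed unit mapping, not the package's internal plan format.

```python
import torch

# Squared norms of weight slices gathered from several modules (made-up values),
# and the canonical unit index that each slice belongs to (hypothetical mapping).
slice_sq_norms = torch.tensor([1.0, 4.0, 9.0, 16.0, 25.0])
unit_index = torch.tensor([0, 1, 0, 2, 1])
num_units = 3

# One scatter-add accumulates every slice into its unit, with no Python loop.
per_unit_sq = torch.zeros(num_units).index_add_(0, unit_index, slice_sq_norms)
per_unit_l2 = per_unit_sq.sqrt()

# A unit counts as active when any of its coupled weights is nonzero.
active_units = int((per_unit_l2 > 0).sum())

print(per_unit_sq, per_unit_l2, active_units)
```

Since the index tensor depends only on the model structure, it can be built once and reused on every training step, which is the reuse the reduction plans are designed around.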

The result is a small public surface with a reusable internal graph:

model + example inputs
        |
        v
dependency groups -> canonical units -> reduction plans -> calculations
                                                          |
                                                          v
                                               regularizers and trackers
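The calculation layer in this pipeline, where calculations depend on each other and cache constant results, can be sketched as a small dependency-resolving graph. The class name `CalculationGraph`, its methods, and the registered calculation names are all hypothetical and do not mirror the package's actual classes.

```python
class CalculationGraph:
    """Tiny dependency-resolving calculator with a cache for constant results."""

    def __init__(self):
        self._specs = {}  # name -> (function, dependency names, is_constant)
        self._cache = {}  # results of constant calculations, kept across calls
        self.calls = []   # names in evaluation order, recorded for illustration

    def register(self, name, fn, deps=(), constant=False):
        self._specs[name] = (fn, tuple(deps), constant)

    def compute(self, name, _memo=None):
        memo = {} if _memo is None else _memo  # per-call memo avoids repeated work
        if name in self._cache:
            return self._cache[name]
        if name in memo:
            return memo[name]
        fn, deps, constant = self._specs[name]
        value = fn(*(self.compute(dep, memo) for dep in deps))
        self.calls.append(name)
        memo[name] = value
        if constant:
            self._cache[name] = value
        return value


graph = CalculationGraph()
graph.register("params_per_unit", lambda: [4, 4, 8], constant=True)
graph.register("active_units", lambda: [1, 0, 1])
graph.register(
    "active_params",
    lambda params, active: sum(p * a for p, a in zip(params, active)),
    deps=("params_per_unit", "active_units"),
)

first = graph.compute("active_params")
second = graph.compute("active_params")
print(first, second, graph.calls)
```

Structure-only quantities such as parameters per unit are evaluated once, while weight-dependent quantities such as active units are recomputed on each request.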

Speed

Compared with a naive implementation, the current implementation achieves the following speedups on ResNet-20 on an RTX 3060:

Comparison        Speedup   Naive extra allocation   WeightTracker extra allocation
Group lasso       15.421x   197.0 MiB                197.0 MiB
Structured BOPs   2.582x    1.7 GiB                  195.9 MiB

Status

This package is pre-1.0. Public APIs may still change while the tracker, calculation, and regularizer surfaces settle.

Future Work

  1. Streamline definitions and method names across the codebase.
  2. Improve calculation caching so repeated computations are not performed twice.
  3. Improve compilation of computation plans for bigger speedups.
  4. Improve memory management within calculations.
  5. Write more comprehensive docstrings.

Future custom use cases will need a broader top-level WeightTracker API for custom operations, custom layers, and generic group definitions.

License

MIT

References

[1] Wang et al., Differentiable Joint Pruning and Quantization for Hardware Efficiency, 2020.
