Tools for tracking structured weight sparsity in PyTorch models.
torch-weighttracker
A package for tracking structured weight sparsity, regularization signals, and bit-operation (BOP) estimates in PyTorch modules.
The API is centered on WeightTracker:
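```python
import torch
from torch import nn
from torch_weighttracker import WeightTracker

# A small MLP: its 8 hidden units couple the two Linear layers.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
# example_inputs is a sample batch shaped like real inputs (1 sample, 4 features).
tracker = WeightTracker(model, example_inputs=torch.randn(1, 4))
print(tracker.view_structures())
```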
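To make "structure" concrete for the quickstart model: each of the 8 hidden units ties together a row of the first weight matrix, one bias entry, and a column of the second weight matrix, so pruning a unit means zeroing all three at once. The sketch below uses plain PyTorch only. `hidden_unit_slices` is a hypothetical helper, not part of torch-weighttracker, and it makes no claim about how `view_structures()` reports these groups.

```python
import torch
from torch import nn

# Same toy model as the quickstart; model[1] is the ReLU between the two layers.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
first, last = model[0], model[2]


def hidden_unit_slices(i: int) -> list[torch.Tensor]:
    """Hypothetical helper: views of every parameter entry tied to hidden unit i."""
    return [
        first.weight[i, :],     # incoming weights of unit i (row i of the 8x4 matrix)
        first.bias[i : i + 1],  # bias of unit i
        last.weight[:, i],      # outgoing weights of unit i (column i of the 2x8 matrix)
    ]


# Structurally pruning unit 3 zeroes all of its coupled slices together.
# The slices are views, so zero_() writes through to the parameters.
with torch.no_grad():
    for view in hidden_unit_slices(3):
        view.zero_()

# One structure spans 4 + 1 + 2 = 7 parameters across three tensors.
n_params = sum(v.numel() for v in hidden_unit_slices(3))
```

Zeroing only the row (or only the column) would leave a "half-pruned" unit that still costs compute, which is why regularizers and metrics in this package operate on coupled groups rather than on single tensors.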
Installation
```shell
python -m pip install torch-weighttracker
```

Structured BOPs MAC accounting uses fvcore to compute the baseline per-module MACs. Install it through the optional extra:

```shell
python -m pip install "torch-weighttracker[structured-bops]"
```
Tensorized cross-weight operations
Weighttracker provides an interface for running cross-weight tensor operations on models efficiently. Using the "Computation plans" and "Calculation" classes, it compiles a set of torch modules that execute tensor and mapping operations on the training device while keeping repeated operations to a minimum.
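A rough picture of the tensorized idea, written in plain PyTorch rather than with the library's Computation plan and Calculation classes: the flat position of every coupled entry is computed once (the "plan"), after which each step needs one gather and one `index_add` instead of a Python loop over groups. The grouping reuses the quickstart model's 8 hidden units; all names here are illustrative assumptions, not the package's internals.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
params = [model[0].weight, model[0].bias, model[2].weight]  # shapes (8, 4), (8,), (2, 8)

# Offset of each parameter inside the concatenated flat vector: [0, 32, 40].
offsets = [0]
for p in params[:-1]:
    offsets.append(offsets[-1] + p.numel())

# "Compile" once: flat position of every coupled entry plus the group it belongs to.
positions, owners = [], []
for unit in range(8):
    members = (
        [offsets[0] + unit * 4 + c for c in range(4)]    # first.weight[unit, :]
        + [offsets[1] + unit]                            # first.bias[unit]
        + [offsets[2] + r * 8 + unit for r in range(2)]  # last.weight[:, unit]
    )
    positions += members
    owners += [unit] * len(members)
flat_idx = torch.tensor(positions)
group_ids = torch.tensor(owners)


def group_norms_tensorized() -> torch.Tensor:
    # Per step: one concat, one gather, one index_add; no Python loop over groups.
    flat = torch.cat([p.reshape(-1) for p in params])
    sq = flat[flat_idx].pow(2)
    sums = torch.zeros(8, dtype=flat.dtype, device=flat.device).index_add(0, group_ids, sq)
    return sums.sqrt()


def group_norms_loop() -> torch.Tensor:
    # Naive reference: slice and reduce each group separately.
    w1, b1, w2 = params
    return torch.stack([torch.cat([w1[u], b1[u : u + 1], w2[:, u]]).norm() for u in range(8)])


print(torch.allclose(group_norms_tensorized(), group_norms_loop()))  # True
```

The index tensors are built once and reused every step, so the per-step cost no longer grows with the number of Python-level group visits; on a GPU, `flat_idx` and `group_ids` would also need to live on the same device as the parameters.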
Use case
Weighttracker's primary use case for now is calculating loss terms and metrics based on structural dependencies, such as structured sparsity, structured compression rates, and group lasso. However, the code is written so that, with some modifications, it can be used for any operation that traverses weights.
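For reference, the standard textbook forms of these quantities are written below, where each group g collects the parameters w_g of one coupled structure (for example, one hidden unit with its incoming row, bias, and outgoing column). These are assumed definitions, not a statement of what torch-weighttracker computes: the library may weight groups (a common choice is the square root of the group size), use a threshold rather than exact zeros for sparsity, or define the compression rate as the inverse ratio.

```latex
\Omega_{\mathrm{GL}}(W) = \sum_{g \in \mathcal{G}} \lVert w_g \rVert_2,
\qquad
\text{sparsity} = \frac{\lvert \{ g \in \mathcal{G} : \lVert w_g \rVert_2 = 0 \} \rvert}{\lvert \mathcal{G} \rvert},
\qquad
\text{compression rate} = \frac{\text{cost}_{\text{dense}}}{\text{cost}_{\text{structured}}}
```

Because the L2 norm is not squared, its gradient pushes a whole group toward zero together, which is what lets group lasso remove entire units rather than scattered individual weights.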
Group lasso
Structured group lasso regularizes coupled units together. Layers can be excluded per regularizer:
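```python
from torch_weighttracker.regularizers import RegularizerType

# `tracker` and `model` come from a setup like the one above; `model.classifier`
# stands for any submodule of your model that should not be regularized.
group_lasso = tracker.create_regularizer(
    RegularizerType.GROUP_LASSO,
    ignore=[model.classifier],
)

# `task_loss` is your usual objective (e.g. cross-entropy) for the current batch.
loss = task_loss + 1e-4 * group_lasso()
loss.backward()
```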
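To see the same pattern end to end without the library, the sketch below builds a toy `task_loss` and a hand-written group lasso term for the quickstart MLP, then takes one optimizer step. `group_lasso_penalty` is a hypothetical stand-in for the callable returned by `create_regularizer`; it does not model `ignore` and uses unweighted group norms, which may differ from the library.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))  # toy batch


def group_lasso_penalty() -> torch.Tensor:
    """Hypothetical stand-in: sum of L2 norms over the 8 hidden-unit groups."""
    w1, b1, w2 = model[0].weight, model[0].bias, model[2].weight
    # Row u holds unit u's incoming weights, bias and outgoing weights: shape (8, 7).
    groups = torch.cat([w1, b1.unsqueeze(1), w2.t()], dim=1)
    return groups.norm(dim=1).sum()


task_loss = F.cross_entropy(model(x), y)
loss = task_loss + 1e-4 * group_lasso_penalty()  # same weighting as the example above

optimizer.zero_grad()
loss.backward()  # gradients now include the penalty's shrinkage of whole groups
optimizer.step()
```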
Structured BOPs
Structured BOPs compares active bit operations against a dense 32-bit baseline:
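```python
import torch
from torch_weighttracker.trackers import TrackerType

# Requires the structured-bops extra (fvcore) for baseline per-module MACs.
metrics = tracker.create_tracker(
    TrackerType.STRUCTURED_BOPS,
    ignore=[torch.nn.BatchNorm2d],  # leave BatchNorm2d layers out of the count
    log_compression_rate=True,      # also report the compression rate
).track()

print(metrics["structured_bops"])                   # active bit operations
print(metrics["structured_bops_baseline"])          # dense 32-bit baseline
print(metrics["structured_bops_compression_rate"])  # reported because log_compression_rate=True
```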
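To get a feel for the numbers, here is a common back-of-the-envelope BOPs model: a layer's BOPs are its MACs times the weight bit width times the activation bit width, so a dense 32-bit baseline uses 32 × 32 per MAC. This formula, the helper `linear_bops`, and the definition of the compression rate as baseline divided by active BOPs are assumptions for illustration; the library's accounting (which takes per-module MACs from fvcore) may cover other layer types or differ in detail. The example prunes 3 of the 8 hidden units of the quickstart MLP, which shrinks both Linear layers.

```python
def linear_bops(in_features: int, out_features: int, w_bits: int = 32, a_bits: int = 32) -> int:
    """Assumed BOPs of one Linear layer for a single input vector (biases ignored)."""
    macs = in_features * out_features
    return macs * w_bits * a_bits


# Quickstart MLP: Linear(4, 8) -> ReLU -> Linear(8, 2).
baseline = linear_bops(4, 8) + linear_bops(8, 2)  # dense: 48 MACs * 32 * 32
# Hypothetical structured pruning: 5 of the 8 hidden units remain active.
active = linear_bops(4, 5) + linear_bops(5, 2)    # 30 MACs * 32 * 32
compression_rate = baseline / active

print(baseline, active, compression_rate)  # 49152 30720 1.6
```

Removing one hidden unit shrinks an output dimension of the first layer and an input dimension of the second, so structured pruning reduces BOPs in both layers at once.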
Speed
Compared with a naive implementation, Weighttracker gives the following speedups:
- Group lasso: 15.503x
  - Naive: 4.6540 s total, 232.698 ms/step
  - Weighttracker: 0.3002 s total, 15.010 ms/step
- Structured BOPs: 2.531x
  - Naive: 0.6757 s total, 33.783 ms/step
  - Weighttracker: 0.2669 s total, 13.346 ms/step
| Comparison | Speedup | Extra allocation (naive) | Extra allocation (Weighttracker) |
|---|---|---|---|
| Group lasso | 15.421x | 197.0 MiB | 197.0 MiB |
| Structured BOPs | 2.582x | 1.7 GiB | 195.9 MiB |
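The speedups in the list above can be reproduced from the per-step timings: each one equals the naive ms/step divided by the Weighttracker ms/step, as this quick check shows. The numbers are copied from the list; nothing is re-measured.

```python
# ms/step figures from the list above: (naive, Weighttracker).
timings_ms_per_step = {
    "Group lasso": (232.698, 15.010),
    "Structured BOPs": (33.783, 13.346),
}

speedups = {name: naive / tracked for name, (naive, tracked) in timings_ms_per_step.items()}
for name, speedup in speedups.items():
    print(f"{name}: {speedup:.3f}x")
# Group lasso: 15.503x
# Structured BOPs: 2.531x
```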
Status
This package is pre-1.0. Public APIs may still change while the tracker, calculation, and regularizer surfaces settle.
Future work:
- Streamline definitions and methods across the codebase for a unified, more compact, and perhaps more understandable API.
- Implement Calculation caching so that the same computation is not run twice
- Improve compilation of computation plans
- Improve memory management within calculations
- Write more comprehensive docstrings
For future use cases, the top-level WeightTracker API needs to be updated, including support for custom operations, custom layers, generic group definitions, etc.
License
MIT
Download files
File details
Details for the file torch_weighttracker-0.1.2.tar.gz.
File metadata
- Download URL: torch_weighttracker-0.1.2.tar.gz
- Upload date:
- Size: 178.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5c4cff95b514d64ee44a13a5d2bceb831d912577fb9b7d7e00d2d9e81aa72623 |
| MD5 | 72da107d95d3f51ca3114ab29edf787b |
| BLAKE2b-256 | 63dfefc350ed2524b235fbf894ca10aa45ba4506841d4ca0f5994b6ba2170fa3 |
Provenance
The following attestation bundles were made for torch_weighttracker-0.1.2.tar.gz:
- Publisher: publish-to-pypi.yml on dadyownes15/torch-weighttracker
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: torch_weighttracker-0.1.2.tar.gz
- Subject digest: 5c4cff95b514d64ee44a13a5d2bceb831d912577fb9b7d7e00d2d9e81aa72623
- Sigstore transparency entry: 1553169366
- Sigstore integration time:
- Permalink: dadyownes15/torch-weighttracker@b249043c74fa96252385974b105f9735a712be33
- Branch / Tag: refs/tags/v0.1.2
- Owner: https://github.com/dadyownes15
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish-to-pypi.yml@b249043c74fa96252385974b105f9735a712be33
- Trigger Event: push
File details
Details for the file torch_weighttracker-0.1.2-py3-none-any.whl.
File metadata
- Download URL: torch_weighttracker-0.1.2-py3-none-any.whl
- Upload date:
- Size: 92.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a20e5fd0bc50f6e924872ccf300cd3035f421abae54e09301b27de4a3fe1d447 |
| MD5 | a3b32e63dcb95329011b328dd842bdba |
| BLAKE2b-256 | b630a2c4d586755df1f3cf797a152b231817780cef2cd971cb217957d297b005 |
Provenance
The following attestation bundles were made for torch_weighttracker-0.1.2-py3-none-any.whl:
- Publisher: publish-to-pypi.yml on dadyownes15/torch-weighttracker
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: torch_weighttracker-0.1.2-py3-none-any.whl
- Subject digest: a20e5fd0bc50f6e924872ccf300cd3035f421abae54e09301b27de4a3fe1d447
- Sigstore transparency entry: 1553169382
- Sigstore integration time:
- Permalink: dadyownes15/torch-weighttracker@b249043c74fa96252385974b105f9735a712be33
- Branch / Tag: refs/tags/v0.1.2
- Owner: https://github.com/dadyownes15
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish-to-pypi.yml@b249043c74fa96252385974b105f9735a712be33
- Trigger Event: push