# Step-Decomposed Influence (SDI)

Step-decomposed influence (SDI) and fast TracIn for (looped) transformers.
Step-Decomposed Influence (SDI) decomposes TracIn into a step-resolved influence trajectory for looped/weight-tied models. This reference implementation computes per-step influence while avoiding per-example gradient materialization by sketching during backprop (TensorSketch for 2D weights, CountSketch for 1D params).
Paper: https://arxiv.org/abs/2602.10097
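As a rough sketch of the quantity involved (notation is ours, not the paper's; see the paper for the precise definitions): TracIn accumulates checkpoint-weighted gradient dot products, and for a weight-tied core the query-side gradient splits into per-step contributions, giving a step-resolved trajectory.

```latex
% Sketch only. TracIn: checkpoint-weighted gradient alignment between a
% training example z and a query example z', with per-checkpoint weight eta_c:
\mathrm{TracIn}(z, z') \;=\; \sum_{c} \eta_c \,
  \big\langle \nabla_\theta L(\theta_c, z),\; \nabla_\theta L(\theta_c, z') \big\rangle

% Weight tying means the query gradient is a sum of contributions g_s from
% each of the S loop steps, so the influence decomposes step by step:
\nabla_\theta L(\theta_c, z') \;=\; \sum_{s=1}^{S} g_s(\theta_c, z'),
\qquad
\mathrm{SDI}_s(z, z') \;=\; \sum_{c} \eta_c \,
  \big\langle \nabla_\theta L(\theta_c, z),\; g_s(\theta_c, z') \big\rangle
```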
## Install (UV-first)

```bash
# From PyPI
uv add step-decomposed-influence

# From source (editable, in a UV project)
uv add --editable .
```
CPU vs CUDA:

- This package is device-agnostic; CUDA works automatically if you install a CUDA-enabled PyTorch.
- Follow the official PyTorch install instructions to choose the right wheel for your system; one common pattern is sketched below.
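For example, you might pull a CUDA wheel before installing this package. The index URL and CUDA tag below are illustrative, not a requirement of this package; pick whatever the PyTorch site recommends for your system:

```bash
# Illustrative only: cu121 is an example CUDA tag.
uv pip install torch --index-url https://download.pytorch.org/whl/cu121
uv add step-decomposed-influence
```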
## Quickstart

```python
import torch
import torch.nn.functional as F

from sdi import ProjectedTracInSDI, CheckpointSpec

# model = your looped transformer
# target_modules = the recurrent core used at each loop step

def loss_fn(model, batch):
    logits = model(batch["tokens"])  # (B, T, vocab)
    targets = batch["tokens"][:, 1:]
    logits = logits[:, :-1, :]
    per_pos = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        reduction="none",
    ).view(targets.size(0), -1)
    return per_pos.mean(dim=1)  # per-example loss (B,)

est = ProjectedTracInSDI(
    model=model,
    target_modules=target_modules,
    projection_size=2048,
    loss_reduction="sum",
)

out = est.influence_across_checkpoints(
    checkpoints=[CheckpointSpec("checkpoints/ckpt_0001.pt")],
    train_loader=train_loader,
    query_loader=query_loader,
    loss_fn=loss_fn,
    mode="sdi",
    train_chunk_size=128,
    query_chunk_size=128,
)

# out.sdi:    (N_train, N_query, steps_query)
# out.tracin: (N_train, N_query)
```
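As a follow-up, here is one way you might rank training examples and inspect a step trajectory, assuming the fields of the returned object are `torch` tensors:

```python
# Rank training examples by scalar TracIn influence on query 0.
scores = out.tracin[:, 0]  # (N_train,)
top = torch.topk(scores, k=10)
print("most influential train indices:", top.indices.tolist())

# Step-resolved view for the top pair: how influence accrues over loop steps.
print("SDI trajectory:", out.sdi[top.indices[0], 0])  # (steps_query,)
```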
## Example

Run the toy looped-transformer demo:

```bash
uv run python examples/toy_looped_transformer_sdi.py --device cuda
```
This script:

- trains a tiny looped transformer on random data,
- saves a few checkpoints,
- computes projected SDI (sketch-during-backprop),
- and shows fast TracIn (scalar-only) via `mode="tracin"`, as sketched below.
## API (minimal)

- `ProjectedTracInSDI`: projected SDI/TracIn using TensorSketch + CountSketch.
- `FullGradientTracInSDI`: exact baseline (materializes per-sample gradients; small models only); see the sanity-check sketch after this list.
- `CheckpointSpec`: path + optional checkpoint weight (eta).
- `InfluenceOutputs`: `sdi` (`None` when `mode="tracin"`), `tracin`, and optional `sdi_matrix`.
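On a model small enough for the exact baseline, a sanity check could compare the two estimators. This sketch assumes `FullGradientTracInSDI` mirrors the `ProjectedTracInSDI` constructor and `influence_across_checkpoints` signature shown in the quickstart:

```python
from sdi import FullGradientTracInSDI

# Exact baseline (materializes per-sample gradients; small models only).
exact = FullGradientTracInSDI(model=model, target_modules=target_modules)
ref = exact.influence_across_checkpoints(
    checkpoints=[CheckpointSpec("checkpoints/ckpt_0001.pt")],
    train_loader=train_loader,
    query_loader=query_loader,
    loss_fn=loss_fn,
    mode="sdi",
)

# Projected scores should track the exact ones more closely as
# projection_size grows.
corr = torch.corrcoef(torch.stack([ref.tracin.flatten(), out.tracin.flatten()]))
print("TracIn correlation (exact vs projected):", corr[0, 1].item())
```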
Assumptions & limitations
- Steps are inferred from how often the selected parameters are used in a single forward/backward. For looped transformers, this equals the loop horizon when you pass only the looped core.
- Projected estimator supports:
- 2D params:
nn.Linear.weight,nn.Embedding.weight(TensorSketch) - 1D params: CountSketch of per-sample grads via Opacus grad samplers (plus analytic
nn.Linear.bias)
- 2D params:
- Custom modules without Opacus grad-sampler coverage are not supported in strict mode.
- Checkpoint loading uses
torch.load. Do not load untrusted checkpoints.
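To make the step-inference point concrete, here is a minimal hypothetical looped model (ours, not the package's): one weight-tied core applied `n_loops` times, so passing only the core makes the inferred step count equal the loop horizon. We assume `target_modules` accepts module objects, and use a Linear-only core to stay inside the supported parameter coverage:

```python
import torch.nn as nn

class TinyLoopedModel(nn.Module):
    """Hypothetical weight-tied model: one shared core applied n_loops times."""

    def __init__(self, vocab=256, d_model=64, n_loops=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, d_model)
        # Linear-only core keeps all 2D/1D params within supported coverage.
        self.core = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.head = nn.Linear(d_model, vocab)
        self.n_loops = n_loops

    def forward(self, tokens):
        h = self.embed(tokens)
        for _ in range(self.n_loops):
            h = self.core(h)  # same parameters reused at every loop step
        return self.head(h)

model = TinyLoopedModel()
target_modules = [model.core]  # only the looped core -> inferred steps == n_loops
```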
## Citation

For citing this software repository, use GitHub's "Cite this repository" button. For citing the paper, use:

```bibtex
@article{kaissis2026step,
  title         = {Step-resolved data attribution for looped transformers},
  author        = {Georgios Kaissis and David Mildenberger and Juan Felipe Gomez and Martin J. Menten and Eleni Triantafillou},
  year          = {2026},
  journal       = {arXiv preprint arXiv:2602.10097},
  eprint        = {2602.10097},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG},
  url           = {https://arxiv.org/abs/2602.10097},
}
```
## Authors
- Georgios Kaissis (Hasso-Plattner Institute)
- David Mildenberger (Technical University of Munich)
- Juan Felipe Gomez (Harvard University)
- Martin J. Menten (Technical University of Munich)
- Eleni Triantafillou (Google DeepMind)