
Step-Decomposed Influence (SDI)

Step-Decomposed Influence (SDI) decomposes TracIn into a step-resolved influence trajectory for looped/weight-tied models. This reference implementation computes per-step influence while avoiding per-example gradient materialization by sketching during backprop (TensorSketch for 2D weights, CountSketch for 1D params).
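For intuition on the sketching idea (a self-contained NumPy toy, not the package's internal code): a CountSketch hashes each coordinate to one of a few buckets with a random sign. The map is linear, so sketches can be accumulated on the fly during backprop instead of materializing per-example gradients, and inner products are preserved in expectation:

```python
import numpy as np

rng = np.random.default_rng(0)

def make_countsketch(dim, sketch_dim, rng):
    """Draw the random bucket and sign hashes once; reuse them for every vector."""
    buckets = rng.integers(0, sketch_dim, size=dim)
    signs = rng.choice([-1.0, 1.0], size=dim)
    def sketch(x):
        out = np.zeros(sketch_dim)
        np.add.at(out, buckets, signs * x)  # accumulate signed coords per bucket
        return out
    return sketch

sketch = make_countsketch(dim=10_000, sketch_dim=256, rng=rng)
x = rng.standard_normal(10_000)
y = x + 0.1 * rng.standard_normal(10_000)

# Linearity: sketching commutes with addition, so sketches can be accumulated.
assert np.allclose(sketch(x + y), sketch(x) + sketch(y))

# Inner products survive approximately (unbiased, variance ~ 1/sketch_dim).
print(float(x @ y), float(sketch(x) @ sketch(y)))
```

The same linearity is what lets a projected-TracIn-style estimator compare sketched train and query gradients via ordinary dot products in the low-dimensional space.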

Paper: https://arxiv.org/abs/2602.10097

Install (UV-first)

# From PyPI
uv add step-decomposed-influence

# From source (editable, in a UV project)
uv add --editable .

CPU vs CUDA:

  • This package is device-agnostic. CUDA works automatically if you install a CUDA-enabled PyTorch.
  • Follow the official PyTorch install instructions to choose the right wheel for your system.

Quickstart

import torch
import torch.nn.functional as F
from sdi import ProjectedTracInSDI, CheckpointSpec

# model = your looped transformer
# target_modules = the recurrent core used at each loop step

def loss_fn(model, batch):
    logits = model(batch["tokens"])  # (B, T, vocab)
    targets = batch["tokens"][:, 1:]
    logits = logits[:, :-1, :]
    per_pos = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        reduction="none",
    ).view(targets.size(0), -1)
    return per_pos.mean(dim=1)  # per-example loss (B,)

est = ProjectedTracInSDI(
    model=model,
    target_modules=target_modules,
    projection_size=2048,
    loss_reduction="sum",
)

out = est.influence_across_checkpoints(
    checkpoints=[CheckpointSpec("checkpoints/ckpt_0001.pt")],
    train_loader=train_loader,
    query_loader=query_loader,
    loss_fn=loss_fn,
    mode="sdi",
    train_chunk_size=128,
    query_chunk_size=128,
)

# out.sdi:    (N_train, N_query, steps_query)
# out.tracin: (N_train, N_query)
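Since SDI decomposes TracIn into a step-resolved trajectory, collapsing the step axis of `out.sdi` should yield a TracIn-style score (worth sanity-checking against `out.tracin` in your own run). A NumPy sketch of that bookkeeping, with a random stand-in for the output array:

```python
import numpy as np

# Stand-in for out.sdi with shape (N_train, N_query, steps_query);
# in a real run this comes from est.influence_across_checkpoints(...).
rng = np.random.default_rng(0)
sdi = rng.standard_normal((5, 3, 4))  # 5 train examples, 3 queries, 4 loop steps

# Collapsing the step axis gives a TracIn-style (N_train, N_query) matrix.
tracin_like = sdi.sum(axis=-1)
assert tracin_like.shape == (5, 3)

# Per-step slices show *when* in the loop a training example mattered.
step0 = sdi[..., 0]   # influence attributed to the first loop step
```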

Example

Run the toy looped-transformer demo:

uv run python examples/toy_looped_transformer_sdi.py --device cuda

This script:

  • trains a tiny looped transformer on random data,
  • saves a few checkpoints,
  • computes projected SDI (sketch-during-backprop),
  • and shows fast TracIn (scalar-only) via mode="tracin".

API (minimal)

  • ProjectedTracInSDI: projected SDI/TracIn using TensorSketch + CountSketch.
  • FullGradientTracInSDI: exact baseline (materializes per-sample gradients; small models only).
  • CheckpointSpec: path + optional checkpoint weight (eta).
  • InfluenceOutputs: sdi (None when mode="tracin"), tracin, and optional sdi_matrix.

Assumptions & limitations

  • Steps are inferred from how often the selected parameters are used in a single forward/backward. For looped transformers, this equals the loop horizon when you pass only the looped core.
  • Projected estimator supports:
    • 2D params: nn.Linear.weight, nn.Embedding.weight (TensorSketch)
    • 1D params: CountSketch of per-sample grads via Opacus grad samplers (plus analytic nn.Linear.bias)
  • Custom modules without Opacus grad-sampler coverage are not supported in strict mode.
  • Checkpoint loading uses torch.load, which can execute arbitrary code when unpickling. Load only checkpoints you trust (on recent PyTorch versions, torch.load(..., weights_only=True) restricts unpickling to tensors and primitive types).
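For intuition on why 2D weight gradients can be sketched without materialization (a self-contained NumPy illustration of the standard TensorSketch identity, not the package's code): the per-example gradient of an nn.Linear weight is an outer product of the output gradient and the input activation, and CountSketching that outer product under product hashes equals the circular convolution (FFT product) of the two factors' CountSketches, so the d_out x d_in matrix never needs to be formed:

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, D = 32, 48, 64

# Independent CountSketch hashes for the two factors.
h1, s1 = rng.integers(0, D, size=d_out), rng.choice([-1.0, 1.0], size=d_out)
h2, s2 = rng.integers(0, D, size=d_in), rng.choice([-1.0, 1.0], size=d_in)

def countsketch(x, h, s):
    out = np.zeros(D)
    np.add.at(out, h, s * x)
    return out

g = rng.standard_normal(d_out)  # gradient w.r.t. the layer's output
a = rng.standard_normal(d_in)   # the layer's input activation

# TensorSketch: FFT-multiply the factor sketches (circular convolution),
# never touching the full outer-product gradient.
ts = np.fft.irfft(np.fft.rfft(countsketch(g, h1, s1)) *
                  np.fft.rfft(countsketch(a, h2, s2)), n=D)

# Reference: CountSketch the materialized outer product with the induced
# hashes h(i,j) = (h1[i] + h2[j]) % D and signs s(i,j) = s1[i] * s2[j].
G = np.outer(g, a).ravel()
H = ((h1[:, None] + h2[None, :]) % D).ravel()
S = (s1[:, None] * s2[None, :]).ravel()
ref = np.zeros(D)
np.add.at(ref, H, S * G)

assert np.allclose(ts, ref)  # exact identity, no approximation at this step
```

The identity is exact; the approximation only enters when sketched gradients are compared by dot products, as with any CountSketch.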

Citation

To cite this software repository, use GitHub's "Cite this repository" button.

To cite the paper, use:

@article{kaissis2026step,
  title         = {Step-resolved data attribution for looped transformers},
  author        = {Georgios Kaissis and David Mildenberger and Juan Felipe Gomez and Martin J. Menten and Eleni Triantafillou},
  year          = {2026},
  journal       = {arXiv preprint arXiv:2602.10097},
  eprint        = {2602.10097},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG},
  url           = {https://arxiv.org/abs/2602.10097},
}

Authors

  • Georgios Kaissis (Hasso-Plattner Institute)
  • David Mildenberger (Technical University of Munich)
  • Juan Felipe Gomez (Harvard University)
  • Martin J. Menten (Technical University of Munich)
  • Eleni Triantafillou (Google DeepMind)
