
Step-Decomposed Influence (SDI) and fast TracIn for (looped) transformers.


Step-Decomposed Influence (SDI)

Step-Decomposed Influence (SDI) decomposes TracIn into a step-resolved influence trajectory for looped/weight-tied models. This reference implementation computes per-step influence while avoiding per-example gradient materialization by sketching during backprop (TensorSketch for 2D weights, CountSketch for 1D params).

Paper (placeholder): https://arxiv.org/abs/XXXX.XXXXX
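For intuition: TracIn scores a training example's influence on a query as a weighted dot product of their loss gradients at each checkpoint, and SDI further resolves that score across loop steps. A minimal self-contained illustration of the per-checkpoint TracIn dot product on a toy scalar model (all numbers invented for illustration):

```python
# Toy model f(x) = w * x with squared loss L = 0.5 * (w * x - y) ** 2.
# The gradient of L w.r.t. w for one example (x, y) is (w * x - y) * x.
def grad(w, x, y):
    return (w * x - y) * x

w = 0.5    # parameters at one checkpoint
eta = 0.1  # checkpoint weight (e.g. the learning rate at that step)

g_train = grad(w, x=2.0, y=0.5)  # gradient on a training example
g_query = grad(w, x=1.0, y=2.0)  # gradient on a query example

# TracIn contribution of this checkpoint: eta * <g_train, g_query>.
influence = eta * g_train * g_query
print(round(influence, 3))  # negative: a step on (2.0, 0.5) would raise the query loss
```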

Install (UV-first)

# From PyPI
uv add step-decomposed-influence

# From source (editable, in a UV project)
uv add --editable .

CPU vs CUDA:

  • This package is device-agnostic. CUDA works automatically if you install a CUDA-enabled PyTorch.
  • Follow the official PyTorch install instructions to choose the right wheel for your system.

Quickstart

import torch
import torch.nn.functional as F
from sdi import ProjectedTracInSDI, CheckpointSpec

# model = your looped transformer
# target_modules = the recurrent core used at each loop step

def loss_fn(model, batch):
    logits = model(batch["tokens"])  # (B, T, vocab)
    targets = batch["tokens"][:, 1:]
    logits = logits[:, :-1, :]
    per_pos = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        reduction="none",
    ).view(targets.size(0), -1)
    return per_pos.mean(dim=1)  # per-example loss (B,)

est = ProjectedTracInSDI(
    model=model,
    target_modules=target_modules,
    projection_size=2048,
    loss_reduction="sum",
)

out = est.influence_across_checkpoints(
    checkpoints=[CheckpointSpec("checkpoints/ckpt_0001.pt")],
    train_loader=train_loader,
    query_loader=query_loader,
    loss_fn=loss_fn,
    mode="sdi",
    train_chunk_size=128,
    query_chunk_size=128,
)

# out.sdi:    (N_train, N_query, steps_query)
# out.tracin: (N_train, N_query)
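The tracin matrix can be used directly for retrieval, e.g. ranking training examples per query. A sketch with a stand-in NumPy array in place of out.tracin (same shape convention as above):

```python
import numpy as np

# Stand-in for out.tracin with shape (N_train, N_query); in practice use
# the array returned by est.influence_across_checkpoints(...).
rng = np.random.default_rng(0)
tracin = rng.normal(size=(6, 3))

# Rank training examples per query, strongest proponents first.
ranking = np.argsort(-tracin, axis=0)   # (N_train, N_query)
top_train = ranking[0]                  # most influential training index per query
print(top_train.shape)                  # (3,)
```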

Example

Run the toy looped-transformer demo:

uv run python examples/toy_looped_transformer_sdi.py --device cuda

This script:

  • trains a tiny looped transformer on random data,
  • saves a few checkpoints,
  • computes projected SDI (sketch-during-backprop),
  • and shows fast TracIn (scalar-only) via mode="tracin".

API (minimal)

  • ProjectedTracInSDI: projected SDI/TracIn using TensorSketch + CountSketch.
  • FullGradientTracInSDI: exact baseline (materializes per-sample gradients; small models only).
  • CheckpointSpec: path + optional checkpoint weight (eta).
  • InfluenceOutputs: sdi (None when mode="tracin"), tracin, and optional sdi_matrix.
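To illustrate the role of the optional checkpoint weight (eta) in CheckpointSpec: TracIn sums per-checkpoint scores weighted by eta. The estimator performs this aggregation internally; the NumPy sketch below (with invented numbers) only shows the arithmetic:

```python
import numpy as np

etas = np.array([0.1, 0.05])                  # e.g. learning rates at two checkpoints
per_ckpt = np.stack([                         # per-checkpoint scores, (n_ckpt, N_train, N_query)
    np.ones((4, 2)),
    2 * np.ones((4, 2)),
])

# Weighted sum over the checkpoint axis.
tracin = np.tensordot(etas, per_ckpt, axes=1)  # (N_train, N_query)
print(tracin[0, 0])  # 0.1*1 + 0.05*2 = 0.2
```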

Assumptions & limitations

  • Steps are inferred from how often the selected parameters are used in a single forward/backward. For looped transformers, this equals the loop horizon when you pass only the looped core.
  • Projected estimator supports:
    • 2D params: nn.Linear.weight, nn.Embedding.weight (TensorSketch)
    • 1D params: CountSketch of per-sample grads via Opacus grad samplers (plus analytic nn.Linear.bias)
  • Custom modules without Opacus grad-sampler coverage are not supported in strict mode.
  • Checkpoint loading uses torch.load. Do not load untrusted checkpoints.
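For intuition about the 1D path, here is a minimal CountSketch in NumPy (illustrative only; the package's internal sketching is more involved). Each coordinate hashes to one of k buckets with a random sign, so sketched dot products are unbiased estimates of the exact ones:

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 10_000, 256                  # gradient dim, sketch dim

# One shared hash: bucket index and random sign per coordinate.
buckets = rng.integers(0, k, size=d)
signs = rng.choice([-1.0, 1.0], size=d)

def sketch(v):
    """CountSketch: scatter-add signed coordinates into k buckets."""
    out = np.zeros(k)
    np.add.at(out, buckets, signs * v)
    return out

g1, g2 = rng.normal(size=d), rng.normal(size=d)
print(g1 @ g2, sketch(g1) @ sketch(g2))  # unbiased estimate of the exact dot product
```

Because the sketch is linear, per-example gradient sketches can be accumulated during backprop without ever materializing the full gradients.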

Citation

For citing this software repository, use GitHub's "Cite this repository" button.

For citing the paper, use:

@article{kaissis2026sdi,
  title  = {Step-Resolved Data Attribution for Looped Transformers},
  author = {Georgios Kaissis and David Mildenberger and Felipe Gomez and Martin Menten and Eleni Triantafillou},
  year   = {2026},
  note   = {arXiv preprint},
}

Authors

  • Georgios Kaissis (Hasso-Plattner Institute)
  • David Mildenberger (Technical University of Munich)
  • Felipe Gomez (Harvard University)
  • Martin Menten (Technical University of Munich)
  • Eleni Triantafillou (Google DeepMind)
