
Bergson

Tracing the memory of neural nets with data attribution

This library enables you to trace the memory of deep neural nets with gradient-based data attribution techniques. We currently focus on TrackStar, as described in Scalable Influence and Fact Tracing for Large Language Model Pretraining by Chang et al. (2024), and also include support for several alternative influence functions. We plan to add support for Magic soon.

We view attribution as a counterfactual question: If we "unlearned" this training sample, how would the model's behavior change? This formulation ties attribution to some notion of what it means to "unlearn" a training sample. Here we focus on a very simple notion of unlearning: taking a gradient ascent step on the loss with respect to the training sample.
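
For intuition, this unlearning step is just the mirror image of an SGD update. A minimal PyTorch sketch (the model, sample, and learning rate here are placeholders, not part of Bergson's API):

import torch

def unlearn_step(model, sample_ids, lr=1e-4):
    # One gradient *ascent* step on the loss of a single training sample:
    # the simple notion of "unlearning" that the attribution scores approximate.
    model.zero_grad()
    loss = model(sample_ids, labels=sample_ids).loss
    loss.backward()
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is not None:
                p.add_(lr * p.grad)  # ascent: move up the loss surface
    model.zero_grad()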

Core features

  • Gradient store for serial queries. We provide collection-time gradient compression for efficient storage, and integrate with FAISS for fast KNN search over large stores.
  • On-the-fly queries. Score a dataset against a set of precomputed query gradients in a single pass, with no disk I/O overhead.
    • Experiment with multiple query strategies based on LESS.
    • Ideal for compression-free gradients.
  • Train‑time gradient collection. Capture gradients produced during training with a ~17% performance overhead.
  • Scalable. We use FSDP2, BitsAndBytes, and other performance optimizations to support large models, datasets, and clusters.
  • Integrated with HuggingFace Transformers and Datasets. We also support on-disk datasets in a variety of formats.
  • Structured gradient views and per-attention head gradient collection. Bergson enables mechanistic interpretability via easy access to per‑module or per-attention head gradients.

Announcements

January 2026

  • [Experimental] Preconditioners can now be distributed across nodes and devices for VRAM-efficient computation via GradientCollectorWithDistributedPreconditioners. If you would like this functionality exposed via the CLI, please get in touch! https://github.com/EleutherAI/bergson/pull/100


Installation

pip install bergson

Quickstart

bergson build runs/quickstart --model EleutherAI/pythia-14m --dataset NeelNanda/pile-10k --truncation --token_batch_size 4096

Usage

There are two ways to use Bergson. The first is to write an index of dataset gradients to disk using build, then query it programmatically with the Attributor or from the command line with query. The second is to specify your query upfront, then map over the dataset, collecting and processing gradients on the fly. With this second strategy, only influence scores are saved.

You can build an index of gradients for each training sample from the command line, using bergson as a CLI tool:

bergson build <output_path> --model <model_name> --dataset <dataset_name>

This will create a directory at <output_path> containing the gradients for each training sample in the specified dataset. The --model and --dataset arguments should be compatible with the Hugging Face transformers library. By default it assumes that the dataset has a text column, but you can specify other columns using --prompt_column and optionally --completion_column. The --help flag will show you all available options.
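
For example, on an instruction-tuning dataset with separate prompt and completion columns (the dataset and column names below are illustrative):

bergson build runs/alpaca --model EleutherAI/pythia-14m --dataset tatsu-lab/alpaca --prompt_column instruction --completion_column output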

You can also use the library programmatically to build an index. The collect_gradients function is just a bit lower level than the CLI tool, and allows you to specify the model and dataset directly as arguments. The result is a HuggingFace dataset with a handful of new columns, including a gradients column that holds the gradients for each training sample. You can then use this dataset to compute attributions.
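
A minimal sketch of programmatic index building; how the gradient processor is constructed is elided here, since its exact type depends on the version (see the package docs):

from datasets import load_dataset
from transformers import AutoModelForCausalLM
from bergson import collect_gradients

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
data = load_dataset("NeelNanda/pile-10k", split="train")
processor = ...  # construct a gradient processor; see the package docs

# Returns a HuggingFace dataset with new columns, including `gradients`.
ds = collect_gradients(
    model=model,
    data=data,
    processor=processor,
    path="runs/programmatic",
)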

At the lowest level of abstraction, the GradientCollector context manager lets you efficiently collect gradients for each individual example in a batch during a backward pass, while simultaneously projecting the gradients randomly to a lower-dimensional space to save memory. If you use Adafactor normalization, this is done in a particularly compute-efficient way that avoids materializing the full gradient for each example before projecting it to the lower dimension. There are two main ways to use GradientCollector (a sketch follows the list):

  1. Using a closure argument, which enables you to make use of the per-example gradients immediately after they are computed, during the backward pass. If you're computing summary statistics or other per-example metrics, this is the most efficient way to do it.
  2. Without a closure argument, in which case the gradients are collected and returned as a dictionary mapping module names to batches of gradients. This is the simplest and most flexible approach but is a bit more memory-intensive.
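
A sketch of the second pattern; constructor arguments beyond the model, and the way results are yielded, are assumptions about the signature rather than guarantees:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from bergson import GradientCollector

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
batch = tokenizer(["an example document"], return_tensors="pt")

# No closure: per-example (projected) gradients are gathered during the
# backward pass and handed back as {module_name: gradients}.
with GradientCollector(model) as grads:
    model(**batch, labels=batch["input_ids"]).loss.backward()

for name, g in grads.items():
    print(name, tuple(g.shape))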

On-the-fly Query

You can score a large dataset against a previously built query index without saving its gradients to disk:

bergson score <output_path> --model <model_name> --dataset <dataset_name> --query_path <existing_index_path> --score mean

We provide a utility to reduce a dataset into its mean or sum query gradient, for use as a query index:

bergson reduce <output_path> --model <model_name> --dataset <dataset_name> --method mean --unit_normalize
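
The reduced index written to <output_path> can then be passed as the --query_path argument of bergson score above.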

Index Query

We provide an Attributor for queries, which supports unit-normalized gradients and KNN search out of the box. Access it via the CLI with

bergson query --index <index_path> --model <model_name> --unit_norm

or programmatically with

from bergson import Attributor, FaissConfig

attr = Attributor("runs/quickstart", device="cuda")

...
query_tokens = tokenizer(query, return_tensors="pt").to("cuda:0")["input_ids"]

# Query the index
with attr.trace(model.base_model, 5) as result:
    model(query_tokens, labels=query_tokens).loss.backward()
    model.zero_grad()
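
Here the second argument to trace is presumably the number of matches to retrieve; once the backward pass inside the block completes, result holds the most influential training examples found in the index.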

To efficiently query on-disk indexes, perform ANN searches, and explore many other scalability features, add a FAISS config:

attr = Attributor("runs/quickstart", device="cuda", faiss_cfg=FaissConfig("IVF1,SQfp16", mmap_index=True))

with attr.trace(model.base_model, 5) as result:
    model(query_tokens, labels=query_tokens).loss.backward()
    model.zero_grad()

Training Gradients

Gradient collection during training is supported via an integration with HuggingFace's Trainer and SFTTrainer classes. Training gradients are saved in the order of their corresponding dataset items, and when the track_order flag is set, the training step associated with each item is saved separately.

from transformers import Trainer

from bergson import GradientCollectorCallback, prepare_for_gradient_collection

callback = GradientCollectorCallback(
    path="runs/example",
    track_order=True,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=dataset,
    callbacks=[callback],
)
trainer = prepare_for_gradient_collection(trainer)
trainer.train()

Attention Head Gradients

By default Bergson collects gradients for named parameter matrices, but per-attention head gradients may be collected by configuring an AttentionConfig for each module of interest.

from bergson import AttentionConfig, collect_gradients
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("RonenEldan/TinyStories-1M", trust_remote_code=True, use_safetensors=True)
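# `data` and `processor` are prepared as in the programmatic build sketch above.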

collect_gradients(
    model=model,
    data=data,
    processor=processor,
    path="runs/split_attention",
    attention_cfgs={
        # Head configuration for the TinyStories-1M transformer
        "h.0.attn.attention.out_proj": AttentionConfig(num_heads=16, head_size=4, head_dim=2),
    },
)

GRPO

Where a reward signal is available, we compute gradients using a weighted advantage estimate based on Dr. GRPO:

bergson build <output_path> --model <model_name> --dataset <dataset_name> --reward_column <reward_column_name>
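
For reference, Dr. GRPO's advantage estimate centers each completion's reward on its group mean, dropping the per-group standard-deviation normalization of vanilla GRPO. A minimal sketch of that weighting (Bergson's internal handling may differ):

import torch

def dr_grpo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    # Dr. GRPO: subtract the group-mean reward; unlike vanilla GRPO, do not
    # divide by the group standard deviation.
    return rewards - rewards.mean()

# Four completions sampled for the same prompt:
advantages = dr_grpo_advantages(torch.tensor([1.0, 0.0, 0.5, 0.5]))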

Development

pip install -e ".[dev]"
pre-commit install
pytest

We use conventional commits for releases.

Support

If you have suggestions, questions, or would like to collaborate, please email lucia@eleuther.ai or drop us a line in the #data-attribution channel of the EleutherAI Discord!
