
A unified framework for attributing model components, data, and training dynamics to model behavior.


ExPLAIND: Unifying Model, Data, and Training Attribution to Study Model Behavior


This repository is the official implementation of Grokking ExPLAIND: Unifying Model, Data, and Training Attribution to Study Model Behavior. To jump directly to the experiments, go to experiments/.

Requirements

Manual installation

We ran all our experiments with Python 3.12.7. You can use conda to create a fresh environment, then clone our repository and install the necessary packages using pip.

conda create -n explaind python=3.12
conda activate explaind

git clone git@github.com:mainlp/explaind.git

pip install torch torchvision numpy tqdm tensorboard pandas

If you also want to recreate the plots shown in the paper, you additionally need the following packages:

pip install plotly umap_learn

Alternatively, you can install everything from the requirements file:

pip install -r requirements.txt

PyPI installation

We also provide a PyPI package, which you can install directly with pip:

pip install explaind

Note that this only installs the code contained in explaind/. To replicate our experiments, you still need to clone this repository.
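
The snippets in the following sections use the package's wrapper classes without showing imports. As a minimal sketch, assuming the classes are exposed at the package's top level (the actual module layout may differ; check explaind/ in the repository):

# assumed top-level imports; adjust to the actual module layout in explaind/
from explaind import (
    ModelPath,
    AdamWOptimizerPath,
    DataPath,
    ExactPathKernelModel,
    RegularizedCrossEntropyLoss,
)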

Training models with history

If you want to apply ExPLAIND to your own model, you need to retrain it while tracking the relevant parts of the training process with the wrappers provided in this repository. Note that, depending on model size, full history tracking can become very expensive; we are currently working on support for cheaper partial tracking. For example, the training loop of the modulo addition model includes the following additions:

model = SingleLayerTransformerClassifier().to(device)
# wrap the model into a path wrapper that records its training history
model = ModelPath(model, device=device, checkpoint_path="model_checkpoint.pt")
loss_fct = RegularizedCrossEntropyLoss(alpha=alpha, p=reg_pow, device=device)  # alpha, reg_pow: regularization hyperparameters
optimizer = AdamWOptimizerPath(model, checkpoint_path="optimizer_checkpoint.pt")
data_path = DataPath(train_loader, checkpoint_path="data_checkpoint.pt", overwrite=True, full_batch=False)

for epoch in range(epochs):
    for batch in data_path.dataloader:
        x, y = data_path.get_batch(batch)
        optimizer.zero_grad()
        output = model.forward(x)
        l, reg = loss_fct(output, y, params=model.parameters(), output_reg=True)
        l.backward()
        optimizer.step()

        # log the checkpoint values we need for EPK prediction
        model.log_checkpoint() 
        optimizer.log_checkpoint()

# save the checkpoints to disk at the locations defined above
# loading these later, we can compute the EPK reformulation of the model
optimizer.save_checkpoints()
model.save_checkpoints()
data_path.save_checkpoints()

For complete, executable training scripts, have a look at experiments/train_models/modulo_model.py and experiments/train_models/cifar2_model.py.

Getting EPK predictions and kernel accumulations

Once you have the history of the training run you want to explain, you can load it into the EPK module and compute the predictions of the reformulated model as follows:

epk = ExactPathKernelModel(
    model=model,  # wrapper from before
    optimizer=optimizer,  # wrapper from before
    loss_fn=RegularizedCrossEntropyLoss(alpha=0.0),
    data_path=data_path,  # wrapper from before
    integral_eps=0.01,  # 1/eps = 1/0.01 = 100 integral steps
    evaluate_predictions=True,
    keep_param_wise_kernel=True,
    param_wise_kernel_keep_out_dims=True,
)

# make batch size small enough so you don't run OOM
val_loader = torch.utils.data.DataLoader(val_loader.dataset, batch_size=100, shuffle=False)

preds = []
for i, (X, y) in enumerate(val_loader):
    torch.cuda.empty_cache()
    X = X.to(device)
    y = y.to(device)
    pred = epk.predict(X, y_test=y, keep_kernel_matrices=True)
    preds.append((i, pred, y))
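
If you want a quick summary score, the stored tuples can be aggregated into an accuracy. This is a hedged sketch that assumes epk.predict returns a (batch_size, num_classes) tensor of class scores; check the validation scripts in experiments/validate_epk/ for the exact return type:

# accuracy of the EPK-reformulated model on the validation set
# (assumes `pred` holds per-class scores; verify against the actual API)
correct = sum((pred.argmax(dim=-1).cpu() == y.cpu()).sum().item() for _, pred, y in preds)
total = sum(y.shape[0] for _, _, y in preds)
print(f"EPK validation accuracy: {correct / total:.4f}")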

Note that there are different settings controlling which (accumulated) slices of the kernel are stored during prediction. Depending on your choices there, runtimes can vary greatly because of the GPU I/O and extra matrix computations involved; one lighter-weight combination is sketched below. For the complete validation scripts, consider giving experiments/validate_epk/ a look.
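
For example, a memory-lighter run could disable the parameter-wise kernel slices and skip storing the per-batch kernel matrices. This sketch only recombines the constructor and predict() flags shown above; the chosen values are assumptions, not tuned recommendations:

# lighter-weight configuration: no per-parameter kernel slices
epk_light = ExactPathKernelModel(
    model=model,
    optimizer=optimizer,
    loss_fn=RegularizedCrossEntropyLoss(alpha=0.0),
    data_path=data_path,
    integral_eps=0.01,
    evaluate_predictions=True,
    keep_param_wise_kernel=False,             # skip per-parameter kernel slices
    param_wise_kernel_keep_out_dims=False,
)

# keep_kernel_matrices=False avoids holding the full kernel matrices per test batch
pred = epk_light.predict(X, y_test=y, keep_kernel_matrices=False)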

Experiments, ablations, and plots

Besides further instructions on how to reproduce the experiments in our paper, the experiments/ folder contains all the scripts to run additional experiments and ablations and to generate plots. Any checkpoints, plots, or other artifacts are stored in results/ by default.

Contributing

We publish this repository under the MIT license and welcome anybody who wants to contribute. If you have a question or an idea, feel free to reach out to Florian (feichin[at]cis[dot]lmu[dot]de) or simply open an issue or pull request.

If you want to use our code for your own projects, please consider citing our work:

TODO
