A unified framework for attributing model components, data, and training dynamics to model behavior.
ExPLAIND: Unifying Model, Data, and Training Attribution to Study Model Behavior
This repository is the official implementation of Grokking ExPLAIND: Unifying Model, Data, and Training Attribution to Study Model Behavior. To jump directly to the experiments, go to experiments/.
Requirements
Manual installation
We ran all our experiments with Python 3.12.7. You can use conda to create a fresh environment, clone our repository, and install the necessary packages with pip:
conda create -n explaind python=3.12
conda activate explaind
git clone git@github.com:mainlp/explaind.git
pip install torch torchvision numpy tqdm tensorboard pandas
If you also want to recreate the plots shown in the paper, you additionally need the following packages:
pip install plotly umap_learn
Alternatively, you can install from the requirements file:
pip install -r requirements.txt
PyPI installation
We also provide a PyPI package, which you can install directly with pip:
pip install explaind
Note that this only installs the code contained in explaind/. To replicate our experiments, you still need to clone this repository.
Training models with history
If you want to apply ExPLAIND to your own model, you need to retrain it while tracking the relevant parts of the training process, using the wrappers provided in this repository. Note that, depending on model size, full history tracking can become very expensive; we are currently working on support for cheaper partial tracking. As an example, the training process of the modulo addition model includes the following additions:
model = SingleLayerTransformerClassifier().to(device)

# wrap model, optimizer, and data loader into their path wrappers
model = ModelPath(model, device=device, checkpoint_path="model_checkpoint.pt")
loss_fct = RegularizedCrossEntropyLoss(alpha=alpha, p=reg_pow, device=device)
optimizer = AdamWOptimizerPath(model, checkpoint_path="optimizer_checkpoint.pt")
data_path = DataPath(train_loader, checkpoint_path=checkpoint_path + "data_checkpoint.pt", overwrite=True, full_batch=False)

for epoch in range(epochs):
    for batch in data_path.dataloader:
        x, y = data_path.get_batch(batch)
        optimizer.zero_grad()
        output = model.forward(x)
        l, reg = loss_fct(output, y, params=model.parameters(), output_reg=True)
        l.backward()
        optimizer.step()
        # log checkpoint values we need for EPK prediction
        model.log_checkpoint()
        optimizer.log_checkpoint()

# save the checkpoints to disk at the locations defined above;
# loading these later, we can compute the EPK reformulation of the model
optimizer.save_checkpoints()
model.save_checkpoints()
data_path.save_checkpoints()
For complete, executable training scripts, have a look at experiments/train_models/modulo_model.py and experiments/train_models/cifar2_model.py.
Getting EPK predictions and kernel accumulations
Once you have the history of the training run you want to explain, you can load it into the EPK module and compute the prediction of the reformulated model as follows:
epk = ExactPathKernelModel(
    model=model,          # ModelPath wrapper from before
    optimizer=optimizer,  # AdamWOptimizerPath wrapper from before
    loss_fn=RegularizedCrossEntropyLoss(alpha=0.0),
    data_path=data_path,  # DataPath wrapper from before
    integral_eps=0.01,    # 1/eps = 1/0.01 = 100 integral steps
    evaluate_predictions=True,
    keep_param_wise_kernel=True,
    param_wise_kernel_keep_out_dims=True,
)

# make the batch size small enough so you don't run out of GPU memory
val_loader = torch.utils.data.DataLoader(val_loader.dataset, batch_size=100, shuffle=False)

preds = []
for i, (X, y) in enumerate(val_loader):
    torch.cuda.empty_cache()
    X = X.to(device)
    y = y.to(device)
    pred = epk.predict(X, y_test=y, keep_kernel_matrices=True)
    preds.append((i, pred, y))
Note that there are different settings controlling which (accumulated) slices of the kernel are stored during prediction. Depending on your choices there, runtimes can vary greatly because of GPU I/O and the extra matrix computations involved. For the complete validation scripts, have a look at experiments/validate_epk/.
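To build intuition for what the EPK reformulation computes, here is a minimal, self-contained NumPy sketch (deliberately not using the explaind API, and with hypothetical variable names). For a linear model trained with full-batch gradient descent, the final prediction can be rewritten exactly as a sum of kernel contributions over training points and steps, which is the kind of per-step, per-example attribution the kernel accumulations above expose:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))   # training inputs
y = rng.normal(size=8)        # training targets
x_test = rng.normal(size=3)   # a test point to explain

eta, steps = 0.05, 50
w = np.zeros(3)               # linear model f(x) = <w, x>, so f_0(x_test) = 0

# during "training", log the per-step residuals (loss gradients w.r.t. outputs),
# playing the role of the checkpoint history tracked by the wrappers
residual_history = []
for _ in range(steps):
    resid = X @ w - y
    residual_history.append(resid.copy())
    w -= eta * X.T @ resid    # full-batch gradient descent step

direct_pred = x_test @ w      # prediction of the trained model

# path-kernel reformulation: f_T(x) = f_0(x) - eta * sum_t sum_i r_{t,i} * K(x, x_i)
# for a linear model the kernel is constant across steps: K(x, x_i) = <x, x_i>
K = X @ x_test
epk_pred = 0.0 - eta * sum(r @ K for r in residual_history)

assert np.isclose(direct_pred, epk_pred)
```

Each term `r @ K` is one training step's contribution to the test prediction, attributed across training examples. For deep models the kernel changes along the training path, which is why the full checkpoint history (and the integral steps set via integral_eps) is needed.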
Experiments, ablations, and plots
Besides further instructions on how to reproduce the experiments in our paper, the experiments/ folder contains all the scripts to run additional experiments and ablations and to generate plots. Any checkpoints, plots, or other artifacts are stored in results/ by default.
Contributing
We publish this repository under the MIT license and welcome anyone who wants to contribute. If you have a question or an idea, feel free to reach out to Florian (feichin[at]cis[dot]lmu[dot]de) or simply open an issue or pull request.
If you want to use our code for your own projects, please consider citing our work:
TODO