A library for data attribution tools and benchmarks.
dattri: A Library for Efficient Data Attribution
dattri is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms. You may use dattri to
- Deploy existing data attribution methods to PyTorch models
  - e.g., Influence Function, TracIn, RPS, TRAK, ...
- Develop new data attribution methods with efficient implementations of low-level utility functions
  - e.g., Hessian (HVP/IHVP), Fisher Information Matrix (IFVP), random projection, dropout ensembling, ...
- Benchmark data attribution methods with standard benchmark settings
  - e.g., MNIST-10+LR/MLP, CIFAR-10/2+ResNet-9, MAESTRO + Music Transformer, Shakespeare + nanoGPT, ...
Quick Start
Installation
pip install dattri
If you want to use all features on CUDA and accelerate the library, you may install the full version with
pip install dattri[all]
[!NOTE] It is highly recommended to run dattri on a device that supports CUDA, especially for moderately large or larger models or datasets. CUDA is required if you want to install the full version of dattri.
[!NOTE] If you are using dattri[all], please use pip<23 and torch<2.3 due to a known issue with the fast_jl library.
Apply Data Attribution methods on PyTorch Models
Different data attribution methods can be applied to PyTorch models. You only need to define:
- the loss function used for model training (it is also used as the target function to be attributed if no other target function is provided).
- the trained model checkpoints.
- the data loaders for training samples and test samples (e.g., train_loader, test_loader).
- (optional) the target function to be attributed, if it differs from the loss function.
The following example uses IFAttributorCG and AttributionTask to apply data attribution to a PyTorch model.
import torch
from torch import nn

from dattri.algorithm import IFAttributorCG
from dattri.task import AttributionTask

# `model`, `train_loader`, `test_loader`, and `attributor_hyperparams`
# are assumed to be defined by the user beforehand.

def f(params, data):  # an example of a loss function using CE loss
    x, y = data
    loss = nn.CrossEntropyLoss()
    yhat = torch.func.functional_call(model, params, x)
    return loss(yhat, y)

task = AttributionTask(loss_func=f,
                       model=model,
                       checkpoints=model.state_dict())

attributor = IFAttributorCG(
    task=task,
    **attributor_hyperparams  # e.g., iter_num
)

attributor.cache(train_loader)  # optional pre-processing to accelerate the attribution
score = attributor.attribute(train_loader, test_loader)
Use low-level utility functions to develop new data attribution methods
HVP/IHVP
The Hessian-vector product (HVP) and the inverse-Hessian-vector product (IHVP) are widely used in data attribution methods. dattri provides efficient implementations of these operators built on torch.func. This example shows how to use the conjugate gradient (CG) implementation of the IHVP.
import torch

from dattri.func.hessian import ihvp_cg, ihvp_at_x_cg

def f(x, param):
    return torch.sin(x / param).sum()

x = torch.randn(2)
param = torch.randn(1)
v = torch.randn(5, 2)

# ihvp_cg method
# argnums=0 indicates that the first element of (x, param) passed to ihvp_func
# is treated as the model parameter, i.e., the argument the Hessian is taken with respect to
ihvp_func = ihvp_cg(f, argnums=0, max_iter=2)
ihvp_result_1 = ihvp_func((x, param), v)  # both (x, param) and v as the inputs

# ihvp_at_x_cg method: (x, param) is cached
ihvp_at_x_func = ihvp_at_x_cg(f, x, param, argnums=0, max_iter=2)
ihvp_result_2 = ihvp_at_x_func(v)  # only v as the input

# the above two will give the same result
assert torch.allclose(ihvp_result_1, ihvp_result_2)
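To illustrate the idea behind a CG-based IHVP, the following is a minimal, dependency-free sketch of conjugate gradient solving H x = v using only Hessian-vector products. It is not dattri's implementation: the function `conjugate_gradient`, the toy matrix `H`, and the helper `hvp` are hypothetical names introduced for illustration, and H is assumed to be symmetric positive definite.

```python
def conjugate_gradient(hvp, v, max_iter=50, tol=1e-10):
    """Approximately solve H x = v, accessing H only through u -> H u."""
    x = [0.0] * len(v)
    r = list(v)  # residual v - H x (x starts at zero)
    p = list(r)  # current search direction
    rs_old = sum(ri * ri for ri in r)
    for _ in range(max_iter):
        if rs_old < tol:  # squared residual norm is small enough
            break
        hp = hvp(p)
        alpha = rs_old / sum(pi * hpi for pi, hpi in zip(p, hp))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * hpi for ri, hpi in zip(r, hp)]
        rs_new = sum(ri * ri for ri in r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


# Hypothetical 2x2 symmetric positive-definite "Hessian" (illustration only).
H = [[4.0, 1.0], [1.0, 3.0]]


def hvp(u):
    """Hessian-vector product for the toy matrix H."""
    return [sum(h * ui for h, ui in zip(row, u)) for row in H]


v = [1.0, 2.0]
ihvp = conjugate_gradient(hvp, v)  # approximates H^{-1} v
residual = [a - b for a, b in zip(hvp(ihvp), v)]  # should be close to zero
```

The key design point is that H is never formed or inverted explicitly; only products with H are needed, which is what makes this approach attractive when H is the Hessian of a large model. Limiting `max_iter` trades accuracy for speed.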
Random Projection
It has been shown that long vectors retain most of their relative information when projected down to a smaller feature dimension. To reduce the computational cost, random projection is widely used in data attribution methods. The following example shows how to use random_project. The implementation leverages fast_jl.
import torch

from dattri.func.random_projection import random_project

# initialize the projector based on users' needs
project_func = random_project(tensor, tensor.size(0), proj_dim=512)

# obtain projected tensors
# (torch.full_like requires a fill value; 1.0 is an arbitrary placeholder input)
projected_tensor = project_func(torch.full_like(tensor, 1.0))
Typically, tensor is the gradient of the loss/target function and has a large dimension (i.e., the number of parameters).
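The claim that projected vectors retain their relative information can be checked with a small, dependency-free sketch of a Gaussian random projection. This is an illustration of the general technique, not of how random_project or fast_jl work internally; `make_projector`, `norm`, and the chosen dimensions are hypothetical.

```python
import math
import random


def make_projector(in_dim, proj_dim, seed=0):
    """Build a Gaussian random projection from in_dim to proj_dim dimensions."""
    rng = random.Random(seed)
    scale = 1.0 / math.sqrt(proj_dim)  # keeps expected squared norms unchanged
    matrix = [
        [rng.gauss(0.0, 1.0) * scale for _ in range(in_dim)]
        for _ in range(proj_dim)
    ]

    def project(vec):
        return [sum(m * x for m, x in zip(row, vec)) for row in matrix]

    return project


def norm(vec):
    return math.sqrt(sum(x * x for x in vec))


# Two hypothetical "gradient" vectors of dimension 1000.
data_rng = random.Random(1)
in_dim, proj_dim = 1000, 256
g1 = [data_rng.gauss(0.0, 1.0) for _ in range(in_dim)]
g2 = [data_rng.gauss(0.0, 1.0) for _ in range(in_dim)]

project = make_projector(in_dim, proj_dim)
p1, p2 = project(g1), project(g2)

original_dist = norm([a - b for a, b in zip(g1, g2)])
projected_dist = norm([a - b for a, b in zip(p1, p2)])
ratio = projected_dist / original_dist  # expected to be close to 1
```

The same projection must be applied to every vector being compared, which is why the projector is built once and then reused.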
Dropout Ensemble
Recent studies have found that ensemble methods can significantly improve the performance of data attribution. Dropout ensemble is one such method. One may prepare a model as follows:
from dattri.model_utils.dropout import activate_dropout
# initialize a torch.nn.Module model
model = MLP()
# (option 1) activate all dropout layers
model = activate_dropout(model, dropout_prob=0.2)
# (option 2) activate specific dropout layers
# here "dropout1" and "dropout2" are the names of dropout layers within the model
model = activate_dropout(model, ["dropout1", "dropout2"], dropout_prob=0.2)
Algorithms Supported
Family | Algorithms |
---|---|
IF | Explicit |
IF | CG |
IF | LiSSA |
IF | Arnoldi |
IF | DataInf |
IF | EK-FAC |
TracIn | TracInCP |
TracIn | Grad-Dot |
TracIn | Grad-Cos |
RPS | RPS-L2 |
TRAK | TRAK |
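As a concrete picture of the simplest gradient-based scores in the table, the sketch below computes Grad-Dot (inner product of per-sample gradients) and Grad-Cos (their cosine similarity) for a toy linear model with squared loss, where the gradient has a closed form. It does not use dattri; the model, data, and all function names are hypothetical, and the (train, test) layout of the score matrix is a choice made for this sketch.

```python
import math


def grad_squared_loss(w, x, y):
    """Gradient of 0.5 * (w.x - y)^2 with respect to w."""
    err = sum(wi * xi for wi, xi in zip(w, x)) - y
    return [err * xi for xi in x]


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def grad_dot(w, train, test):
    """scores[i][j] = <grad of train sample i, grad of test sample j>."""
    train_grads = [grad_squared_loss(w, x, y) for x, y in train]
    test_grads = [grad_squared_loss(w, x, y) for x, y in test]
    return [[dot(gt, gs) for gs in test_grads] for gt in train_grads]


def grad_cos(w, train, test):
    """Same as grad_dot, but normalized by the gradient norms."""
    train_grads = [grad_squared_loss(w, x, y) for x, y in train]
    test_grads = [grad_squared_loss(w, x, y) for x, y in test]
    scores = []
    for gt in train_grads:
        row = []
        for gs in test_grads:
            denom = math.sqrt(dot(gt, gt)) * math.sqrt(dot(gs, gs))
            row.append(dot(gt, gs) / denom if denom > 0 else 0.0)
        scores.append(row)
    return scores


# Hypothetical toy data: (features, label) pairs.
w = [1.0, -1.0]
train = [([1.0, 0.0], 0.0), ([0.0, 1.0], 0.0), ([1.0, 1.0], 1.0)]
test = [([2.0, 0.0], 0.0)]

dot_scores = grad_dot(w, train, test)  # [[4.0], [0.0], [-4.0]]
cos_scores = grad_cos(w, train, test)
```

A positive score suggests that the training sample's gradient points in the same direction as the test sample's gradient, and a negative score the opposite.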
Metrics Supported
- Leave-one-out (LOO) correlation
- Linear datamodeling score (LDS)
- Area under the ROC curve (AUC) for noisy label detection
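For readers unfamiliar with LDS, the following dependency-free sketch shows one common way to compute it for a single test sample: predict each subset's model output by summing the attribution scores of the training samples in that subset, then take the Spearman rank correlation with the outputs of models actually retrained on those subsets. This is not dattri's metric implementation; all names and numbers are hypothetical, and in practice the value is typically averaged over many test samples.

```python
import math


def rank(values):
    """1-based ranks, with ties receiving the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def pearson(a, b):
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


def lds(scores, subsets, actual_outputs):
    """Spearman correlation between additive predictions and retrained outputs."""
    predicted = [sum(scores[i] for i in subset) for subset in subsets]
    return pearson(rank(predicted), rank(actual_outputs))


# Hypothetical attribution scores of 4 training samples for one test sample.
scores = [0.5, -0.2, 0.1, 0.3]
# Hypothetical training subsets (indices) and the test outputs of models
# retrained on each subset.
subsets = [[0, 1], [1, 2], [2, 3], [0, 3]]
actual_outputs = [1.0, 0.2, 1.5, 2.0]

lds_value = lds(scores, subsets, actual_outputs)
```

Because only ranks matter, the attribution scores do not need to be on the same scale as the model outputs; they only need to order the subsets correctly.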
Benchmark Settings Supported
Dataset | Model | Task | Sample size (train,test) | Parameter size | Metrics | Data Source |
---|---|---|---|---|---|---|
MNIST-10 | LR | Image Classification | (5000,500) | 7840 | LOO/LDS/AUC | link |
MNIST-10 | MLP | Image Classification | (5000,500) | 0.11M | LOO/LDS/AUC | link |
CIFAR-2 | ResNet-9 | Image Classification | (5000,500) | 4.83M | LDS | link |
CIFAR-10 | ResNet-9 | Image Classification | (5000,500) | 4.83M | AUC | link |
MAESTRO | Music Transformer | Music Generation | (5000,178) | 13.3M | LDS | link |
Shakespeare | nanoGPT | Text Generation | (3921,435) | 10.7M | LDS | link |
Benchmark Results
MNIST+LR/MLP
LDS performance on larger models
AUC performance
Development Plan
- More (larger) benchmark settings to come
  - ImageNet + ResNet-18
  - Tinystories + nanoGPT
  - Comparison with other libraries
- More algorithms and low-level utility functions to come
  - KNN filter
  - TF-IDF filter
  - RelativeIF
  - KNN Shapley
  - In-Run Shapley
- Better documentation
  - Quick start colab notebooks