Influence Functions with (Eigenvalue-corrected) Kronecker-factored Approximate Curvature
Kronfluence is a research repository designed to compute influence functions using Kronecker-factored Approximate Curvature (KFAC) or Eigenvalue-corrected KFAC (EKFAC). For a detailed description of the methodology, see the paper Studying Large Language Model Generalization with Influence Functions.
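As a rough illustration of what KFAC and EKFAC approximate, here is a minimal sketch in plain PyTorch. It is not Kronfluence's implementation, and all function names are hypothetical. Up to sign convention, the influence of a training example z_m on a query measurement f has the form ∇θf^T (G + λI)^{-1} ∇θL(z_m), where G is a curvature matrix (the Gauss-Newton Hessian in the paper) and λ is a damping term. For one linear layer with a weight of shape (d_out, d_in), KFAC approximates that layer's block of G as A ⊗ S. Here A is the second moment of the layer inputs and S is the second moment of the gradients with respect to the layer outputs. EKFAC keeps the eigenbasis of A ⊗ S but refits each eigenvalue to the per-example gradients. The sketch assumes precomputed inputs and output gradients, omits the bias, and uses the column-major vec convention, under which (A ⊗ S) vec(X) = vec(S X A).

```python
import torch


def kfac_factors(acts: torch.Tensor, grads: torch.Tensor):
    """Kronecker factors for one linear layer (bias omitted for brevity).

    acts:  (N, d_in)  inputs to the layer for N examples
    grads: (N, d_out) gradients of the loss w.r.t. the layer outputs
    """
    n = acts.shape[0]
    A = acts.T @ acts / n    # second moment of the inputs, (d_in, d_in)
    S = grads.T @ grads / n  # second moment of the output gradients, (d_out, d_out)
    return A, S


def kfac_eigenvalues(A: torch.Tensor, S: torch.Tensor):
    """Eigenbasis of A ⊗ S and its eigenvalues, laid out like a (d_out, d_in) weight."""
    lam_a, Q_a = torch.linalg.eigh(A)
    lam_s, Q_s = torch.linalg.eigh(S)
    return Q_a, Q_s, torch.outer(lam_s, lam_a)


def ekfac_eigenvalues(acts, grads, Q_a, Q_s):
    """EKFAC: refit each eigenvalue as the second moment of the per-example weight
    gradient g_n a_n^T projected onto the same Kronecker eigenbasis."""
    proj_a = (acts @ Q_a) ** 2   # (N, d_in)
    proj_s = (grads @ Q_s) ** 2  # (N, d_out)
    return proj_s.T @ proj_a / acts.shape[0]


def damped_ihvp(grad_w, Q_a, Q_s, eigvals, damping):
    """Solve (G + damping * I) x = grad_w, where G = (Q_a ⊗ Q_s) diag(eigvals) (Q_a ⊗ Q_s)^T."""
    rotated = Q_s.T @ grad_w @ Q_a          # rotate the gradient into the eigenbasis
    scaled = rotated / (eigvals + damping)  # invert the (damped) diagonal
    return Q_s @ scaled @ Q_a.T             # rotate back to the weight layout
```

Because G + λI is symmetric, a per-layer score could then be computed as `(query_grad * damped_ihvp(train_grad, Q_a, Q_s, eigvals, damping)).sum()` with either set of eigenvalues. Only the eigenvalues differ between KFAC and EKFAC, so the same solver serves both.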
[!WARNING] This repository is under active development and has not reached its first stable release.
Installation
[!IMPORTANT] Requirements:
- Python: Version 3.9 or later
- PyTorch: Version 2.1 or later
To install the latest version, use the following `pip` command:
```bash
pip install kronfluence
```
Alternatively, you can install the library directly from the source:
```bash
git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e .
```
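To confirm that an environment meets the stated requirements (Python 3.9 or later, PyTorch 2.1 or later), here is a small standard-library-only sketch. The helper names are hypothetical, and the version parsing is deliberately simple: it keeps only the leading numeric components.

```python
import sys
from importlib import metadata


def parse_version(version: str) -> tuple:
    """Leading numeric components of a version string, e.g. "2.1.0+cu121" -> (2, 1, 0)."""
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def installed_version(package: str):
    """Installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def meets_requirements(python_version=None, torch_version=None) -> bool:
    """Check the stated minimums: Python >= 3.9 and PyTorch >= 2.1."""
    if python_version is None:
        python_version = sys.version_info[:2]
    if torch_version is None:
        torch_version = installed_version("torch")
        if torch_version is None:
            return False
    return tuple(python_version)[:2] >= (3, 9) and parse_version(torch_version)[:2] >= (2, 1)
```

After installing, `print(installed_version("kronfluence"), meets_requirements())` reports the installed Kronfluence version and whether the current interpreter and PyTorch meet the minimums.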
Getting Started
Kronfluence supports influence computations on `nn.Linear` and `nn.Conv2d` modules. See the Technical Documentation page for a comprehensive guide.
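To see which parts of a model these two module types cover, here is a small sketch that uses only standard PyTorch calls. The helper names are hypothetical. Parameters in other layer types, such as normalization layers, are simply counted as not covered by this check. This sketch does not say how Kronfluence itself treats them.

```python
from torch import nn

# Module types this README names as supported.
SUPPORTED_TYPES = (nn.Linear, nn.Conv2d)


def supported_modules(model: nn.Module) -> list:
    """Names (as given by named_modules) of nn.Linear / nn.Conv2d submodules."""
    return [name for name, m in model.named_modules() if isinstance(m, SUPPORTED_TYPES)]


def supported_param_fraction(model: nn.Module) -> float:
    """Fraction of the model's parameters that live in supported modules."""
    total = sum(p.numel() for p in model.parameters())
    covered = sum(
        p.numel()
        for m in model.modules()
        if isinstance(m, SUPPORTED_TYPES)
        for p in m.parameters()
    )
    return covered / total


# The MNIST model from the example below: every parameter sits in an nn.Linear.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 1024), nn.ReLU(),
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 10),
)
print(supported_modules(model))         # ['1', '3', '5', '7']
print(supported_param_fraction(model))  # 1.0
```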
Learn More
The examples folder contains several examples, and we plan to add more in the future. TL;DR: prepare the trained model, the datasets, and a task, then pass them to `Analyzer`.
```python
import torch
import torchvision
from torch import nn

from kronfluence.analyzer import Analyzer, prepare_model

# Define the model and load the trained model weights.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 10, bias=True),
)
model.load_state_dict(torch.load("model_path.pth"))

# Load the datasets. ToTensor() turns the PIL images into tensors so they can be
# batched; use the same preprocessing as during training.
transform = torchvision.transforms.ToTensor()
train_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=True,
    transform=transform,
)
eval_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=False,  # query examples come from the held-out test split
    transform=transform,
)

# Define the task. `MnistTask` is not defined in this snippet: it is a user-defined
# task that specifies the training loss and the measurement (see the examples folder).
task = MnistTask()

# Prepare the model for influence computation.
model = prepare_model(model=model, task=task)
analyzer = Analyzer(analysis_name="mnist", model=model, task=task)

# Fit all EKFAC factors for the given model.
analyzer.fit_all_factors(factors_name="my_factors", dataset=train_dataset)

# Compute all pairwise influence scores with the computed factors.
analyzer.compute_pairwise_scores(
    scores_name="my_scores",
    factors_name="my_factors",
    query_dataset=eval_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=1024,
)

# Load the scores with dimension `len(eval_dataset) x len(train_dataset)`.
scores = analyzer.load_pairwise_scores(scores_name="my_scores")
```
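A common next step is to list, for each query example, the training examples with the highest and lowest scores. The sketch below does not rely on the exact return type of `load_pairwise_scores`. It assumes you already have a single tensor of shape `len(eval_dataset) x len(train_dataset)`. The helper name and the toy values are hypothetical. Whether the high end or the low end counts as "helpful" depends on the sign convention, which this README does not state.

```python
import torch


def most_influential(scores: torch.Tensor, k: int):
    """Per-query indices of the k highest- and k lowest-scoring training examples.

    scores: (num_query, num_train) tensor, e.g. len(eval_dataset) x len(train_dataset).
    """
    top = torch.topk(scores, k=k, dim=1)                    # largest scores, descending
    bottom = torch.topk(scores, k=k, dim=1, largest=False)  # smallest scores, ascending
    return top.indices, bottom.indices


# Toy scores for 2 query examples and 3 training examples.
toy_scores = torch.tensor([[0.1, -2.0, 3.0], [5.0, 0.0, -1.0]])
top_idx, bottom_idx = most_influential(toy_scores, k=2)
print(top_idx.tolist())     # [[2, 0], [0, 1]]
print(bottom_idx.tolist())  # [[1, 0], [2, 1]]
```

The returned indices refer to positions in `train_dataset`, so `train_dataset[i]` retrieves the corresponding example for inspection.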
Contributing
Contributions are welcome! For bug fixes, please submit a pull request without prior discussion. To propose new features, examples, or extensions, please open an issue to start a discussion before proceeding.
Setting Up Development Environment
To contribute, you will need to set up a development environment on your machine. This setup includes installing all the dependencies required for linting and testing.
```bash
git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e ".[dev]"
```
License
This software is released under the Apache 2.0 License, as detailed in the LICENSE file.
Download files
File details for kronfluence-0.0.2.tar.gz (source distribution)
- Download URL: kronfluence-0.0.2.tar.gz
- Size: 81.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.9.18

Algorithm | Hash digest
---|---
SHA256 | caef89fa75e8d8a8793f38172f4a34dd40ce38e9aa96985ff8d10ec0d2c8bf29
MD5 | 4761b10dcbaca6dbd84087867f0dba3e
BLAKE2b-256 | 1dc2cec2dca3ff81ff1a404dfcbb8871a2ba165334422cdcf5ec717d287051e8
File details for kronfluence-0.0.2-py3-none-any.whl (built distribution)
- Download URL: kronfluence-0.0.2-py3-none-any.whl
- Size: 126.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.9.18

Algorithm | Hash digest
---|---
SHA256 | b83e085d81b3d46ba21b86e0e4516176eb13d10878e761e06cff9de444274f09
MD5 | c2d7e8ee58d0f1e0e6d1fe47ea0c3f45
BLAKE2b-256 | 505c6e345c78568c52b9ba7b752b44011489cd391ab9bab0722385a57bfcc0c5
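To check a downloaded file against the digests listed above, here is a minimal standard-library sketch. The helper names are hypothetical. It assumes BLAKE2b-256 means BLAKE2b with a 32-byte digest.

```python
import hashlib


def file_digests(path, chunk_size=1 << 20):
    """Compute SHA256, MD5, and BLAKE2b-256 hex digests of a file, reading it in chunks."""
    hashers = {
        "SHA256": hashlib.sha256(),
        "MD5": hashlib.md5(),
        "BLAKE2b-256": hashlib.blake2b(digest_size=32),
    }
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            for h in hashers.values():
                h.update(chunk)
    return {name: h.hexdigest() for name, h in hashers.items()}


def verify_sha256(path, expected: str) -> bool:
    """True if the file's SHA256 digest matches the expected hex string."""
    return file_digests(path)["SHA256"] == expected.strip().lower()
```

For example, `verify_sha256("kronfluence-0.0.2.tar.gz", "caef89fa75e8d8a8793f38172f4a34dd40ce38e9aa96985ff8d10ec0d2c8bf29")` should return `True` for an intact download of the source distribution.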