Influence Functions with (Eigenvalue-corrected) Kronecker-factored Approximate Curvature
Kronfluence is a research repository designed to compute influence functions using Kronecker-factored Approximate Curvature (KFAC) or Eigenvalue-corrected KFAC (EKFAC). For a detailed description of the methodology, see the paper Studying Large Language Model Generalization with Influence Functions.
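The core computation behind KFAC and EKFAC can be illustrated for a single linear layer in NumPy. This is a minimal sketch under our own assumptions, not Kronfluence's implementation. The function names are hypothetical. Per-example layer inputs and output gradients are taken as given (in practice they would be captured with forward/backward hooks). Bias terms are omitted, and one damping value is added to every eigenvalue. KFAC approximates the layer's curvature matrix as `A ⊗ S`, where `A` is the covariance of the layer inputs and `S` is the covariance of the output gradients. EKFAC keeps the eigenbasis of that Kronecker product but refits the eigenvalues from per-example gradients. The score below is the raw inner product `g_query^T (G + λI)^{-1} g_train`. Whether a positive value counts as helpful or harmful is a sign convention this sketch does not settle.

```python
import numpy as np


def kfac_factors(acts: np.ndarray, grads: np.ndarray):
    """KFAC factors of one linear layer (bias omitted for brevity).

    acts:  (n, d_in)  per-example inputs a_n to the layer
    grads: (n, d_out) per-example gradients g_n of the loss w.r.t. the layer outputs
    The per-example weight gradient is g_n a_n^T, and KFAC approximates
    E[vec(g a^T) vec(g a^T)^T] by A kron S.
    """
    n = acts.shape[0]
    A = acts.T @ acts / n    # (d_in, d_in)   E[a a^T]
    S = grads.T @ grads / n  # (d_out, d_out) E[g g^T]
    return A, S


def kfac_eigen(A: np.ndarray, S: np.ndarray):
    """Eigendecompose both factors; every eigenvalue of A kron S is a product lam_S[j] * lam_A[i]."""
    lam_A, Q_A = np.linalg.eigh(A)
    lam_S, Q_S = np.linalg.eigh(S)
    return Q_A, Q_S, np.outer(lam_S, lam_A)  # eigenvalue grid of shape (d_out, d_in)


def ekfac_eigenvalues(acts, grads, Q_A, Q_S):
    """EKFAC: keep the KFAC eigenbasis but refit each eigenvalue as the mean squared
    per-example gradient coordinate in that basis. (KFAC instead uses the product of
    the two means, mean(gq**2) * mean(aq**2).)"""
    gq = grads @ Q_S  # (n, d_out) output gradients rotated into S's eigenbasis
    aq = acts @ Q_A   # (n, d_in)  inputs rotated into A's eigenbasis
    # The rotated per-example weight gradient is outer(gq_n, aq_n); average its squared entries.
    return np.einsum("no,ni->oi", gq**2, aq**2) / acts.shape[0]


def inverse_hvp(V, Q_A, Q_S, eigvals, damping):
    """Apply (Q diag(eigvals) Q^T + damping * I)^{-1}, with Q = Q_A kron Q_S, to a
    (d_out, d_in) gradient matrix V: rotate into the eigenbasis, divide, rotate back."""
    rotated = Q_S.T @ V @ Q_A
    return Q_S @ (rotated / (eigvals + damping)) @ Q_A.T


def influence_score(query_grad, train_grad, Q_A, Q_S, eigvals, damping):
    """Inner product query_grad^T (G + damping * I)^{-1} train_grad for one layer."""
    return float(np.sum(query_grad * inverse_hvp(train_grad, Q_A, Q_S, eigvals, damping)))
```

Because KFAC-style methods treat layers independently (a block-diagonal approximation), a model-level score in this sketch would be the sum of such per-layer terms.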
[!WARNING] This repository is under active development and has not reached its first stable release.
Installation
[!IMPORTANT] Requirements:
- Python: Version 3.9 or later
- PyTorch: Version 2.1 or later
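To confirm that an environment meets these requirements, a small standard-library script like the one below can help. It is a hypothetical helper, not part of Kronfluence. It reads the installed PyTorch version through `importlib.metadata` (available from Python 3.8) and compares only the major and minor version numbers, ignoring local build tags such as `+cu118`.

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import List, Tuple

MIN_PYTHON = (3, 9)  # from the requirements above
MIN_TORCH = (2, 1)


def major_minor(version_string: str) -> Tuple[int, int]:
    """Parse the leading "X.Y" of a version string such as "2.1.0+cu118"."""
    match = re.match(r"(\d+)(?:\.(\d+))?", version_string)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version_string!r}")
    return int(match.group(1)), int(match.group(2) or 0)


def check_requirements() -> List[str]:
    """Return a list of human-readable problems; an empty list means both checks passed."""
    problems = []
    if sys.version_info[:2] < MIN_PYTHON:
        found = ".".join(map(str, sys.version_info[:2]))
        problems.append(f"Python {found} found; 3.9 or later is required.")
    try:
        torch_version = version("torch")
    except PackageNotFoundError:
        problems.append("PyTorch is not installed.")
    else:
        if major_minor(torch_version) < MIN_TORCH:
            problems.append(f"PyTorch {torch_version} found; 2.1 or later is required.")
    return problems


if __name__ == "__main__":
    issues = check_requirements()
    print("\n".join(issues) if issues else "Python and PyTorch requirements are satisfied.")
```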
To install the latest version, use the following `pip` command:
```bash
pip install kronfluence
```
Alternatively, you can install the library directly from the source:
```bash
git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e .
```
Getting Started
Kronfluence supports influence computations on `nn.Linear` and `nn.Conv2d` modules. See the Technical Documentation page for a comprehensive guide.
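To check which layers of a given model fall under that support, you can list the `nn.Linear` and `nn.Conv2d` submodules with plain PyTorch. `find_supported_modules` is a hypothetical helper, not a Kronfluence function. It only inspects module types (via `isinstance`, so subclasses also match) and does not reflect how Kronfluence itself selects or wraps modules. Layers of other types, such as normalization layers, would not be covered.

```python
from typing import List

from torch import nn

# Module types that the project description lists as supported.
SUPPORTED_TYPES = (nn.Linear, nn.Conv2d)


def find_supported_modules(model: nn.Module) -> List[str]:
    """Return the qualified names of submodules that are nn.Linear or nn.Conv2d instances."""
    return [
        name
        for name, module in model.named_modules()
        if isinstance(module, SUPPORTED_TYPES)
    ]


if __name__ == "__main__":
    model = nn.Sequential(
        nn.Conv2d(1, 8, kernel_size=3),  # "0": supported type
        nn.BatchNorm2d(8),               # "1": not a supported type
        nn.ReLU(),                       # "2"
        nn.Flatten(),                    # "3"
        nn.Linear(8 * 26 * 26, 10),      # "4": supported type
    )
    print(find_supported_modules(model))
```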
Learn More
The examples folder contains several examples, and we plan to add more in the future.
TL;DR: Prepare the trained model and datasets, then pass them into `Analyzer`:
```python
import torch
import torchvision
from torch import nn

from kronfluence.analyzer import Analyzer, prepare_model

# Define the model and load the trained model weights.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 10, bias=True),
)
model.load_state_dict(torch.load("model_path.pth"))

# Load the datasets. `ToTensor()` converts the PIL images into tensors the model can consume.
train_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=True,
    transform=torchvision.transforms.ToTensor(),
)
eval_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=False,  # Use the held-out test split as the query set.
    transform=torchvision.transforms.ToTensor(),
)

# Define the task. `MnistTask` is not defined here: it is a user-written subclass of
# Kronfluence's `Task` class that specifies the training loss and the measurement
# (see the examples folder for complete task definitions).
task = MnistTask()

# Prepare the model for influence computation.
model = prepare_model(model=model, task=task)
analyzer = Analyzer(analysis_name="mnist", model=model, task=task)

# Fit all EKFAC factors for the given model.
analyzer.fit_all_factors(factors_name="my_factors", dataset=train_dataset)

# Compute all pairwise influence scores with the computed factors.
analyzer.compute_pairwise_scores(
    scores_name="my_scores",
    factors_name="my_factors",
    query_dataset=eval_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=1024,
)

# Load the scores with dimension `len(eval_dataset) x len(train_dataset)`.
scores = analyzer.load_pairwise_scores(scores_name="my_scores")
```
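Because the pairwise scores form a `len(eval_dataset) x len(train_dataset)` matrix, a common next step is to look up, for each query, the training examples with the largest scores. The sketch below assumes the scores have already been extracted as a 2-D array (for a PyTorch tensor, for example via `.cpu().numpy()`). How `load_pairwise_scores` packages its result, and whether large positive scores mean "helpful" or "harmful", should be checked against the Technical Documentation. `top_k_train_indices` is a hypothetical helper.

```python
import numpy as np


def top_k_train_indices(scores, k: int = 5) -> np.ndarray:
    """For each query row, return the indices of the k highest-scoring training examples.

    `scores` is assumed to be a (num_queries, num_train) array with k <= num_train.
    Each output row is ordered from highest to lowest score.
    """
    scores = np.asarray(scores)
    # argpartition moves each row's k largest scores into the first k slots in O(num_train).
    top = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k]
    # Sort only those k candidates in descending order of score.
    order = np.argsort(-np.take_along_axis(scores, top, axis=1), axis=1)
    return np.take_along_axis(top, order, axis=1)
```

Passing `-scores` instead returns the lowest-scoring training examples for each query.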
Contributing
Your contributions are welcome! For bug fixes, please submit a pull request without prior discussion. For proposing new features, examples, or extensions, please open an issue to discuss them before starting work.
Setting Up Development Environment
To contribute, you will need to set up a development environment on your machine. This setup includes installing all the dependencies required for linting and testing.
```bash
git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e ."[dev]"
```
License
This software is released under the Apache 2.0 License, as detailed in the LICENSE file.
Hashes for kronfluence-0.0.2-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | b83e085d81b3d46ba21b86e0e4516176eb13d10878e761e06cff9de444274f09 |
| MD5 | c2d7e8ee58d0f1e0e6d1fe47ea0c3f45 |
| BLAKE2b-256 | 505c6e345c78568c52b9ba7b752b44011489cd391ab9bab0722385a57bfcc0c5 |