Influence Functions with (Eigenvalue-corrected) Kronecker-factored Approximate Curvature
Kronfluence is a repository designed to compute influence functions using Kronecker-factored Approximate Curvature (KFAC) or Eigenvalue-corrected KFAC (EKFAC). For a detailed description of the methodology, see the paper "Studying Large Language Model Generalization with Influence Functions".
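To give a feel for why a Kronecker-factored curvature approximation is cheap to invert, here is a minimal NumPy sketch. It is not Kronfluence's implementation; the function name `damped_kron_solve`, the toy shapes, and the damping value are all assumptions made for illustration. It solves a damped linear system whose matrix is a Kronecker product of two small covariance factors, using only the eigendecompositions of the factors, and checks the result against a dense solve.

```python
import numpy as np


def damped_kron_solve(A, S, V, damping):
    """Solve (kron(S, A) + damping * I) vec(X) = vec(V) without forming kron(S, A).

    A: (n, n) symmetric input-side factor, S: (m, m) symmetric output-side factor,
    V: (m, n) matrix (e.g. a gradient for an m-by-n weight), vec = row-major flatten.
    """
    eigvals_a, Q_a = np.linalg.eigh(A)
    eigvals_s, Q_s = np.linalg.eigh(S)
    # Rotate V into the Kronecker eigenbasis.
    V_rot = Q_s.T @ V @ Q_a
    # In that basis the damped matrix is diagonal with entries s_k * a_j + damping.
    V_rot = V_rot / (np.outer(eigvals_s, eigvals_a) + damping)
    # Rotate back to the original basis.
    return Q_s @ V_rot @ Q_a.T


# Toy factors built from random "activations" and "output gradients" (illustrative only).
rng = np.random.default_rng(0)
acts = rng.normal(size=(64, 5))
grads = rng.normal(size=(64, 3))
A = acts.T @ acts / 64
S = grads.T @ grads / 64
V = rng.normal(size=(3, 5))
damping = 1e-3

X = damped_kron_solve(A, S, V, damping)

# Dense reference: only feasible here because the toy problem is tiny.
dense = np.kron(S, A) + damping * np.eye(15)
X_dense = np.linalg.solve(dense, V.flatten()).reshape(3, 5)
print(np.allclose(X, X_dense))
```

The sketch uses the product of the factor eigenvalues as the diagonal. EKFAC, as described in the literature, keeps the same Kronecker eigenbasis but re-estimates that diagonal from data rather than taking the product.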
Warning: This repository is under active development and has not reached its first stable release.
Installation
Requirements:
- Python: Version 3.9 or later
- PyTorch: Version 2.1 or later
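If you want to confirm these requirements before installing, a small script along the following lines may help. It is only a sketch: the helper names `parse_major_minor`, `meets_minimum`, and `check_requirements` are hypothetical, and it assumes PyTorch is distributed under the package name `torch`.

```python
import re
import sys
from importlib import metadata


def parse_major_minor(version: str) -> tuple:
    """Extract (major, minor) from strings such as '2.1.0+cu118'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def meets_minimum(version: str, minimum: tuple) -> bool:
    """Compare the (major, minor) prefix of a version string against a minimum."""
    return parse_major_minor(version) >= minimum


def check_requirements() -> list:
    """Return a list of human-readable problems; empty if all requirements are met."""
    problems = []
    if sys.version_info[:2] < (3, 9):
        problems.append("Python 3.9 or later is required")
    try:
        torch_version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        problems.append("PyTorch is not installed")
    else:
        if not meets_minimum(torch_version, (2, 1)):
            problems.append(f"PyTorch 2.1 or later is required, found {torch_version}")
    return problems


if __name__ == "__main__":
    for problem in check_requirements():
        print(problem)
```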
To install the latest version, use the following pip command:
pip install kronfluence
Alternatively, you can install the library directly from the source:
git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e .
Getting Started
Kronfluence supports influence computations on nn.Linear and nn.Conv2d modules. See the Technical Documentation page for a comprehensive guide.
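To see which layers of a given model fall into the supported categories, you can walk the module tree yourself. The sketch below uses only standard PyTorch calls; the helper `list_supported_modules` and the toy model are hypothetical, and this is not how Kronfluence itself selects modules.

```python
from torch import nn

# Module types named as supported in this README.
SUPPORTED_TYPES = (nn.Linear, nn.Conv2d)


def list_supported_modules(model: nn.Module) -> list:
    """Return the qualified names of all nn.Linear and nn.Conv2d submodules."""
    return [
        name
        for name, module in model.named_modules()
        if isinstance(module, SUPPORTED_TYPES)
    ]


# Toy model for illustration only.
toy_model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 26 * 26, 10),
)
supported = list_supported_modules(toy_model)
print(supported)  # ['0', '3']
```

Layers of other types, such as the ReLU and Flatten modules here, are not in the list.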
Examples
The examples folder contains several examples of how to use Kronfluence. We plan to add more language model examples. In short: prepare the trained model and datasets, then pass them to the Analyzer.
import torch
import torchvision
from torch import nn
from kronfluence.analyzer import Analyzer, prepare_model
# Define the model and load the trained model weights.
model = torch.nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 10, bias=True),
)
model.load_state_dict(torch.load("model_path.pth"))
# Load the dataset.
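train_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=True,
    # Convert PIL images to tensors so that batches can be collated.
    transform=torchvision.transforms.ToTensor(),
)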
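eval_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    # Use the held-out test split for the query examples.
    train=False,
    transform=torchvision.transforms.ToTensor(),
)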
# Initialize the task with the relevant loss and measurement.
# `MnistTask` is a user-defined task class; see the Technical Documentation.
task = MnistTask()
# Prepare the model for influence computation with the specified task.
model = prepare_model(model=model, task=task)
analyzer = Analyzer(analysis_name="mnist", model=model, task=task)
# Fit all EKFAC factors for the given model on the training dataset.
analyzer.fit_all_factors(factors_name="ekfac", dataset=train_dataset)
# Compute all pairwise influence scores using the computed factors.
analyzer.compute_pairwise_scores(
    scores_name="pairwise_scores",
    factors_name="ekfac",
    query_dataset=eval_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=1024,
)
# Load the scores with dimension `len(eval_dataset) x len(train_dataset)`.
scores = analyzer.load_pairwise_scores(scores_name="pairwise_scores")
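The quickstart above calls MnistTask() without showing its definition. The sketch below illustrates what such a task might compute: a summed cross-entropy training loss (optionally with labels sampled from the model's own predictions) and a measurement on query examples. The class name `MnistTaskSketch`, the method names `compute_train_loss` and `compute_measurement`, and their signatures are assumptions, and the class deliberately does not subclass anything from Kronfluence; consult the Technical Documentation for the actual task interface.

```python
import torch
import torch.nn.functional as F
from torch import nn


class MnistTaskSketch:
    """Hypothetical task object; method names and signatures are assumptions."""

    def compute_train_loss(self, batch, model, sample=False):
        inputs, labels = batch
        logits = model(inputs)
        if sample:
            # Draw labels from the model's predictive distribution
            # instead of using the dataset labels.
            with torch.no_grad():
                probs = F.softmax(logits, dim=-1)
                labels = torch.multinomial(probs, num_samples=1).flatten()
        # Sum (rather than average) the per-example losses.
        return F.cross_entropy(logits, labels, reduction="sum")

    def compute_measurement(self, batch, model):
        # Here the measurement is simply the summed loss on the query batch.
        inputs, labels = batch
        logits = model(inputs)
        return F.cross_entropy(logits, labels, reduction="sum")


# Tiny stand-in model and random batch, for illustration only.
sketch_model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
sketch_batch = (torch.randn(8, 1, 28, 28), torch.randint(0, 10, (8,)))
sketch_task = MnistTaskSketch()

train_loss = sketch_task.compute_train_loss(sketch_batch, sketch_model)
sampled_loss = sketch_task.compute_train_loss(sketch_batch, sketch_model, sample=True)
measurement = sketch_task.compute_measurement(sketch_batch, sketch_model)
print(train_loss.item(), sampled_loss.item(), measurement.item())
```

Using the loss as the measurement is only one possible choice; other quantities computed from the model outputs could serve as the measurement instead.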
Contributing
Your contributions are welcome! For bug fixes, please submit a pull request without prior discussion. For proposing new features, examples, or extensions, kindly start a discussion through an issue before proceeding.
Setting Up Development Environment
To contribute, you will need to set up a development environment on your machine. This setup includes installing all the dependencies required for linting and testing.
git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e ".[dev]"
License
This software is released under the Apache 2.0 License, as detailed in the LICENSE file.
Hashes for kronfluence-0.0.1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 390e40e39c8fc18deac61fa4729501447c249e84f0e5dd64e637c93901d590dd
MD5 | 13950f68951f08c5bfcb5d49e04dd9e6
BLAKE2b-256 | ad7264beba91f91b19e46605052f85665531d0cc2c0f94b9f2db14ce9905ffdd