Influence Functions with (Eigenvalue-corrected) Kronecker-factored Approximate Curvature
Kronfluence is a PyTorch package designed to compute influence functions using Kronecker-factored Approximate Curvature (KFAC) or Eigenvalue-corrected KFAC (EKFAC). For a detailed description of the methodology, see the paper Studying Large Language Model Generalization with Influence Functions.
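As a rough illustration of the curvature approximation, here is a simplified NumPy sketch for a single linear layer. It is not Kronfluence's implementation. For per-example layer inputs $a_n$ and output gradients $g_n$, KFAC approximates the layer's curvature block as $A \otimes S$ with $A = \mathbb{E}[a a^\top]$ and $S = \mathbb{E}[g g^\top]$. EKFAC keeps the eigenbases $Q_A$ and $Q_S$ of these factors but refits the eigenvalues as the elementwise second moments $\mathbb{E}[(Q_S^\top g a^\top Q_A)^2]$. In both cases, the damped inverse can be applied to a weight-shaped gradient without ever forming the full $d_\text{in} d_\text{out} \times d_\text{in} d_\text{out}$ matrix. The function names, the uncentered covariance estimates, the omission of the bias term, and the fixed damping value are assumptions made for this sketch.

```python
import numpy as np


def kfac_factors(acts: np.ndarray, grads: np.ndarray):
    """Uncentered covariance estimates of layer inputs (A) and output gradients (S).

    acts:  (n, d_in)  inputs to the linear layer, one row per example (bias ignored)
    grads: (n, d_out) gradients of the loss w.r.t. the layer's outputs
    """
    n = acts.shape[0]
    return acts.T @ acts / n, grads.T @ grads / n


def eigenbases(A: np.ndarray, S: np.ndarray):
    """Eigendecompose both symmetric factors: A = Q_A diag(l_A) Q_A^T, S = Q_S diag(l_S) Q_S^T."""
    l_A, Q_A = np.linalg.eigh(A)
    l_S, Q_S = np.linalg.eigh(S)
    return l_A, Q_A, l_S, Q_S


def ekfac_eigenvalues(acts, grads, Q_A, Q_S):
    """EKFAC correction: second moments of per-example gradients in the KFAC eigenbasis."""
    per_example = np.einsum("no,ni->noi", grads, acts)  # g_n a_n^T, shape (n, d_out, d_in)
    projected = np.einsum("os,noi,ij->nsj", Q_S, per_example, Q_A)  # Q_S^T (g_n a_n^T) Q_A
    return (projected ** 2).mean(axis=0)  # shape (d_out, d_in)


def precondition(V, Q_A, Q_S, eigvals, damping=1e-3):
    """Apply (Q diag(eigvals) Q^T + damping * I)^{-1}, with Q = Q_A kron Q_S,
    to a weight-shaped gradient V of shape (d_out, d_in)."""
    V_rot = Q_S.T @ V @ Q_A              # rotate into the Kronecker-factored eigenbasis
    V_rot = V_rot / (eigvals + damping)  # divide by the damped eigenvalues
    return Q_S @ V_rot @ Q_A.T           # rotate back to the parameter basis


# Random data for a layer with d_in=3 and d_out=2.
rng = np.random.default_rng(0)
acts = rng.normal(size=(64, 3))
grads = rng.normal(size=(64, 2))
A, S = kfac_factors(acts, grads)
l_A, Q_A, l_S, Q_S = eigenbases(A, S)

query_grad = rng.normal(size=(2, 3))
# KFAC eigenvalues are the products l_S[i] * l_A[j].
kfac_out = precondition(query_grad, Q_A, Q_S, np.outer(l_S, l_A))
ekfac_out = precondition(query_grad, Q_A, Q_S, ekfac_eigenvalues(acts, grads, Q_A, Q_S))
```

The KFAC result matches a dense solve against `np.kron(A, S) + damping * I` (with column-major vectorization of the gradient), which is a convenient way to check the sketch. Under these assumptions, a pairwise influence score is then, up to the sign convention the library uses, an inner product between a preconditioned query gradient and a training gradient, summed over layers.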
Installation
[!IMPORTANT] Requirements:
- Python: Version 3.9 or later
- PyTorch: Version 2.1 or later
To install the latest stable version, use the following `pip` command:

```bash
pip install kronfluence
```
Alternatively, you can install directly from source:

```bash
git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e .
```
Getting Started
Kronfluence supports influence computations on `nn.Linear` and `nn.Conv2d` modules.
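Because only these two module types are supported, it can help to check which parts of a model fall inside them before running an analysis. The sketch below uses two hypothetical helpers (`supported_modules` and `uncovered_parameters`, not part of Kronfluence's API) that rely only on `torch.nn`. They list the submodules of a supported type and the parameters that live in any other module type.

```python
from torch import nn

# Module types listed as supported in this README.
SUPPORTED_TYPES = (nn.Linear, nn.Conv2d)


def supported_modules(model: nn.Module) -> list[str]:
    """Names of submodules whose type is nn.Linear or nn.Conv2d (or a subclass)."""
    return [
        name
        for name, module in model.named_modules()
        if isinstance(module, SUPPORTED_TYPES)
    ]


def uncovered_parameters(model: nn.Module) -> list[str]:
    """Names of parameters that do not belong directly to a supported module."""
    supported = set(supported_modules(model))
    # "4.weight" -> owner "4"; a top-level "weight" -> owner "" (the model itself).
    return [
        name
        for name, _ in model.named_parameters()
        if name.rpartition(".")[0] not in supported
    ]


model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 26 * 26, 10),
)
print(supported_modules(model))     # ['0', '4']
print(uncovered_parameters(model))  # ['1.weight', '1.bias']
```

In this example, the batch normalization parameters are the only ones outside the supported module types.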
See the Technical Documentation page for a comprehensive guide.
TL;DR: You need to prepare a trained model and datasets, and pass them into the `Analyzer` class.
```python
import torch
import torchvision
from torch import nn

from kronfluence.analyzer import Analyzer, prepare_model

# Define the model and load the trained model weights.
model = torch.nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 10, bias=True),
)
model.load_state_dict(torch.load("model_path.pth"))

# Load the datasets. `ToTensor()` converts the PIL images to tensors so that
# the default DataLoader collate function can batch them.
transform = torchvision.transforms.ToTensor()
train_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=True,
    transform=transform,
)
# Use the held-out test split as the query (evaluation) set.
eval_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=False,
    transform=transform,
)

# Define the task. `MnistTask` is a user-defined subclass of
# `kronfluence.task.Task`; see the Technical Documentation page for details.
task = MnistTask()

# Prepare the model for influence computation.
model = prepare_model(model=model, task=task)
analyzer = Analyzer(analysis_name="mnist", model=model, task=task)

# Fit all EKFAC factors for the given model.
analyzer.fit_all_factors(factors_name="my_factors", dataset=train_dataset)

# Compute all pairwise influence scores with the computed factors.
analyzer.compute_pairwise_scores(
    scores_name="my_scores",
    factors_name="my_factors",
    query_dataset=eval_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=1024,
)

# Load the scores with dimension `len(eval_dataset) x len(train_dataset)`.
scores = analyzer.load_pairwise_scores(scores_name="my_scores")
```
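A common next step is to look up, for each query, the training examples with the largest (or most negative) scores. The sketch below assumes `scores` is a dense `torch.Tensor` of shape `len(eval_dataset) x len(train_dataset)`; if your installed version returns a dictionary of tensors instead, select the relevant entry first. `top_influential` is a hypothetical helper built on `torch.topk`, and which end of the ranking counts as "most helpful" depends on the sign convention described in the Technical Documentation.

```python
import torch


def top_influential(scores: torch.Tensor, k: int = 5, largest: bool = True):
    """Return (values, train_indices), each of shape (num_queries, k).

    scores:  (num_queries, num_train) pairwise influence scores
    largest: if False, return the most negative scores instead
    """
    k = min(k, scores.shape[1])  # guard against k > number of training examples
    values, indices = torch.topk(scores, k=k, dim=1, largest=largest)
    return values, indices


# Toy score matrix for 2 queries and 4 training examples.
scores = torch.tensor([[0.1, 3.0, -2.0, 0.5], [1.0, -1.0, 2.0, 0.0]])
values, indices = top_influential(scores, k=2)
print(indices.tolist())  # [[1, 3], [2, 0]]
```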
Kronfluence supports various PyTorch features; the following table summarizes the supported features:
Feature | Supported |
---|---|
Distributed Data Parallel (DDP) | ✅ |
Automatic Mixed Precision (AMP) | ✅ |
Torch Compile | ✅ |
Gradient Checkpointing | ✅ |
Fully Sharded Data Parallel (FSDP) | ✅ |
The examples folder contains several examples demonstrating how to use Kronfluence.
LogIX
While Kronfluence supports influence function computations on large-scale models like `Meta-Llama-3-8B-Instruct`, those interested in running influence analysis on even larger models or with a large number of query data points may want to explore our project LogIX. It integrates with frameworks like HuggingFace Trainer and PyTorch Lightning and is also compatible with many PyTorch features (DDP, FSDP, and DeepSpeed).
Contributing
Contributions are welcome! To get started, please review our Code of Conduct. For bug fixes, please submit a pull request. If you would like to propose new features or extensions, we kindly request that you open an issue first to discuss your ideas.
Setting Up Development Environment
To contribute to Kronfluence, you will need to set up a development environment on your machine. This setup includes installing all the dependencies required for linting and testing.
```bash
git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e ".[dev]"
```
Style Testing
To maintain code quality and consistency, we run ruff and linting checks on pull requests. Before submitting a pull request, please ensure that your code adheres to our formatting and linting guidelines. The formatting commands below modify your code in place, so it is recommended to create a Git commit before running them; this makes it easy to revert any unintended changes.
Sort imports using isort:

```bash
isort kronfluence
```

Format code using ruff:

```bash
ruff format kronfluence
```

To view all pylint warnings, run the following command:

```bash
pylint kronfluence
```
Please address any reported issues before submitting your PR.
Acknowledgements
Omkar Dige contributed to the profiling, DDP, and FSDP utilities, and Adil Asif provided valuable insights and suggestions on structuring the DDP and FSDP implementations. I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Sophie Liao, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.
License
This software is released under the Apache 2.0 License, as detailed in the LICENSE file.
Download files
File details
Details for the file kronfluence-1.0.0.tar.gz.
File metadata
- Download URL: kronfluence-1.0.0.tar.gz
- Upload date:
- Size: 115.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.9.18
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 315aa9c7a577b1e863e00e84ed2c0f9958b47fdc9f9586795c99430863f068c7 |
MD5 | 3c17ab3671b69932d03c648adf0fcd0a |
BLAKE2b-256 | a7ae3debf9c9d080b836a8db39e7f8fa820f47bfc8c5a3d469705cdc3fa52fb1 |
File details
Details for the file kronfluence-1.0.0-py3-none-any.whl.
File metadata
- Download URL: kronfluence-1.0.0-py3-none-any.whl
- Upload date:
- Size: 181.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.9.18
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 8a0d4fe08124b0562ed47997eead66e3fdce8d95ab230cb6f31ed716f62abaf5 |
MD5 | d4644a94a42385e508da60a535dbe985 |
BLAKE2b-256 | 8e4619bd7ec84a4fc9f03341f93e891e395e9d3448f61ed20e8858c894018428 |