
Influence Functions with (Eigenvalue-corrected) Kronecker-factored Approximate Curvature

Project description

Kronfluence



Kronfluence is a PyTorch package designed to compute influence functions using Kronecker-factored Approximate Curvature (KFAC) or Eigenvalue-corrected KFAC (EKFAC). For a detailed description of the methodology, see the paper Studying Large Language Model Generalization with Influence Functions.


Installation

Requirements:

  • Python: Version 3.9 or later
  • PyTorch: Version 2.1 or later

To install the latest stable version, use the following pip command:

pip install kronfluence

Alternatively, you can install directly from source:

git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e .
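
Either way, a quick sanity check is to confirm that the package imports cleanly:

python -c "import kronfluence"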

Getting Started

Kronfluence supports influence computations on nn.Linear and nn.Conv2d modules. See the Technical Documentation page for a comprehensive guide.

TL;DR You need to prepare a trained model and datasets, and pass them into the Analyzer class.

import torch
import torchvision
from torch import nn

from kronfluence.analyzer import Analyzer, prepare_model

# Define the model and load the trained model weights.
model = torch.nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 10, bias=True),
)
model.load_state_dict(torch.load("model_path.pth"))

# Load the training and evaluation datasets.
train_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=True,
    transform=torchvision.transforms.ToTensor(),
)
eval_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=False,
    transform=torchvision.transforms.ToTensor(),
)

# Define the task (a user-defined subclass of Task); a sketch of MnistTask appears
# after this example. See the Technical Documentation page for details.
task = MnistTask()

# Prepare the model for influence computation.
model = prepare_model(model=model, task=task)
analyzer = Analyzer(analysis_name="mnist", model=model, task=task)

# Fit all EKFAC factors for the given model.
analyzer.fit_all_factors(factors_name="my_factors", dataset=train_dataset)

# Compute all pairwise influence scores with the computed factors.
analyzer.compute_pairwise_scores(
    scores_name="my_scores",
    factors_name="my_factors",
    query_dataset=eval_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=1024,
)

# Load the scores with dimension `len(eval_dataset) x len(train_dataset)`.
scores = analyzer.load_pairwise_scores(scores_name="my_scores")
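
The MnistTask used above is user-defined: you subclass Task and tell Kronfluence how to compute the training loss and the measurement (the quantity whose influence is estimated for each query example). Below is a minimal, illustrative sketch for this MNIST classifier; it assumes the Task base class from kronfluence.task with compute_train_loss and compute_measurement hooks and uses a summed cross-entropy loss. Refer to the Technical Documentation for the exact interface.

import torch
import torch.nn.functional as F

from kronfluence.task import Task  # assumed import path; see the Technical Documentation


class MnistTask(Task):
    def compute_train_loss(self, batch, model, sample=False):
        # Loss used for fitting the (E)KFAC factors and for per-example train gradients.
        inputs, labels = batch
        logits = model(inputs)
        if not sample:
            return F.cross_entropy(logits, labels, reduction="sum")
        # When sampling, draw labels from the model's own predictive distribution.
        with torch.no_grad():
            probs = torch.nn.functional.softmax(logits, dim=-1)
            sampled_labels = torch.multinomial(probs, num_samples=1).flatten()
        return F.cross_entropy(logits, sampled_labels, reduction="sum")

    def compute_measurement(self, batch, model):
        # Quantity measured on each query example (here, its loss).
        inputs, labels = batch
        logits = model(inputs)
        return F.cross_entropy(logits, labels, reduction="sum")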

Kronfluence is compatible with a range of PyTorch features, including distributed training with DDP and FSDP.
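
For example, a distributed run with DDP might look roughly like the sketch below. This is illustrative only: it assumes prepare_model is called before wrapping the model in DistributedDataParallel, that `model`, `task`, and `train_dataset` are defined as in the example above, and that the script is launched with torchrun. Consult the Technical Documentation for the supported distributed setup.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

from kronfluence.analyzer import Analyzer, prepare_model

# Illustrative sketch: assumes `model`, `task`, and `train_dataset` are defined as in the
# example above and that the script is launched with torchrun (one process per GPU).
dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

# Prepare the model first, then wrap it for distributed training (assumed ordering).
model = prepare_model(model=model, task=task).to(device=local_rank)
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

analyzer = Analyzer(analysis_name="mnist_ddp", model=model, task=task)
analyzer.fit_all_factors(factors_name="my_factors", dataset=train_dataset)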

The examples folder contains several examples demonstrating how to use Kronfluence.

LogIX

While Kronfluence supports influence function computations on large-scale models like Meta-Llama-3-8B-Instruct, for those interested in running influence analysis on even larger models or with a large number of query data points, our project LogIX may be worth exploring. It integrates with frameworks like HuggingFace Trainer and PyTorch Lightning and is also compatible with many PyTorch features (DDP, FSDP, and DeepSpeed).

Contributing

Contributions are welcome! To get started, please review our Code of Conduct. For bug fixes, please submit a pull request. If you would like to propose new features or extensions, we kindly request that you open an issue first to discuss your ideas.

Setting Up Development Environment

To contribute to Kronfluence, you will need to set up a development environment on your machine. This setup includes installing all the dependencies required for linting and testing.

git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e ."[dev]"

Style Testing

To maintain code quality and consistency, we run ruff and linting tests on pull requests. Before submitting a pull request, please ensure that your code adheres to our formatting and linting guidelines. The following commands will modify your code. It is recommended to create a Git commit before running them to easily revert any unintended changes.

Sort import orderings using isort:

isort kronfluence

Format code using ruff:

ruff format kronfluence

To view all pylint complaints, run the following command:

pylint kronfluence

Please address any reported issues before submitting your PR.

Acknowledgements

Omkar Dige contributed to the profiling, DDP, and FSDP utilities, and Adil Asif provided valuable insights and suggestions on structuring the DDP and FSDP implementations. I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Sophie Liao, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.

License

This software is released under the Apache 2.0 License, as detailed in the LICENSE file.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

kronfluence-1.0.0.tar.gz (115.5 kB)

Uploaded Source

Built Distribution

kronfluence-1.0.0-py3-none-any.whl (181.6 kB)

Uploaded Python 3

File details

Details for the file kronfluence-1.0.0.tar.gz.

File metadata

  • Download URL: kronfluence-1.0.0.tar.gz
  • Upload date:
  • Size: 115.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for kronfluence-1.0.0.tar.gz

  • SHA256: 315aa9c7a577b1e863e00e84ed2c0f9958b47fdc9f9586795c99430863f068c7
  • MD5: 3c17ab3671b69932d03c648adf0fcd0a
  • BLAKE2b-256: a7ae3debf9c9d080b836a8db39e7f8fa820f47bfc8c5a3d469705cdc3fa52fb1

See more details on using hashes here.
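
As an illustration, you can compare a locally computed digest against the published value (the archive is assumed to be in the current directory):

import hashlib

# Compute the SHA256 digest of the downloaded source distribution and compare it
# to the published value above.
with open("kronfluence-1.0.0.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print(digest == "315aa9c7a577b1e863e00e84ed2c0f9958b47fdc9f9586795c99430863f068c7")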

File details

Details for the file kronfluence-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: kronfluence-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 181.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for kronfluence-1.0.0-py3-none-any.whl

  • SHA256: 8a0d4fe08124b0562ed47997eead66e3fdce8d95ab230cb6f31ed716f62abaf5
  • MD5: d4644a94a42385e508da60a535dbe985
  • BLAKE2b-256: 8e4619bd7ec84a4fc9f03341f93e891e395e9d3448f61ed20e8858c894018428

See more details on using hashes here.
