
Influence Functions with (Eigenvalue-corrected) Kronecker-factored Approximate Curvature

Project description

Kronfluence



Kronfluence is a PyTorch package designed to compute influence functions using Kronecker-factored Approximate Curvature (KFAC) or Eigenvalue-corrected KFAC (EKFAC). For a detailed description of the methodology, see the paper Studying Large Language Model Generalization with Influence Functions.


Installation

Requirements:

  • Python: Version 3.9 or later
  • PyTorch: Version 2.1 or later

To install the latest stable version, use the following pip command:

pip install kronfluence

Alternatively, you can install directly from source:

git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e .

Getting Started

Kronfluence supports influence computations on nn.Linear and nn.Conv2d modules. See the Technical Documentation page for a comprehensive guide.

TL;DR: You need to prepare a trained model and datasets, then pass them to the Analyzer class.

import torch
import torchvision
from torch import nn

from kronfluence.analyzer import Analyzer, prepare_model

# Define the model and load the trained model weights.
model = torch.nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 10, bias=True),
)
model.load_state_dict(torch.load("model_path.pth"))

# Load the training and evaluation datasets (images are converted to tensors for the MLP).
transform = torchvision.transforms.ToTensor()
train_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=True,
    transform=transform,
)
eval_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=False,
    transform=transform,
)

# Define the task. See the Technical Documentation page for details
# (a sketch of a possible MnistTask is given after this example).
task = MnistTask()

# Prepare the model for influence computation.
model = prepare_model(model=model, task=task)
analyzer = Analyzer(analysis_name="mnist", model=model, task=task)

# Fit all EKFAC factors for the given model.
analyzer.fit_all_factors(factors_name="my_factors", dataset=train_dataset)

# Compute all pairwise influence scores with the computed factors.
analyzer.compute_pairwise_scores(
    scores_name="my_scores",
    factors_name="my_factors",
    query_dataset=eval_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=1024,
)

# Load the scores with dimension `len(eval_dataset) x len(train_dataset)`.
scores = analyzer.load_pairwise_scores(scores_name="my_scores")
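
The example above references an MnistTask without defining it. A task tells Kronfluence how to compute the training loss and the measurement (the quantity whose sensitivity to the training data is estimated). The following is a minimal sketch of what such a task might look like for this classifier, assuming the two-method Task interface (compute_train_loss and compute_measurement) described in the Technical Documentation; consult that page for the exact interface.

import torch
import torch.nn.functional as F
from torch import nn

from kronfluence.task import Task


class MnistTask(Task):
    def compute_train_loss(
        self,
        batch,
        model: nn.Module,
        sample: bool = False,
    ) -> torch.Tensor:
        inputs, labels = batch
        logits = model(inputs)
        if not sample:
            # Summed cross-entropy on the true labels.
            return F.cross_entropy(logits, labels, reduction="sum")
        # When `sample=True`, draw labels from the model's own predictive
        # distribution (used while fitting the EKFAC factors).
        with torch.no_grad():
            probs = torch.softmax(logits.detach(), dim=-1)
            sampled_labels = torch.multinomial(probs, num_samples=1).flatten()
        return F.cross_entropy(logits, sampled_labels, reduction="sum")

    def compute_measurement(self, batch, model: nn.Module) -> torch.Tensor:
        # Here the measurement is simply the loss on the query example.
        inputs, labels = batch
        logits = model(inputs)
        return F.cross_entropy(logits, labels, reduction="sum")

Because the loaded pairwise scores have one row per query example and one column per training example, the most influential training points for a given query can be read off with torch.topk on the corresponding row.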

Kronfluence supports various PyTorch features, including distributed training with DDP and FSDP; a table summarizing the supported features is available in the repository README.
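
For example, a distributed run typically wraps the prepared model in DistributedDataParallel before constructing the Analyzer. The sketch below illustrates one plausible pattern under a torchrun launch, reusing the model, task, and train_dataset from the example above; the analysis name "mnist_ddp" is a placeholder, and the DDP example in the repository should be treated as the authoritative recipe.

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

from kronfluence.analyzer import Analyzer, prepare_model

# Launched with: torchrun --nproc_per_node=<num_gpus> compute_factors.py
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")

# `model`, `task`, and `train_dataset` are defined as in the example above.
model = prepare_model(model=model, task=task)
model = model.to(device)
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

analyzer = Analyzer(analysis_name="mnist_ddp", model=model, task=task)
analyzer.fit_all_factors(factors_name="my_factors", dataset=train_dataset)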

The examples folder contains several examples demonstrating how to use Kronfluence.

LogIX

While Kronfluence supports influence function computations on large-scale models like Meta-Llama-3-8B-Instruct, those interested in running influence analysis on even larger models, or with a large number of query data points, may want to explore our project LogIX. It integrates with frameworks such as the HuggingFace Trainer and PyTorch Lightning, and it is compatible with many PyTorch features (DDP, FSDP, and DeepSpeed).

Contributing

Contributions are welcome! To get started, please review our Code of Conduct. For bug fixes, please submit a pull request. If you would like to propose new features or extensions, we kindly request that you open an issue first to discuss your ideas.

Setting Up Development Environment

To contribute to Kronfluence, you will need to set up a development environment on your machine. This setup includes installing all the dependencies required for linting and testing.

git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e ".[dev]"

Style Testing

To maintain code quality and consistency, we run formatting and linting checks (isort, ruff, and pylint) on pull requests. Before submitting a pull request, please ensure that your code adheres to our formatting and linting guidelines. The following commands will modify your code, so it is recommended to create a Git commit before running them in order to easily revert any unintended changes.

Sort import orderings using isort:

isort kronfluence

Format code using ruff:

ruff format kronfluence

To view all pylint complaints, run the following command:

pylint kronfluence

Please address any reported issues before submitting your PR.

Acknowledgements

Omkar Dige contributed to the profiling, DDP, and FSDP utilities, and Adil Asif provided valuable insights and suggestions on structuring the DDP and FSDP implementations. I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Sophie Liao, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.

License

This software is released under the Apache 2.0 License, as detailed in the LICENSE file.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

kronfluence-1.0.1.tar.gz (116.4 kB)

Built Distribution

kronfluence-1.0.1-py3-none-any.whl (182.9 kB)

File details

Details for the file kronfluence-1.0.1.tar.gz.

File metadata

  • Download URL: kronfluence-1.0.1.tar.gz
  • Upload date:
  • Size: 116.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for kronfluence-1.0.1.tar.gz

  • SHA256: 52758eb71f98764c9b3759ab52932adfe2b05805faabfcf4bc236dc56157daa2
  • MD5: dc35448ec6695f701282ff48da44651e
  • BLAKE2b-256: 53a73291c79507d1cec93cabfaeff40ba227d0491c5b6fca3934c06dbcb597ba

See more details on using hashes here.

File details

Details for the file kronfluence-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: kronfluence-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 182.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for kronfluence-1.0.1-py3-none-any.whl

  • SHA256: 79d2a6d0b211c51296a9b33ecb9fb41edd1f2ca23f9bdff9cb9fce9749aa62e1
  • MD5: 5c5d17551110f264867b2bbf917766b1
  • BLAKE2b-256: f93fb428c9a51205b2e8671c02dd7e950bff744f5ca520db0da8b120d653e926

See more details on using hashes here.
