
Influence Functions with (Eigenvalue-corrected) Kronecker-factored Approximate Curvature


Kronfluence



Kronfluence is a research repository designed to compute influence functions using Kronecker-factored Approximate Curvature (KFAC) or Eigenvalue-corrected KFAC (EKFAC). For a detailed description of the methodology, see the paper Studying Large Language Model Generalization with Influence Functions.
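To make the method concrete, here is a from-scratch NumPy sketch of the linear algebra behind KFAC- and EKFAC-based influence for a single fully connected layer. KFAC approximates the layer's curvature block as a Kronecker product A ⊗ S of the input covariance A and the output-gradient covariance S. EKFAC keeps the same eigenvectors but refits the eigenvalues from per-example gradients. This is an illustration under simplifying assumptions, not Kronfluence's implementation. All function names are hypothetical, the damping value is arbitrary, and the factors are built from whatever per-example output gradients are passed in. The library and paper may compute those gradients differently, for example with labels sampled from the model.

```python
import numpy as np


def kfac_eigenbasis(acts, grads):
    """Eigendecompose the KFAC factors of one linear layer.

    acts:  (n, d_in)  inputs a_n to the layer for n examples
    grads: (n, d_out) gradients g_n of the loss w.r.t. the layer outputs
    Returns eigenvectors and eigenvalues of A = E[a a^T] and S = E[g g^T].
    """
    n = acts.shape[0]
    A = acts.T @ acts / n
    S = grads.T @ grads / n
    eig_A, Q_A = np.linalg.eigh(A)
    eig_S, Q_S = np.linalg.eigh(S)
    return Q_A, Q_S, eig_A, eig_S


def kfac_eigenvalues(eig_A, eig_S):
    """KFAC: the eigenvalues of A kron S are products of the factor eigenvalues."""
    return np.outer(eig_S, eig_A)  # shape (d_out, d_in)


def ekfac_eigenvalues(acts, grads, Q_A, Q_S):
    """EKFAC: refit each eigenvalue as the mean squared per-example
    weight gradient expressed in the Kronecker eigenbasis."""
    per_example = np.einsum("ni,nj->nij", grads, acts)  # G_n = g_n a_n^T
    rotated = np.einsum("ia,nij,jb->nab", Q_S, per_example, Q_A)  # Q_S^T G_n Q_A
    return (rotated ** 2).mean(axis=0)  # shape (d_out, d_in)


def damped_inverse_hvp(G, Q_A, Q_S, eigvals, damping):
    """Apply (F + damping * I)^{-1} to a weight-shaped gradient G (d_out, d_in),
    where F = (Q_A kron Q_S) diag(eigvals) (Q_A kron Q_S)^T."""
    rotated = Q_S.T @ G @ Q_A
    return Q_S @ (rotated / (eigvals + damping)) @ Q_A.T


def influence_score(query_grad, train_grad, Q_A, Q_S, eigvals, damping):
    """Influence of one training example on one query:
    grad_query^T (F + damping * I)^{-1} grad_train."""
    ihvp = damped_inverse_hvp(query_grad, Q_A, Q_S, eigvals, damping)
    return float(np.sum(ihvp * train_grad))


# Toy data standing in for a layer's recorded inputs and output gradients.
rng = np.random.default_rng(0)
acts = rng.normal(size=(256, 5))
grads = rng.normal(size=(256, 3))

Q_A, Q_S, eig_A, eig_S = kfac_eigenbasis(acts, grads)
kfac_eigs = kfac_eigenvalues(eig_A, eig_S)
ekfac_eigs = ekfac_eigenvalues(acts, grads, Q_A, Q_S)

# Toy weight gradients; in practice the query gradient comes from the
# measurement of interest rather than from the training loss.
query_grad = np.outer(grads[0], acts[0])
train_grad = np.outer(grads[1], acts[1])
damping = 1e-3
print("KFAC influence: ", influence_score(query_grad, train_grad, Q_A, Q_S, kfac_eigs, damping))
print("EKFAC influence:", influence_score(query_grad, train_grad, Q_A, Q_S, ekfac_eigs, damping))
```

The eigenbasis form is convenient because a damped Kronecker-factored inverse reduces to two small rotations and an elementwise division, so the full (d_in·d_out) × (d_in·d_out) matrix is never formed.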


[!WARNING] This repository is under active development and has not reached its first stable release.

Installation

[!IMPORTANT] Requirements:

  • Python: Version 3.9 or later
  • PyTorch: Version 2.1 or later

To install the latest version, use the following pip command:

pip install kronfluence

Alternatively, you can install the library directly from the source:

git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e .

Getting Started

Kronfluence supports influence computations on nn.Linear and nn.Conv2d modules. See the Technical Documentation page for a comprehensive guide.
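Only nn.Linear and nn.Conv2d modules are supported, so it can help to check a model before fitting factors. The helpers below are a hypothetical pre-flight check written in plain PyTorch and do not call Kronfluence. They list the supported submodules and flag modules that own parameters directly but are of some other type, such as nn.LayerNorm; those parameters are likely not covered. Kronfluence's own module selection (for example, how it treats subclasses or custom layers) may differ.

```python
from torch import nn

# Module types the README lists as supported for influence computations.
SUPPORTED_TYPES = (nn.Linear, nn.Conv2d)


def supported_module_names(model: nn.Module) -> list[str]:
    """Names of submodules whose type is one of SUPPORTED_TYPES."""
    return [name for name, module in model.named_modules() if isinstance(module, SUPPORTED_TYPES)]


def unsupported_parametric_module_names(model: nn.Module) -> list[str]:
    """Names of submodules that directly own parameters but are not a supported type."""
    return [
        name
        for name, module in model.named_modules()
        if not isinstance(module, SUPPORTED_TYPES)
        and any(True for _ in module.parameters(recurse=False))
    ]


model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 26 * 26, 10),
    nn.LayerNorm(10),
)
print(supported_module_names(model))  # ['0', '3']
print(unsupported_parametric_module_names(model))  # ['4']
```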

Learn More

The examples folder contains several examples, and more are planned. TL;DR: prepare the trained model, the datasets, and a task definition, then pass them to Analyzer and its methods.

import torch
import torchvision
from torch import nn

from kronfluence.analyzer import Analyzer, prepare_model

# Define the model and load the trained model weights.
model = torch.nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 1024, bias=True),
    nn.ReLU(),
    nn.Linear(1024, 10, bias=True),
)
model.load_state_dict(torch.load("model_path.pth"))

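# Load the datasets. ToTensor() converts the PIL images into tensors so they can be batched.
train_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=True,
    transform=torchvision.transforms.ToTensor(),
)
# Use the held-out MNIST test split as the query (evaluation) dataset.
eval_dataset = torchvision.datasets.MNIST(
    root="./data",
    download=True,
    train=False,
    transform=torchvision.transforms.ToTensor(),
)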

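# Define the task. `MnistTask` is not provided by the library: it is a user-defined
# subclass of Kronfluence's `Task` class that specifies the training loss and the
# measurement (see the examples folder for complete definitions).
task = MnistTask()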

# Prepare the model for influence computation.
model = prepare_model(model=model, task=task)
analyzer = Analyzer(analysis_name="mnist", model=model, task=task)

# Fit all EKFAC factors for the given model.
analyzer.fit_all_factors(factors_name="my_factors", dataset=train_dataset)

# Compute all pairwise influence scores with the computed factors.
analyzer.compute_pairwise_scores(
    scores_name="my_scores",
    factors_name="my_factors",
    query_dataset=eval_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=1024,
)

# Load the scores with dimension `len(eval_dataset) x len(train_dataset)`.
scores = analyzer.load_pairwise_scores(scores_name="my_scores")
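Once loaded, the scores can be ranked per query to find the training examples with the largest and smallest influence. This is a minimal sketch that assumes the scores are already a 2D array of shape `len(eval_dataset) x len(train_dataset)`. A CPU torch tensor can be converted with `.numpy()`. If the loaded object is a container rather than a single matrix, extract the matrix first. The helper names are hypothetical. How to read the sign of a score depends on the task's measurement, so check the documentation before treating high scores as "helpful".

```python
import numpy as np


def top_k_train_indices(scores, k):
    """For each query (row), indices of the k training examples with the highest scores, highest first."""
    scores = np.asarray(scores)
    return np.argsort(-scores, axis=1, kind="stable")[:, :k]


def bottom_k_train_indices(scores, k):
    """For each query (row), indices of the k training examples with the lowest (most negative) scores."""
    scores = np.asarray(scores)
    return np.argsort(scores, axis=1, kind="stable")[:, :k]


# Toy score matrix: 2 queries (rows) x 4 training examples (columns).
scores = np.array([
    [0.1, 2.0, -1.5, 0.7],
    [3.0, -0.2, 0.4, -4.0],
])
print(top_k_train_indices(scores, k=2).tolist())     # [[1, 3], [0, 2]]
print(bottom_k_train_indices(scores, k=2).tolist())  # [[2, 0], [3, 1]]
```

For very large training sets, `np.argpartition` avoids a full sort per row, at the cost of returning the top k in no particular order.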

Contributing

Contributions are welcome! For bug fixes, please submit a pull request directly. To propose new features, examples, or extensions, please open an issue to discuss them first.

Setting Up Development Environment

To contribute, you will need to set up a development environment on your machine. This setup includes installing all the dependencies required for linting and testing.

git clone https://github.com/pomonam/kronfluence.git
cd kronfluence
pip install -e ".[dev]"

License

This software is released under the Apache 2.0 License, as detailed in the LICENSE file.



Download files

Download the file for your platform.

Source Distribution

kronfluence-0.0.2.tar.gz (81.4 kB)

Uploaded Source

Built Distribution

kronfluence-0.0.2-py3-none-any.whl (126.8 kB)

Uploaded Python 3

File details

Details for the file kronfluence-0.0.2.tar.gz.

File metadata

  • Download URL: kronfluence-0.0.2.tar.gz
  • Size: 81.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for kronfluence-0.0.2.tar.gz
Algorithm Hash digest
SHA256 caef89fa75e8d8a8793f38172f4a34dd40ce38e9aa96985ff8d10ec0d2c8bf29
MD5 4761b10dcbaca6dbd84087867f0dba3e
BLAKE2b-256 1dc2cec2dca3ff81ff1a404dfcbb8871a2ba165334422cdcf5ec717d287051e8


File details

Details for the file kronfluence-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: kronfluence-0.0.2-py3-none-any.whl
  • Size: 126.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for kronfluence-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 b83e085d81b3d46ba21b86e0e4516176eb13d10878e761e06cff9de444274f09
MD5 c2d7e8ee58d0f1e0e6d1fe47ea0c3f45
BLAKE2b-256 505c6e345c78568c52b9ba7b752b44011489cd391ab9bab0722385a57bfcc0c5

