

Centered Kernel Alignment (CKA) - PyTorch Implementation

A PyTorch implementation of Centered Kernel Alignment (CKA) with GPU support for fast and efficient computation.
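For reference, the widely used linear variant of CKA compares two feature matrices X (n × p1) and Y (n × p2) computed on the same n inputs: after centering each feature column, CKA(X, Y) = ||XᵀY||_F² / (||XᵀX||_F · ||YᵀY||_F), which lies in [0, 1]. The snippet below is a minimal full-batch PyTorch sketch of that formula for illustration only. The function name linear_cka is hypothetical and is not part of the cka_pytorch API, and this README does not state which kernel or estimator the library uses internally.

```python
import torch


def linear_cka(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Linear CKA between features x (n, p1) and y (n, p2) of the same n inputs.

    Hypothetical helper for illustration; not part of the cka_pytorch API.
    """
    # Flatten trailing dimensions (e.g. conv feature maps) into one feature axis.
    x = x.reshape(x.shape[0], -1).double()
    y = y.reshape(y.shape[0], -1).double()
    # Center every feature column across the n examples.
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    # ||X^T Y||_F^2 / (||X^T X||_F * ||Y^T Y||_F); matrix_norm defaults to Frobenius.
    cross = torch.linalg.matrix_norm(x.T @ y) ** 2
    norm_x = torch.linalg.matrix_norm(x.T @ x)
    norm_y = torch.linalg.matrix_norm(y.T @ y)
    return cross / (norm_x * norm_y)


torch.manual_seed(0)
feats = torch.randn(64, 32)
rotation = torch.linalg.qr(torch.randn(32, 32)).Q  # random orthogonal matrix
print(linear_cka(feats, feats))                   # ~1.0 (identical features)
print(linear_cka(feats, 3.0 * feats @ rotation))  # ~1.0 (invariant to rotation and isotropic scaling)
print(linear_cka(feats, torch.randn(64, 32)))     # unrelated random features: clearly below 1.0
```

The invariance to orthogonal transforms and isotropic scaling is what makes CKA suitable for comparing layers whose neurons are not aligned, for example the same layer in two independently trained networks.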

Warning: This project is intended for educational and academic purposes (and for fun 🤷🏻).

Features

  • GPU Accelerated: Runs CKA computations on the GPU, which is typically much faster than NumPy-based (CPU) implementations.
  • On-the-Fly Calculation: Computes CKA incrementally over mini-batches, so large intermediate feature representations never need to be cached in memory.
  • Easy to Use: Simple and intuitive API for calculating the CKA matrix between two models.
  • Flexible: Can be used with any PyTorch models and dataloaders.
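One common way to compute CKA from mini-batches without storing all features is to sum an unbiased HSIC estimator batch by batch and normalize once at the end (the minibatch CKA approach described by Nguyen et al., 2021). The sketch below illustrates that technique with a linear kernel; it is an assumption for illustration, not a description of cka_pytorch's internals, and the names unbiased_hsic and MinibatchLinearCKA are hypothetical. Each batch needs more than 3 examples, and the result is an estimator that can differ slightly from full-dataset linear CKA.

```python
import torch


def unbiased_hsic(k: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
    """Unbiased HSIC estimate from two (n, n) Gram matrices; needs n > 3."""
    n = k.shape[0]
    # The unbiased estimator works on Gram matrices with zeroed diagonals;
    # no explicit centering is needed, the correction terms below handle it.
    k = k - torch.diag(torch.diagonal(k))
    l = l - torch.diag(torch.diagonal(l))
    ones = torch.ones(n, 1, dtype=k.dtype, device=k.device)
    trace_term = torch.trace(k @ l)
    sum_term = (ones.T @ k @ ones) * (ones.T @ l @ ones) / ((n - 1) * (n - 2))
    cross_term = 2.0 * (ones.T @ k @ l @ ones) / (n - 2)
    return ((trace_term + sum_term - cross_term) / (n * (n - 3))).squeeze()


class MinibatchLinearCKA:
    """Hypothetical accumulator: sums per-batch HSIC terms, then normalizes once."""

    def __init__(self) -> None:
        self.hsic_xy = 0.0
        self.hsic_xx = 0.0
        self.hsic_yy = 0.0

    @torch.no_grad()
    def update(self, x: torch.Tensor, y: torch.Tensor) -> None:
        # Linear kernels on flattened per-example features of one mini-batch.
        x = x.reshape(x.shape[0], -1).double()
        y = y.reshape(y.shape[0], -1).double()
        k, l = x @ x.T, y @ y.T
        self.hsic_xy += unbiased_hsic(k, l).item()
        self.hsic_xx += unbiased_hsic(k, k).item()
        self.hsic_yy += unbiased_hsic(l, l).item()

    def compute(self) -> float:
        # Only three running scalars are kept, never the full feature matrices.
        return self.hsic_xy / (self.hsic_xx * self.hsic_yy) ** 0.5


device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
metric = MinibatchLinearCKA()
for _ in range(10):
    feats = torch.randn(32, 128, device=device)
    metric.update(feats, 2.0 * feats)  # a rescaled copy of the same features
print(metric.compute())  # ~1.0
```

Because only three scalars are carried between batches, memory use depends on the batch size rather than on the dataset size, which is the property the On-the-Fly Calculation feature refers to.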

Installation

pip install cka-pytorch

Usage

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision.models import ResNet18_Weights, resnet18

from cka_pytorch.cka import CKACalculator


# 1. Define your models and dataloader
# (torchvision >= 0.13 uses `weights=` instead of the deprecated `pretrained=True`;
#  .eval() makes BatchNorm use its running statistics instead of batch statistics)
model1 = resnet18(weights=ResNet18_Weights.DEFAULT).cuda().eval()
model2 = resnet18(weights=ResNet18_Weights.DEFAULT).cuda().eval()  # Or a different model

# Create a dummy dataloader for demonstration: 100 random 3x224x224 images with labels
dummy_data = torch.randn(100, 3, 224, 224)
dummy_labels = torch.randint(0, 10, (100,))
dummy_dataset = TensorDataset(dummy_data, dummy_labels)
dataloader = DataLoader(dummy_dataset, batch_size=32)

# 2. Initialize the CKACalculator
# By default, CKA is calculated across all layers of the two models
calculator = CKACalculator(
    model1=model1,
    model2=model2,
    model1_name="ResNet18",
    model2_name="ResNet18",
    batched_feature_size=256,
    verbose=True,
)

# 3. Calculate the CKA matrix
cka_matrix = calculator.calculate_cka_matrix(dataloader)

# 4. Plot the CKA matrix as a heatmap
calculator.plot_cka_matrix(title="CKA between ResNet18 and ResNet18")
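Computing CKA "across all layers" requires access to intermediate activations. In PyTorch these are commonly captured with forward hooks (nn.Module.register_forward_hook). The sketch below shows that general technique on a toy model, together with a full-batch linear CKA, to build a layer-by-layer matrix. It does not reproduce CKACalculator's internals; the helper names capture_features and layerwise_cka are hypothetical, and the layer names come from named_modules() (for a torchvision ResNet they look like "layer1.0.conv1").

```python
import torch
from torch import nn


def linear_cka(x: torch.Tensor, y: torch.Tensor) -> float:
    # Full-batch linear CKA on flattened, column-centered features.
    x = x.reshape(x.shape[0], -1).double()
    y = y.reshape(y.shape[0], -1).double()
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    cross = torch.linalg.matrix_norm(x.T @ y) ** 2
    denom = torch.linalg.matrix_norm(x.T @ x) * torch.linalg.matrix_norm(y.T @ y)
    return (cross / denom).item()


def capture_features(model: nn.Module, layer_names: list, inputs: torch.Tensor) -> dict:
    """Run one forward pass and record the outputs of the named submodules."""
    features = {}
    modules = dict(model.named_modules())
    handles = []
    for name in layer_names:
        # The default argument binds the current name inside each hook.
        def hook(module, args, output, name=name):
            # clone() guards against later in-place ops (e.g. ReLU(inplace=True)).
            features[name] = output.detach().clone()
        handles.append(modules[name].register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        for handle in handles:
            handle.remove()  # always detach hooks so they don't affect later calls
    return features


def layerwise_cka(model1, layers1, model2, layers2, inputs) -> torch.Tensor:
    f1 = capture_features(model1, layers1, inputs)
    f2 = capture_features(model2, layers2, inputs)
    out = torch.empty(len(layers1), len(layers2))
    for i, a in enumerate(layers1):
        for j, b in enumerate(layers2):
            out[i, j] = linear_cka(f1[a], f2[b])
    return out


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 32), nn.ReLU()).eval()
layers = ["0", "1", "2", "3"]  # submodule names from net.named_modules()
matrix = layerwise_cka(net, layers, net, layers, torch.randn(128, 20))
print(matrix.diagonal())  # a model compared with itself: ~1.0 on the diagonal
```

A full-batch computation like this holds every captured activation in memory at once, which is exactly what a mini-batch (on-the-fly) approach avoids for large models and datasets.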

Contributing

  • If you find this repository helpful, please give it a ⭐.
  • If you encounter any bugs or have suggestions for improvements, feel free to open an issue.
  • This implementation has been primarily tested with ResNet architectures.

Acknowledgement

This project is based on:


Download files


Source Distribution

cka_pytorch-1.1.3.tar.gz (16.9 kB)


Built Distribution


cka_pytorch-1.1.3-py3-none-any.whl (18.2 kB)


File details

Details for the file cka_pytorch-1.1.3.tar.gz.

File metadata

  • Download URL: cka_pytorch-1.1.3.tar.gz
  • Upload date:
  • Size: 16.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for cka_pytorch-1.1.3.tar.gz
  • SHA256: 1ca08b2fa414c8f273d4ac724732611d500db5dd826a8b0a6b36666c4af0ce05
  • MD5: 05781763fafa92cbea5db0682357671a
  • BLAKE2b-256: 3075d6c5e7ff18fdde717022b95549ce4a1c1c5e8be125d0d223cdc9bbe359cb
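These digests let you check a downloaded file before installing it. A minimal sketch using Python's standard hashlib module, assuming the source distribution was saved to the current directory under its original name; the function name sha256_of is hypothetical.

```python
from __future__ import annotations

import hashlib
from pathlib import Path


def sha256_of(path: str | Path, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA-256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks so large files never have to fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 published above for cka_pytorch-1.1.3.tar.gz
EXPECTED = "1ca08b2fa414c8f273d4ac724732611d500db5dd826a8b0a6b36666c4af0ce05"

if __name__ == "__main__":
    # Assumption: the sdist was downloaded into the current directory.
    sdist = Path("cka_pytorch-1.1.3.tar.gz")
    if sdist.exists():
        print("SHA256 OK" if sha256_of(sdist) == EXPECTED else "SHA256 MISMATCH")
```

pip can perform the same check automatically through its hash-checking mode (--require-hashes with pinned hashes in a requirements file).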


Provenance

The following attestation bundles were made for cka_pytorch-1.1.3.tar.gz:

Publisher: publish.yml on datthinh1801/CKA-pytorch

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file cka_pytorch-1.1.3-py3-none-any.whl.

File metadata

  • Download URL: cka_pytorch-1.1.3-py3-none-any.whl
  • Upload date:
  • Size: 18.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for cka_pytorch-1.1.3-py3-none-any.whl
  • SHA256: d70a23772c50a5e3fcc29561936ded5ab0a629167c88b12c2da8ba613d9281b9
  • MD5: eeae45d3f53ae2784dbdb436cf341d3a
  • BLAKE2b-256: 41f6be736a66615e66f774e5677792628bfbcdb4b8ce4d1baffb63f8abe0cf90


Provenance

The following attestation bundles were made for cka_pytorch-1.1.3-py3-none-any.whl:

Publisher: publish.yml on datthinh1801/CKA-pytorch

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
