A PyTorch implementation of Centered Kernel Alignment (CKA) with GPU support.
Centered Kernel Alignment (CKA) - PyTorch Implementation
A PyTorch implementation of Centered Kernel Alignment (CKA) with GPU support for fast and efficient computation.
[!WARNING] This project is for educational and academic purposes (and for fun 🤷🏻).
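For reference, CKA as defined by Kornblith et al. (2019) compares two feature matrices for the same $n$ examples through their Gram matrices. The symbols below ($X$, $Y$, $K$, $L$, $H$) are notation chosen for this sketch, not names taken from the package's code:

```latex
% X \in \mathbb{R}^{n \times p_1}, Y \in \mathbb{R}^{n \times p_2}: features of the same n examples
K = XX^\top, \qquad L = YY^\top, \qquad H = I_n - \tfrac{1}{n}\mathbf{1}\mathbf{1}^\top

\mathrm{HSIC}(K, L) = \frac{1}{(n-1)^2}\,\operatorname{tr}(KHLH)

\mathrm{CKA}(K, L) = \frac{\mathrm{HSIC}(K, L)}{\sqrt{\mathrm{HSIC}(K, K)\,\mathrm{HSIC}(L, L)}}
```

With a linear kernel and column-centered $X$ and $Y$, this reduces to $\lVert Y^\top X\rVert_F^2 / (\lVert X^\top X\rVert_F\,\lVert Y^\top Y\rVert_F)$. The score is invariant to orthogonal transformations and isotropic scaling of the features, and $p_1$ and $p_2$ may differ, so layers of different widths can be compared.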
Features
- GPU Accelerated: Performs CKA computations with PyTorch tensor operations on the GPU, which is typically much faster than CPU-bound, NumPy-based implementations.
- On-the-Fly Calculation: Computes CKA incrementally over mini-batches, so large intermediate feature representations never have to be cached for the whole dataset.
- Easy to Use: Simple and intuitive API for calculating the CKA matrix between two models.
- Flexible: Can be used with any PyTorch models and dataloaders.
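The mini-batch computation described above can be illustrated with the mini-batch CKA estimator of Nguyen, Raghu and Kornblith (2021), which sums an unbiased HSIC estimate over batches instead of storing every feature. This is a minimal sketch assuming a linear kernel; `unbiased_hsic` and `MinibatchLinearCKA` are hypothetical names, not the `cka_pytorch` API, and the package's own implementation may differ in its details.

```python
import math

import torch


def unbiased_hsic(K: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    """Unbiased HSIC estimate for two n x n Gram matrices (requires n > 3)."""
    n = K.shape[0]
    if n <= 3:
        raise ValueError("the unbiased HSIC estimator needs more than 3 examples per batch")
    # The unbiased estimator works on Gram matrices with zeroed diagonals
    off_diag = 1.0 - torch.eye(n, dtype=K.dtype, device=K.device)
    K = K * off_diag
    L = L * off_diag
    KL = K @ L
    trace_term = torch.trace(KL)
    sum_term = K.sum() * L.sum() / ((n - 1) * (n - 2))
    cross_term = 2.0 * KL.sum() / (n - 2)  # 1^T K L 1
    return (trace_term + sum_term - cross_term) / (n * (n - 3))


class MinibatchLinearCKA:
    """Hypothetical accumulator: sums per-batch unbiased HSIC values for one layer pair."""

    def __init__(self) -> None:
        self.hsic_xy = 0.0
        self.hsic_xx = 0.0
        self.hsic_yy = 0.0

    def update(self, x: torch.Tensor, y: torch.Tensor) -> None:
        # Flatten each example's activations to one row: (batch, features)
        x = x.flatten(start_dim=1).double()
        y = y.flatten(start_dim=1).double()
        K = x @ x.T  # linear-kernel Gram matrices, shape (batch, batch)
        L = y @ y.T
        self.hsic_xy += unbiased_hsic(K, L).item()
        self.hsic_xx += unbiased_hsic(K, K).item()
        self.hsic_yy += unbiased_hsic(L, L).item()

    def compute(self) -> float:
        # Ratio of the accumulated sums, not an average of per-batch CKA values
        return self.hsic_xy / math.sqrt(self.hsic_xx * self.hsic_yy)


# Example: features scaled by a constant have CKA 1 with the originals
torch.manual_seed(0)
cka = MinibatchLinearCKA()
for _ in range(4):
    feats = torch.randn(32, 64)
    cka.update(feats, 3.0 * feats)
print(cka.compute())  # ≈ 1.0
```

Each layer pair only needs three running sums, so memory does not grow with the size of the dataset; each batch must contain more than three examples for the unbiased estimator to be defined.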
Installation
```shell
pip install cka-pytorch
```
Usage
```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision.models import resnet18, ResNet18_Weights
from cka_pytorch.cka import CKACalculator

# 1. Define your models and dataloader
# `weights=` replaces the deprecated `pretrained=True` argument (torchvision >= 0.13);
# eval() makes BatchNorm use its running statistics during feature extraction
model1 = resnet18(weights=ResNet18_Weights.DEFAULT).cuda().eval()
model2 = resnet18(weights=ResNet18_Weights.DEFAULT).cuda().eval()  # Or a different model

# Create a dummy dataloader for demonstration (100 random 224x224 RGB images)
dummy_data = torch.randn(100, 3, 224, 224)
dummy_labels = torch.randint(0, 10, (100,))
dummy_dataset = TensorDataset(dummy_data, dummy_labels)
dataloader = DataLoader(dummy_dataset, batch_size=32)

# 2. Initialize the CKACalculator
# By default, CKA is calculated across all layers of the two models
calculator = CKACalculator(
    model1=model1,
    model2=model2,
    model1_name="ResNet18",
    model2_name="ResNet18",
    batched_feature_size=256,
    verbose=True,
)

# 3. Calculate the CKA matrix
cka_matrix = calculator.calculate_cka_matrix(dataloader)

# 4. Plot the CKA matrix as a heatmap
calculator.plot_cka_matrix(title="CKA between ResNet18 and ResNet18")
```
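The calculator compares activations from many layers of both models. A common way to collect per-layer activations in PyTorch is with forward hooks; the sketch below assumes that the outputs of `nn.Conv2d` and `nn.Linear` modules are what you want to compare. `capture_features` is a hypothetical helper for illustration, not part of `cka_pytorch`, and `CKACalculator` may select and record layers differently.

```python
import torch
from torch import nn


def capture_features(model: nn.Module, inputs: torch.Tensor,
                     layer_types=(nn.Conv2d, nn.Linear)) -> dict:
    """Run one forward pass and return {module_name: output} for matching modules."""
    features = {}
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, layer_types):
            # The default argument binds the current name (avoids late-binding closures)
            def hook(mod, args, output, name=name):
                features[name] = output.detach()

            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        # Always remove the hooks so the model is left unchanged
        for handle in handles:
            handle.remove()
    return features


# Usage on a small model: keys come from model.named_modules()
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
feats = capture_features(net, torch.randn(5, 8))
print({k: tuple(v.shape) for k, v in feats.items()})  # {'0': (5, 16), '2': (5, 4)}
```

Each captured tensor can be flattened to `(batch, features)` before building a Gram matrix, so convolutional and fully connected layers can be compared on the same footing.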
Contributing
- If you find this repository helpful, please give it a :star:.
- If you encounter any bugs or have suggestions for improvements, feel free to open an issue.
- This implementation has been primarily tested with ResNet architectures.
Acknowledgement
This project is based on:
Download files
File details
Details for the file cka_pytorch-1.1.3.tar.gz.
File metadata
- Download URL: cka_pytorch-1.1.3.tar.gz
- Upload date:
- Size: 16.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 1ca08b2fa414c8f273d4ac724732611d500db5dd826a8b0a6b36666c4af0ce05 |
| MD5 | 05781763fafa92cbea5db0682357671a |
| BLAKE2b-256 | 3075d6c5e7ff18fdde717022b95549ce4a1c1c5e8be125d0d223cdc9bbe359cb |
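To check that a downloaded file matches a published SHA256 digest, a minimal sketch using only Python's standard library (`hashlib`) is shown here. `sha256_of_file` and `matches_digest` are illustrative helper names, and the file path in the usage comment is an assumption about where the download was saved.

```python
import hashlib
from pathlib import Path

# SHA256 digests copied from the "File hashes" tables on this page (version 1.1.3)
EXPECTED_SDIST_SHA256 = "1ca08b2fa414c8f273d4ac724732611d500db5dd826a8b0a6b36666c4af0ce05"
EXPECTED_WHEEL_SHA256 = "d70a23772c50a5e3fcc29561936ded5ab0a629167c88b12c2da8ba613d9281b9"


def sha256_of_file(path, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_digest(path, expected_hex: str) -> bool:
    """Compare a file's SHA256 with a published hex digest (case-insensitive)."""
    return sha256_of_file(path) == expected_hex.strip().lower()


# Usage (hypothetical local path to the downloaded source distribution):
# matches_digest("cka_pytorch-1.1.3.tar.gz", EXPECTED_SDIST_SHA256)
```

pip can also enforce digests at install time through its hash-checking mode (`--hash=sha256:...` entries in a requirements file), although that mode then requires pinned versions and hashes for every dependency.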
Provenance
The following attestation bundles were made for cka_pytorch-1.1.3.tar.gz:
Publisher: publish.yml on datthinh1801/CKA-pytorch
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: cka_pytorch-1.1.3.tar.gz
- Subject digest: 1ca08b2fa414c8f273d4ac724732611d500db5dd826a8b0a6b36666c4af0ce05
- Sigstore transparency entry: 602059509
- Sigstore integration time:
- Permalink: datthinh1801/CKA-pytorch@48f56ef30582cdac69eddd4871697442d1d58fdc
- Branch / Tag: refs/tags/v1.1.3
- Owner: https://github.com/datthinh1801
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@48f56ef30582cdac69eddd4871697442d1d58fdc
- Trigger Event: release
File details
Details for the file cka_pytorch-1.1.3-py3-none-any.whl.
File metadata
- Download URL: cka_pytorch-1.1.3-py3-none-any.whl
- Upload date:
- Size: 18.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | d70a23772c50a5e3fcc29561936ded5ab0a629167c88b12c2da8ba613d9281b9 |
| MD5 | eeae45d3f53ae2784dbdb436cf341d3a |
| BLAKE2b-256 | 41f6be736a66615e66f774e5677792628bfbcdb4b8ce4d1baffb63f8abe0cf90 |
Provenance
The following attestation bundles were made for cka_pytorch-1.1.3-py3-none-any.whl:
Publisher: publish.yml on datthinh1801/CKA-pytorch
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: cka_pytorch-1.1.3-py3-none-any.whl
- Subject digest: d70a23772c50a5e3fcc29561936ded5ab0a629167c88b12c2da8ba613d9281b9
- Sigstore transparency entry: 602059511
- Sigstore integration time:
- Permalink: datthinh1801/CKA-pytorch@48f56ef30582cdac69eddd4871697442d1d58fdc
- Branch / Tag: refs/tags/v1.1.3
- Owner: https://github.com/datthinh1801
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@48f56ef30582cdac69eddd4871697442d1d58fdc
- Trigger Event: release