Centered Kernel Alignment (CKA) with Efficient Computation and Layer-wise Visualization for PyTorch
pytorch-cka
The fastest, memory-efficient Python library for computing layer-wise similarity between neural network models
44x faster CKA computation across 18 representational layers of ResNet-18 models on CIFAR-10 using NVIDIA H100 GPUs
- ⚡️ Fastest among CKA libraries thanks to vectorized ops & GPU acceleration
- 📦 Efficient memory management with explicit deallocation
- 🧠 Supports Hugging Face models, DataParallel, and DistributedDataParallel (DDP)
- 🎨 Customizable visualizations: heatmaps and line charts
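The speed claim above rests on vectorized tensor ops. For reference, linear CKA (Kornblith et al., 2019) needs only a few matrix products: for activations X (n × p1) and Y (n × p2) recorded on the same n inputs and centered per feature, CKA(X, Y) = ‖YᵀX‖²_F / (‖XᵀX‖_F · ‖YᵀY‖_F), which equals ⟨K, L⟩_F / (‖K‖_F · ‖L‖_F) for the Gram matrices K = XXᵀ and L = YYᵀ. The sketch below is a hypothetical `linear_cka` helper written for illustration only; it is not the library's code, and `compute_cka` may use a different estimator (for example, one accumulated over minibatches).

```python
import torch


def linear_cka(x: torch.Tensor, y: torch.Tensor) -> float:
    """Linear CKA between two layers' activations on the same n inputs.

    x and y have shape (n, ...); all trailing dimensions are flattened.
    Hypothetical helper for illustration, not part of pytorch-cka.
    """
    # One row per example; float64 reduces rounding error in the sums
    x = x.flatten(start_dim=1).to(torch.float64)
    y = y.flatten(start_dim=1).to(torch.float64)
    # Center every feature over the n examples
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    # Gram matrices (n, n); centered features make them centered too
    gram_x = x @ x.T
    gram_y = y @ y.T
    # <K, L>_F / (||K||_F * ||L||_F) using vectorized tensor ops only
    numerator = (gram_x * gram_y).sum()
    denominator = torch.linalg.matrix_norm(gram_x) * torch.linalg.matrix_norm(gram_y)
    return (numerator / denominator).item()
```

Working with n × n Gram matrices keeps the pairwise matrices independent of feature width, which suits flattened convolutional feature maps where p is far larger than the number of examples.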
📦 Installation
Requires Python 3.10+
# Using pip
pip install pytorch-cka
# Using uv
uv add pytorch-cka
👟 Quick Start
Basic Usage
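import torch
from torch.utils.data import DataLoader
from torchvision.models import resnet18, resnet34, ResNet18_Weights, ResNet34_Weights
from cka import compute_cka

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torchvision >= 0.13 replaces pretrained=True with the weights argument
resnet_18 = resnet18(weights=ResNet18_Weights.DEFAULT)
resnet_34 = resnet34(weights=ResNet34_Weights.DEFAULT)

# your_dataset1..3 are placeholders for your own Dataset objects;
# shuffle=False keeps the example order deterministic
batch_size = 128  # example value
dataloader1 = DataLoader(your_dataset1, batch_size=batch_size, shuffle=False, num_workers=4)
dataloader2 = DataLoader(your_dataset2, batch_size=batch_size, shuffle=False, num_workers=4)
dataloader3 = DataLoader(your_dataset3, batch_size=batch_size, shuffle=False, num_workers=4)
dataloaders = [dataloader1, dataloader2, dataloader3]

# Module names as reported by model.named_modules(); all exist in both ResNet-18 and ResNet-34
layers = [
    'conv1',
    'layer1.0.conv1',
    'layer2.0.conv1',
    'layer3.0.conv1',
    'layer4.0.conv1',
    'fc',
]

# One CKA matrix (layers x layers) per dataloader
cka_matrices = compute_cka(
    resnet_18,
    resnet_34,
    dataloaders,
    layers=layers,
    device=device,
)
for cka_matrix in cka_matrices:
    print(cka_matrix)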
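`compute_cka` takes layer names rather than module objects. The names follow `model.named_modules()` (`'layer1.0.conv1'` is the first convolution of the first block in `layer1`), and a common way to record those layers' outputs in PyTorch is a forward hook. The following is a minimal sketch under that assumption, not necessarily how `compute_cka` works internally; `collect_activations` is a hypothetical helper, and it assumes each batch is an input tensor already on the model's device (with a `DataLoader` that yields `(inputs, labels)`, pass `(x.to(device) for x, _ in dataloader1)`).

```python
import torch
from torch import nn


def collect_activations(model: nn.Module, layer_names: list[str], batches) -> dict[str, torch.Tensor]:
    """Run `batches` through `model`; return each named layer's outputs, one flattened row per example.

    Hypothetical helper for illustration, not part of pytorch-cka.
    """
    modules = dict(model.named_modules())
    missing = [name for name in layer_names if name not in modules]
    if missing:
        raise KeyError(f"Unknown layer names: {missing}")

    outputs: dict[str, list[torch.Tensor]] = {name: [] for name in layer_names}
    handles = []
    for name in layer_names:
        # The default argument binds the current layer name inside the closure
        def hook(module, inputs, output, name=name):
            outputs[name].append(output.detach().flatten(start_dim=1).cpu())

        handles.append(modules[name].register_forward_hook(hook))

    model.eval()  # inference behavior for dropout and batch norm
    try:
        with torch.no_grad():
            for batch in batches:
                model(batch)
    finally:
        # Always remove the hooks so the model is left unchanged
        for handle in handles:
            handle.remove()
    return {name: torch.cat(chunks) for name, chunks in outputs.items()}
```

This sketch keeps every recorded activation in CPU memory, so it suits small evaluation sets; the explicit deallocation advertised above is not reproduced here.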
Visualization
Heatmap
from cka import plot_cka_heatmap

# Plot the first matrix from the Basic Usage example (ResNet-18 vs. ResNet-34)
fig, ax = plot_cka_heatmap(
    cka_matrices[0],
    layers1=layers,
    layers2=layers,
    model1_name="ResNet-18 (pretrained)",
    model2_name="ResNet-34 (pretrained)",
    annot=False,  # set True to print the CKA value in each cell
    cmap="inferno",  # colormap name
)
[Figure: example heatmaps, Self-comparison and Cross-model]
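With linear CKA, a self-comparison matrix is symmetric and has ones on its diagonal, because layer i is compared with itself. A minimal sketch of how such a layer-by-layer matrix can be assembled from per-layer activations; `linear_cka` and `cka_matrix` are hypothetical helpers for illustration, and the library's own computation may differ (for example, by working batch by batch).

```python
import torch


def linear_cka(x: torch.Tensor, y: torch.Tensor) -> float:
    # Linear CKA via centered Gram matrices (Kornblith et al., 2019)
    x = x.flatten(start_dim=1).double()
    y = y.flatten(start_dim=1).double()
    x = x - x.mean(dim=0)
    y = y - y.mean(dim=0)
    gram_x, gram_y = x @ x.T, y @ y.T
    norms = torch.linalg.matrix_norm(gram_x) * torch.linalg.matrix_norm(gram_y)
    return ((gram_x * gram_y).sum() / norms).item()


def cka_matrix(acts1: list[torch.Tensor], acts2: list[torch.Tensor]) -> torch.Tensor:
    """Entry (i, j) compares layer i of model 1 with layer j of model 2 on the same inputs."""
    matrix = torch.empty(len(acts1), len(acts2), dtype=torch.float64)
    for i, a in enumerate(acts1):
        for j, b in enumerate(acts2):
            matrix[i, j] = linear_cka(a, b)
    return matrix
```

Passing the same activation list twice reproduces the self-comparison case; passing activations from two different models gives the cross-model case, which in general is neither symmetric nor square.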
Trend Plot
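import torch
from cka import plot_cka_trend, plot_cka_layer_trend

# Placeholders (assumed inputs, not provided by the library):
# epoch_cka_matrices is assumed to hold one CKA matrix per fine-tuning epoch, each computed
# with compute_cka between the pretrained and fine-tuned ResNet-18 using the same layer list
RESNET18_LAYERS = layers  # layer names passed to compute_cka
epochs = list(range(1, len(epoch_cka_matrices) + 1))
cka_loader_names = ["Dataset 1", "Dataset 2", "Dataset 3"]  # one label per dataloader

# The diagonal compares each layer with the same layer in the other model
diagonals = [torch.diag(matrix) for matrix in epoch_cka_matrices]
layer_trends = torch.stack(diagonals, dim=1)  # assumed layout: one row per layer, one column per epoch

fig, ax = plot_cka_trend(
    layer_trends,
    x_values=epochs,
    labels=RESNET18_LAYERS,
    markers=['o'],
    xlabel='Epoch',
    ylabel='CKA Score',
    title='Pretrained vs. Fine-tuned Across Epochs (ResNet-18)',
    legend=True,
)

fig, ax = plot_cka_layer_trend(
    cka_matrices,
    layers=RESNET18_LAYERS,
    labels=cka_loader_names,
    ylabel='CKA Score',
    title='Pretrained vs. Fine-tuned Across Layers (ResNet-18)',
    legend=True,
)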
[Figure: CKA Score Trend Across Epochs; CKA Score Trend Across Layers]
📚 References
Kornblith, Simon, et al. "Similarity of Neural Network Representations Revisited." ICML 2019.
Download files
File details
Details for the file pytorch_cka-1.0.1.tar.gz.
File metadata
- Download URL: pytorch_cka-1.0.1.tar.gz
- Upload date:
- Size: 264.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 486504f8bda401173846d24462640c24659ed0f8f23984152c620b6b3bd9853d |
| MD5 | 81f28b483465f1d5083989397b145365 |
| BLAKE2b-256 | 4318c1d1144781d567b0a98a89cac37923d94ead11f104ea61da3c7257307e02 |
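To check a downloaded file against the digests listed above, hash it locally and compare. A minimal sketch with Python's standard `hashlib`; `sha256_of` and `verify` are hypothetical helper names, and the expected value is the SHA256 published for pytorch_cka-1.0.1.tar.gz. (From the command line, `pip hash <file>` prints the same SHA256 digest in `--hash=sha256:` form.)

```python
import hashlib


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel stops when read() returns b"" at end of file
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str, expected_hex: str) -> bool:
    """True if the file's SHA256 matches the published digest (case-insensitive)."""
    return sha256_of(path) == expected_hex.strip().lower()


# SHA256 listed above for pytorch_cka-1.0.1.tar.gz
EXPECTED_SDIST_SHA256 = "486504f8bda401173846d24462640c24659ed0f8f23984152c620b6b3bd9853d"
```

After downloading the source distribution, `verify("pytorch_cka-1.0.1.tar.gz", EXPECTED_SDIST_SHA256)` should return `True`; reading in chunks keeps memory use flat for large files.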
Provenance
The following attestation bundles were made for pytorch_cka-1.0.1.tar.gz:
Publisher: publish.yaml on ryusudol/Centered-Kernel-Alignment
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: pytorch_cka-1.0.1.tar.gz
- Subject digest: 486504f8bda401173846d24462640c24659ed0f8f23984152c620b6b3bd9853d
- Sigstore transparency entry: 896054152
- Sigstore integration time:
- Permalink: ryusudol/Centered-Kernel-Alignment@7d718a56af494e96161cbc72791dbc7e9d3db5d0
- Branch / Tag: refs/tags/v1.0.1
- Owner: https://github.com/ryusudol
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yaml@7d718a56af494e96161cbc72791dbc7e9d3db5d0
- Trigger Event: release
File details
Details for the file pytorch_cka-1.0.1-py3-none-any.whl.
File metadata
- Download URL: pytorch_cka-1.0.1-py3-none-any.whl
- Upload date:
- Size: 14.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e247ca79016ab07e0f72c05f78f55b477a67b92befd2acfc6f55f433dc156e34 |
| MD5 | 3bf70d2c5e4a69eac6eb259320a16f13 |
| BLAKE2b-256 | 93cb1e4973264157862c2b6331e092377d35d23fb48efe994e3acb6230c439ba |
Provenance
The following attestation bundles were made for pytorch_cka-1.0.1-py3-none-any.whl:
Publisher: publish.yaml on ryusudol/Centered-Kernel-Alignment
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: pytorch_cka-1.0.1-py3-none-any.whl
- Subject digest: e247ca79016ab07e0f72c05f78f55b477a67b92befd2acfc6f55f433dc156e34
- Sigstore transparency entry: 896054154
- Sigstore integration time:
- Permalink: ryusudol/Centered-Kernel-Alignment@7d718a56af494e96161cbc72791dbc7e9d3db5d0
- Branch / Tag: refs/tags/v1.0.1
- Owner: https://github.com/ryusudol
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yaml@7d718a56af494e96161cbc72791dbc7e9d3db5d0
- Trigger Event: release