
Centered Kernel Alignment (CKA) with Efficient Computation and Layer-wise Visualization for PyTorch


pytorch-cka


A fast, memory-efficient Python library for computing layer-wise similarity between neural network models

[Figure: benchmark bar chart] 44x faster CKA computation across 18 representational layers of ResNet-18 models on CIFAR-10, using NVIDIA H100 GPUs

  • ⚡️ Fastest among CKA libraries thanks to vectorized ops & GPU acceleration
  • 📦 Efficient memory management with explicit deallocation
  • 🧠 Supports HuggingFace models, DataParallel, and DDP
  • 🎨 Customizable visualizations: heatmaps and line charts
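To make the "vectorized ops" point concrete, here is a minimal sketch of linear CKA (Kornblith et al., 2019) written only with PyTorch matrix products. It is an illustration under our own assumptions, not pytorch-cka's internal code; `centering` and `linear_cka` are hypothetical helper names.

```python
import torch


def centering(x: torch.Tensor) -> torch.Tensor:
    # Subtract the per-feature (column) mean so every feature has zero mean
    return x - x.mean(dim=0, keepdim=True)


def linear_cka(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Linear CKA between activations x (n, ...) and y (n, ...) for the same n inputs."""
    # Flatten channel/spatial dims so each row is one input's feature vector,
    # then center in float64 for numerical stability
    x = centering(x.reshape(x.shape[0], -1).double())
    y = centering(y.reshape(y.shape[0], -1).double())
    # n x n Gram matrices; their size does not depend on the feature width
    gram_x = x @ x.T
    gram_y = y @ y.T
    # <K, L>_F / (||K||_F ||L||_F), equal to ||Y^T X||_F^2 / (||X^T X||_F ||Y^T Y||_F)
    return (gram_x * gram_y).sum() / (
        torch.linalg.matrix_norm(gram_x) * torch.linalg.matrix_norm(gram_y)
    )


torch.manual_seed(0)
acts = torch.randn(256, 64)
q, _ = torch.linalg.qr(torch.randn(64, 64))  # random orthogonal matrix
print(linear_cka(acts, acts).item())            # 1.0 up to rounding
print(linear_cka(acts, 3.0 * acts @ q).item())  # ~1.0: invariant to rotation and isotropic scaling
```

Working with n x n Gram matrices keeps memory at O(n²) per layer no matter how many channels and pixels a layer has, which matters for flattened convolutional activations.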

📦 Installation

Requires Python 3.10+

# Using pip
pip install pytorch-cka

# Using uv
uv add pytorch-cka

👟 Quick Start

Basic Usage

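import torch
from torch.utils.data import DataLoader
from torchvision.models import resnet18, resnet34, ResNet18_Weights, ResNet34_Weights

from cka import compute_cka

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 64

# torchvision >= 0.13 replaces `pretrained=True` with the `weights` argument
resnet_18 = resnet18(weights=ResNet18_Weights.DEFAULT)
resnet_34 = resnet34(weights=ResNet34_Weights.DEFAULT)

# your_dataset1..3 are placeholders for your own Dataset objects;
# shuffle=False keeps the sample order deterministic
dataloader1 = DataLoader(your_dataset1, batch_size=batch_size, shuffle=False, num_workers=4)
dataloader2 = DataLoader(your_dataset2, batch_size=batch_size, shuffle=False, num_workers=4)
dataloader3 = DataLoader(your_dataset3, batch_size=batch_size, shuffle=False, num_workers=4)
dataloaders = [dataloader1, dataloader2, dataloader3]

# Submodule names as reported by model.named_modules(); each exists in both models
layers = [
    'conv1',
    'layer1.0.conv1',
    'layer2.0.conv1',
    'layer3.0.conv1',
    'layer4.0.conv1',
    'fc',
]

cka_matrices = compute_cka(
    resnet_18,
    resnet_34,
    dataloaders,
    layers=layers,
    device=device,
)

# Print the CKA matrix computed for each dataloader
for cka_matrix in cka_matrices:
    print(cka_matrix)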
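The quick-start call hides the mechanics. As a rough sketch of what a layer-wise CKA computation involves (not pytorch-cka's actual implementation), the code below captures activations from named submodules with PyTorch forward hooks, removes the hooks afterwards, and fills a layers1 x layers2 matrix with linear CKA. `collect_activations` and `cka_matrix` are hypothetical helper names, and the toy MLP stands in for real models so the example runs without downloading weights.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def linear_cka(x, y):
    # Linear CKA on flattened, column-centered activations, via n x n Gram matrices
    x = x.reshape(x.shape[0], -1).double()
    y = y.reshape(y.shape[0], -1).double()
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    gram_x, gram_y = x @ x.T, y @ y.T
    return (gram_x * gram_y).sum() / (
        torch.linalg.matrix_norm(gram_x) * torch.linalg.matrix_norm(gram_y)
    )


@torch.no_grad()
def collect_activations(model, layers, loader, device):
    """Run `loader` through `model`; return {layer name: (N, features) tensor on CPU}."""
    modules = dict(model.named_modules())
    feats = {name: [] for name in layers}
    handles = []
    for name in layers:
        # `name=name` binds the current layer name into this hook
        def hook(_module, _inputs, output, name=name):
            # clone so a later in-place op cannot overwrite the stored activation
            feats[name].append(output.flatten(start_dim=1).cpu().clone())
        handles.append(modules[name].register_forward_hook(hook))
    model.eval().to(device)
    try:
        for batch in loader:
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            model(inputs.to(device))
    finally:
        for handle in handles:
            handle.remove()  # detach hooks even if a forward pass raises
    return {name: torch.cat(chunks) for name, chunks in feats.items()}


def cka_matrix(model1, model2, layers1, layers2, loader, device="cpu"):
    """Return a (len(layers1), len(layers2)) matrix of linear CKA scores."""
    acts1 = collect_activations(model1, layers1, loader, device)
    acts2 = collect_activations(model2, layers2, loader, device)
    out = torch.zeros(len(layers1), len(layers2), dtype=torch.float64)
    for i, name1 in enumerate(layers1):
        for j, name2 in enumerate(layers2):
            out[i, j] = linear_cka(acts1[name1], acts2[name2])
    return out


# Toy usage: compare a small MLP with itself on random inputs
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 10))
names = ["0", "1", "2"]  # submodule names from net.named_modules()
loader = DataLoader(TensorDataset(torch.randn(128, 20)), batch_size=32, shuffle=False)
print(cka_matrix(net, net, names, names, loader).shape)  # torch.Size([3, 3])
```

Keeping every activation in memory is the simplest choice for a sketch; with large datasets or wide layers, an implementation would more likely accumulate the needed statistics batch by batch.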

Visualization

Heatmap

from cka import plot_cka_heatmap

# cka_matrix is the last matrix from the Basic Usage loop (ResNet-18 vs. ResNet-34)
fig, ax = plot_cka_heatmap(
    cka_matrix,
    layers1=layers,
    layers2=layers,
    model1_name="ResNet-18 (pretrained)",
    model2_name="ResNet-34 (pretrained)",
    annot=False,          # Set to True to print the CKA value in each cell
    cmap="inferno",       # Colormap name
)
[Figures: self-comparison heatmap; cross-model comparison heatmap]
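To see what such a heatmap encodes, or to get a quick plot without the library, a comparable figure can be sketched with Matplotlib alone. This is not `plot_cka_heatmap`'s implementation; `simple_cka_heatmap` is a hypothetical name, and it assumes the CKA matrix is a 2D NumPy array or CPU tensor whose rows correspond to model 1's layers.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend for scripts; omit in notebooks
import matplotlib.pyplot as plt
import numpy as np


def simple_cka_heatmap(matrix, layers1, layers2, cmap="inferno"):
    """Draw a (len(layers1), len(layers2)) CKA matrix with layer names as tick labels."""
    matrix = np.asarray(matrix, dtype=float)  # accepts NumPy arrays or CPU tensors
    fig, ax = plt.subplots(figsize=(6, 5))
    # CKA lies in [0, 1], so fix the color scale; origin="lower" puts layers1[0] at the bottom
    im = ax.imshow(matrix, cmap=cmap, vmin=0.0, vmax=1.0, origin="lower")
    ax.set_xticks(range(len(layers2)), labels=layers2, rotation=45, ha="right")
    ax.set_yticks(range(len(layers1)), labels=layers1)
    ax.set_xlabel("Model 2 layer")
    ax.set_ylabel("Model 1 layer")
    fig.colorbar(im, ax=ax, label="CKA")
    fig.tight_layout()
    return fig, ax


layers = ["conv1", "layer1.0.conv1", "layer2.0.conv1"]
fig, ax = simple_cka_heatmap(np.eye(3), layers, layers)
```

Fixing `vmin`/`vmax` keeps colors comparable across several heatmaps; call `fig.savefig(...)` to write the figure to disk.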

Trend Plot

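import torch
from cka import plot_cka_trend, plot_cka_layer_trend

# The diagonal of a cross-model CKA matrix holds the similarity between
# corresponding layers (layer i of model 1 vs. layer i of model 2)
diagonal = torch.diag(cka_matrix)

# Placeholders to supply yourself:
#   layer_trends     - per-layer CKA scores across training epochs
#   epochs           - x-axis values, e.g. [1, 2, ..., num_epochs]
#   RESNET18_LAYERS  - layer names, one label per trend line
#   cka_loader_names - one label per CKA matrix in cka_matrices
fig, ax = plot_cka_trend(
    layer_trends,
    x_values=epochs,
    labels=RESNET18_LAYERS,
    markers=['o'],
    xlabel='Epoch',
    ylabel='CKA Score',
    title='Pretrained vs. Fine-tuned Across Epochs (ResNet-18)',
    legend=True,
)

fig, ax = plot_cka_layer_trend(
    cka_matrices,
    layers=RESNET18_LAYERS,
    labels=cka_loader_names,
    ylabel='CKA Score',
    title='Pretrained vs. Fine-tuned Across Layers (ResNet-18)',
    legend=True,
)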
[Figures: CKA score trend across epochs; CKA score trend across layers]
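The trend example leaves `layer_trends` undefined. One plausible way to build it is to take the diagonal of the pretrained-vs-fine-tuned CKA matrix at each epoch and stack the diagonals so each row is one layer's score over time. This assumes `plot_cka_trend` expects one sequence per entry in `labels`, which is worth checking against the library's docstring; `diagonal_trends` is a hypothetical helper and the per-epoch matrices are made-up data.

```python
import torch


def diagonal_trends(cka_matrices):
    """Stack diag(M) of each square CKA matrix into a (num_layers, num_matrices) tensor."""
    # diag(M)[i] compares layer i of model 1 with the same layer i of model 2
    diagonals = [torch.diag(torch.as_tensor(m)) for m in cka_matrices]
    return torch.stack(diagonals, dim=1)  # row i = layer i's score over time


# Hypothetical data: one 3x3 CKA matrix (3 layers) per epoch for 4 epochs
per_epoch = [torch.full((3, 3), score) for score in (1.0, 0.9, 0.8, 0.7)]
layer_trends = diagonal_trends(per_epoch)
epochs = [1, 2, 3, 4]
print(layer_trends.shape)  # torch.Size([3, 4])
```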

📚 References

Kornblith, Simon, et al. "Similarity of Neural Network Representations Revisited." ICML 2019.
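For reference, the similarity index defined in that paper can be written as follows, where K and L are the n x n Gram matrices of two layers' activations on the same n inputs and H is the centering matrix. With a linear kernel and column-centered activation matrices X and Y, it reduces to the closed form on the last line.

```latex
H = I_n - \tfrac{1}{n}\mathbf{1}\mathbf{1}^\top, \qquad
\mathrm{HSIC}(K, L) = \frac{\operatorname{tr}(K H L H)}{(n-1)^2}

\mathrm{CKA}(K, L) = \frac{\mathrm{HSIC}(K, L)}{\sqrt{\mathrm{HSIC}(K, K)\,\mathrm{HSIC}(L, L)}}

\mathrm{CKA}_{\mathrm{linear}}(X, Y) = \frac{\lVert Y^\top X \rVert_F^2}{\lVert X^\top X \rVert_F\,\lVert Y^\top Y \rVert_F},
\qquad K = XX^\top,\; L = YY^\top
```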



Download files


Source Distribution

pytorch_cka-1.0.1.tar.gz (264.0 kB)

Uploaded Source

Built Distribution


pytorch_cka-1.0.1-py3-none-any.whl (14.9 kB)

Uploaded Python 3

File details

Details for the file pytorch_cka-1.0.1.tar.gz.

File metadata

  • Download URL: pytorch_cka-1.0.1.tar.gz
  • Upload date:
  • Size: 264.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for pytorch_cka-1.0.1.tar.gz
Algorithm Hash digest
SHA256 486504f8bda401173846d24462640c24659ed0f8f23984152c620b6b3bd9853d
MD5 81f28b483465f1d5083989397b145365
BLAKE2b-256 4318c1d1144781d567b0a98a89cac37923d94ead11f104ea61da3c7257307e02


Provenance

The following attestation bundles were made for pytorch_cka-1.0.1.tar.gz:

Publisher: publish.yaml on ryusudol/Centered-Kernel-Alignment

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file pytorch_cka-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: pytorch_cka-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 14.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for pytorch_cka-1.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 e247ca79016ab07e0f72c05f78f55b477a67b92befd2acfc6f55f433dc156e34
MD5 3bf70d2c5e4a69eac6eb259320a16f13
BLAKE2b-256 93cb1e4973264157862c2b6331e092377d35d23fb48efe994e3acb6230c439ba


Provenance

The following attestation bundles were made for pytorch_cka-1.0.1-py3-none-any.whl:

Publisher: publish.yaml on ryusudol/Centered-Kernel-Alignment

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
