Centered Kernel Alignment in PyTorch
[!WARNING] This repository was built mainly for personal and academic use, since Captum has yet to implement its own variant of CKA. As such, do not expect this project to work with every model.
✒️ About
[!NOTE] Centered Kernel Alignment (CKA) [1] is a similarity index between representations of features in neural networks, based on the Hilbert-Schmidt Independence Criterion (HSIC) [2]. Given a set of examples, CKA compares the representations that the layers under comparison produce for those examples.
Given two matrices $X \in \mathbb{R}^{n\times s_1}$ and $Y \in \mathbb{R}^{n\times s_2}$ representing the output of two layers, we can define two auxiliary $n \times n$ Gram matrices $K=XX^T$ and $L=YY^T$ and compute the dot-product similarity between them
$$\langle vec(XX^T), vec(YY^T)\rangle = tr(XX^T YY^T) = \lVert Y^T X \rVert_F^2.$$
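A quick numerical check of the identity above, sketched in PyTorch with random matrices standing in for the outputs of two layers (the sizes `n`, `s1`, `s2` are arbitrary assumptions). The Frobenius form on the right avoids building the two $n \times n$ Gram matrices.

```python
import torch

# Hypothetical layer outputs for n = 8 examples (random data, illustration only)
torch.manual_seed(0)
n, s1, s2 = 8, 5, 3
X = torch.randn(n, s1, dtype=torch.float64)
Y = torch.randn(n, s2, dtype=torch.float64)

K = X @ X.T  # n x n Gram matrix of the first layer
L = Y @ Y.T  # n x n Gram matrix of the second layer

# <vec(K), vec(L)>: element-wise product summed over all entries
dot_vec = (K * L).sum()
# tr(K L)
trace_form = torch.trace(K @ L)
# ||Y^T X||_F^2, which only needs an s2 x s1 matrix
frob_form = torch.linalg.matrix_norm(Y.T @ X, ord="fro") ** 2

print(torch.allclose(dot_vec, trace_form), torch.allclose(trace_form, frob_form))
```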
Then, the $HSIC$ on $K$ and $L$ is defined as
$$HSIC_0(K, L) = \frac{tr(KHLH)}{(n - 1)^2},$$
where $H = I_n - \frac{1}{n}J_n$ is the centering matrix and $J_n$ is an $n \times n$ matrix filled with ones. Finally, to obtain the CKA value we only need to normalize $HSIC_0$
$$CKA(K, L) = \frac{HSIC_0(K, L)}{\sqrt{HSIC_0(K, K) HSIC_0(L, L)}}.$$
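A minimal full-batch sketch of $HSIC_0$ and linear CKA, following the equations above literally. It builds the $n \times n$ matrices explicitly, so it is meant for illustration rather than for large $n$; the function names `centering_matrix`, `hsic0` and `linear_cka` are hypothetical and not part of ckatorch's API. The usage lines check two known properties of linear CKA: it equals 1 for identical representations and is invariant to orthogonal transformations and isotropic scaling.

```python
import torch


def centering_matrix(n: int, dtype: torch.dtype = torch.float64) -> torch.Tensor:
    """H = I_n - (1/n) J_n, where J_n is the n x n all-ones matrix."""
    return torch.eye(n, dtype=dtype) - torch.ones(n, n, dtype=dtype) / n


def hsic0(K: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    """Biased HSIC estimator: tr(K H L H) / (n - 1)^2."""
    n = K.shape[0]
    H = centering_matrix(n, dtype=K.dtype)
    return torch.trace(K @ H @ L @ H) / (n - 1) ** 2


def linear_cka(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Full-batch linear CKA between activations X (n x s1) and Y (n x s2)."""
    K, L = X @ X.T, Y @ Y.T
    return hsic0(K, L) / torch.sqrt(hsic0(K, K) * hsic0(L, L))


torch.manual_seed(0)
X = torch.randn(100, 32, dtype=torch.float64)
Q, _ = torch.linalg.qr(torch.randn(32, 32, dtype=torch.float64))  # random orthogonal matrix

print(linear_cka(X, X))            # identical representations: ~1
print(linear_cka(X, 3.0 * X @ Q))  # rotated and scaled copy: ~1
print(linear_cka(X, torch.randn(100, 16, dtype=torch.float64)))  # unrelated: noticeably below 1
```

Note that the biased estimator does not reach exactly 0 for unrelated representations at finite $n$, which is one motivation for the unbiased estimator that follows.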
[!NOTE] However, naively computing linear CKA (i.e., the previous equation) requires keeping the activations of the entire dataset in memory, which is challenging for wide and deep networks [3].
Therefore, we need an unbiased estimator of HSIC, so that the value of CKA does not depend on the batch size:
$$HSIC_1(K, L)=\frac{1}{n(n-3)}\left( tr(\tilde{K}\tilde{L}) + \frac{\mathbf{1}^T\tilde{K}\mathbf{1}\mathbf{1}^T\tilde{L}\mathbf{1}}{(n-1)(n-2)} - \frac{2}{n-2}\mathbf{1}^T\tilde{K}\tilde{L}\mathbf{1}\right),$$
where $\tilde{K}$ and $\tilde{L}$ are obtained by setting the diagonal entries of $K$ and $L$ to zero, and $\mathbf{1}$ is the $n$-dimensional vector of ones. Finally, we can compute the minibatch version of CKA by averaging $HSIC_1$ scores over $k$ minibatches
$$CKA_{minibatch}=\frac{\frac{1}{k} \displaystyle\sum_{i=1}^{k} HSIC_1(K_i, L_i)}{\sqrt{\frac{1}{k} \displaystyle\sum_{i=1}^{k} HSIC_1(K_i, K_i)}\sqrt{\frac{1}{k} \displaystyle\sum_{i=1}^{k} HSIC_1(L_i, L_i)}},$$
with $K_i=X_iX_i^T$ and $L_i=Y_iY_i^T$, where $X_i \in \mathbb{R}^{m \times s_1}$ and $Y_i \in \mathbb{R}^{m \times s_2}$ are now matrices containing the activations of the $i$-th minibatch of $m$ examples sampled without replacement [3].
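A minimal sketch of $HSIC_1$ and of the minibatch CKA estimator, written directly from the two equations above. It is not taken from ckatorch, and the names `hsic1` and `minibatch_cka` are hypothetical. The terms $\mathbf{1}^T\tilde{K}\mathbf{1}$ and $\mathbf{1}^T\tilde{K}\tilde{L}\mathbf{1}$ are computed as sums of all matrix entries, and each minibatch must contain more than 3 examples.

```python
import torch


def hsic1(K: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    """Unbiased HSIC estimator HSIC_1 (requires n > 3 examples)."""
    n = K.shape[0]
    # K~ and L~: copies of K and L with the diagonal set to zero
    K_t = K.clone().fill_diagonal_(0)
    L_t = L.clone().fill_diagonal_(0)
    KL = K_t @ L_t
    trace_term = torch.trace(KL)                               # tr(K~ L~)
    middle_term = K_t.sum() * L_t.sum() / ((n - 1) * (n - 2))  # 1^T K~ 1 1^T L~ 1 / ((n-1)(n-2))
    last_term = 2.0 * KL.sum() / (n - 2)                       # 2/(n-2) 1^T K~ L~ 1
    return (trace_term + middle_term - last_term) / (n * (n - 3))


def minibatch_cka(x_batches, y_batches) -> torch.Tensor:
    """Average HSIC_1 over k minibatches, then normalize as in CKA_minibatch."""
    hsic_xy, hsic_xx, hsic_yy, k = 0.0, 0.0, 0.0, 0
    for X_i, Y_i in zip(x_batches, y_batches):
        K_i, L_i = X_i @ X_i.T, Y_i @ Y_i.T
        hsic_xy = hsic_xy + hsic1(K_i, L_i)
        hsic_xx = hsic_xx + hsic1(K_i, K_i)
        hsic_yy = hsic_yy + hsic1(L_i, L_i)
        k += 1
    # The 1/k factors cancel out, but are kept to mirror the formula
    return (hsic_xy / k) / (torch.sqrt(hsic_xx / k) * torch.sqrt(hsic_yy / k))


torch.manual_seed(0)
X = torch.randn(640, 32, dtype=torch.float64)  # hypothetical activations of layer A
Y = torch.randn(640, 16, dtype=torch.float64)  # hypothetical activations of layer B, independent of A
# k = 10 minibatches of m = 64 examples, without replacement
x_batches, y_batches = X.split(64), Y.split(64)

print(minibatch_cka(x_batches, x_batches))                      # identical representations: ~1
print(minibatch_cka(x_batches, [2.0 * b for b in x_batches]))  # isotropic scaling: ~1
print(minibatch_cka(x_batches, y_batches))                      # independent features: close to 0
```

Because $HSIC_1$ is unbiased, the score for independent representations is centered near 0 instead of being pushed upwards, and it can come out slightly negative.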
📦 Installation
This project requires Python >= 3.10.
Create a new venv
# If you have uv installed
uv venv
# Otherwise
python -m venv .venv
# Activate the virtual environment
source .venv/bin/activate # if you are on Linux
.\.venv\Scripts\activate.bat # if you are using the cmd on Windows
.\.venv\Scripts\Activate.ps1 # if you are using the PowerShell on Windows
Install the package
[!NOTE] Please refer to uv's PyTorch integration guide for details on the `--torch-backend` option.
You can install the package:
- from PyPI:
  uv pip install ckatorch --torch-backend=auto
- from this repository:
  uv pip install git+https://github.com/RistoAle97/centered-kernel-alignment --torch-backend=auto
- by cloning the repository and installing the dependencies:
  git clone https://github.com/RistoAle97/centered-kernel-alignment
  cd centered-kernel-alignment
  uv pip install -e . --torch-backend=auto
  uv pip install ckatorch --group dev # if you want to also install the dev dependencies
Take a look at the examples directory to understand how to compute CKA in two basic scenarios.
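The examples directory remains the reference for ckatorch's own API. Independently of it, the sketch below illustrates the general hook-based approach: register forward hooks on the layers to compare, run a batch through the model, and feed the flattened activations to linear CKA. The helpers `collect_activations` and `linear_cka` and the toy model are hypothetical; `linear_cka` uses the Frobenius form $\lVert Y_c^T X_c \rVert_F^2 / (\lVert X_c^T X_c \rVert_F \lVert Y_c^T Y_c \rVert_F)$ on column-centered activations, which is equivalent to the $HSIC_0$ definition because the $(n-1)^2$ factors cancel.

```python
import torch
import torch.nn as nn


def collect_activations(model: nn.Module, layer_names, inputs: torch.Tensor) -> dict:
    """Run `inputs` through `model` and return the flattened outputs of the named layers."""
    activations, handles = {}, []
    modules = dict(model.named_modules())
    for name in layer_names:
        def hook(module, args, output, name=name):
            # Flatten everything except the batch dimension: (n, ...) -> (n, features)
            activations[name] = output.detach().flatten(start_dim=1)
        handles.append(modules[name].register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        # Always remove the hooks so later forward passes are unaffected
        for handle in handles:
            handle.remove()
    return activations


def linear_cka(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Linear CKA in its Frobenius-norm form on column-centered activations."""
    X = X - X.mean(dim=0, keepdim=True)
    Y = Y - Y.mean(dim=0, keepdim=True)
    num = torch.linalg.matrix_norm(Y.T @ X) ** 2
    den = torch.linalg.matrix_norm(X.T @ X) * torch.linalg.matrix_norm(Y.T @ Y)
    return num / den


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 10))
inputs = torch.randn(256, 20)
acts = collect_activations(model, ["1", "3"], inputs)  # outputs of the two ReLU layers
print(linear_cka(acts["1"], acts["3"]))
```

This full-batch version keeps all activations in memory; for large datasets the minibatch estimator described above is the one to use.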
🖼️ Plots
[!NOTE] The comparison makes more sense if the models share a common architecture.
| Model compared with itself | Different models compared |
|---|---|
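Assuming the plots are the usual layer-by-layer CKA heatmaps, the sketch below computes the matrices behind the two cases: a model compared with itself, and two models with the same architecture but different random initializations. All names are hypothetical and the models are untrained toy MLPs; the resulting matrix could then be passed to a plotting function such as matplotlib's `imshow`.

```python
import torch
import torch.nn as nn


def linear_cka(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Linear CKA in its Frobenius-norm form on column-centered activations."""
    X = X - X.mean(dim=0, keepdim=True)
    Y = Y - Y.mean(dim=0, keepdim=True)
    return torch.linalg.matrix_norm(Y.T @ X) ** 2 / (
        torch.linalg.matrix_norm(X.T @ X) * torch.linalg.matrix_norm(Y.T @ Y)
    )


def layer_outputs(model: nn.Sequential, inputs: torch.Tensor) -> list:
    """Flattened output of every layer of a Sequential model."""
    outputs, h = [], inputs
    with torch.no_grad():
        for layer in model:
            h = layer(h)
            outputs.append(h.flatten(start_dim=1))
    return outputs


def cka_matrix(acts_a, acts_b) -> torch.Tensor:
    """Entry (i, j) is the linear CKA between layer i of model A and layer j of model B."""
    return torch.stack([torch.stack([linear_cka(a, b) for b in acts_b]) for a in acts_a])


def make_mlp() -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10)
    ).double()


torch.manual_seed(0)
model_a, model_b = make_mlp(), make_mlp()  # same architecture, different initializations
inputs = torch.randn(512, 20, dtype=torch.float64)
outs_a, outs_b = layer_outputs(model_a, inputs), layer_outputs(model_b, inputs)

self_matrix = cka_matrix(outs_a, outs_a)   # model compared with itself
cross_matrix = cka_matrix(outs_a, outs_b)  # different models compared
print(self_matrix)
```

The self-comparison matrix is symmetric with ones on its diagonal, while the cross-model matrix has no such guarantee, which is why the comparison is easier to read when both models share the same layer structure.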
📚 Bibliography
[1] Kornblith, Simon, et al. "Similarity of neural network representations revisited." International Conference on Machine Learning. PMLR, 2019.
[2] Wang, Tinghua, Xiaolu Dai, and Yuze Liu. "Learning with Hilbert–Schmidt independence criterion: A review and new perspectives." Knowledge-Based Systems 234 (2021): 107567.
[3] Nguyen, Thao, Maithra Raghu, and Simon Kornblith. "Do wide and deep networks learn the same things? Uncovering how neural network representations vary with width and depth." arXiv preprint arXiv:2010.15327 (2020).
This project is also based on the following repositories:
- representation_similarity (original implementation).
- PyTorch-Model-Compare (nice PyTorch implementation that employs hooks).
- CKA.pytorch (minibatch version of CKA and useful batched implementation of $HSIC_1$).
📝 License
This project is MIT licensed.