simtorch
A PyTorch library to measure the similarity between two neural network representations. The library currently supports the following (dis)similarity measures:
- Centered Kernel Alignment (CKA) - Kornblith et al., ICML 2019
- Deconfounded CKA - Cui et al., NeurIPS 2022
- Procrustes [WIP]
- CCA [WIP]
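For reference, the standard CKA definition from Kornblith et al. (2019) is reproduced below for activation matrices $X \in \mathbb{R}^{n \times p_1}$ and $Y \in \mathbb{R}^{n \times p_2}$ collected on the same $n$ inputs, with $H$ the centering matrix. The last line is the linear-kernel special case ($K = XX^\top$, $L = YY^\top$) with $X$ and $Y$ column-centered; which kernels simtorch exposes is not stated here.

$$
H = I_n - \tfrac{1}{n}\mathbf{1}\mathbf{1}^\top, \qquad
\mathrm{HSIC}(K, L) = \frac{\operatorname{tr}(K H L H)}{(n-1)^2}
$$

$$
\mathrm{CKA}(K, L) = \frac{\mathrm{HSIC}(K, L)}{\sqrt{\mathrm{HSIC}(K, K)\,\mathrm{HSIC}(L, L)}}
$$

$$
K = XX^\top,\; L = YY^\top \;\Rightarrow\;
\mathrm{CKA}(K, L) = \frac{\lVert Y^\top X \rVert_F^2}{\lVert X^\top X \rVert_F \,\lVert Y^\top Y \rVert_F}
$$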
Design
The package consists of two components:
- `SimilarityModel`, a thin wrapper on `torch.nn.Module()` that adds forward hooks to store the layer-wise activations (aka representations) in a dictionary.
- `BaseSimilarity`, which sets the interface for classes that compute similarity between network representations.
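The forward-hook mechanism described above can be sketched with PyTorch's `register_forward_hook`. This is not simtorch's `SimilarityModel`: the class name `ActivationRecorder` is hypothetical, and selecting layers by substring match against the names returned by `named_modules()` is an assumption made for illustration.

```python
import torch
import torch.nn as nn


class ActivationRecorder:
    """Hypothetical sketch: store outputs of selected submodules in a dict via forward hooks."""

    def __init__(self, model: nn.Module, layers_to_include: list[str]):
        self.model = model
        self.activations: dict[str, torch.Tensor] = {}
        self._handles = []
        for name, module in model.named_modules():
            # Assumption: a submodule is hooked if any requested string occurs in its name.
            # The root module has the empty name "" and is skipped.
            if name and any(key in name for key in layers_to_include):
                handle = module.register_forward_hook(self._make_hook(name))
                self._handles.append(handle)

    def _make_hook(self, name: str):
        def hook(module, inputs, output):
            # Detach so the stored activation does not keep the autograd graph alive.
            self.activations[name] = output.detach()

        return hook

    def remove_hooks(self) -> None:
        # Unregister every hook this recorder added.
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
```

After a forward pass, `recorder.activations` maps each hooked layer name to its most recent output, which is the per-layer representation a similarity measure consumes.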
Installation
The package is available on PyPI:
```bash
pip install simtorch
```
Usage
Each torch model needs to be wrapped with `SimilarityModel`. The names of the layers whose representations we wish to compute are passed as a list through the `layers_to_include` argument.
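```python
import torchvision

# SimilarityModel is provided by simtorch; import it from the installed package.
model1 = torchvision.models.densenet121()
model2 = torchvision.models.resnet101()

# Wrap each model and list the layers whose activations should be recorded.
sim_model1 = SimilarityModel(
    model1,
    model_name="DenseNet 121",
    layers_to_include=["conv", "classifier"],
)
sim_model2 = SimilarityModel(
    model2,
    model_name="ResNet 101",
    layers_to_include=["conv", "fc"],
)
```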
An instance of a similarity metric can then be initialized with these `SimilarityModel`s. Its `compute()` method can be used to obtain a similarity matrix $S$ for the two models, where $S[i, j]$ is the similarity metric between the $i^{th}$ layer of the first model and the $j^{th}$ layer of the second model.
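```python
# CKA is provided by simtorch. torch_dataloader is a torch.utils.data.DataLoader
# over the inputs fed to both models.
sim_cka = CKA(sim_model1, sim_model2, device="cuda")
cka_matrix = sim_cka.compute(torch_dataloader)
```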
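To make the definition of $S$ concrete, here is a minimal NumPy sketch of how such a matrix can be filled with linear CKA, assuming the activations have already been collected into dictionaries mapping layer names to arrays of shape `(n, ...)` over the same `n` inputs. The functions `linear_cka` and `similarity_matrix` are hypothetical illustrations, not simtorch's implementation, and flattening convolutional feature maps into a single feature axis is an assumption.

```python
import numpy as np


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between activations x (n, ...) and y (n, ...) for the same n inputs."""
    # Assumption: flatten any trailing dims (e.g. conv feature maps) into one feature axis.
    x = x.reshape(x.shape[0], -1).astype(np.float64)
    y = y.reshape(y.shape[0], -1).astype(np.float64)
    # Center each feature (column) across the n examples.
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    # ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (norm_x * norm_y))


def similarity_matrix(acts1: dict, acts2: dict) -> np.ndarray:
    """S[i, j] = linear CKA between layer i of model 1 and layer j of model 2."""
    names1, names2 = list(acts1), list(acts2)
    s = np.zeros((len(names1), len(names2)))
    for i, name1 in enumerate(names1):
        for j, name2 in enumerate(names2):
            s[i, j] = linear_cka(acts1[name1], acts2[name2])
    return s
```

Because linear CKA is invariant to orthogonal transformations and isotropic scaling of the features, comparing a layer with a rotated or rescaled copy of itself gives a value of 1.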
The similarity matrix can be visualized with the `sim_cka.plot_similarity()` method, which produces a CKA similarity plot.
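If you already have a similarity matrix as an array and want a plot outside the library, a heatmap in the same orientation as $S[i, j]$ (rows for the first model's layers, columns for the second's) can be sketched with Matplotlib. The function `plot_similarity_matrix`, its labels and its colormap are hypothetical choices and do not reproduce `plot_similarity()`.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_similarity_matrix(s, row_labels, col_labels, title="CKA similarity"):
    """Heatmap of S where S[i, j] compares layer i of model 1 with layer j of model 2."""
    s = np.asarray(s)
    fig, ax = plt.subplots(figsize=(6, 5))
    # origin="lower" puts the first layer of model 1 at the bottom row.
    im = ax.imshow(s, origin="lower", cmap="magma", vmin=0.0, vmax=1.0)
    ax.set_xticks(range(len(col_labels)))
    ax.set_xticklabels(col_labels, rotation=90)
    ax.set_yticks(range(len(row_labels)))
    ax.set_yticklabels(row_labels)
    ax.set_xlabel("Model 2 layers")
    ax.set_ylabel("Model 1 layers")
    ax.set_title(title)
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    return fig
```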
Citations
If you use Deconfounded Centered Kernel Alignment (dCKA) for your research, please cite:
```bibtex
@article{cui2022deconfounded,
  title={Deconfounded Representation Similarity for Comparison of Neural Networks},
  author={Cui, Tianyu and Kumar, Yogesh and Marttinen, Pekka and Kaski, Samuel},
  journal={Neural Information Processing Systems (NeurIPS)},
  year={2022}
}
```
Credits
This library was built using the following awesome repos as references:
- anatome, maintained by @moskomule
- Pytorch-Model-Compare, maintained by @AntixK
- centered-kernel-alignment, maintained by @Kennethborup