
simtorch

A PyTorch library for measuring the similarity between two neural network representations. The library currently supports the following (dis)similarity measures:

Design

The package consists of two components:

  • SimilarityModel: a thin wrapper around a torch.nn.Module that registers forward hooks to store the layer-wise activations (a.k.a. representations) in a dictionary.
  • BaseSimilarity: defines the interface for classes that compute the similarity between network representations.
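The forward-hook mechanism that SimilarityModel is described as using can be sketched in plain PyTorch. This is a minimal illustration, not simtorch's implementation: `ActivationRecorder` is a hypothetical name, and it matches layer names exactly, which may differ from how simtorch interprets `layers_to_include`.

```python
import torch
import torch.nn as nn


class ActivationRecorder:
    """Hypothetical stand-in for a SimilarityModel-style wrapper (not simtorch's API).

    Registers a forward hook on every submodule whose name is in `layer_names`
    and stores that submodule's output in `self.activations`.
    """

    def __init__(self, model: nn.Module, layer_names: list):
        self.model = model
        self.activations = {}
        self._handles = []
        for name, module in model.named_modules():
            if name in layer_names:  # exact name match (an assumption of this sketch)
                self._handles.append(module.register_forward_hook(self._make_hook(name)))

    def _make_hook(self, name: str):
        # Closure factory so each hook remembers its own layer name.
        def hook(module, inputs, output):
            # Detach so stored activations do not keep the autograd graph alive.
            self.activations[name] = output.detach()
        return hook

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def remove_hooks(self) -> None:
        for handle in self._handles:
            handle.remove()
        self._handles.clear()


# Usage: record the outputs of the two Linear layers of a toy model.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
recorder = ActivationRecorder(model, ["0", "2"])
_ = recorder(torch.randn(5, 8))
print({k: tuple(v.shape) for k, v in recorder.activations.items()})
# {'0': (5, 16), '2': (5, 4)}
```

Storing activations in a dictionary keyed by layer name is what lets a similarity measure later pair up "layer i of model 1" with "layer j of model 2".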

Installation

The package is available on PyPI:

pip install simtorch

Usage

Torch model objects must be wrapped with SimilarityModel. The names of the layers whose representations should be computed are passed to this class as a list via the layers_to_include argument.

import torchvision

model1 = torchvision.models.densenet121()
model2 = torchvision.models.resnet101()

# layers_to_include lists the layer names whose activations are recorded
sim_model1 = SimilarityModel(
    model1,
    model_name="DenseNet 121",
    layers_to_include=["conv", "classifier"],
)

sim_model2 = SimilarityModel(
    model2,
    model_name="ResNet 101",
    layers_to_include=["conv", "fc"],
)
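Because layers are selected by name, it can help to list the names a model actually exposes. A minimal sketch using PyTorch's `named_modules()`; `TinyNet` is a hypothetical stand-in for `model1`/`model2`, and the substring filter only illustrates one way to pick names, since the exact matching rule simtorch applies to `layers_to_include` is not stated here.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy CNN used only to show how module names look."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.conv2 = nn.Conv2d(8, 8, kernel_size=3)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        return self.fc(self.pool(x).flatten(1))


# named_modules() yields (name, module) pairs; the root module has the name "".
names = [name for name, _ in TinyNet().named_modules() if "conv" in name]
print(names)  # ['conv1', 'conv2']
```

The same loop over `torchvision.models.densenet121().named_modules()` shows which DenseNet layer names contain "conv".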

An instance of a similarity metric can then be initialized with these SimilarityModel objects. Its compute() method returns a similarity matrix $S$ for the two models, where $S[i, j]$ is the similarity between the $i^{th}$ layer of the first model and the $j^{th}$ layer of the second model.

# torch_dataloader: a torch.utils.data.DataLoader over the inputs used to compare the two models
sim_cka = CKA(sim_model1, sim_model2, device="cuda")
cka_matrix = sim_cka.compute(torch_dataloader)
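To make the definition of $S[i, j]$ concrete, here is a NumPy sketch of linear CKA (as defined by Kornblith et al., 2019): for column-centered activations $X \in \mathbb{R}^{n \times p_1}$ and $Y \in \mathbb{R}^{n \times p_2}$ on the same $n$ inputs, $\mathrm{CKA}(X, Y) = \lVert Y^\top X \rVert_F^2 / (\lVert X^\top X \rVert_F \, \lVert Y^\top Y \rVert_F)$. This is an illustration only; simtorch's CKA class may use a different kernel, estimator, or minibatch accumulation, and `linear_cka` / `similarity_matrix` are hypothetical helper names.

```python
import numpy as np


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between activations x (n, p1) and y (n, p2) for the same n inputs."""
    # Center each feature (column) across the n examples.
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    # ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (norm_x * norm_y))


def similarity_matrix(acts1, acts2) -> np.ndarray:
    """S[i, j] = linear CKA between layer i of model 1 and layer j of model 2."""
    s = np.zeros((len(acts1), len(acts2)))
    for i, a in enumerate(acts1):
        for j, b in enumerate(acts2):
            # Flatten conv feature maps (n, C, H, W) into (n, C*H*W).
            s[i, j] = linear_cka(a.reshape(len(a), -1), b.reshape(len(b), -1))
    return s


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 20))
q, _ = np.linalg.qr(rng.normal(size=(20, 20)))  # random orthogonal matrix
# Linear CKA is invariant to orthogonal transforms and isotropic scaling.
print(round(linear_cka(x, 3.0 * x @ q), 6))  # 1.0
```

Only the number of examples $n$ has to agree between the two layers, which is why layers of different widths (or different architectures) can be compared entry by entry.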

The similarity matrix can be visualized with the sim_cka.plot_similarity() method, which produces the CKA similarity plot below.

Figure: Centered Kernel Alignment (CKA) similarity matrix
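For a similarity matrix obtained some other way, a comparable heatmap can be drawn directly with Matplotlib. This is a sketch, not simtorch's plot_similarity() implementation: `plot_similarity_matrix` is a hypothetical helper, and the colormap, axis orientation, and labels are arbitrary choices.

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen; omit this line in a notebook
import matplotlib.pyplot as plt
import numpy as np


def plot_similarity_matrix(s, row_labels, col_labels, xlabel="", ylabel="",
                           title="CKA similarity"):
    """Heatmap of s: rows are layers of model 1, columns are layers of model 2."""
    fig, ax = plt.subplots(figsize=(5, 4))
    # origin="lower" puts the first layer of model 1 at the bottom of the y axis.
    im = ax.imshow(s, origin="lower", cmap="magma", vmin=0.0, vmax=1.0)
    ax.set_xticks(range(len(col_labels)))
    ax.set_xticklabels(col_labels, rotation=90)
    ax.set_yticks(range(len(row_labels)))
    ax.set_yticklabels(row_labels)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    fig.colorbar(im, ax=ax, label="similarity")
    fig.tight_layout()
    return fig, ax


# Illustrative 2x2 matrix (made-up values, not results from any model).
s = np.array([[1.0, 0.4], [0.3, 0.9]])
fig, ax = plot_similarity_matrix(
    s, ["conv", "classifier"], ["conv", "fc"],
    xlabel="ResNet 101", ylabel="DenseNet 121",
)
```

Calling `fig.savefig("cka_matrix.png")` would write the plot to disk.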

Citations

If you use Deconfounded Centered Kernel Alignment (dCKA) for your research, please cite:

@article{cui2022deconfounded,
  title={Deconfounded Representation Similarity for Comparison of Neural Networks},
  author={Cui, Tianyu and Kumar, Yogesh and Marttinen, Pekka and Kaski, Samuel},
  journal={Neural Information Processing Systems (NeurIPS)},
  year={2022}
}

Credits

This library was built using the following excellent repositories as references:

Download files

Source Distribution

simtorch-0.2.1.tar.gz (12.4 kB)

Built Distribution

simtorch-0.2.1-py3-none-any.whl (13.4 kB, Python 3)

File details

Details for the file simtorch-0.2.1.tar.gz.

File metadata

  • Download URL: simtorch-0.2.1.tar.gz
  • Size: 12.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.8

File hashes

Hashes for simtorch-0.2.1.tar.gz
Algorithm Hash digest
SHA256 bc3846313d888e256ca78e166c602370d2648d1294acb881bf8736a1281c58c9
MD5 179e4f0e0d198e3bd272b0ad435c53fc
BLAKE2b-256 472ad512a9cd232b4d417a5ef61e4503035a9c3a49446687e0e42063c2a00f5a

File details

Details for the file simtorch-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: simtorch-0.2.1-py3-none-any.whl
  • Size: 13.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.8

File hashes

Hashes for simtorch-0.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 ea870e1fea7ee23c0f86503c985c07417f94ff18074da2e56d7b65b2a50b94a1
MD5 38eb57fd38fd56fb703f3a19094e7b9b
BLAKE2b-256 db256d02095bcdd1b3483b8e9053114bbdd02f4e46ed40c5acb0e20408ff8db7
