PyTorch implementation of the InfoNCE loss for self-supervised learning.

Project description

PyTorch implementation of the InfoNCE loss from “Representation Learning with Contrastive Predictive Coding”. In contrastive learning, we want to learn how to map high-dimensional data to a lower-dimensional embedding space. This mapping should place semantically similar samples close together in the embedding space, whilst placing semantically distinct samples further apart. The InfoNCE loss is a standard objective for this kind of contrastive learning.
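
For intuition, here is a minimal sketch of the InfoNCE computation without explicit negative keys, assuming L2-normalized embeddings and using the other in-batch positives as negatives; the package's actual implementation may differ in details such as the temperature value:

import torch
import torch.nn.functional as F

def info_nce_sketch(query, positive_key, temperature=0.1):
    # Normalize so that dot products are cosine similarities.
    query = F.normalize(query, dim=-1)
    positive_key = F.normalize(positive_key, dim=-1)
    # logits[i, j] is the similarity between query i and key j.
    logits = query @ positive_key.T / temperature
    # The matching pair (i, i) is the positive; the other columns act as negatives.
    labels = torch.arange(len(query), device=query.device)
    return F.cross_entropy(logits, labels)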

This package is available on PyPI and can be installed via:

pip install info-nce-pytorch

Example usage

Can be used without explicit negative keys, whereby each sample is compared with the other samples in the batch.

import torch
from info_nce import InfoNCE

loss = InfoNCE()
batch_size, embedding_size = 32, 128
query = torch.randn(batch_size, embedding_size)
# positive_key[i] is the positive sample for query[i]; the other batch entries act as negatives.
positive_key = torch.randn(batch_size, embedding_size)
output = loss(query, positive_key)

Can be used with negative keys, whereby every combination between query and negative key is compared.

loss = InfoNCE(negative_mode='unpaired')  # negative_mode='unpaired' is the default value
batch_size, num_negative, embedding_size = 32, 48, 128
query = torch.randn(batch_size, embedding_size)
positive_key = torch.randn(batch_size, embedding_size)
# Unpaired mode: all queries share the same pool of num_negative negative keys.
negative_keys = torch.randn(num_negative, embedding_size)
output = loss(query, positive_key, negative_keys)

Can be used with negative keys, whereby each query sample is compared with only the negative keys it is paired with.

loss = InfoNCE(negative_mode='paired')
batch_size, num_negative, embedding_size = 32, 6, 128
query = torch.randn(batch_size, embedding_size)
positive_key = torch.randn(batch_size, embedding_size)
# Paired mode: negative_keys[i] holds the negatives for query[i].
negative_keys = torch.randn(batch_size, num_negative, embedding_size)
output = loss(query, positive_key, negative_keys)

Loss graph

Suppose we have some initial mean vectors µ_q, µ_p, µ_n and a covariance matrix Σ = I/10. We can then plot the value of the InfoNCE loss by sampling from distributions with interpolated mean vectors. Given interpolation weights α and β, we define the distribution Q ~ N(µ_q, Σ) for the query samples, the distribution P_α ~ N(αµ_q + (1-α)µ_p, Σ) for the positive samples, and the distribution N_β ~ N(βµ_q + (1-β)µ_n, Σ) for the negative samples. Shown below is the value of the loss for inputs sampled from these distributions at different values of α and β.
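
A rough sketch of how such a plot could be generated. The mean vectors, embedding size, sample count, and grid resolution below are illustrative assumptions rather than the settings used for the figure:

import torch
from info_nce import InfoNCE

loss = InfoNCE()
n_samples, embedding_size = 1024, 2
mu_q = torch.tensor([1.0, 0.0])   # assumed mean vectors
mu_p = torch.tensor([0.0, 1.0])
mu_n = torch.tensor([-1.0, 0.0])
std = (1 / 10) ** 0.5             # Σ = I/10, so the per-dimension std is sqrt(1/10)

for alpha in (0.0, 0.25, 0.5, 0.75, 1.0):
    for beta in (0.0, 0.25, 0.5, 0.75, 1.0):
        # Sample from Q, P_α and N_β as defined above.
        query = mu_q + std * torch.randn(n_samples, embedding_size)
        positive_key = alpha * mu_q + (1 - alpha) * mu_p + std * torch.randn(n_samples, embedding_size)
        negative_keys = beta * mu_q + (1 - beta) * mu_n + std * torch.randn(n_samples, embedding_size)
        value = loss(query, positive_key, negative_keys)
        print(f'alpha={alpha:.2f} beta={beta:.2f} loss={value.item():.3f}')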

https://raw.githubusercontent.com/RElbers/info-nce-pytorch/main/imgs/loss.png

Download files

Download the file for your platform. If you're not sure which to choose, see the Python Packaging User Guide for more on installing packages.

Source Distribution

info-nce-pytorch-0.1.4.tar.gz (4.3 kB)

Built Distribution

info_nce_pytorch-0.1.4-py3-none-any.whl (4.8 kB)

File details

Details for the file info-nce-pytorch-0.1.4.tar.gz.

File metadata

  • Download URL: info-nce-pytorch-0.1.4.tar.gz
  • Upload date:
  • Size: 4.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.9.4

File hashes

Hashes for info-nce-pytorch-0.1.4.tar.gz

  • SHA256: e1c3019be1238dd42cb2a6f71825483869a78340585a6772b154fd134d2b48c5
  • MD5: 850426ab40a8a066c564079bc8acbafb
  • BLAKE2b-256: 1a6139bc456f50428f74a95f37c020dfb2ee70954dbec81cc7e9b97f50305a62

See the pip documentation for more details on using file hashes.
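
For example, a downloaded archive's SHA256 digest can be checked against the value above with a few lines of Python (the local file path is an assumption):

import hashlib

# Expected SHA256 for info-nce-pytorch-0.1.4.tar.gz, as listed above.
expected = 'e1c3019be1238dd42cb2a6f71825483869a78340585a6772b154fd134d2b48c5'

with open('info-nce-pytorch-0.1.4.tar.gz', 'rb') as f:  # assumed local path
    digest = hashlib.sha256(f.read()).hexdigest()

assert digest == expected, f'hash mismatch: {digest}'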

File details

Details for the file info_nce_pytorch-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: info_nce_pytorch-0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 4.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.9.4

File hashes

Hashes for info_nce_pytorch-0.1.4-py3-none-any.whl

  • SHA256: 1dd232e7503828bdee1bd7ea29d92ac4b3c6c5d15e53adbdb1830574d992ba79
  • MD5: e1ecede277933b116dba30bf085a5c6f
  • BLAKE2b-256: 8ec6435455387a6c32fa10be9311affd5dc9bc6208cb80f1f9cf5b84ac5c58a5

See the pip documentation for more details on using file hashes.
