Neighbourhood Components Analysis in PyTorch.

torchnca

A PyTorch implementation of Neighbourhood Components Analysis by J. Goldberger, G. Hinton, S. Roweis, R. Salakhutdinov.

NCA learns a linear transformation of the dataset such that the expected leave-one-out performance of kNN in the transformed space is maximized.
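The quantity being maximized can be sketched in a few lines. The snippet below is a minimal NumPy illustration of the NCA objective from the paper (soft leave-one-out neighbour assignment), not torchnca's own code; the function name `nca_objective` and its argument layout are assumptions made for this example.

```python
import numpy as np

def nca_objective(A, X, y):
    """Expected leave-one-out accuracy of stochastic nearest neighbours.

    A: (d, D) linear map, X: (N, D) data, y: (N,) integer labels.
    Hypothetical helper for illustration; not part of torchnca.
    """
    Z = X @ A.T                                   # project the data: (N, d)
    diff = Z[:, None, :] - Z[None, :, :]          # pairwise differences
    d2 = (diff ** 2).sum(axis=-1)                 # squared distances (N, N)
    np.fill_diagonal(d2, np.inf)                  # a point never picks itself
    logits = -d2
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    p = np.exp(logits)
    p = p / p.sum(axis=1, keepdims=True)          # p[i, j]: i picks j as neighbour
    same = y[:, None] == y[None, :]               # mask of same-class pairs
    return (p * same).sum(axis=1).mean()          # mean prob. of a correct pick

# Two tight, well-separated clusters in 2-D.
X = np.array([[0.0, 0.0], [0.1, 0.0], [10.0, 0.0], [10.1, 0.0]])
y = np.array([0, 0, 1, 1])

good = nca_objective(np.eye(2), X, y)                # keeps the informative axis
bad = nca_objective(np.array([[0.0, 1.0]]), X, y)    # keeps only the constant axis
```

In this toy case the identity map scores close to 1, while projecting onto the uninformative axis collapses all points together and scores 1/3, since each point then picks one of its three neighbours uniformly at random.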

Installation

You can install torchnca with pip:

pip install torchnca

API

from torchnca import NCA

# instantiate torchnca object and initialize with
# an identity matrix
nca = NCA(dim=2, init="identity")

# fit a torchnca model to a dataset,
# normalizing the input data before
# running the optimization
nca.train(X, y, batch_size=64, normalize=True)

# apply the learned linear map to the data
X_nca = nca(X)

Dimensionality Reduction

We generate a 3-D dataset where the first 2 dimensions are concentric rings and the third dimension is Gaussian noise. We plot the result of PCA, LDA and NCA with 2 components.

Notice how PCA fails to project out the noise: the noise variance in the third dimension is high, and PCA keeps the directions of largest variance. LDA also struggles to recover the concentric pattern since the classes themselves are not linearly separable.
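A dataset of this kind could be generated roughly as follows. This is a sketch under assumed settings: the function name `make_noisy_rings`, the radii, the number of points per ring and the noise scale are all illustrative choices, not the values used for the experiment above.

```python
import numpy as np

def make_noisy_rings(n_per_ring=100, radii=(1.0, 2.0, 3.0),
                     ring_jitter=0.05, noise_std=5.0, seed=0):
    """Concentric rings in dims 0-1, pure Gaussian noise in dim 2.

    All defaults are assumptions for illustration only.
    """
    rng = np.random.default_rng(seed)
    X_parts, y_parts = [], []
    for label, r in enumerate(radii):
        theta = rng.uniform(0.0, 2.0 * np.pi, n_per_ring)    # angle on the ring
        radius = r + rng.normal(0.0, ring_jitter, n_per_ring)  # slight thickness
        ring = np.stack([radius * np.cos(theta),
                         radius * np.sin(theta)], axis=1)
        noise = rng.normal(0.0, noise_std, (n_per_ring, 1))  # uninformative axis
        X_parts.append(np.hstack([ring, noise]))
        y_parts.append(np.full(n_per_ring, label))           # one class per ring
    return np.vstack(X_parts), np.concatenate(y_parts)

X, y = make_noisy_rings()
```

With a `noise_std` well above the largest radius, the third axis carries the most variance, which is the situation in which a variance-based projection tends to keep the noise.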

kNN on MNIST

We compute the classification error, computation time and storage cost of two algorithms:

  • kNN (k = 5) on the raw 784-dimensional MNIST dataset
  • kNN (k = 5) on a learned 32-dimensional NCA projection of the MNIST dataset

Method     NCA + kNN    Raw kNN
Time       2.37 s       155.25 s
Storage    6.40 MB      156.8 MB
Error      3.3%         2.8%
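The storage figures can be checked with simple arithmetic. The sketch below assumes the stored set holds 50,000 vectors in 32-bit floats; that count is not stated above, but it is the one that reproduces both numbers. The helper `storage_mb` is a hypothetical name used only here.

```python
def storage_mb(n_samples, n_dims, bytes_per_value=4):
    """Size in megabytes (10**6 bytes) of a dense n_samples x n_dims array."""
    return n_samples * n_dims * bytes_per_value / 1e6

n = 50_000  # assumed number of stored vectors

raw_mb = storage_mb(n, 784)   # raw 784-dimensional pixels
nca_mb = storage_mb(n, 32)    # 32-dimensional NCA projection
ratio = raw_mb / nca_mb       # equals 784 / 32
```

Under these assumptions the raw data takes 156.8 MB and the projection 6.4 MB, a 24.5-fold reduction that follows directly from the drop in dimensionality.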
