
Large Margin Nearest Neighbor implementation in Python

Project description

PyLMNN is an implementation of the Large Margin Nearest Neighbor (LMNN) algorithm for metric learning in pure Python.
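In LMNN, "learning a metric" amounts to learning a linear map L. Running Euclidean kNN on the mapped points Lx is equivalent to using the Mahalanobis metric M = L^T L on the original inputs. The minimal NumPy check below illustrates this with made-up random data; `mahalanobis_sq` is a hypothetical helper and not part of PyLMNN.

```python
import numpy as np

# Made-up example data: 5 points in 3 dimensions.
rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
# An arbitrary linear map L (2 x 3); LMNN would learn this matrix from labeled data.
L = rng.normal(size=(2, 3))


def mahalanobis_sq(xi, xj, M):
    """Squared Mahalanobis distance (xi - xj)^T M (xi - xj)."""
    d = xi - xj
    return d @ M @ d


# The metric induced by L; it is positive semidefinite by construction.
M = L.T @ L

# Squared Euclidean distance after mapping both points with L ...
Z = X @ L.T
d_transformed = np.sum((Z[0] - Z[1]) ** 2)
# ... equals the squared Mahalanobis distance under M in the input space.
d_metric = mahalanobis_sq(X[0], X[1], M)
print(np.isclose(d_transformed, d_metric))  # True
```

Learning L rather than M directly keeps M positive semidefinite without any explicit constraint, which is why the optimization problem can be solved as an unconstrained one.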

This implementation closely follows the original MATLAB code by Kilian Weinberger, found at https://bitbucket.org/mlcircus/lmnn. This version solves the unconstrained optimization problem and learns a linear transformation using L-BFGS as the backend optimizer.
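The unconstrained problem can be sketched as follows. This is a simplified, dense version of the loss from Weinberger & Saul (2009), not PyLMNN's implementation: it fixes the target neighbors (the k nearest same-class points in the input space) up front, weights the pull and push terms with a parameter `mu` chosen here for illustration, checks every differently labeled point as a potential impostor (O(n^2 k) work), and hands the loss and its gradient to SciPy's L-BFGS-B through `scipy.optimize.minimize`. The helper names and the toy data are hypothetical.

```python
import numpy as np
from scipy.optimize import minimize


def target_neighbors(X, y, k):
    """Indices of the k nearest same-class points of each sample (Euclidean).

    Assumes every class has at least k + 1 points.
    """
    D = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    np.fill_diagonal(D, np.inf)          # a point is not its own neighbor
    D[y[:, None] != y[None, :]] = np.inf  # only same-class candidates
    return np.argsort(D, axis=1)[:, :k]


def lmnn_loss_and_grad(flat_L, X, y, targets, n_components, mu=0.5):
    """Simplified LMNN objective and its gradient with respect to L."""
    n, d = X.shape
    L = flat_L.reshape(n_components, d)
    Z = X @ L.T
    # Squared distances between all pairs in the transformed space.
    D = np.sum((Z[:, None, :] - Z[None, :, :]) ** 2, axis=-1)
    different = y[:, None] != y[None, :]

    loss = 0.0
    G = np.zeros((d, d))  # weighted sum of outer products; grad = 2 L G
    for i in range(n):
        for j in targets[i]:
            dij = X[i] - X[j]
            # Pull term: keep target neighbors close.
            loss += (1 - mu) * D[i, j]
            G += (1 - mu) * np.outer(dij, dij)
            # Push term: hinge loss on impostors that violate the unit margin.
            margins = 1 + D[i, j] - D[i]
            active = np.flatnonzero(different[i] & (margins > 0))
            loss += mu * margins[active].sum()
            for imp in active:
                dil = X[i] - X[imp]
                G += mu * (np.outer(dij, dij) - np.outer(dil, dil))
    return loss, (2 * L @ G).ravel()


# Hypothetical toy data: two overlapping classes in 2-D.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 1, (20, 2)), rng.normal(2, 1, (20, 2))])
y = np.repeat([0, 1], 20)
targets = target_neighbors(X, y, k=3)

L0 = np.eye(2).ravel()  # start from the identity, i.e. plain Euclidean distance
res = minimize(lmnn_loss_and_grad, L0, args=(X, y, targets, 2),
               jac=True, method="L-BFGS-B", options={"maxiter": 50})
L_learned = res.x.reshape(2, 2)
print(res.fun, L_learned)
```

Passing `jac=True` tells `minimize` that the objective returns the loss and the gradient together, so the pairwise distances are computed only once per evaluation.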

This package can also search for good hyperparameters for LMNN via Bayesian optimization, using the excellent GPyOpt package.

Installation

The code was developed in Python 3.5 under Ubuntu 16.04 and was also tested under Ubuntu 18.04 with Python 3.6. You can clone the repository with:

git clone https://github.com/johny-c/pylmnn.git

or install it via pip:

pip3 install pylmnn

Dependencies

  • numpy>=1.11.2

  • scipy>=0.18.1

  • scikit_learn>=0.18.1

Optional dependencies

To use the hyperparameter optimization module, you also need to install:

  • GPy>=1.5.6

  • GPyOpt>=1.0.3
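To check whether the optional packages are already available, a small standard-library sketch could look like this. It assumes the import names are `GPy` and `GPyOpt`; the helper `missing_optional_deps` is hypothetical and not part of PyLMNN.

```python
import importlib.util

# Import names of the optional packages used for hyperparameter optimization.
OPTIONAL = ("GPy", "GPyOpt")


def missing_optional_deps(names=OPTIONAL):
    """Return the names of the packages that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_optional_deps()
if missing:
    print("Hyperparameter optimization unavailable; install:", ", ".join(missing))
```

`find_spec` only locates the packages without importing them, so the check is cheap even when GPy is installed.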

Usage

Here is a minimal use case:

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

from pylmnn import LargeMarginNearestNeighbor as LMNN


# Load a data set
X, y = load_iris(return_X_y=True)

# Split into training and test sets (test_size=0.7 holds out 70% of the points for testing)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.7, stratify=y, random_state=42)

# Set up the hyperparameters
k_train, k_test, n_components, max_iter = 3, 3, X.shape[1], 180

# Instantiate the metric learner
lmnn = LMNN(n_neighbors=k_train, max_iter=max_iter, n_components=n_components)

# Train the metric learner
lmnn.fit(X_train, y_train)

# Fit the nearest neighbors classifier
knn = KNeighborsClassifier(n_neighbors=k_test)
knn.fit(lmnn.transform(X_train), y_train)

# Compute the k-nearest neighbor test accuracy after applying the learned transformation
lmnn_acc = knn.score(lmnn.transform(X_test), y_test)
print('LMNN accuracy on test set of {} points: {:.4f}'.format(X_test.shape[0], lmnn_acc))
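To judge how much the learned transformation helps, it is useful to compare against kNN on the raw features. The sketch below is a NumPy stand-in for `KNeighborsClassifier.score` with uniform weights; `knn_accuracy` is a hypothetical helper, its tie-breaking may differ from scikit-learn's, and it assumes non-negative integer class labels (true for iris).

```python
import numpy as np


def knn_accuracy(X_train, y_train, X_test, y_test, k=3):
    """Majority-vote kNN accuracy using Euclidean distances."""
    # Squared distances between every test point and every training point.
    D = np.sum((X_test[:, None, :] - X_train[None, :, :]) ** 2, axis=-1)
    nn = np.argsort(D, axis=1)[:, :k]
    votes = y_train[nn]
    # Most frequent label among the k neighbors; ties go to the smallest label.
    pred = np.array([np.bincount(v).argmax() for v in votes])
    return np.mean(pred == y_test)


# Tiny illustrative data set with two well-separated classes.
X_tr = np.array([[0.0], [0.1], [0.2], [5.0], [5.1], [5.2]])
y_tr = np.array([0, 0, 0, 1, 1, 1])
X_te = np.array([[0.05], [5.05]])
y_te = np.array([0, 1])
print(knn_accuracy(X_tr, y_tr, X_te, y_te, k=3))  # 1.0
```

With the variables from the example above, `knn_accuracy(X_train, y_train, X_test, y_test, k=k_test)` gives the Euclidean baseline, and `knn_accuracy(lmnn.transform(X_train), y_train, lmnn.transform(X_test), y_test, k=k_test)` gives the accuracy in the learned space.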

See the examples directory for demonstrations of how to use the code with different datasets and how to estimate good hyperparameters with Bayesian optimization.

Documentation is also available at http://pylmnn.readthedocs.io/en/latest/.

References

If you use this code in your work, please cite the following publication.

@ARTICLE{weinberger09distance,
    title={Distance metric learning for large margin nearest neighbor classification},
    author={Weinberger, K.Q. and Saul, L.K.},
    journal={The Journal of Machine Learning Research},
    volume={10},
    pages={207--244},
    year={2009},
    publisher={MIT Press}
}

License and Contact

This work is released under the 3-Clause BSD License.

Contact John Chiotellis for questions, comments, and bug reports.

Download files

Source Distribution

PyLMNN-1.6.4.tar.gz (18.8 kB)

File details

Details for the file PyLMNN-1.6.4.tar.gz.

File metadata

  • Download URL: PyLMNN-1.6.4.tar.gz
  • Upload date:
  • Size: 18.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.1.3.post20200330 requests-toolbelt/0.9.1 tqdm/4.46.0 CPython/3.8.2

File hashes

Hashes for PyLMNN-1.6.4.tar.gz

Algorithm     Hash digest
SHA256        84e54a48b0bcdc2b2ed4c9017dfb60288fb7ab869b7e84a8404a2f6b627eadfe
MD5           046fa2d3888f3f9fd99099770a2a4d93
BLAKE2b-256   71920473d6a03157e72a4eba5578f8ce97d9547d6404e7104b62d14b7c12b3e1
