Large Margin Nearest Neighbor implementation in Python

Project description

PyLMNN is an implementation of the Large Margin Nearest Neighbor algorithm for metric learning in pure Python.
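For reference, Weinberger and Saul (2009), the publication listed under References, formulate LMNN as learning a linear map $\mathbf{L}$ that minimises a weighted sum of a pull term over each point's target neighbours $j \rightsquigarrow i$ (its $k$ nearest same-class points, fixed in the input space) and a hinge-loss push term over differently labelled impostors $l$:

```latex
\varepsilon(\mathbf{L}) = (1-\mu) \sum_{i,\, j \rightsquigarrow i} \lVert \mathbf{L}(\mathbf{x}_i - \mathbf{x}_j) \rVert^2
  + \mu \sum_{i,\, j \rightsquigarrow i} \sum_{l} (1 - y_{il})
    \left[ 1 + \lVert \mathbf{L}(\mathbf{x}_i - \mathbf{x}_j) \rVert^2 - \lVert \mathbf{L}(\mathbf{x}_i - \mathbf{x}_l) \rVert^2 \right]_+
```

Here $y_{il} = 1$ if $\mathbf{x}_i$ and $\mathbf{x}_l$ share a label and $0$ otherwise, $[z]_+ = \max(z, 0)$, and $\mu \in [0, 1]$ trades off pulling target neighbours in against pushing impostors out. This is the paper's formulation; PyLMNN's internal weighting of the two terms may differ.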

This implementation closely follows the original MATLAB code by Kilian Weinberger, available at https://bitbucket.org/mlcircus/lmnn. It solves the unconstrained optimisation problem, learning a linear transformation with L-BFGS as the backend optimizer.
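To make the optimisation concrete, here is a compact NumPy/SciPy sketch of the same idea: fix each point's k same-class target neighbours in the input space, then minimise the pull/push objective over the linear map L with SciPy's L-BFGS-B. It is an illustrative dense re-implementation (O(n²k) memory, so only for small data sets), not PyLMNN's code; the function names `pairwise_sq_dists`, `find_target_neighbors`, `lmnn_loss_grad` and `fit_lmnn`, and the default `mu=0.5`, are assumptions made for this example.

```python
import numpy as np
from scipy.optimize import minimize
from sklearn.datasets import load_iris


def pairwise_sq_dists(A):
    """Squared Euclidean distances between all rows of A."""
    sq = np.sum(A ** 2, axis=1)
    return np.maximum(sq[:, None] + sq[None, :] - 2.0 * A @ A.T, 0.0)


def find_target_neighbors(X, y, k):
    """Indices of the k nearest same-class points of each sample (input space)."""
    D = pairwise_sq_dists(X)
    np.fill_diagonal(D, np.inf)              # a point is not its own neighbour
    D[y[:, None] != y[None, :]] = np.inf     # only same-class candidates
    return np.argsort(D, axis=1)[:, :k]


def lmnn_loss_grad(L, X, y, targets, mu=0.5):
    """LMNN objective and its gradient w.r.t. L (dense, O(n^2 k) memory)."""
    n, k = targets.shape
    D = pairwise_sq_dists(X @ L.T)           # distances in the transformed space
    rows = np.repeat(np.arange(n), k)        # i for every (i, j) target pair
    cols = targets.ravel()                   # j for every (i, j) target pair
    d_target = D[rows, cols]
    # Hinge margins 1 + d_ij - d_il for every triple (i, j, l); shape (n*k, n).
    margins = 1.0 + d_target[:, None] - D[rows]
    active = (margins > 0) & (y[rows][:, None] != y[None, :])
    loss = (1 - mu) * d_target.sum() + mu * margins[active].sum()

    # Weight W[a, b] of ||L(x_a - x_b)||^2 in the loss (constant hinge offsets dropped).
    W = np.zeros((n, n))
    np.add.at(W, (rows, cols), (1 - mu) + mu * active.sum(axis=1))
    W -= mu * active.reshape(n, k, n).sum(axis=1)
    S = W + W.T
    M = X.T @ (np.diag(S.sum(axis=1)) - S) @ X
    return loss, 2.0 * L @ M


def fit_lmnn(X, y, k=3, mu=0.5, max_iter=180, n_components=None):
    """Learn L (n_components x n_features) with L-BFGS-B, starting from the identity."""
    n_features = X.shape[1]
    n_components = n_features if n_components is None else n_components
    targets = find_target_neighbors(X, y, k)   # kept fixed during optimisation

    def fun(flat_L):
        L = flat_L.reshape(n_components, n_features)
        loss, grad = lmnn_loss_grad(L, X, y, targets, mu)
        return loss, grad.ravel()

    L0 = np.eye(n_components, n_features)
    result = minimize(fun, L0.ravel(), jac=True, method='L-BFGS-B',
                      options={'maxiter': max_iter})
    return result.x.reshape(n_components, n_features)


X, y = load_iris(return_X_y=True)
L = fit_lmnn(X, y, k=3)
X_transformed = X @ L.T   # Euclidean distances here realise the learned metric
```

The hinge term makes the objective only piecewise smooth; the sketch simply hands L-BFGS-B the gradient of the currently active terms.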

This package also uses Bayesian Optimization, via the excellent GPyOpt package, to search for good hyperparameters for LMNN.
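Bayesian optimisation treats a cross-validated score as an expensive black-box function of the hyperparameters, fits a Gaussian-process surrogate to the evaluations made so far, and chooses the next value to try by maximising an acquisition function such as expected improvement. The sketch below implements that loop from scratch with NumPy/SciPy to show the mechanics; it is not GPyOpt's API, and the objective (plain k-NN accuracy on iris as a function of `n_neighbors`), the kernel length scale and the iteration counts are illustrative assumptions standing in for the real LMNN hyperparameters.

```python
import numpy as np
from scipy.stats import norm
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier


def gp_posterior(x_obs, y_obs, x_query, length_scale=3.0, noise=1e-6):
    """Posterior mean/std of a zero-mean GP with a unit-variance RBF kernel (1-D)."""
    def rbf(a, b):
        return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / length_scale ** 2)

    K = rbf(x_obs, x_obs) + noise * np.eye(len(x_obs))
    K_s = rbf(x_query, x_obs)
    mean = K_s @ np.linalg.solve(K, y_obs)
    var = 1.0 - np.sum(K_s * np.linalg.solve(K, K_s.T).T, axis=1)
    return mean, np.sqrt(np.maximum(var, 1e-12))


def expected_improvement(mean, std, best, xi=0.01):
    """Expected improvement over `best` for a maximisation problem."""
    gain = mean - best - xi
    z = gain / std
    return gain * norm.cdf(z) + std * norm.pdf(z)


def bayes_opt(objective, candidates, n_init=3, n_iter=7, seed=0):
    """Maximise `objective` over a finite set of candidate values."""
    candidates = np.asarray(candidates, dtype=float)
    rng = np.random.default_rng(seed)
    tried = list(rng.choice(len(candidates), size=n_init, replace=False))
    scores = [objective(candidates[i]) for i in tried]
    for _ in range(n_iter):
        remaining = np.setdiff1d(np.arange(len(candidates)), tried)
        if remaining.size == 0:
            break
        s = np.array(scores)
        # Standardise scores so the unit-variance GP prior is a sensible fit.
        s_norm = (s - s.mean()) / (s.std() if s.std() > 0 else 1.0)
        mean, std = gp_posterior(candidates[tried], s_norm, candidates[remaining])
        nxt = remaining[np.argmax(expected_improvement(mean, std, s_norm.max()))]
        tried.append(nxt)
        scores.append(objective(candidates[nxt]))
    best = int(np.argmax(scores))
    return candidates[tried[best]], scores[best], len(tried)


# Hypothetical objective: 5-fold CV accuracy of plain k-NN on iris vs. n_neighbors.
X, y = load_iris(return_X_y=True)


def cv_accuracy(n_neighbors):
    clf = KNeighborsClassifier(n_neighbors=int(n_neighbors))
    return cross_val_score(clf, X, y, cv=5).mean()


best_k, best_score, n_evals = bayes_opt(cv_accuracy, candidates=range(1, 31))
print('best n_neighbors = {:d}, CV accuracy = {:.4f} after {} evaluations'.format(
    int(best_k), best_score, n_evals))
```

The point of the surrogate is sample efficiency: only 10 of the 30 candidate values are ever evaluated, which matters when each evaluation means training LMNN.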

Installation

The code was developed with Python 3.5 on Ubuntu 16.04. You can clone the repository with:

git clone https://github.com/johny-c/pylmnn.git

or install it via pip:

pip3 install pylmnn

Dependencies

  • numpy>=1.11.2

  • scipy>=0.18.1

  • scikit_learn>=0.18.1

  • GPy>=1.5.6

  • GPyOpt>=1.0.3

  • matplotlib>=1.5.3

Usage

Here is a minimal use case:

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

from pylmnn.lmnn import LargeMarginNearestNeighbor as LMNN
from pylmnn.plots import plot_comparison


# Load a data set
dataset = load_iris()
X, y = dataset.data, dataset.target

# Split in training and testing set
x_tr, x_te, y_tr, y_te = train_test_split(X, y, test_size=0.7, stratify=y, random_state=42)

# Set up the hyperparameters
k_tr, k_te, dim_out, max_iter = 3, 1, X.shape[1], 180

# Instantiate the classifier
clf = LMNN(n_neighbors=k_tr, max_iter=max_iter, n_features_out=dim_out)

# Train the classifier
clf = clf.fit(x_tr, y_tr)

# Compute the k-nearest neighbor test accuracy after applying the learned transformation
accuracy_lmnn = clf.score(x_te, y_te)
print('LMNN accuracy on test set of {} points: {:.4f}'.format(x_te.shape[0], accuracy_lmnn))

# Draw a comparison plot of the test data before and after applying the learned transformation
plot_comparison(clf.L, x_te, y_te, dim_pref=3)
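The score above is a k-nearest-neighbour accuracy measured after mapping the data through the learned transformation. A minimal scikit-learn sketch of that evaluation follows; it assumes the transformation is a matrix `L` of shape `(n_features_out, n_features)` applied as x ↦ Lx (an assumption about how `clf.L` is laid out), and the helper `knn_accuracy_after_transform` is hypothetical. Passing the identity matrix gives the plain Euclidean baseline to compare `accuracy_lmnn` against, using 1 neighbour at test time, the value the example assigns to `k_te`.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier


def knn_accuracy_after_transform(L, x_tr, y_tr, x_te, y_te, n_neighbors=1):
    """k-NN test accuracy after mapping every sample x to L @ x."""
    clf = KNeighborsClassifier(n_neighbors=n_neighbors)
    clf.fit(x_tr @ L.T, y_tr)
    return clf.score(x_te @ L.T, y_te)


X, y = load_iris(return_X_y=True)
x_tr, x_te, y_tr, y_te = train_test_split(X, y, test_size=0.7, stratify=y, random_state=42)

# The identity transformation reproduces plain Euclidean k-NN: the baseline.
L_identity = np.eye(X.shape[1])
baseline = knn_accuracy_after_transform(L_identity, x_tr, y_tr, x_te, y_te)
print('Euclidean 1-NN accuracy on test set of {} points: {:.4f}'.format(x_te.shape[0], baseline))
```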

You can check the examples directory for a demonstration of how to use the code with different datasets and how to estimate good hyperparameters with Bayesian Optimisation.

Documentation is also available at http://pylmnn.readthedocs.io/en/latest/.

References

If you use this code in your work, please cite the following publication.

@ARTICLE{weinberger09distance,
    title={Distance metric learning for large margin nearest neighbor classification},
    author={Weinberger, K.Q. and Saul, L.K.},
    journal={The Journal of Machine Learning Research},
    volume={10},
    pages={207--244},
    year={2009},
    publisher={MIT Press}
}

License and Contact

This work is released under the 3-Clause BSD License.

Contact John Chiotellis with questions, comments, or bug reports.

Download files

Source Distribution

PyLMNN-1.5.1.tar.gz (18.6 kB)

Built Distribution

PyLMNN-1.5.1-py3-none-any.whl (22.6 kB)

File details

Details for the file PyLMNN-1.5.1.tar.gz.

File metadata

  • Download URL: PyLMNN-1.5.1.tar.gz
  • Size: 18.6 kB
  • Tags: Source

File hashes

Hashes for PyLMNN-1.5.1.tar.gz
Algorithm Hash digest
SHA256 2e4dbb0c7d634e1a8d659c7196f896cff2a497751be60243dd347d847295d319
MD5 ee55e74927d1e552e73a51c32b1a5b8e
BLAKE2b-256 82f815e58dcebe9c34c928bfd0b0220ced18d9a8365cfcd9b90f1bbad39cbb22

File details

Details for the file PyLMNN-1.5.1-py3-none-any.whl.

File hashes

Hashes for PyLMNN-1.5.1-py3-none-any.whl
Algorithm Hash digest
SHA256 6b4a015c02ead42357d1d78d4a0233be97391bd604235188e015619b1dff65ad
MD5 ee02cc0805c31b5a6883e3933ce52ed5
BLAKE2b-256 4f2c7f3b028369c259a2c09b065a861389df49612fc9565201d0ed44ea685b81
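The SHA256 digests above can be checked locally before installing. A minimal sketch using Python's standard `hashlib`; the helper names `sha256_of` and `matches_published_digest` are hypothetical, and the commented-out call assumes the wheel was downloaded to the current directory. pip can also enforce a digest itself in hash-checking mode, with a requirements line such as `PyLMNN==1.5.1 --hash=sha256:<digest>` installed via `pip install --require-hashes -r requirements.txt`.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 20):
    """Hex SHA256 digest of a file, read in 1 MiB chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_digest(path, expected_sha256):
    """True if the file's SHA256 equals the published hex digest (case-insensitive)."""
    return sha256_of(path) == expected_sha256.strip().lower()


# SHA256 digests published for PyLMNN 1.5.1, copied from the tables above.
PUBLISHED_SHA256 = {
    'PyLMNN-1.5.1.tar.gz': '2e4dbb0c7d634e1a8d659c7196f896cff2a497751be60243dd347d847295d319',
    'PyLMNN-1.5.1-py3-none-any.whl': '6b4a015c02ead42357d1d78d4a0233be97391bd604235188e015619b1dff65ad',
}

# Example, assuming the wheel sits in the current directory:
# print(matches_published_digest('PyLMNN-1.5.1-py3-none-any.whl',
#                                PUBLISHED_SHA256['PyLMNN-1.5.1-py3-none-any.whl']))
```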
