An easy-to-use interface for (approximate) nearest neighbors algorithms.
nearness is a unified interface for (approximate) nearest neighbors search.
Using `pip install nearness` only installs the interface and does not add any concrete nearest
neighbors search implementation. The following implementations are currently available:
- Annoy exposes `AnnoyNeighbors`
- AutoFaiss exposes `AutoFaissNeighbors`
- Faiss exposes `FaissNeighbors`
- PyGlass exposes `GlassNeighbors`
- HNSWLib exposes `HNSWNeighbors`
- Jax exposes `JaxNeighbors`
- Numpy exposes `NumpyNeighbors`
- ScaNN exposes `ScannNeighbors`
- SciPy exposes `ScipyNeighbors`
- scikit-learn exposes `SklearnNeighbors`
- PyTorch exposes `TorchNeighbors`
Installing one of the above packages exposes the corresponding nearest neighbors implementation. For example,
`nearness.FaissNeighbors` is available if Faiss is installed.
Alternatively, you can install the underlying packages as package extras, e.g. `pip install nearness[faiss]`
installs nearness together with `faiss-cpu`. If you need control over the specific versions of the
installed packages, it's recommended to install them explicitly.
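To see which of the implementations listed above your current environment could expose, you can check whether the backing packages are importable. The sketch below is not part of nearness: the `BACKENDS` mapping and the `available_backends` helper are hypothetical, the import names are assumptions (PyGlass is left out because its import name isn't stated here), and it only relies on the standard library's `importlib.util.find_spec`, which locates a module without importing it.

```python
import importlib.util

# Hypothetical mapping from nearness class names to the modules they need.
# The import names are assumptions for illustration; PyGlass is omitted.
BACKENDS = {
    "AnnoyNeighbors": "annoy",
    "AutoFaissNeighbors": "autofaiss",
    "FaissNeighbors": "faiss",
    "HNSWNeighbors": "hnswlib",
    "JaxNeighbors": "jax",
    "NumpyNeighbors": "numpy",
    "ScannNeighbors": "scann",
    "ScipyNeighbors": "scipy",
    "SklearnNeighbors": "sklearn",
    "TorchNeighbors": "torch",
}


def available_backends(backends: dict[str, str] = BACKENDS) -> list[str]:
    """Return the class names whose backing module can be found (without importing it)."""
    return [
        name
        for name, module in backends.items()
        if importlib.util.find_spec(module) is not None
    ]


if __name__ == "__main__":
    print(available_backends())
```

Note that this only tells you a package is importable; whether nearness actually exports the matching class is up to the library.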
API
The nearness API consists of a single class called `NearestNeighbors` with the following methods.
```python
from pathlib import Path

import numpy as np
from typing import Self  # Python 3.11+; use typing_extensions.Self on older versions


# signature overview of the interface (method bodies elided)
class NearestNeighbors:
    def fit(self, data: np.ndarray) -> Self:
        """Learn an index structure based on a matrix of points."""
        ...

    def query(self, point: np.ndarray, n_neighbors: int) -> tuple[np.ndarray, np.ndarray]:
        """Search ``n_neighbors`` for a single point, returning the indices and distances."""
        ...

    def query_batch(self, points: np.ndarray, n_neighbors: int) -> tuple[np.ndarray, np.ndarray]:
        """Search ``n_neighbors`` for a batch of points, returning the indices and distances."""
        ...

    def save(self, file: str | Path) -> None:
        """Save the state of the model using pickle such that it can be fully restored."""
        ...

    @classmethod
    def load(cls, file: str | Path) -> Self:
        """Load a model using pickle to fully restore the saved state."""
        ...
```
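`save` and `load` rely on pickle. As a rough, library-independent sketch of that pattern (assuming nothing about nearness internals), the hypothetical `PicklableModel` below pickles the whole object in `save` and restores it in a `load` classmethod, which matches how `SklearnNeighbors.load(...)` is called in the usage example. Only load pickle files from sources you trust, because unpickling can execute arbitrary code.

```python
from __future__ import annotations

import pickle
from pathlib import Path


class PicklableModel:
    """Hypothetical model illustrating pickle-based save/load (not a nearness class)."""

    def __init__(self, *, n_trees: int = 10) -> None:
        self.n_trees = n_trees
        self.index: list[float] | None = None

    def fit(self, data: list[float]) -> PicklableModel:
        # stand-in for learning an index structure
        self.index = sorted(data)
        return self

    def save(self, file: str | Path) -> None:
        # pickle the whole object so every attribute is restored on load
        with open(file, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, file: str | Path) -> PicklableModel:
        with open(file, "rb") as f:
            obj = pickle.load(f)
        # guard against loading an unrelated pickled object
        if not isinstance(obj, cls):
            raise TypeError(f"expected {cls.__name__}, got {type(obj).__name__}")
        return obj
```

Making `load` a classmethod lets it return a fully restored instance instead of mutating an existing one.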
The interface to all methods is based on NumPy arrays, but implementations might overload the methods
to support other data types. For example, `TorchNeighbors` supports both NumPy arrays and PyTorch tensors.
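One way such type-based overloading can be realized at runtime is `functools.singledispatchmethod`, which selects an implementation based on the type of the first argument after `self`. The class below is a hypothetical illustration, not how `TorchNeighbors` is implemented: NumPy arrays are handled directly (brute force, Euclidean distance assumed), and plain Python lists are converted and re-dispatched.

```python
from functools import singledispatchmethod

import numpy as np


class ListFriendlyNeighbors:
    """Hypothetical class accepting NumPy arrays and, via dispatch, plain lists."""

    def __init__(self, data: np.ndarray) -> None:
        self.data = np.asarray(data, dtype=float)

    @singledispatchmethod
    def query(self, point, n_neighbors: int):
        # fallback for types that have no registered implementation
        raise TypeError(f"unsupported point type: {type(point).__name__}")

    @query.register
    def _(self, point: np.ndarray, n_neighbors: int) -> tuple[np.ndarray, np.ndarray]:
        # brute force: distance from the point to every stored row
        dist = np.linalg.norm(self.data - point, axis=1)
        idx = np.argsort(dist)[:n_neighbors]
        return idx, dist[idx]

    @query.register
    def _(self, point: list, n_neighbors: int) -> tuple[np.ndarray, np.ndarray]:
        # convert the list and re-dispatch to the NumPy implementation
        return self.query(np.asarray(point, dtype=float), n_neighbors)
```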
The library additionally exports a global `config` object whose current state is copied into every
`NearestNeighbors` instance on instantiation. Modifying an instance's config only affects that instance
and does not change the global object.
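A minimal sketch of this copy-on-instantiation behavior, assuming a dataclass-based config. `Config`, `config`, and `Model` are hypothetical stand-ins rather than nearness objects; only the option name `save_compression` is borrowed from the subclass example in this README.

```python
import copy
from dataclasses import dataclass


@dataclass
class Config:
    # option name borrowed from the README example; the default is an assumption
    save_compression: int = 0


# hypothetical module-level global config
config = Config()


class Model:
    def __init__(self) -> None:
        # snapshot the global config's current state at instantiation
        self.config = copy.copy(config)


config.save_compression = 3         # change the global default
model = Model()                     # picks up save_compression == 3
model.config.save_compression = 9   # instance-only change
print(config.save_compression, model.config.save_compression)  # 3 9
```

`copy.copy` is a shallow copy; a config holding mutable nested values would need `copy.deepcopy` to stay fully independent.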
In addition to the global config, all `__init__` arguments of a `NearestNeighbors` class are treated as
parameters of the class and automagically bound to the object before `__init__` runs. The config and
parameters of an object are exposed as `obj.config` and `obj.parameters`.
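The parameter binding can be approximated with `inspect.signature`: a base class's `__new__` binds the call arguments to the subclass `__init__` signature, fills in defaults, and attaches the result before `__init__` runs. `AutoParams` and `Example` are hypothetical names, and this is one possible mechanism, not necessarily the one nearness uses. The `__init_subclass__` hook additionally enforces the keyword-only rule stated in the implementation section.

```python
import inspect
from types import SimpleNamespace


class AutoParams:
    """Hypothetical base class that exposes __init__ arguments as ``parameters``."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # only check subclasses that define their own __init__
        if "__init__" in cls.__dict__:
            params = list(inspect.signature(cls.__init__).parameters.values())[1:]
            if any(p.kind is not inspect.Parameter.KEYWORD_ONLY for p in params):
                raise TypeError(f"{cls.__name__}.__init__ must only take keyword-only arguments")

    def __new__(cls, *args, **kwargs):
        obj = super().__new__(cls)
        # bind the call against the subclass __init__ and fill in defaults
        bound = inspect.signature(cls.__init__).bind(obj, *args, **kwargs)
        bound.apply_defaults()
        # drop the leading ``self`` argument and attach the rest
        arguments = list(bound.arguments.items())[1:]
        obj.parameters = SimpleNamespace(**dict(arguments))
        return obj


class Example(AutoParams):
    def __init__(self, *, a: int = 0, b: str = "x"):
        # parameters are already bound when __init__ starts
        self.a_at_init = self.parameters.a
```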
Usage Example
The following example demonstrates how to use nearness, assuming scikit-learn is installed.
```python
from nearness import SklearnNeighbors
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

X, _ = load_digits(return_X_y=True)
X_train, X_test = train_test_split(X)

# create a brute force nearest neighbors model
model = SklearnNeighbors(algorithm="brute")
model.fit(X_train)

# query a single test point
idx, dist = model.query(X_test[0], n_neighbors=5)

# query all test points
idx_batch, dist_batch = model.query_batch(X_test, n_neighbors=5)

# change the algorithm to a K-D tree and fit again
model.parameters.algorithm = "kd_tree"
model.fit(X_train)

# save the model to a file
model.save("my_sklearn_model")

# load the model from file
kdtree_model = SklearnNeighbors.load("my_sklearn_model")

# query again using the loaded model
kdtree_model.query(X_test[0], n_neighbors=5)
```
Algorithm Implementation
To define your own `NearestNeighbors` algorithm, you only need to implement the `fit` and `query`
methods specified above. By default, `query_batch` uses joblib to process a batch of queries in a
thread pool, but in most cases you'll want to implement `query_batch` yourself for better efficiency.
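The sketch below contrasts the two `query_batch` strategies on a hypothetical brute-force model (`BruteForce`, Euclidean distance assumed). For the threaded fallback it uses the standard library's `concurrent.futures.ThreadPoolExecutor` in place of joblib, so it mirrors the idea rather than the library's code. The vectorized version computes all distances in a single NumPy expression, which is usually much faster than looping over single-point queries.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


class BruteForce:
    """Hypothetical brute-force model used to contrast two query_batch strategies."""

    def fit(self, data: np.ndarray) -> "BruteForce":
        self.data = np.asarray(data, dtype=float)
        return self

    def query(self, point: np.ndarray, n_neighbors: int) -> tuple[np.ndarray, np.ndarray]:
        dist = np.linalg.norm(self.data - point, axis=1)
        idx = np.argsort(dist)[:n_neighbors]
        return idx, dist[idx]

    def query_batch_threaded(self, points: np.ndarray, n_neighbors: int) -> tuple[np.ndarray, np.ndarray]:
        # fallback: run the single-point queries concurrently in a thread pool
        with ThreadPoolExecutor() as pool:
            results = list(pool.map(lambda p: self.query(p, n_neighbors), points))
        idx, dist = zip(*results)
        return np.stack(idx), np.stack(dist)

    def query_batch(self, points: np.ndarray, n_neighbors: int) -> tuple[np.ndarray, np.ndarray]:
        # vectorized override: all pairwise distances, shape (n_points, n_data)
        dist = np.linalg.norm(self.data[None, :, :] - points[:, None, :], axis=-1)
        idx = np.argsort(dist, axis=1)[:, :n_neighbors]
        return idx, np.take_along_axis(dist, idx, axis=1)
```

Broadcasting materializes an intermediate array of shape `(n_points, n_data, dim)`, so for large inputs you'd typically process the queries in chunks.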
The following example illustrates the concepts of `config` and `parameters`.
```python
import numpy as np
from nearness import NearestNeighbors


class MyNearestNeighbors(NearestNeighbors):
    # only keyword-only arguments are allowed for subclasses of ``NearestNeighbors``.
    def __init__(self, *, a: int = 0):
        # the __init__ parameters are injected as ``parameters``
        print(self.parameters.a)  # 0
        # the parameters can be modified as needed
        self.parameters.a += 1
        print(self.parameters.a)  # 1
        # a copy of the current global configuration is injected as ``config``
        print(self.config.save_compression)  # 0
        # the configuration can be modified as needed (does not modify the global config)
        self.config.save_compression = 1
        print(self.config.save_compression)  # 1

    def fit(self, data: np.ndarray) -> "Self":
        ...

    def query(self, point: np.ndarray, n_neighbors: int) -> tuple[np.ndarray, np.ndarray]:
        ...
```
An interesting configuration option is `methods_require_fit`, which specifies the set of methods that
require a successful call to `fit` before they can be used. By default, the query methods are listed in
`methods_require_fit`, and calling a query method before `fit` results in an informative error message.
A successful `fit` additionally sets the `is_fitted` property to `True` and removes the fit checks, so
queries incur zero overhead. Manually setting `is_fitted` back to `False` re-adds the checks to all
methods listed in `methods_require_fit`.
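The fit checks described above can be sketched with instance-level method shadowing: while the model is unfitted, every method named in `methods_require_fit` is shadowed by a stub on the instance that raises, and setting `is_fitted = True` deletes those stubs so calls go straight to the class methods with no extra check. `FitChecked` and `NotFittedError` are hypothetical names, and nearness may implement this differently.

```python
import functools


class NotFittedError(RuntimeError):
    """Hypothetical error raised when a method needs a prior successful fit."""


class FitChecked:
    # names of methods guarded until a successful fit (assumed setting)
    methods_require_fit = ("query",)

    def __init__(self) -> None:
        self.is_fitted = False  # goes through the setter and installs the stubs

    @property
    def is_fitted(self) -> bool:
        return self._is_fitted

    @is_fitted.setter
    def is_fitted(self, value: bool) -> None:
        self._is_fitted = value
        for name in self.methods_require_fit:
            if value:
                # drop the instance stub so lookups reach the class method directly
                self.__dict__.pop(name, None)
            else:
                # instance attributes shadow (non-data descriptor) methods
                self.__dict__[name] = self._make_stub(name)

    def _make_stub(self, name: str):
        method = getattr(type(self), name)

        @functools.wraps(method)
        def stub(*args, **kwargs):
            raise NotFittedError(f"call fit() before {name}()")

        return stub

    def fit(self, data: list[float]) -> "FitChecked":
        self.data = list(data)
        self.is_fitted = True  # only reached if fitting succeeded
        return self

    def query(self, point: float) -> float:
        return min(self.data, key=lambda x: abs(x - point))
```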