PyTorch-optimized geometric median implementation

torch-geometric-median

A simplified version of the geom-median Python library, updated for higher performance on PyTorch and with full type hints. Thanks to @themachinefan!

Installation

pip install torch-geometric-median

Usage

This library exports a single function, geometric_median, which takes a tensor of shape (N, D), where N is the number of samples and D is the dimension of each sample, and computes the geometric median of those points. The median itself is available as the median attribute of the returned result, as in the example below.

import torch
from torch_geometric_median import geometric_median

# Create a tensor of points
points = torch.tensor([
    [0.0, 0.0],
    [1.0, 1.0],
    [2.0, 2.0],
    [3.0, 3.0],
    [4.0, 4.0],
])

# Compute the geometric median
median = geometric_median(points).median
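The README does not describe the solver itself. As a rough illustration of what is being computed: the geometric median is the point z that minimizes the sum of Euclidean distances Σᵢ ‖z − xᵢ‖, and the classic Weiszfeld iteration approaches it by repeatedly replacing z with an average of the points weighted by 1 / ‖z − xᵢ‖. The plain-Python sketch below shows that iteration; it is not the library's implementation, and the function name weiszfeld and the eps guard are hypothetical choices for this example.

```python
import math

def weiszfeld(points, iters=100, eps=1e-8):
    """Approximate the geometric median of `points`, a list of equal-length coordinate tuples."""
    n, dim = len(points), len(points[0])
    # Start from the coordinate-wise mean.
    z = [sum(p[d] for p in points) / n for d in range(dim)]
    for _ in range(iters):
        # Weight each point by 1 / distance to the current estimate; eps avoids
        # dividing by zero when the estimate lands exactly on a data point.
        inv = [1.0 / max(math.dist(z, p), eps) for p in points]
        total = sum(inv)
        # The new estimate is the inverse-distance-weighted average of the points.
        z = [sum(w * p[d] for w, p in zip(inv, points)) / total for d in range(dim)]
    return z

points = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (4.0, 4.0)]
print(weiszfeld(points))  # approximately [2.0, 2.0]
```

For the five collinear points from the usage example, the geometric median is the middle point (2, 2), which is where this sketch ends up.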

Backprop

Like the original geom-median library, this library supports backpropagation through the geometric median computation.

# Track gradients on the input points before computing the median
points.requires_grad_(True)
median = geometric_median(points).median
torch.linalg.norm(median).backward()
# The gradient of the median with respect to the input points is now in `points.grad`
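The README does not say how gradients are obtained. One simple way to make a geometric median differentiable, shown here only as an assumption-labeled sketch and not necessarily what this library or the original geom-median library does, is to run a fixed number of Weiszfeld steps with ordinary tensor operations so that autograd records every step. The helper name geometric_median_unrolled is hypothetical.

```python
import torch

def geometric_median_unrolled(points: torch.Tensor, iters: int = 100, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical unrolled Weiszfeld iteration on an (N, D) tensor; every step is differentiable."""
    z = points.mean(dim=0)
    for _ in range(iters):
        # Distances from the current estimate to each point, clamped to avoid division by zero.
        dist = torch.linalg.norm(points - z, dim=1).clamp(min=eps)
        w = 1.0 / dist
        # Inverse-distance-weighted average of the points, shape (D,).
        z = (w[:, None] * points).sum(dim=0) / w.sum()
    return z

points = torch.tensor([[0.0, 0.0], [4.0, 0.0], [5.0, 5.0], [0.0, 3.0]], requires_grad=True)
median = geometric_median_unrolled(points)
torch.linalg.norm(median).backward()
print(median)       # close to (12/7, 12/7)
print(points.grad)  # gradient with shape (4, 2)
```

These four points form a convex quadrilateral, whose geometric median is the intersection of its diagonals, here (12/7, 12/7). Unrolling keeps every iteration in the autograd graph, so memory use grows with iters.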

Extra options

The geometric_median function also supports a few extra options:

  • maxiter: The maximum number of iterations to run the optimization for. Default is 100.
  • ftol: If the objective value does not improve by at least this fraction, the algorithm terminates. Default is 1e-20.
  • weights: A tensor of shape (N,) containing the weights for each point, where N is the number of samples. Default is None, which means all points are weighted equally.
  • show_progress: If True, show a progress bar for the optimization. Default is False.
  • log_objective_values: If True, record the objective value at each iteration in the result under the key objective_values_log. Default is False.

median = geometric_median(
    points,
    maxiter=1000,
    ftol=1e-10,
    weights=torch.tensor([1.0, 2.0, 1.0, 1.0, 1.0]),
    show_progress=True,
    log_objective_values=True
).median
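The options above map naturally onto a weighted variant of the same iteration. The sketch below is a hypothetical illustration, not the library's code: it minimizes the weighted objective Σᵢ wᵢ ‖z − xᵢ‖, treats weights=None as equal weights, caps the loop at maxiter, stops once the objective improves by no more than an ftol fraction, and returns the per-iteration objective values, similar in spirit to what log_objective_values records. The names objective and weighted_weiszfeld are made up for this example.

```python
import math

def objective(z, points, weights):
    # Weighted sum of Euclidean distances from z to every point.
    return sum(w * math.dist(z, p) for w, p in zip(weights, points))

def weighted_weiszfeld(points, weights=None, maxiter=100, ftol=1e-20, eps=1e-8):
    n, dim = len(points), len(points[0])
    weights = [1.0] * n if weights is None else list(weights)  # None means equal weights
    wsum = sum(weights)
    # Start from the weighted mean of the points.
    z = [sum(w * p[d] for w, p in zip(weights, points)) / wsum for d in range(dim)]
    history = [objective(z, points, weights)]
    for _ in range(maxiter):
        # Each point's pull is its weight divided by its (clamped) distance to z.
        coef = [w / max(math.dist(z, p), eps) for w, p in zip(weights, points)]
        total = sum(coef)
        z = [sum(c * p[d] for c, p in zip(coef, points)) / total for d in range(dim)]
        history.append(objective(z, points, weights))
        # Stop once the objective improves by no more than an ftol fraction.
        if history[-2] - history[-1] <= ftol * history[-2]:
            break
    return z, history

points = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (4.0, 4.0)]
z, history = weighted_weiszfeld(points, weights=[1.0, 3.0, 1.0, 1.0, 1.0], maxiter=1000, ftol=1e-10)
print(z)  # approximately [1.0, 1.0]
```

The weight on (1, 1) is 3.0 here rather than the 2.0 used above: with 2.0, the weighted objective is flat along the whole segment from (1, 1) to (2, 2), so the minimizer is not unique, whereas 3.0 makes (1, 1) the single weighted median.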

Why does this library exist?

It appears that the original geom-median library is no longer maintained and, as pointed out by @themachinefan, it does not perform well on PyTorch. This library repackages @themachinefan's improvements to the original geom-median library, simplifying the code to support only PyTorch, improving PyTorch performance, and adding full type hints.

Acknowledgements

This library is a repackaging of the work done by the authors of the original geom-median library and by @themachinefan in their PR, so all credit goes to these incredible authors. If you use this library, please cite the original geom-median paper.

License

This library is licensed under a GPL license, as per the original geom-median library.

Contributing

Contributions are welcome! Please open an issue or a PR if you have any suggestions or improvements. This library uses PDM for dependency management, Ruff for linting, Pyright for type-checking, and Pytest for tests.

Download files

Source distribution: torch_geometric_median-0.1.1.tar.gz (16.7 kB)

Built distribution: torch_geometric_median-0.1.1-py3-none-any.whl (16.7 kB, Python 3)
