
PyTorch-optimized geometric median implementation


torch-geometric-median


A simplified version of the geom-median Python library, updated for higher performance in PyTorch and with full type hints. Thanks to @themachinefan!

Installation

pip install torch-geometric-median

Usage

This library exports a single function, geometric_median. It takes a tensor of shape (N, D), where N is the number of samples and D is the size of each sample, and returns a result whose median attribute holds the geometric median of the points.

import torch
from torch_geometric_median import geometric_median

# Create a tensor of points
points = torch.tensor([
    [0.0, 0.0],
    [1.0, 1.0],
    [2.0, 2.0],
    [3.0, 3.0],
    [4.0, 4.0],
])

# Compute the geometric median
median = geometric_median(points).median
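The geometric median is the point z that minimizes the sum of Euclidean distances to the input points, f(z) = Σ_i w_i ‖z − x_i‖, assuming the usual weighted definition with every w_i = 1 when no weights are given. A minimal sketch of that objective in plain PyTorch, handy for sanity-checking a result; `geometric_median_objective` is a hypothetical helper written for illustration and is not exported by this library.

```python
from typing import Optional

import torch


def geometric_median_objective(
    points: torch.Tensor, z: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Weighted sum of Euclidean distances from z (shape (D,)) to each row of points (shape (N, D))."""
    if weights is None:
        # Hypothetical default: every point counts equally.
        weights = torch.ones(points.shape[0], dtype=points.dtype, device=points.device)
    distances = torch.linalg.norm(points - z, dim=1)  # shape (N,)
    return (weights * distances).sum()


points = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]])

# For these collinear points the middle point (2, 2) gives a smaller objective
# than a nearby point further along the same line.
at_middle = geometric_median_objective(points, torch.tensor([2.0, 2.0]))
nearby = geometric_median_objective(points, torch.tensor([2.5, 2.5]))
print(at_middle.item(), nearby.item())
```

At (2, 2) the distances are 2√2, √2, 0, √2 and 2√2, so the objective is 6√2 (about 8.49), while at (2.5, 2.5) it is 6.5√2.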

Backprop

Like the original geom-median library, this library supports backpropagation through the geometric median computation.

points.requires_grad_(True)  # track gradients with respect to the input points
out = geometric_median(points)
median = out.median
torch.linalg.norm(median).backward()
# The gradient of the median's norm with respect to the input points is now in `points.grad`
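A minimal sketch of one way backpropagation through an iterative geometric median solver can work, assuming the fixed-point iterations run under `torch.no_grad()` and only a final weighted average of the input points is tracked by autograd, with the per-point weights treated as constants. This illustrates the general technique under that assumption and is not a description of this library's internals; `median_with_grad` and its `n_iter` and `eps` parameters are hypothetical.

```python
import torch


def median_with_grad(points: torch.Tensor, n_iter: int = 100, eps: float = 1e-8) -> torch.Tensor:
    # Weiszfeld iterations without autograd: none of these steps are differentiated.
    with torch.no_grad():
        z = points.mean(dim=0)
        for _ in range(n_iter):
            # Clamp distances to eps so a point coinciding with z does not divide by zero.
            w = 1.0 / torch.linalg.norm(points - z, dim=1).clamp_min(eps)
            z = (w[:, None] * points).sum(dim=0) / w.sum()
        # Weights from the final estimate; they carry no gradient below.
        w = 1.0 / torch.linalg.norm(points - z, dim=1).clamp_min(eps)
    # Final weighted average outside no_grad, so gradients flow back to `points`.
    return (w[:, None] * points).sum(dim=0) / w.sum()


points = torch.tensor(
    [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]], requires_grad=True
)
median = median_with_grad(points)
torch.linalg.norm(median).backward()
print(points.grad.shape)  # torch.Size([4, 2])
```

Keeping the iterations out of the autograd graph means memory use does not grow with the number of iterations, at the cost of ignoring how the weights themselves depend on the points.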

Extra options

The geometric_median function also supports a few extra options:

  • maxiter: The maximum number of iterations to run the optimization for. Default is 100.
  • ftol: If the objective value does not improve by at least this fraction, the algorithm terminates early. Default is 1e-20.
  • weights: A tensor of shape (N,) containing the weight of each point, where N is the number of samples. Default is None, which means all points are weighted equally.
  • show_progress: If True, show a progress bar for the optimization. Default is False.
  • log_objective_values: If True, record the objective value at each iteration under objective_values_log in the returned result. Default is False.

median = geometric_median(
    points,
    maxiter=1000,
    ftol=1e-10,
    weights=torch.tensor([1.0, 2.0, 1.0, 1.0, 1.0]),
    show_progress=True,
    log_objective_values=True
).median
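The options above map naturally onto a weighted Weiszfeld iteration, the classic fixed-point method for the geometric median. A minimal sketch under that assumption; the function name `weiszfeld`, the `eps` clamp, the weighted-mean starting point, and the exact form of the ftol test are illustrative choices, not this library's implementation.

```python
from typing import List, Optional, Tuple

import torch


def weiszfeld(
    points: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    maxiter: int = 100,
    ftol: float = 1e-20,
    eps: float = 1e-8,
) -> Tuple[torch.Tensor, List[float]]:
    """Approximate the weighted geometric median of the rows of `points` (shape (N, D))."""
    if weights is None:
        weights = torch.ones(points.shape[0], dtype=points.dtype, device=points.device)

    def objective(z: torch.Tensor) -> float:
        # Weighted sum of Euclidean distances from z to every point.
        return float((weights * torch.linalg.norm(points - z, dim=1)).sum())

    # Start from the weighted mean of the points.
    z = (weights[:, None] * points).sum(dim=0) / weights.sum()
    objective_values = [objective(z)]
    for _ in range(maxiter):
        # Each point pulls with strength weight / distance; distances are
        # clamped to eps to avoid dividing by zero at a data point.
        dist = torch.linalg.norm(points - z, dim=1).clamp_min(eps)
        coeffs = weights / dist
        z = (coeffs[:, None] * points).sum(dim=0) / coeffs.sum()
        objective_values.append(objective(z))
        # Stop once the objective improves by less than an ftol fraction of its value.
        prev, curr = objective_values[-2], objective_values[-1]
        if prev - curr <= ftol * curr:
            break
    return z, objective_values


# Three points; the origin carries weight 3, the other two weight 1.
points = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
weights = torch.tensor([3.0, 1.0, 1.0])
median, log = weiszfeld(points, weights=weights)
```

Here the combined pull of the two unit-weight points on the origin is a vector of length √2, which is smaller than the origin's weight of 3, so the weighted geometric median is the origin itself and the sketch converges to approximately (0, 0). The returned list plays the role that log_objective_values describes: one objective value per iteration.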

Why does this library exist?

It appears that the original geom-median library is no longer maintained and, as @themachinefan pointed out, it is not very performant in PyTorch. This library repackages @themachinefan's improvements to the original geom-median library, simplifying the code to support only PyTorch, improving PyTorch performance, and adding full type hints.

Acknowledgements

This library repackages the work of the original geom-median library and of @themachinefan in their PR, so all credit goes to these incredible authors. If you use this library, please cite the original geom-median paper.

License

This library is licensed under the GPL, as is the original geom-median library.

Contributing

Contributions are welcome! Please open an issue or a PR if you have any suggestions or improvements. This library uses PDM for dependency management, Ruff for linting, Pyright for type-checking, and Pytest for tests.

Download files


Source Distribution

torch_geometric_median-0.1.1.tar.gz (16.7 kB)

Built Distribution

torch_geometric_median-0.1.1-py3-none-any.whl (16.7 kB, Python 3)

File hashes

torch_geometric_median-0.1.1.tar.gz
  SHA256: 00316a0237471d0628516db407a34fc99e04504b62575fd1c03670e839aba2c0
  MD5: 9a025a8606fb4454f410a45de193d01b
  BLAKE2b-256: 781fa593a3f2f04ed90ca48b6bf4d614d52cb502cd50db6c4925a5e579ab2d16

torch_geometric_median-0.1.1-py3-none-any.whl
  SHA256: 20c26ef7a40d3cc3f6f29fa4cb9f78eebf56e8c7e702f4fb8e2d8b2910b3c386
  MD5: 768752594d26ee938beb0af5be830f75
  BLAKE2b-256: 06f3b70c4295c91e1ef306a6d9b30a028072cf7e0f6bab29e4339a6cd2299fe2
