PyTorch-optimized geometric median implementation
torch-geometric-median
A simplified version of the geom-median Python library, updated for higher performance on PyTorch and with full type hints. Thanks to @themachinefan!
Installation
pip install torch-geometric-median
Usage
This library exports a single function, geometric_median, which takes a tensor of shape (N, D), where N is the number of samples and D is the size of each sample, and returns a result whose median attribute holds the geometric median of the points in the tensor.
import torch
from torch_geometric_median import geometric_median
# Create a tensor of points
points = torch.tensor([
    [0.0, 0.0],
    [1.0, 1.0],
    [2.0, 2.0],
    [3.0, 3.0],
    [4.0, 4.0],
])
# Compute the geometric median
median = geometric_median(points).median
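To make the quantity concrete: the geometric median is the point that minimizes the (optionally weighted) sum of Euclidean distances to the samples. The dependency-free sketch below approximates it with the classical Weiszfeld iteration. It is an illustration only and is not claimed to be this library's implementation; the names weiszfeld_median, objective, and eps are hypothetical, while maxiter, ftol, and weights merely mirror the option names documented on this page.

```python
import math
from typing import List, Optional, Sequence

Point = Sequence[float]


def objective(y: Point, points: Sequence[Point], weights: Sequence[float]) -> float:
    """Weighted sum of Euclidean distances from y to every sample."""
    return sum(w * math.dist(y, p) for w, p in zip(weights, points))


def weiszfeld_median(
    points: Sequence[Point],
    weights: Optional[Sequence[float]] = None,
    maxiter: int = 100,
    ftol: float = 1e-20,
    eps: float = 1e-12,  # hypothetical guard against division by zero
) -> List[float]:
    """Approximate the geometric median with plain Weiszfeld iterations."""
    n, dim = len(points), len(points[0])
    if weights is None:
        weights = [1.0] * n
    total = sum(weights)
    # Start from the weighted mean of the samples.
    y = [sum(w * p[k] for w, p in zip(weights, points)) / total for k in range(dim)]
    prev = objective(y, points, weights)
    for _ in range(maxiter):
        # Re-weight every sample by weight / distance to the current estimate.
        coeffs = [w / max(math.dist(y, p), eps) for w, p in zip(weights, points)]
        norm = sum(coeffs)
        y = [sum(c * p[k] for c, p in zip(coeffs, points)) / norm for k in range(dim)]
        cur = objective(y, points, weights)
        # Stop once the relative improvement of the objective falls below ftol.
        if abs(prev - cur) <= ftol * prev:
            break
        prev = cur
    return y


# Same five collinear points as in the usage example above.
points = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]
median = weiszfeld_median(points)
print(median)
```

For these symmetric, collinear points the estimate stays at the middle sample, which is what one would expect the library call above to return as well.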
Backprop
Like the original geom-median library, this library supports backpropagation through the geometric median computation.
# `points` must have been created with requires_grad=True for a gradient to be recorded
out = geometric_median(points)
torch.linalg.norm(out.median).backward()
# The gradient of the median's norm with respect to the input points is now in `points.grad`
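As a rough illustration of what differentiating through a geometric median can mean, the sketch below unrolls a few Weiszfeld iterations using ordinary differentiable torch operations and lets autograd do the rest. This is an assumption-laden stand-in: unrolled_geometric_median, iters, and eps are hypothetical names, and nothing here is claimed to match how this library or geom-median compute their gradients.

```python
import torch


def unrolled_geometric_median(
    points: torch.Tensor, iters: int = 50, eps: float = 1e-8
) -> torch.Tensor:
    """Hypothetical helper: Weiszfeld iterations built from differentiable ops."""
    y = points.mean(dim=0)  # start from the centroid
    for _ in range(iters):
        # Distance of every sample to the current estimate, clamped away from zero.
        dist = torch.linalg.norm(points - y, dim=1).clamp_min(eps)
        w = 1.0 / dist
        # Distance-weighted average of the samples gives the next estimate.
        y = (w[:, None] * points).sum(dim=0) / w.sum()
    return y


# requires_grad=True makes autograd track operations on the input points.
points = torch.tensor(
    [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [3.0, 3.0]], requires_grad=True
)
median = unrolled_geometric_median(points)
torch.linalg.norm(median).backward()
grad = points.grad  # same shape as `points`
```

The same pattern applies to the library call: create the input with requires_grad=True, reduce the returned median to a scalar, and call backward() on it.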
Extra options
The geometric_median function also supports a few extra options:
maxiter: The maximum number of iterations to run the optimization for. Default is 100.
ftol: If the objective value does not improve by at least this fraction, terminate the algorithm. Default is 1e-20.
weights: A tensor of shape (N,) containing the weights for each point, where N is the number of samples. Default is None, which means all points are weighted equally.
show_progress: If True, show a progress bar for the optimization. Default is False.
log_objective_values: If True, log the objective value at each iteration under the key objective_values_log. Default is False.
median = geometric_median(
    points,
    maxiter=1000,
    ftol=1e-10,
    weights=torch.tensor([1.0, 2.0, 1.0, 1.0, 1.0]),
    show_progress=True,
    log_objective_values=True,
).median
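For orientation, the weights and ftol options can be read against the standard weighted geometric-median objective. In the formulas below, x_i are the N input points, w_i their weights (all equal when weights is None), and y_t the estimate after iteration t. The stopping rule is one plausible reading of "does not improve by at least this fraction"; the exact test the library applies is an assumption here.

```latex
\hat{y} \;=\; \arg\min_{y \in \mathbb{R}^{D}} f(y),
\qquad
f(y) \;=\; \sum_{i=1}^{N} w_i \,\lVert x_i - y \rVert_2 ,
\qquad
\text{stop when}\quad f(y_{t-1}) - f(y_t) \;\le\; \mathtt{ftol}\cdot f(y_{t-1}).
```

Under this reading, the very small default ftol of 1e-20 means the run is usually bounded by maxiter unless the objective stops changing altogether.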
Why does this library exist?
It appears that the original geom-median library is no longer maintained and, as pointed out by @themachinefan, it is not very performant on PyTorch. This library repackages @themachinefan's improvements to the original geom-median library, simplifying the code to support only PyTorch, improving PyTorch performance, and adding full type hints.
Acknowledgements
This library is a repackaging of the work done by the authors of the original geom-median library and by @themachinefan in their PR; all credit goes to these incredible authors. If you use this library, you should cite the original geom-median paper.
License
This library is licensed under a GPL license, as per the original geom-median library.
Contributing
Contributions are welcome! Please open an issue or a PR if you have any suggestions or improvements. This library uses PDM for dependency management, Ruff for linting, Pyright for type-checking, and Pytest for tests.
Hashes for torch_geometric_median-0.1.0.tar.gz
Algorithm | Hash digest
---|---
SHA256 | be8bfb32b3844bd27644752093d03f5239d071d283aabd88f10b92332aa79cf4
MD5 | 9c579f6a5966224304f0aac0b78dc796
BLAKE2b-256 | bdefb765f52f119a6845ea06086a563f436d9d75762fb767503e962ae09fc4c1
Hashes for torch_geometric_median-0.1.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 6f534acef41998f8120d2bc5d7cbb36522cc724da10e14db46dc060ac98e1df6
MD5 | 4b85f1b46d5b1c77189b19aef1eb246e
BLAKE2b-256 | 357dcf3050ff3cb049d619e341636f3dca6b03a65b6f9569644962a06dfd8017