PyTorch-optimized geometric median implementation
torch-geometric-median
A simplified version of the geom-median Python library, updated for higher performance in PyTorch and with full type hints. Thanks to @themachinefan!
Installation
pip install torch-geometric-median
Usage
This library exports a single function, geometric_median, which takes a tensor of shape (N, D), where N is the number of samples and D is the size of each sample. It returns a result object whose median attribute holds the geometric median of the points in the tensor.
import torch
from torch_geometric_median import geometric_median

# Create a tensor of points with shape (N, D) = (5, 2)
points = torch.tensor([
    [0.0, 0.0],
    [1.0, 1.0],
    [2.0, 2.0],
    [3.0, 3.0],
    [4.0, 4.0],
])

# Compute the geometric median; the result object exposes it as .median
median = geometric_median(points).median
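The geometric median is the point that minimizes the sum of Euclidean distances to all input points, which makes it much less sensitive to outliers than the mean. The sketch below is plain PyTorch and does not use this library; the helper name geometric_median_objective is hypothetical. It evaluates that objective and shows a single outlier dragging the mean away while a point near the bulk of the data keeps a much lower total distance.

```python
import torch

def geometric_median_objective(points: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: sum of Euclidean distances from `candidate` (shape (D,))
    # to every row of `points` (shape (N, D)). The geometric median minimizes this.
    return torch.linalg.norm(points - candidate, dim=1).sum()

points = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]])

# For these evenly spaced collinear points, the middle point minimizes the objective.
center = torch.tensor([2.0, 2.0])
print(geometric_median_objective(points, center))  # 6 * sqrt(2), about 8.485

# Add one far-away outlier: the mean is pulled toward it (to about (18.33, 18.33)).
with_outlier = torch.cat([points, torch.tensor([[100.0, 100.0]])])
mean = with_outlier.mean(dim=0)

# A point near the bulk of the data has a much smaller sum of distances than the mean.
near_bulk = torch.tensor([2.5, 2.5])
print(geometric_median_objective(with_outlier, mean) > geometric_median_objective(with_outlier, near_bulk))
```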
Backprop
Like the original geom-median library, this library supports backpropagation through the geometric median computation.
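# Track gradients on the input points (points is a leaf tensor, so this works in place)
points.requires_grad_(True)
median = geometric_median(points).median
torch.linalg.norm(median).backward()
# The gradient of the median's norm with respect to the input points is now in `points.grad`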
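One way backpropagation through a geometric median can work is to build the solver entirely from differentiable tensor operations so autograd can trace every step. The sketch below illustrates that idea with a hypothetical weiszfeld_median function that unrolls a fixed number of Weiszfeld iterations (the classic fixed-point method for the geometric median). It is an assumption-laden illustration, not this library's implementation; the iteration count and the eps guard are arbitrary choices.

```python
import torch

def weiszfeld_median(points: torch.Tensor, num_iters: int = 100, eps: float = 1e-8) -> torch.Tensor:
    # Hypothetical sketch: unrolled Weiszfeld iterations, differentiable end to end.
    # Start from the coordinate-wise mean of the points.
    median = points.mean(dim=0)
    for _ in range(num_iters):
        # Distances from the current estimate; the clamp guards against division by zero.
        dists = torch.linalg.norm(points - median, dim=1).clamp(min=eps)
        inv_dists = 1.0 / dists
        # Weiszfeld update: inverse-distance-weighted average of the points.
        median = (inv_dists[:, None] * points).sum(dim=0) / inv_dists.sum()
    return median

points = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]], requires_grad=True)
median = weiszfeld_median(points)  # converges to about (0.5, 0.5) for these points
torch.linalg.norm(median).backward()
print(points.grad)  # gradient of ||median|| with respect to each input point
```

Unrolling keeps the backward pass simple, but every iteration is stored in the autograd graph, so memory grows with the number of iterations.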
Extra options
The geometric_median function also supports a few extra options:
maxiter: The maximum number of iterations to run the optimization for. Default is 100.
ftol: If the objective value does not improve by at least this fraction, terminate the algorithm. Default is 1e-20.
weights: A tensor of shape (N,) containing the weight of each point, where N is the number of samples. Default is None, which weights all points equally.
show_progress: If True, show a progress bar during the optimization. Default is False.
log_objective_values: If True, log the objective value at each iteration under the key objective_values_log. Default is False.
median = geometric_median(
    points,
    maxiter=1000,
    ftol=1e-10,
    # One weight per point (shape (N,)); here the second point counts double
    weights=torch.tensor([1.0, 2.0, 1.0, 1.0, 1.0]),
    show_progress=True,
    log_objective_values=True,
).median
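To make the options above concrete, here is a minimal sketch of a weighted Weiszfeld loop that honors maxiter, ftol and weights and records the objective value at every step, similar in spirit to what log_objective_values exposes. The function name weighted_weiszfeld and the eps guard are hypothetical, and this is not the library's code; its exact stopping rule may differ.

```python
from typing import Optional

import torch

def weighted_weiszfeld(
    points: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    maxiter: int = 100,
    ftol: float = 1e-20,
    eps: float = 1e-8,
) -> tuple[torch.Tensor, list[float]]:
    # Hypothetical helper: weighted Weiszfeld iterations with maxiter / ftol stopping.
    if weights is None:
        # No weights given: every point counts equally.
        weights = torch.ones(points.shape[0], dtype=points.dtype, device=points.device)

    def objective(m: torch.Tensor) -> torch.Tensor:
        # Weighted sum of Euclidean distances from m to each point.
        return (weights * torch.linalg.norm(points - m, dim=1)).sum()

    # Start from the weighted mean.
    median = (weights[:, None] * points).sum(dim=0) / weights.sum()
    objective_values = [objective(median).item()]
    for _ in range(maxiter):
        # Clamp keeps the update finite if the estimate lands on a data point.
        dists = torch.linalg.norm(points - median, dim=1).clamp(min=eps)
        coeffs = weights / dists
        median = (coeffs[:, None] * points).sum(dim=0) / coeffs.sum()
        objective_values.append(objective(median).item())
        prev, curr = objective_values[-2], objective_values[-1]
        # Stop once the objective fails to improve by at least an ftol fraction.
        if prev - curr <= ftol * abs(prev):
            break
    return median, objective_values

# float64 so that tiny relative improvements remain visible to the ftol check.
points = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]], dtype=torch.float64)
weighted, log = weighted_weiszfeld(points, weights=torch.tensor([2.0, 1.0, 1.0, 1.0], dtype=torch.float64))
# Giving the first point weight 2 should match listing it twice with equal weights.
duplicated, _ = weighted_weiszfeld(torch.cat([points[:1], points]))
print(weighted, duplicated)
```

The duplicate-point comparison is a cheap sanity check for any weighted implementation: an integer weight should behave exactly like repeating that point.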
Why does this library exist?
The original geom-median library appears to be no longer maintained, and, as @themachinefan pointed out, it is not very performant in PyTorch. This library repackages @themachinefan's improvements to geom-median, simplifying the code to support only PyTorch, improving its performance, and adding full type hints.
Acknowledgements
This library repackages work from the original geom-median library and from @themachinefan's PR, so all credit goes to these incredible authors. If you use this library, you should cite the original geom-median paper.
License
This library is licensed under a GPL license, as per the original geom-median library.
Contributing
Contributions are welcome! Please open an issue or a PR if you have any suggestions or improvements. This library uses PDM for dependency management, Ruff for linting, Pyright for type-checking, and Pytest for tests.
Download files
Hashes for torch_geometric_median-0.1.1.tar.gz
Algorithm | Hash digest
---|---
SHA256 | 00316a0237471d0628516db407a34fc99e04504b62575fd1c03670e839aba2c0
MD5 | 9a025a8606fb4454f410a45de193d01b
BLAKE2b-256 | 781fa593a3f2f04ed90ca48b6bf4d614d52cb502cd50db6c4925a5e579ab2d16
Hashes for torch_geometric_median-0.1.1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 20c26ef7a40d3cc3f6f29fa4cb9f78eebf56e8c7e702f4fb8e2d8b2910b3c386
MD5 | 768752594d26ee938beb0af5be830f75
BLAKE2b-256 | 06f3b70c4295c91e1ef306a6d9b30a028072cf7e0f6bab29e4339a6cd2299fe2