PyTorch-optimized geometric median implementation
Project description
torch-geometric-median
A simplified version of the geom-median Python library, updated for higher performance with PyTorch and with full type hints. Thanks to @themachinefan!
Installation
```bash
pip install torch-geometric-median
```
Usage
This library exports a single function, `geometric_median`, which takes a tensor of shape `(N, D)`, where `N` is the number of samples and `D` is the size of each sample, and computes the geometric median of the points in the tensor. The median is available as the `.median` attribute of the returned result.
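For reference, the quantity being computed is the standard (optionally weighted) geometric median: the point that minimizes the weighted sum of Euclidean distances to all samples $x_1, \dots, x_N \in \mathbb{R}^D$. Here the weights $w_i$ stand for the `weights` option described under Extra options, with every $w_i = 1$ when it is `None`. The second formula is the classical Weiszfeld fixed-point update for this problem; it is shown only as background, and we assume rather than confirm that this library uses a Weiszfeld-style iteration.

$$
m^\star = \arg\min_{m \in \mathbb{R}^D} \sum_{i=1}^{N} w_i \, \lVert x_i - m \rVert_2
$$

$$
m^{(t+1)} = \frac{\sum_{i=1}^{N} w_i \, x_i \,/\, \lVert x_i - m^{(t)} \rVert_2}{\sum_{i=1}^{N} w_i \,/\, \lVert x_i - m^{(t)} \rVert_2}
$$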
```python
import torch
from torch_geometric_median import geometric_median

# Create a tensor of points with shape (N, D) = (5, 2)
points = torch.tensor([
    [0.0, 0.0],
    [1.0, 1.0],
    [2.0, 2.0],
    [3.0, 3.0],
    [4.0, 4.0],
])

# Compute the geometric median; the result object exposes it as `.median`
median = geometric_median(points).median
```
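As a sanity check on the example above: all five points lie on the line through the origin with direction (1, 1), so their geometric median reduces to the ordinary one-dimensional median along that line, which is the middle point `[2.0, 2.0]`. The snippet below computes that value directly in plain PyTorch without calling this library. Because `geometric_median` is iterative, its output should be close to, though not necessarily bit-identical with, this value.

```python
import torch

points = torch.tensor([
    [0.0, 0.0],
    [1.0, 1.0],
    [2.0, 2.0],
    [3.0, 3.0],
    [4.0, 4.0],
])

# Unit vector along the line y = x on which every point lies.
direction = torch.tensor([1.0, 1.0]) / torch.tensor(2.0).sqrt()
# Signed position of each point along that line.
positions = points @ direction
# With an odd number of points the 1D median is unique.
median_position = positions.median()
# Map the 1D median back into the plane.
expected_median = median_position * direction
print(expected_median)  # values close to [2.0, 2.0]
```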
Backprop
Like the original geom-median library, this library supports backpropagation through the geometric median computation.
```python
points.requires_grad_(True)  # track gradients with respect to the input points
median = geometric_median(points).median
torch.linalg.norm(median).backward()
# The gradient of the median's norm with respect to the input points is now in `points.grad`
```
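To illustrate why backpropagation through an iterative median is possible at all, the sketch below unrolls a fixed number of Weiszfeld steps in plain PyTorch, so autograd can differentiate the result with respect to the input points. It does not use this library, and the library may compute its gradients differently; the helper `weiszfeld_unrolled`, the fixed step count, and the `eps` guard are hypothetical choices made for illustration.

```python
import torch


def weiszfeld_unrolled(points: torch.Tensor, steps: int = 50, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical helper: approximate the geometric median with `steps` Weiszfeld updates.

    Every operation is differentiable, so autograd can backpropagate through the loop.
    """
    median = points.mean(dim=0)  # start from the centroid
    for _ in range(steps):
        # Distance from each point to the current estimate; clamping avoids division by zero.
        dists = torch.linalg.norm(points - median, dim=1).clamp(min=eps)
        inv = 1.0 / dists
        # New estimate: average of the points weighted by inverse distance.
        median = (inv[:, None] * points).sum(dim=0) / inv.sum()
    return median


points = torch.tensor(
    [[0.0, 0.0], [4.0, 0.0], [0.0, 3.0], [5.0, 5.0]],
    requires_grad=True,
)
median = weiszfeld_unrolled(points)
torch.linalg.norm(median).backward()
print(points.grad.shape)  # torch.Size([4, 2])
```

A useful check on such gradients: the median moves with the data under translation, so the per-point gradients summed over all points should equal the gradient of the loss with respect to the median itself, here `median / ||median||`.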
Extra options
The `geometric_median` function also supports a few extra options:

- `maxiter`: The maximum number of iterations to run the optimization for. Default is `100`.
- `ftol`: If the objective value does not improve by at least this fraction between iterations, the algorithm terminates. Default is `1e-20`.
- `weights`: A tensor of shape `(N,)` containing the weight of each point, where `N` is the number of samples. Default is `None`, which means all points are weighted equally.
- `show_progress`: If `True`, show a progress bar for the optimization. Default is `False`.
- `log_objective_values`: If `True`, log the objective value at each iteration under the key `objective_values_log`. Default is `False`.
```python
median = geometric_median(
    points,
    maxiter=1000,
    ftol=1e-10,
    weights=torch.tensor([1.0, 2.0, 1.0, 1.0, 1.0]),  # one weight per point, shape (N,)
    show_progress=True,
    log_objective_values=True,
).median
```
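The options above map naturally onto a weighted Weiszfeld loop. The following sketch shows one plausible way `maxiter`, `ftol`, `weights`, and objective logging could interact; it is not this library's code. The function `geometric_median_sketch`, the `SketchResult` type, the `eps` guard, starting from the weighted mean, and the exact relative-improvement test are all assumptions, the progress bar is omitted, and gradients are switched off for simplicity.

```python
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class SketchResult:
    """Hypothetical result type; only `.median` mirrors the usage shown above."""

    median: torch.Tensor
    objective_values_log: Optional[List[float]] = None


def weighted_objective(median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # Weighted sum of Euclidean distances from `median` to every point.
    return (weights * torch.linalg.norm(points - median, dim=1)).sum()


@torch.no_grad()
def geometric_median_sketch(
    points: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    maxiter: int = 100,
    ftol: float = 1e-20,
    log_objective_values: bool = False,
    eps: float = 1e-8,
) -> SketchResult:
    """Hypothetical weighted Weiszfeld loop illustrating maxiter, ftol, weights and logging."""
    if weights is None:
        # Equal weighting when no weights are given.
        weights = torch.ones(points.shape[0], dtype=points.dtype, device=points.device)
    # Start from the weighted mean.
    median = (weights[:, None] * points).sum(dim=0) / weights.sum()
    objective = weighted_objective(median, points, weights)
    log = [objective.item()] if log_objective_values else None
    for _ in range(maxiter):
        # Inverse-distance coefficients; clamping avoids dividing by zero at a data point.
        dists = torch.linalg.norm(points - median, dim=1).clamp(min=eps)
        coeffs = weights / dists
        median = (coeffs[:, None] * points).sum(dim=0) / coeffs.sum()
        new_objective = weighted_objective(median, points, weights)
        if log is not None:
            log.append(new_objective.item())
        # Stop when the objective improves by less than an `ftol` fraction.
        converged = bool(objective - new_objective <= ftol * new_objective)
        objective = new_objective
        if converged:
            break
    return SketchResult(median=median, objective_values_log=log)


points = torch.tensor([[0.0, 0.0], [4.0, 0.0], [0.0, 3.0], [5.0, 5.0]])
result = geometric_median_sketch(
    points,
    weights=torch.tensor([2.0, 1.0, 1.0, 1.0]),
    maxiter=1000,
    ftol=1e-10,
    log_objective_values=True,
)
print(result.median, len(result.objective_values_log))
```

Because the objective is a weighted sum of distances, giving a point an integer weight k is equivalent to repeating that point k times, which makes a convenient test for any weighted implementation.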
Why does this library exist?
It appears that the original geom-median library is no longer maintained, and, as pointed out by @themachinefan, it is not very performant with PyTorch. This library repackages @themachinefan's improvements to the original geom-median library, simplifying the code to support only PyTorch, improving PyTorch performance, and adding full type hints.
Acknowledgements
This library is a repackaging of the work done by the authors of the original geom-median library and by @themachinefan in their PR; as such, all credit goes to these incredible authors. If you use this library, you should cite the original geom-median paper.
License
This library is licensed under a GPL license, as per the original geom-median library.
Contributing
Contributions are welcome! Please open an issue or a PR if you have any suggestions or improvements. This library uses PDM for dependency management, Ruff for linting, Pyright for type-checking, and Pytest for tests.
Download files
Source Distribution
Built Distribution
File details
Details for the file torch_geometric_median-0.1.1.tar.gz.
File metadata
- Download URL: torch_geometric_median-0.1.1.tar.gz
- Upload date:
- Size: 16.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 00316a0237471d0628516db407a34fc99e04504b62575fd1c03670e839aba2c0
MD5 | 9a025a8606fb4454f410a45de193d01b
BLAKE2b-256 | 781fa593a3f2f04ed90ca48b6bf4d614d52cb502cd50db6c4925a5e579ab2d16
File details
Details for the file torch_geometric_median-0.1.1-py3-none-any.whl.
File metadata
- Download URL: torch_geometric_median-0.1.1-py3-none-any.whl
- Upload date:
- Size: 16.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 20c26ef7a40d3cc3f6f29fa4cb9f78eebf56e8c7e702f4fb8e2d8b2910b3c386
MD5 | 768752594d26ee938beb0af5be830f75
BLAKE2b-256 | 06f3b70c4295c91e1ef306a6d9b30a028072cf7e0f6bab29e4339a6cd2299fe2