Unofficial implementation for “Riemannian Adaptive Optimization Methods” ICLR2019 and more
Manifold aware pytorch.optim.
What is done so far
Work is in progress, but you can already use it. Note that the API might change in future releases.
Tensors
geoopt.ManifoldTensor – just like torch.Tensor, with an additional manifold keyword argument.
geoopt.ManifoldParameter – same as above; because it is correctly subclassed, it is recognized by torch.nn.Module.parameters.
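The container idea can be illustrated without PyTorch. The snippet below is a hypothetical NumPy analogue, not geoopt's code: ManifoldArray is an illustrative name and the manifold attribute is just an arbitrary tag. Because it is a genuine np.ndarray subclass, isinstance checks pass and views keep the attribute. This mirrors why a correctly subclassed ManifoldParameter is still picked up by torch.nn.Module, which registers any attribute that is an instance of torch.nn.Parameter.

```python
import numpy as np


class ManifoldArray(np.ndarray):
    """Hypothetical NumPy analogue of geoopt.ManifoldTensor.

    An ordinary ndarray that additionally carries a `manifold` attribute.
    """

    def __new__(cls, data, manifold=None):
        # View the input data as this subclass, then attach the manifold tag.
        obj = np.asarray(data, dtype=float).view(cls)
        obj.manifold = manifold
        return obj

    def __array_finalize__(self, obj):
        # Called for views and slices: inherit the tag from the source array.
        self.manifold = getattr(obj, "manifold", None)


# "stiefel" is an arbitrary placeholder tag for this illustration.
x = ManifoldArray(np.eye(3)[:, :2], manifold="stiefel")
print(isinstance(x, np.ndarray))  # True
print(x[:, 0].manifold)           # stiefel
```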
All of the above containers have special methods for working with them as points on a certain manifold:
.proj_() – in-place projection onto the manifold.
.proju(u) – project vector u onto the tangent space. All vectors passed to the methods below need to be projected this way first.
.inner(u, v=None) – inner product of two tangent vectors at this point. The passed vectors are not projected; they are assumed to be projected already.
.retr(u, t) – retraction map following vector u for time t.
.transp(u, t, v, *more) – transport vector v (and possibly more vectors) along direction u for time t.
.retr_transp(u, t, v, *more) – retract self and transport vector v (and possibly more vectors) along direction u for time t; the results are returned as plain tensors.
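To make the method list concrete, here is a minimal sketch of the same operations for the simplest manifold, Euclidean space, written as plain NumPy functions rather than tensor methods. The function names mirror the methods above but take the current point x explicitly, and treating an omitted v in inner as <u, u> is an assumption; none of this is geoopt's code. In flat space every operation is trivial: both projections are the identity, the retraction is a straight step x + t*u, and transport leaves vectors unchanged.

```python
import numpy as np

# Hypothetical free-function versions of the container methods for the
# Euclidean manifold. Each one takes the current point x explicitly.


def proj(x):
    # Every array already lies on the manifold.
    return x


def proju(x, u):
    # The tangent space at x is the whole space.
    return u


def inner(x, u, v=None):
    # Euclidean inner product; assumption: v=None means <u, u>.
    if v is None:
        v = u
    return float(np.sum(u * v))


def retr(x, u, t):
    # Follow direction u for time t along a straight line.
    return x + t * u


def transp(x, u, t, v, *more):
    # Flat space: transporting a vector leaves it unchanged.
    if more:
        return (v,) + more
    return v


def retr_transp(x, u, t, v, *more):
    # Move the point and transport the extra vectors in one call.
    moved = retr(x, u, t)
    return (moved,) + tuple(transp(x, u, t, w) for w in (v,) + more)


x = np.array([1.0, 2.0])
u = proju(x, np.array([0.5, -1.0]))
new_x, v = retr_transp(x, u, 2.0, u)
print(new_x)  # [2. 0.]
```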
Manifolds
geoopt.Euclidean – unconstrained manifold in R^n with the Euclidean metric
geoopt.Stiefel – Stiefel manifold of matrices A in R^{n x p} with A^T A = I, n >= p
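On the Stiefel manifold the same operations become non-trivial. The sketch below uses common textbook choices, which may differ from what geoopt uses internally: projection onto the manifold via the polar factor from an SVD, tangent projection Z - A sym(A^T Z) for the metric inherited from the ambient space, a polar retraction, and vector transport by re-projecting onto the tangent space at the new point. All function names are illustrative.

```python
import numpy as np


def sym(m):
    # Symmetric part of a square matrix.
    return 0.5 * (m + m.T)


def proj(a):
    # Nearest matrix with orthonormal columns: the polar factor Q V^T.
    q, _, vt = np.linalg.svd(a, full_matrices=False)
    return q @ vt


def proju(a, z):
    # Tangent space at A is {Z : A^T Z + Z^T A = 0}.
    return z - a @ sym(a.T @ z)


def inner(a, u, v):
    # Embedded (Euclidean) metric: <u, v> = trace(u^T v).
    return float(np.sum(u * v))


def retr(a, u, t):
    # Polar retraction: ambient step, then project back onto the manifold.
    return proj(a + t * u)


def transp(a, u, t, v):
    # Vector transport by projection onto the tangent space at the new point.
    return proju(retr(a, u, t), v)


rng = np.random.default_rng(0)
a = proj(rng.standard_normal((5, 3)))       # a point on St(5, 3)
u = proju(a, rng.standard_normal((5, 3)))   # a tangent vector at a
b = retr(a, u, 0.1)
v = transp(a, u, 0.1, u)
print(np.allclose(b.T @ b, np.eye(3)))  # True
```

Polar retraction and projection-based transport are chosen here because they need nothing beyond an SVD and are easy to check; other retractions (QR-based or Cayley-based, for example) are equally valid.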
Optimizers
geoopt.optim.RiemannianSGD – a subclass of torch.optim.SGD with the same API
geoopt.optim.RiemannianAdam – a subclass of torch.optim.Adam
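What makes such optimizers Riemannian is how a step is taken: the Euclidean gradient is projected onto the tangent space, the point is moved with a retraction instead of plain addition, and optimizer state such as momentum is transported to the new point. The sketch below shows that pattern for SGD with heavy-ball momentum on the Stiefel manifold, with a polar retraction and projection-based transport. It is a hypothetical illustration on a toy objective (rsgd_momentum and the problem setup are made up for this example), not the code behind geoopt.optim.RiemannianSGD.

```python
import numpy as np


def sym(m):
    return 0.5 * (m + m.T)


def proju(a, z):
    # Project an ambient matrix onto the tangent space of Stiefel at a.
    return z - a @ sym(a.T @ z)


def retr(a, u):
    # Polar retraction: ambient step, then back to orthonormal columns.
    q, _, vt = np.linalg.svd(a + u, full_matrices=False)
    return q @ vt


def rsgd_momentum(a, egrad_fn, lr=0.05, momentum=0.5, steps=500):
    """Hypothetical Riemannian SGD with heavy-ball momentum on Stiefel."""
    buf = np.zeros_like(a)
    for _ in range(steps):
        rgrad = proju(a, egrad_fn(a))  # Riemannian gradient at a
        buf = momentum * buf + rgrad   # momentum, a tangent vector at a
        a_new = retr(a, -lr * buf)     # step along the manifold
        buf = proju(a_new, buf)        # transport momentum to a_new
        a = a_new
    return a


# Toy problem: minimize f(A) = -trace(A^T M A) over 5x2 orthonormal A.
m = np.diag([5.0, 4.0, 3.0, 2.0, 1.0])
a0, _ = np.linalg.qr(np.random.default_rng(0).standard_normal((5, 2)))
a = rsgd_momentum(a0, lambda x: -2.0 * m @ x)  # Euclidean gradient of f
print(-np.trace(a.T @ m @ a))
```

The printed objective should approach -9, the negated sum of the two largest eigenvalues of M, and the columns of a stay orthonormal at every step because each iterate comes out of the retraction.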
Samplers
geoopt.samplers.RSGLD – Riemannian Stochastic Gradient Langevin Dynamics
geoopt.samplers.RHMC – Riemannian Hamiltonian Monte-Carlo
geoopt.samplers.SGRHMC – Stochastic Gradient Riemannian Hamiltonian Monte-Carlo
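Langevin-type samplers follow the same recipe as the optimizers but inject noise. A minimal sketch of Riemannian SGLD, assuming the update x ← retr(x, -(ε/2)·proju(grad U(x)) + sqrt(ε)·proju(ξ)) with ξ standard normal noise, is shown below on the Euclidean manifold, where proju and retr are trivial and the update reduces to ordinary SGLD. The exact gradient stands in for a stochastic one, and the target is a standard normal, U(x) = |x|²/2. The function rsgld and its parameters are hypothetical and are not geoopt's RSGLD interface.

```python
import numpy as np


def proju(x, u):
    # Euclidean manifold: the tangent space is the whole space.
    return u


def retr(x, u):
    # Euclidean manifold: straight-line step.
    return x + u


def rsgld(grad_u, x0, epsilon=0.1, n_steps=20_000, seed=0):
    """Hypothetical Riemannian SGLD sampler (Euclidean special case)."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)
    samples = np.empty((n_steps,) + x.shape)
    for i in range(n_steps):
        noise = proju(x, rng.standard_normal(x.shape))
        drift = -0.5 * epsilon * proju(x, grad_u(x))
        x = retr(x, drift + np.sqrt(epsilon) * noise)
        samples[i] = x
    return samples


# Standard normal target in 50 dimensions: grad U(x) = x.
samples = rsgld(lambda x: x, np.zeros(50))
kept = samples[1_000:]  # drop burn-in
print(kept.mean(), kept.var())
```

With a finite step size the chain is slightly biased: for this Gaussian target the update is x' = (1 - ε/2)x + sqrt(ε)ξ, whose stationary variance is 1/(1 - ε/4), about 1.026 for ε = 0.1 rather than exactly 1.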
Built Distribution
Hashes for geoopt-0.0.1rc2-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 9217829f1d6eff6d964c476632173a0998993e6e012062f50531404752c27043
MD5 | 46fd99ba5a9a9054939fec90106c0016
BLAKE2b-256 | 0fd9f25529555f986d04e986141c0f1ac033811f6e340625b8de8c8bfaa4ead8