
Unofficial implementation of “Riemannian Adaptive Optimization Methods” (ICLR 2019) and more


Manifold-aware torch.optim.


Installation

Make sure you have PyTorch >= 1.10.2 installed.

There are two ways to install geoopt:

  1. From GitHub (currently preferred, because the project is under active development):

pip install git+https://github.com/geoopt/geoopt.git

  2. From PyPI (this release may lag significantly behind the master branch):

pip install geoopt

The preferred installation method will change once the project reaches a stable stage. For now, the PyPI release lags behind master because new features are being actively developed.

PyTorch Support

Geoopt officially supports the two latest stable versions of PyTorch upstream, or the latest major release.

What is done so far

Work is in progress, but the library is already usable. Note that the API might change in future releases.

Tensors

  • geoopt.ManifoldTensor - just like torch.Tensor, with an additional manifold keyword argument.

  • geoopt.ManifoldParameter - same as above, and correctly subclassed so that it is recognized by torch.nn.Module.parameters.

ManifoldTensor and ManifoldParameter both have special methods for treating them as points on their manifold:

  • .proj_() - in-place projection onto the manifold.

  • .proju(u) - project vector u onto the tangent space at this point. All vectors passed to the methods below must be projected first.

  • .egrad2rgrad(u) - convert the Euclidean gradient u into the Riemannian gradient at this point

  • .inner(u, v=None) - inner product of two tangent vectors at this point (if v is None, u is used). The passed vectors are not projected; they are assumed to be tangent already.

  • .retr(u) - retraction map along vector u

  • .expmap(u) - exponential map along vector u (if no closed form is available, the best available approximation is used)

  • .transp(v, u) - transport vector v along direction u

  • .retr_transp(v, u) - retract self and transport vector v (and possibly more vectors) along direction u (the results are plain tensors)

Manifolds

  • geoopt.Euclidean - unconstrained manifold in R^n with the Euclidean metric

  • geoopt.Stiefel - Stiefel manifold of matrices A in R^{n x p} with A^T A = I, n >= p

  • geoopt.Sphere - Sphere manifold ||x|| = 1

  • geoopt.BirkhoffPolytope - manifold of doubly stochastic matrices

  • geoopt.Stereographic - Constant curvature stereographic projection model

  • geoopt.SphereProjection - Sphere stereographic projection model

  • geoopt.PoincareBall - Poincaré ball model

  • geoopt.Lorentz - Hyperboloid model

  • geoopt.ProductManifold - Product manifold constructor

  • geoopt.Scaled - Scaled version of a manifold. Combined with ProductManifold, this is similar to “Learning Mixed-Curvature Representations in Product Spaces”

  • geoopt.SymmetricPositiveDefinite - SPD matrix manifold

  • geoopt.UpperHalf - Siegel upper half-space manifold. Supports Riemannian and Finsler metrics, as in “Symmetric Spaces for Graph Embeddings: A Finsler-Riemannian Approach”.

  • geoopt.BoundedDomain - Siegel bounded domain manifold. Supports Riemannian and Finsler metrics.

All manifolds implement the methods needed to manipulate points and tangent vectors, so they can be used for general-purpose computations. See the documentation for details.

Optimizers

  • geoopt.optim.RiemannianSGD - a subclass of torch.optim.SGD with the same API

  • geoopt.optim.RiemannianAdam - a subclass of torch.optim.Adam

Samplers

  • geoopt.samplers.RSGLD - Riemannian Stochastic Gradient Langevin Dynamics

  • geoopt.samplers.RHMC - Riemannian Hamiltonian Monte-Carlo

  • geoopt.samplers.SGRHMC - Stochastic Gradient Riemannian Hamiltonian Monte-Carlo

Layers

The experimental geoopt.layers module allows embedding geoopt into deep learning models.

Citing Geoopt

If you find this project useful in your research, please cite it using this BibTeX entry:

@misc{geoopt2020kochurov,
    title={Geoopt: Riemannian Optimization in PyTorch},
    author={Max Kochurov and Rasul Karimov and Serge Kozlukov},
    year={2020},
    eprint={2005.02819},
    archivePrefix={arXiv},
    primaryClass={cs.CG}
}

Donations

ETH: 0x008319973D4017414FdF5B3beF1369bA78275C6A

