Constrained Optimization and Manifold Optimization in PyTorch
A library for constrained optimization and manifold optimization for deep learning in PyTorch
Overview
GeoTorch provides a simple way to perform constrained optimization and optimization on manifolds in PyTorch. It is compatible out of the box with any optimizer, layer, and model implemented in PyTorch without any boilerplate in the training code. Just state the constraints when you construct the model and you are ready to go!
```python
import torch
import torch.nn as nn
import geotorch

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # One line suffices: make a linear layer with orthonormal columns
        self.linear = nn.Linear(64, 128)
        geotorch.orthogonal(self.linear, "weight")

        # Works with tensors: make a CNN with kernels of rank 1
        self.cnn = nn.Conv2d(16, 32, 3)
        geotorch.low_rank(self.cnn, "weight", rank=1)

        # Weights are initialized to a random value when you put the constraints, but
        # you may re-initialize them to a different value by assigning to them
        self.linear.weight = torch.eye(128, 64)
        # And that's all you need to do. The rest is regular PyTorch code

    def forward(self, x, img):
        # self.linear is orthogonal and every 3x3 kernel in self.cnn is of rank 1
        # (this two-input forward is only a placeholder so the example runs)
        return self.linear(x), self.cnn(img)

# Use the model as you would normally do. Everything just works
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model().to(device)

# Use your optimizer of choice. Any optimizer works out of the box with any parametrization
lr = 1e-3
optim = torch.optim.Adam(model.parameters(), lr=lr)
```
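The one-line constraints above work by re-parametrizing a tensor of an existing layer. As a rough sketch of that mechanism (not GeoTorch's code), PyTorch's own `torch.nn.utils.parametrize` module, available since PyTorch 1.9, can express a symmetric-weight constraint. The `Symmetric` class below is a hypothetical stand-in for `geotorch.symmetric`. In this sketch the optimizer only ever sees an unconstrained tensor, which illustrates why a parametrization can be paired with any optimizer.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    # Hypothetical parametrization: maps an unconstrained square matrix X
    # to the symmetric matrix (X + X^T) / 2
    def forward(self, X):
        return 0.5 * (X + X.transpose(-2, -1))

    # Used when a value is assigned to the weight: a symmetric S is its own preimage
    def right_inverse(self, S):
        return S

layer = nn.Linear(8, 8)
parametrize.register_parametrization(layer, "weight", Symmetric())

# Regular PyTorch training code: Adam updates the unconstrained tensor
# stored in layer.parametrizations.weight.original
optim = torch.optim.Adam(layer.parameters(), lr=1e-2)
x = torch.randn(32, 8)
for _ in range(10):
    optim.zero_grad()
    loss = layer(x).pow(2).sum()
    loss.backward()
    optim.step()

W = layer.weight  # recomputed through Symmetric.forward on every access
print(torch.allclose(W, W.T))  # True
```

Reading `layer.weight` recomputes the matrix from `layer.parametrizations.weight.original`, so the constraint holds after every optimizer step without any change to the training loop.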
Constraints
The following constraints are implemented and may be used as in the example above:
geotorch.symmetric. Symmetric matrices
geotorch.skew. Skew-symmetric matrices
geotorch.sphere. Vectors of norm 1
geotorch.orthogonal. Matrices with orthonormal columns
geotorch.grassmannian. Subspaces of dimension k in Rⁿ
geotorch.almost_orthogonal(λ). Matrices with singular values in the interval [1-λ, 1+λ]
geotorch.invertible. Invertible matrices with positive determinant
geotorch.low_rank(r). Matrices of rank at most r
geotorch.fixed_rank(r). Matrices of rank r
geotorch.positive_definite. Positive definite matrices
geotorch.positive_semidefinite. Positive semidefinite matrices
geotorch.positive_semidefinite_low_rank(r). Positive semidefinite matrices of rank at most r
geotorch.positive_semidefinite_fixed_rank(r). Positive semidefinite matrices of rank r
Each of these constraints has some extra parameters that can be used to tailor its behavior to the problem at hand. For more on this, see the documentation.
These constraints are a frontend for the families of spaces listed below.
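To make the `almost_orthogonal(λ)` constraint concrete, the sketch below builds an n×k matrix U diag(σ) Vᵀ whose singular values σ are pushed into the interval [1-λ, 1+λ] by squashing unconstrained values s with 1 + λ·tanh(s). The tanh choice and the helper name `almost_orthogonal_from_svd` are assumptions for illustration; GeoTorch's own construction may differ.

```python
import torch

def almost_orthogonal_from_svd(U, s, Vh, lam):
    # Hypothetical helper: squash unconstrained values s into (1 - lam, 1 + lam)
    # and rebuild the matrix U diag(sigma) Vh
    sigma = 1.0 + lam * torch.tanh(s)
    return U @ torch.diag_embed(sigma) @ Vh

torch.manual_seed(0)
n, k, lam = 6, 4, 0.3
# Orthonormal factors from the reduced QR decomposition of random matrices
U, _ = torch.linalg.qr(torch.randn(n, k, dtype=torch.float64))  # shape (6, 4)
V, _ = torch.linalg.qr(torch.randn(k, k, dtype=torch.float64))  # shape (4, 4)
s = torch.randn(k, dtype=torch.float64)                          # unconstrained values
W = almost_orthogonal_from_svd(U, s, V.T, lam)

# The singular values of W are exactly sigma, up to floating-point error
sv = torch.linalg.svdvals(W)
in_band = torch.all((sv > 1 - lam - 1e-10) & (sv < 1 + lam + 1e-10))
print(bool(in_band))  # True
```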
Supported Spaces
Each constraint in GeoTorch is implemented as a manifold. This gives the user more flexibility in the options they choose for each parametrization. All of these support Riemannian Gradient Descent by default (see Using GeoTorch in your Code below), but they also support optimization via any other PyTorch optimizer.
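As an illustration of what Riemannian gradient descent means, here is a minimal sketch on Sphere(n) in plain PyTorch (not GeoTorch's implementation): the Euclidean gradient is projected onto the tangent space at x, a step is taken, and the result is mapped back onto the sphere by normalizing (a retraction). Minimizing xᵀAx over the sphere recovers the smallest eigenvalue of a symmetric matrix A. The function name `riemannian_gd_sphere`, the step size, and the number of steps are hypothetical choices.

```python
import torch

def riemannian_gd_sphere(A, x, lr=0.1, steps=200):
    # Minimize f(x) = x^T A x over the unit sphere {x : ||x|| = 1}
    for _ in range(steps):
        egrad = 2 * A @ x                # Euclidean gradient of f
        rgrad = egrad - (x @ egrad) * x  # project onto the tangent space at x
        x = x - lr * rgrad               # step along the tangent direction
        x = x / x.norm()                 # retraction: map back onto the sphere
    return x

A = torch.diag(torch.arange(1.0, 6.0, dtype=torch.float64))  # eigenvalues 1, ..., 5
x0 = torch.ones(5, dtype=torch.float64) / 5 ** 0.5           # a point on the sphere
x = riemannian_gd_sphere(A, x0)
print(round((x @ A @ x).item(), 6))  # 1.0, the smallest eigenvalue of A
```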
GeoTorch currently supports the following spaces:
Rn(n): Rⁿ. Unrestricted optimization
Sym(n): Vector space of symmetric matrices
Skew(n): Vector space of skew-symmetric matrices
Sphere(n): Sphere in Rⁿ. { x ∈ Rⁿ | ||x|| = 1 } ⊂ Rⁿ
SO(n): Manifold of n×n orthogonal matrices
St(n,k): Manifold of n×k matrices with orthonormal columns
AlmostOrthogonal(n,k,λ): Manifold of n×k matrices with singular values in the interval [1-λ, 1+λ]
Gr(n,k): Manifold of k-dimensional subspaces in Rⁿ
GLp(n): Manifold of invertible n×n matrices with positive determinant
LowRank(n,k,r): Variety of n×k matrices of rank r or less
FixedRank(n,k,r): Manifold of n×k matrices of rank r
PSD(n): Cone of n×n symmetric positive definite matrices
PSSD(n): Cone of n×n symmetric positive semi-definite matrices
PSSDLowRank(n,r): Variety of n×n symmetric positive semi-definite matrices of rank r or less
PSSDFixedRank(n,r): Manifold of n×n symmetric positive semi-definite matrices of rank r
ProductManifold(M₁, ..., Mₖ): Product of manifolds M₁ × ... × Mₖ
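Several of these spaces have simple factorized descriptions. The plain-PyTorch sketch below, with a hypothetical `numerical_rank` helper, checks two of them numerically: a product UVᵀ with U of shape n×r and V of shape k×r lies in LowRank(n,k,r), and BBᵀ with B of shape n×r lies in PSSDLowRank(n,r). It only illustrates the definitions and does not claim that GeoTorch parametrizes these spaces this way.

```python
import torch

def numerical_rank(M, rel_tol=1e-8):
    # Hypothetical helper: count singular values that are not negligible
    # relative to the largest one (svdvals returns them in descending order)
    sv = torch.linalg.svdvals(M)
    return int((sv > rel_tol * sv[0]).sum())

torch.manual_seed(0)
n, k, r = 6, 5, 2

# LowRank(n, k, r): a product U V^T with U of shape (n, r) and V of shape (k, r)
U = torch.randn(n, r, dtype=torch.float64)
V = torch.randn(k, r, dtype=torch.float64)
X = U @ V.T
rank_X = numerical_rank(X)

# PSSDLowRank(n, r): B B^T is symmetric positive semidefinite of rank at most r
B = torch.randn(n, r, dtype=torch.float64)
P = B @ B.T
min_eig = torch.linalg.eigvalsh(P).min().item()  # exact value is 0, up to rounding
rank_P = numerical_rank(P)

print(rank_X, rank_P, min_eig > -1e-10)  # 2 2 True
```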
Every space of dimension (n, k) can be applied to tensors of shape (*, n, k), so we also get efficient parallel implementations of product spaces such as
ObliqueManifold(n,k): Matrices with unit-length columns, Sⁿ⁻¹ × ... × Sⁿ⁻¹ (k factors)
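Because each space acts on the last two dimensions of a tensor, one batched operation handles every factor of a product space at once. A plain-PyTorch sketch of this idea (not GeoTorch's code): normalizing the columns of a tensor of shape (10, 5, 3) projects each of its 10 matrices onto ObliqueManifold(5,3), and a batched reduced QR decomposition yields points of St(5,3).

```python
import torch

torch.manual_seed(0)
# A batch of 10 matrices of shape 5 x 3, i.e. a tensor of shape (*, n, k) with * = (10,)
X = torch.randn(10, 5, 3, dtype=torch.float64)

# ObliqueManifold(5, 3): normalize every column; one call covers the whole batch
oblique = X / X.norm(dim=-2, keepdim=True)
print(torch.allclose(oblique.norm(dim=-2), torch.ones(10, 3, dtype=torch.float64)))  # True

# St(5, 3): the Q factor of a batched reduced QR decomposition has orthonormal columns
Q, _ = torch.linalg.qr(X)  # Q has shape (10, 5, 3)
eye = torch.eye(3, dtype=torch.float64).expand(10, 3, 3)
print(torch.allclose(Q.transpose(-2, -1) @ Q, eye))  # True
```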
Using GeoTorch in your Code
The files examples/copying_problem.py and examples/sequential_mnist.py serve as tutorials showing how to initialize and use GeoTorch in real code. They also show how to implement Riemannian Gradient Descent and some other tricks. For an introduction to how the library is actually implemented, see the Jupyter Notebook examples/parametrisations.ipynb.
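A minimal taste of one idea behind such parametrizations, in plain PyTorch: optimize over a manifold through a map from an unconstrained Euclidean space. The matrix exponential of a skew-symmetric matrix is a classic example of such a map, and it always lands in SO(n). The helper `orthogonal_from_unconstrained` is hypothetical and is not GeoTorch's API; GeoTorch may use a different map for `geotorch.orthogonal`.

```python
import torch

def orthogonal_from_unconstrained(A):
    # Hypothetical helper: A - A^T is skew-symmetric, and the matrix exponential
    # of a skew-symmetric matrix is orthogonal with determinant 1
    return torch.matrix_exp(A - A.transpose(-2, -1))

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
Q = orthogonal_from_unconstrained(A)

print(torch.allclose(Q.T @ Q, torch.eye(4, dtype=torch.float64)))  # True
print(round(torch.linalg.det(Q).item(), 6))  # 1.0

# Gradients flow back to the unconstrained A, so any optimizer can update it
loss = (Q - torch.eye(4, dtype=torch.float64)).pow(2).sum()
loss.backward()
print(A.grad.shape)  # torch.Size([4, 4])
```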
You may try GeoTorch by installing it with

```bash
pip install git+https://github.com/Lezcano/geotorch/
```
GeoTorch is tested on Linux, macOS, and Windows for Python >= 3.6 and supports PyTorch >= 1.9.
Bibliography
Please cite the following work if you find GeoTorch useful. This paper presents a simplified mathematical explanation of part of the inner workings of GeoTorch.
```bibtex
@inproceedings{lezcano2019trivializations,
    title = {Trivializations for gradient-based optimization on manifolds},
    author = {Lezcano-Casado, Mario},
    booktitle = {Advances in Neural Information Processing Systems, NeurIPS},
    pages = {9154--9164},
    year = {2019},
}
```
File details
Details for the file geotorch-0.3.0.tar.gz.
File metadata
- Download URL: geotorch-0.3.0.tar.gz
- Upload date:
- Size: 41.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | fe99106d90667a93288d7e2ce349d5757475320332ccc058af270c043ed2c3f4
MD5 | 5eb6803515f7e2a2da02b2c519aaa869
BLAKE2b-256 | 80400e8e34ed13c431676989e46b1f4a7dd1f1689c98a59269a403a662b81f35
File details
Details for the file geotorch-0.3.0-py3-none-any.whl.
File metadata
- Download URL: geotorch-0.3.0-py3-none-any.whl
- Upload date:
- Size: 54.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | c7028a4ecc0ddc8602e9f3fe9351389cff2dcb12f6eccbcbb268731196bf8cff
MD5 | 5e5c113dafaa2de9864fb0a63d213a7e
BLAKE2b-256 | d5903b05a65c4399eee2c27b0f4d72b54fe3e0ad1a5f7a8fa9341c239309c090