RiemannAX
Hardware-accelerated Riemannian Manifold Optimization with JAX
Overview
RiemannAX is a high-performance library for optimization on Riemannian manifolds, built upon JAX's ecosystem. It provides mathematically rigorous implementations of manifold structures and optimization algorithms, leveraging automatic differentiation, just-in-time compilation, and hardware acceleration to deliver exceptional computational efficiency for geometric optimization problems.
The library bridges the gap between theoretical differential geometry and practical machine learning applications, enabling researchers and practitioners to solve complex optimization problems that arise in computer vision, machine learning, and scientific computing.
Key Features
🔬 Comprehensive Manifold Library
- Sphere (`S^n`): Unit hypersphere with geodesic operations
- Special Orthogonal Group (`SO(n)`): Rotation matrices with Lie group structure
- Grassmann Manifold (`Gr(p,n)`): Subspace optimization for dimensionality reduction and principal component analysis
- Stiefel Manifold (`St(p,n)`): Orthonormal frames with applications in orthogonal Procrustes problems
- Rigorous implementations with validation, batch operations, and numerical stability
⚡ High-Performance Optimization
- Riemannian Gradient Descent: First-order optimization with exponential maps and retractions
- Automatic Differentiation: Seamless computation of Riemannian gradients from Euclidean cost functions
- Hardware Acceleration: GPU/TPU support through JAX's XLA compilation
- Batch Processing: Vectorized operations for multiple optimization instances
🛠 Robust Framework
- Flexible Problem Definition: Support for custom cost functions and gradients
- Comprehensive Validation: Manifold constraint verification and numerical stability checks
- Extensive Testing: 77+ unit and integration tests ensuring mathematical correctness
- Type Safety: Full type annotations for Python 3.10+ compatibility
Installation
Standard Installation
```bash
pip install riemannax
```
Development Installation
```bash
git clone https://github.com/lv416e/riemannax.git
cd riemannax
pip install -e ".[dev]"
```
With UV Package Manager
```bash
uv venv && source .venv/bin/activate
uv pip install -e .
```
Quick Start Examples
Sphere Optimization: Finding Optimal Directions
```python
import jax
import jax.numpy as jnp
import riemannax as rx

# Define the unit sphere manifold
sphere = rx.Sphere()

# Optimization problem: find the point closest to a target direction
target = jnp.array([0., 0., 1.])  # North pole

def cost_fn(x):
    return -jnp.dot(x, target)

problem = rx.RiemannianProblem(sphere, cost_fn)

# Initialize and solve
key = jax.random.key(42)
x0 = sphere.random_point(key)
result = rx.minimize(
    problem, x0, method='rsgd',
    options={'learning_rate': 0.1, 'max_iterations': 100},
)

print(f"Optimal point: {result.x}")
print(f"Final cost: {result.fun:.6f}")
```
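Under the hood, Riemannian gradient descent on the sphere reduces to a project-step-renormalize loop. The following is a minimal pure-NumPy sketch of that idea, independent of the riemannax API (function names here are illustrative, not library calls):

```python
import numpy as np

def sphere_rgd(cost_grad, x0, lr=0.1, steps=100):
    """Riemannian gradient descent on the unit sphere.

    Each step: project the Euclidean gradient onto the tangent space
    at x, take a descent step, then retract back onto the sphere by
    renormalization.
    """
    x = x0 / np.linalg.norm(x0)
    for _ in range(steps):
        g = cost_grad(x)
        rg = g - np.dot(g, x) * x      # tangent-space projection
        x = x - lr * rg                # Euclidean descent step
        x = x / np.linalg.norm(x)      # retraction (renormalize)
    return x

target = np.array([0.0, 0.0, 1.0])
grad = lambda x: -target               # gradient of -<x, target>
x_opt = sphere_rgd(grad, np.array([1.0, 0.2, -0.3]))
print(x_opt)                           # converges to [0, 0, 1]
```

The same project/step/retract pattern generalizes to every manifold in the library; only the projection and retraction formulas change.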
Grassmann Manifold: Subspace Fitting
```python
import jax
import jax.numpy as jnp
import riemannax as rx

# Generate synthetic data in a 3-D subspace of an 8-D space
key = jax.random.key(123)
n, p, m = 8, 3, 100
keys = jax.random.split(key, 4)  # split once; never reuse a consumed key
true_subspace = rx.Grassmann(n, p).random_point(keys[0])

# Create noisy data
coeffs = jax.random.normal(keys[1], (p, m))
noise = 0.1 * jax.random.normal(keys[2], (n, m))
data = true_subspace @ coeffs + noise

# Define the subspace-fitting problem
def subspace_cost(x):
    projector = x @ x.T
    reconstruction = projector @ data
    return jnp.sum((data - reconstruction) ** 2)

# Optimize on the Grassmann manifold
manifold = rx.Grassmann(n, p)
problem = rx.RiemannianProblem(manifold, subspace_cost)
x0 = manifold.random_point(keys[3])
result = rx.minimize(
    problem, x0, method='rsgd',
    options={'learning_rate': 0.01, 'max_iterations': 200},
)

print(f"Reconstruction error: {result.fun:.6f}")
```
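This subspace-fitting problem has a classical closed-form reference: the minimizer of the reconstruction error over orthonormal bases is spanned by the top-p left singular vectors of the data matrix (PCA, by Eckart-Young), which is useful for sanity-checking the optimizer. A NumPy sketch of that check (not part of the riemannax API):

```python
import numpy as np

# The minimizer of ||D - X X^T D||_F^2 over n x p orthonormal X is
# the span of the top-p left singular vectors of D, and the optimal
# residual equals the sum of the squared tail singular values.
rng = np.random.default_rng(0)
n, p, m = 8, 3, 100
D = rng.normal(size=(n, m))
U, s, _ = np.linalg.svd(D, full_matrices=False)
X = U[:, :p]                                   # optimal subspace basis
residual = np.sum((D - X @ X.T @ D) ** 2)
print(residual, np.sum(s[p:] ** 2))            # the two values agree
```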
Stiefel Manifold: Orthogonal Procrustes Problem
```python
import jax
import jax.numpy as jnp
import riemannax as rx

# Set up a Procrustes problem: find the optimal orthogonal transformation
key = jax.random.key(789)
n, p = 6, 4
keys = jax.random.split(key, 3)
A = jax.random.normal(keys[0], (n, p))
B = jax.random.normal(keys[1], (n, p))

# Minimize ||A - BQ||_F^2 over orthogonal matrices Q
def procrustes_cost(Q):
    return jnp.sum((A - B @ Q) ** 2)

# Optimize on the Stiefel manifold (square case = orthogonal group)
manifold = rx.Stiefel(p, p)
problem = rx.RiemannianProblem(manifold, procrustes_cost)
x0 = manifold.random_point(keys[2])
result = rx.minimize(
    problem, x0, method='rsgd',
    options={'learning_rate': 0.1, 'max_iterations': 100},
)

print(f"Procrustes cost: {result.fun:.6f}")
print(f"Orthogonality check: {jnp.allclose(result.x.T @ result.x, jnp.eye(p))}")
```
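The orthogonal Procrustes problem also admits a closed-form solution via the SVD, which makes it a good benchmark for the iterative optimizer. A NumPy sketch of the classical result (a fact about the problem itself, not a library call):

```python
import numpy as np

# min_Q ||A - B Q||_F^2 over orthogonal Q is solved in closed form:
# if B^T A = U S V^T, then the optimal Q* = U V^T.
rng = np.random.default_rng(1)
n, p = 6, 4
A = rng.normal(size=(n, p))
B = rng.normal(size=(n, p))
U, _, Vt = np.linalg.svd(B.T @ A)
Q = U @ Vt                                   # optimal orthogonal matrix
print(np.allclose(Q.T @ Q, np.eye(p)))       # True: Q is orthogonal
```

Comparing the iterative result's cost against this closed form is a quick way to confirm convergence.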
Advanced Usage
Custom Gradient Functions
# Define Euclidean gradient for automatic projection
def euclidean_grad(x):
return jax.grad(cost_fn)(x)
problem = rx.RiemannianProblem(manifold, cost_fn, euclidean_grad_fn=euclidean_grad)
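When supplying a hand-written Euclidean gradient instead of relying on autodiff, it is worth validating it against finite differences before the tangent-space projection. A small, generic NumPy sketch (the helper `check_grad` is illustrative, not a riemannax function):

```python
import numpy as np

# Central-difference check for a hand-written Euclidean gradient.
def check_grad(f, g, x, eps=1e-6):
    num = np.array([(f(x + eps * e) - f(x - eps * e)) / (2 * eps)
                    for e in np.eye(x.size)])
    return np.max(np.abs(num - g(x)))   # max elementwise discrepancy

target = np.array([0.0, 0.0, 1.0])
f = lambda x: -np.dot(x, target)        # cost
g = lambda x: -target                   # its analytic gradient
print(check_grad(f, g, np.array([0.3, -0.5, 0.8])))  # ~0
```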
Batch Optimization
```python
# Optimize multiple instances simultaneously
batch_size = 10
x0_batch = manifold.random_point(key, batch_size)

# Vectorized cost function
def batch_cost(x_batch):
    return jax.vmap(cost_fn)(x_batch)

batch_problem = rx.RiemannianProblem(manifold, batch_cost)
```
Exponential Map vs. Retraction
```python
# Use the exponential map for geodesically exact update steps
result_exp = rx.minimize(problem, x0, method='rsgd', use_retraction=False)

# Use a retraction for computational efficiency
result_retr = rx.minimize(problem, x0, method='rsgd', use_retraction=True)
```
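To see why a retraction is an acceptable substitute, compare the two on the sphere, where both have simple closed forms. A NumPy sketch (function names are illustrative): the projection retraction agrees with the exponential map up to third order in the step size, so for small learning rates the iterates are nearly identical at a fraction of the cost.

```python
import numpy as np

def sphere_exp(x, v):
    """Exact exponential map on the sphere (geodesic step)."""
    nv = np.linalg.norm(v)
    if nv < 1e-12:
        return x
    return np.cos(nv) * x + np.sin(nv) * (v / nv)

def sphere_retr(x, v):
    """Projection retraction: step in the embedding, then renormalize."""
    y = x + v
    return y / np.linalg.norm(y)

x = np.array([1.0, 0.0, 0.0])
v = np.array([0.0, 0.1, 0.0])          # tangent vector (v is orthogonal to x)
print(np.linalg.norm(sphere_exp(x, v) - sphere_retr(x, v)))  # difference is O(|v|^3)
```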
Comprehensive Examples
Explore detailed implementations in the examples/ directory:
- `sphere_optimization_demo.py`: Sphere optimization with visualization
- `grassmann_optimization_demo.py`: Subspace fitting and principal-angle analysis
- `stiefel_optimization_demo.py`: Orthogonal Procrustes with multiple exponential map methods
- `manifolds_comparison_demo.py`: Comparative analysis across all manifolds
- `notebooks/`: Interactive Jupyter notebooks with step-by-step tutorials
Testing and Development
Running Tests
```bash
# Quick test suite
make test

# With coverage analysis
make coverage

# Specific test categories
pytest tests/manifolds/    # Manifold implementations
pytest tests/optimizers/   # Optimization algorithms
pytest tests/integration/  # End-to-end workflows
```
Development Workflow
```bash
# Install development dependencies
pip install -e ".[dev]"

# Code formatting and linting
make format
make lint

# Type checking
make typecheck

# Documentation building
make docs
```
Performance Characteristics
RiemannAX leverages JAX's XLA compilation for exceptional performance:
- GPU Acceleration: Automatic device placement and parallel execution
- JIT Compilation: the first call pays a one-time compilation cost; subsequent calls execute optimized XLA code
- Memory Efficiency: XLA buffer reuse and optimized memory layouts
- Batch Processing: Vectorized operations across multiple problem instances
Typical performance improvements over CPU-based alternatives:
- 10-100x speedup on GPU for large-scale problems
- 2-5x speedup on CPU through XLA optimization
- Linear scaling with batch size for parallel optimization
Mathematical Foundation
RiemannAX implements manifolds with rigorous differential geometric operations:
Manifold Interface
Each manifold provides:
- Exponential Map (`exp`): Geodesic curves from tangent vectors
- Logarithmic Map (`log`): Inverse of the exponential map
- Retraction (`retr`): Computationally efficient approximation to the exponential map
- Parallel Transport (`transp`): Moving tangent vectors along the manifold
- Riemannian Metric (`inner`): Tangent-space inner products
- Projection (`proj`): Orthogonal projection onto the tangent space
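For concreteness, here is a minimal NumPy sketch of this interface for the sphere, where every operation has a simple closed form. The class name and method bodies are illustrative, not the riemannax implementation:

```python
import numpy as np

class SphereSketch:
    """Minimal sketch of the manifold interface for the unit sphere."""

    def proj(self, x, v):
        # Orthogonal projection of v onto the tangent space at x
        return v - np.dot(x, v) * x

    def inner(self, x, u, v):
        # Riemannian metric: the sphere inherits the Euclidean inner product
        return np.dot(u, v)

    def exp(self, x, v):
        # Exponential map: follow the geodesic from x in direction v
        t = np.linalg.norm(v)
        return x if t < 1e-12 else np.cos(t) * x + np.sin(t) * v / t

    def log(self, x, y):
        # Logarithmic map: tangent vector at x pointing to y (inverse of exp)
        t = np.arccos(np.clip(np.dot(x, y), -1.0, 1.0))
        w = self.proj(x, y - x)
        n = np.linalg.norm(w)
        return w if n < 1e-12 else t * w / n

    def retr(self, x, v):
        # Cheap retraction: step in the embedding, then renormalize
        y = x + v
        return y / np.linalg.norm(y)

m = SphereSketch()
x = np.array([1.0, 0.0, 0.0])
v = np.array([0.0, 0.3, -0.2])
y = m.exp(x, v)
print(np.allclose(m.log(x, y), v))     # True: log inverts exp
```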
Numerical Stability
- Robust QR-based orthogonalization for Stiefel and Grassmann manifolds
- Numerically stable distance computations using principal angles
- Careful handling of edge cases and degenerate configurations
- Comprehensive validation with appropriate floating-point tolerances
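The principal-angle approach mentioned above can be sketched in a few lines of NumPy: the principal angles between two subspaces are the arccosines of the singular values of `X^T Y` for orthonormal bases `X` and `Y`, and their Euclidean norm gives the Grassmann geodesic distance. The snippet below is an illustrative sketch, not the library's internal code:

```python
import numpy as np

# Principal angles between span(X) and span(Y), with X, Y orthonormal:
# arccos of the singular values of X^T Y. Clipping guards against
# values slightly outside [-1, 1] from floating-point roundoff.
rng = np.random.default_rng(2)
n, p = 8, 3
X, _ = np.linalg.qr(rng.normal(size=(n, p)))   # QR-based orthonormal basis
Y, _ = np.linalg.qr(rng.normal(size=(n, p)))
s = np.linalg.svd(X.T @ Y, compute_uv=False)
angles = np.arccos(np.clip(s, -1.0, 1.0))
dist = np.linalg.norm(angles)                  # Grassmann geodesic distance
print(dist)
```

The `clip` call is exactly the kind of edge-case handling the stability notes refer to: without it, roundoff can push a singular value above 1 and make `arccos` return NaN.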
Contributing
We welcome contributions! Please see our Contributing Guidelines for details on:
- Development setup and workflow
- Code style and testing requirements
- Documentation standards
- Pull request process
Development Priorities
- Additional manifold implementations (Hyperbolic, Product manifolds)
- Advanced optimization algorithms (Conjugate Gradient, L-BFGS)
- Enhanced visualization and debugging tools
- Performance optimizations and benchmarking
Citation
If you use RiemannAX in your research, please cite:
```bibtex
@software{riemannax2024,
  title  = {RiemannAX: Hardware-accelerated Riemannian Manifold Optimization with JAX},
  author = {mary},
  year   = {2024},
  url    = {https://github.com/lv416e/riemannax}
}
```
License
Licensed under the Apache License 2.0. See LICENSE for details.
Acknowledgments
RiemannAX draws inspiration from:
- JAX: Functional programming and automatic differentiation paradigms
- Optax: Optimization algorithm design patterns
- Pymanopt: Comprehensive Riemannian optimization reference
- Geoopt: PyTorch-based Riemannian optimization library
Special thanks to the JAX development team for creating an exceptional foundation for scientific computing.