
JAX-native Riemannian manifold optimization


RiemannAX

Hardware-accelerated Riemannian Manifold Optimization with JAX


Overview

RiemannAX is a high-performance library for optimization on Riemannian manifolds, built upon JAX's ecosystem. It provides mathematically rigorous implementations of manifold structures and optimization algorithms, leveraging automatic differentiation, just-in-time compilation, and hardware acceleration to deliver exceptional computational efficiency for geometric optimization problems.

The library bridges the gap between theoretical differential geometry and practical machine learning applications, enabling researchers and practitioners to solve complex optimization problems that arise in computer vision, machine learning, and scientific computing.

Key Features

🔬 Comprehensive Manifold Library

  • Sphere (S^n): Unit hypersphere with geodesic operations
  • Special Orthogonal Group (SO(n)): Rotation matrices with Lie group structure
  • Grassmann Manifold (Gr(p,n)): Subspace optimization for dimensionality reduction and principal component analysis
  • Stiefel Manifold (St(p,n)): Orthonormal frames with applications in orthogonal Procrustes problems
  • Rigorous implementations with validation, batch operations, and numerical stability

High-Performance Optimization

  • Riemannian Gradient Descent: First-order optimization with exponential maps and retractions
  • Automatic Differentiation: Seamless computation of Riemannian gradients from Euclidean cost functions
  • Hardware Acceleration: GPU/TPU support through JAX's XLA compilation
  • Batch Processing: Vectorized operations for multiple optimization instances
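On embedded manifolds, the automatic-differentiation feature rests on a simple rule: compute the Euclidean gradient, then project it onto the tangent space. Below is a minimal plain-JAX sketch for the unit sphere, where the projection is P_x(g) = g - (x·g) x. The helper names `sphere_proj` and `riemannian_grad` are hypothetical and illustrative only. This is not RiemannAX's internal code. The last lines show the batch-processing idea with `jax.vmap`.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers for illustration only; not part of the RiemannAX API.
def sphere_proj(x, g):
    """Project an ambient vector g onto the tangent space of the unit sphere at x."""
    return g - jnp.dot(x, g) * x

def riemannian_grad(cost, x):
    """Riemannian gradient = tangent-space projection of the Euclidean gradient."""
    return sphere_proj(x, jax.grad(cost)(x))

target = jnp.array([0.0, 0.0, 1.0])

def cost(x):
    return -jnp.dot(x, target)

# Single point on the equator: the Euclidean gradient -target is already tangent here.
x = jnp.array([1.0, 0.0, 0.0])
rg = riemannian_grad(cost, x)

# Batch of random unit vectors: vmap applies the same computation to every row.
xs = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
xs = xs / jnp.linalg.norm(xs, axis=1, keepdims=True)
rgs = jax.vmap(lambda p: riemannian_grad(cost, p))(xs)
```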

🛠 Robust Framework

  • Flexible Problem Definition: Support for custom cost functions and gradients
  • Comprehensive Validation: Manifold constraint verification and numerical stability checks
  • Extensive Testing: 77+ unit and integration tests ensuring mathematical correctness
  • Type Safety: Full type annotations, targeting Python 3.10+

Installation

Standard Installation

pip install riemannax

Development Installation

git clone https://github.com/lv416e/riemannax.git
cd riemannax
pip install -e ".[dev]"

With UV Package Manager

uv venv && source .venv/bin/activate
uv pip install -e .

Quick Start Examples

Sphere Optimization: Finding Optimal Directions

import jax
import jax.numpy as jnp
import riemannax as rx

# Define the unit sphere manifold
sphere = rx.Sphere()

# Optimization problem: find point closest to target direction
target = jnp.array([0., 0., 1.])  # North pole
def cost_fn(x):
    return -jnp.dot(x, target)

problem = rx.RiemannianProblem(sphere, cost_fn)

# Initialize and solve
key = jax.random.key(42)
x0 = sphere.random_point(key)
result = rx.minimize(problem, x0, method='rsgd',
                    options={'learning_rate': 0.1, 'max_iterations': 100})

print(f"Optimal point: {result.x}")
print(f"Final cost: {result.fun:.6f}")
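The `'rsgd'` call can be illustrated with a minimal sketch of Riemannian gradient descent on the sphere in plain JAX. Each step projects the Euclidean gradient onto the tangent space, takes a step, and retracts back to the sphere by normalization. This assumes `'rsgd'` follows this standard scheme. The function `rgd_sphere` is hypothetical, and RiemannAX's actual implementation may differ; for example, it may use the exponential map instead of normalization.

```python
from functools import partial

import jax
import jax.numpy as jnp

target = jnp.array([0.0, 0.0, 1.0])  # north pole, as in the example above

def cost_fn(x):
    return -jnp.dot(x, target)

# Hypothetical re-implementation for illustration; RiemannAX's 'rsgd' may differ.
@partial(jax.jit, static_argnames=("max_iterations",))
def rgd_sphere(x0, learning_rate=0.1, max_iterations=100):
    def step(_, x):
        egrad = jax.grad(cost_fn)(x)             # Euclidean gradient
        rgrad = egrad - jnp.dot(x, egrad) * x    # project onto the tangent space at x
        y = x - learning_rate * rgrad            # gradient step in the tangent direction
        return y / jnp.linalg.norm(y)            # retraction: renormalize onto the sphere
    return jax.lax.fori_loop(0, max_iterations, step, x0)

x0 = jnp.array([1.0, 0.0, 0.0])   # start on the equator
x_opt = rgd_sphere(x0)
print(x_opt, cost_fn(x_opt))      # x_opt approaches target, cost approaches -1
```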

Grassmann Manifold: Subspace Fitting

import jax
import jax.numpy as jnp
import riemannax as rx

# Generate synthetic data in a 3D subspace of 8D space
key = jax.random.key(123)
n, p, m = 8, 3, 100
keys = jax.random.split(key, 4)  # independent keys; never reuse a key after splitting it
true_subspace = rx.Grassmann(n, p).random_point(keys[3])

# Create noisy data
coeffs = jax.random.normal(keys[0], (p, m))
noise = 0.1 * jax.random.normal(keys[1], (n, m))
data = true_subspace @ coeffs + noise

# Define subspace fitting problem
def subspace_cost(x):
    projector = x @ x.T
    reconstruction = projector @ data
    return jnp.sum((data - reconstruction) ** 2)

# Optimize on Grassmann manifold
manifold = rx.Grassmann(n, p)
problem = rx.RiemannianProblem(manifold, subspace_cost)
x0 = manifold.random_point(keys[2])

result = rx.minimize(problem, x0, method='rsgd',
                    options={'learning_rate': 0.01, 'max_iterations': 200})

print(f"Reconstruction error: {result.fun:.6f}")
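This cost has a closed-form optimum, which is useful for sanity-checking `result.fun`. By the Eckart–Young theorem, the span of the top-p left singular vectors of `data` minimizes the reconstruction error. The minimum equals the sum of the squared discarded singular values. The plain-JAX sketch below generates comparable synthetic data with a QR-based basis instead of `rx.Grassmann`, so its values will not match the example above. It then checks this identity.

```python
import jax
import jax.numpy as jnp

# Synthetic data near a p-dimensional subspace of R^n (built with QR, not
# rx.Grassmann, so the numbers differ from the example above).
n, p, m = 8, 3, 100
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(123), 3)
basis, _ = jnp.linalg.qr(jax.random.normal(k1, (n, p)))   # orthonormal n x p
data = basis @ jax.random.normal(k2, (p, m)) + 0.1 * jax.random.normal(k3, (n, m))

def subspace_cost(x):
    reconstruction = x @ (x.T @ data)      # same as (x @ x.T) @ data, but cheaper
    return jnp.sum((data - reconstruction) ** 2)

# Closed-form optimum: the top-p left singular vectors of the data matrix.
u, s, _ = jnp.linalg.svd(data, full_matrices=False)
x_star = u[:, :p]

optimal_cost = subspace_cost(x_star)
eckart_young = jnp.sum(s[p:] ** 2)        # squared singular values that were dropped
print(optimal_cost, eckart_young)          # should agree up to float32 round-off
```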

Stiefel Manifold: Orthogonal Procrustes Problem

import jax
import jax.numpy as jnp
import riemannax as rx

# Set up the Procrustes problem: find the optimal orthogonal transformation
key = jax.random.key(789)
n, p = 6, 4
keys = jax.random.split(key, 3)

A = jax.random.normal(keys[0], (n, p))
B = jax.random.normal(keys[1], (n, p))

# Minimize ||A - BQ||_F^2 over orthogonal matrices Q
def procrustes_cost(Q):
    return jnp.sum((A - B @ Q) ** 2)

# Optimize on the Stiefel manifold St(p, p), i.e. the orthogonal group O(p)
manifold = rx.Stiefel(p, p)
problem = rx.RiemannianProblem(manifold, procrustes_cost)
x0 = manifold.random_point(keys[2])

result = rx.minimize(problem, x0, method='rsgd',
                    options={'learning_rate': 0.1, 'max_iterations': 100})

print(f"Procrustes cost: {result.fun:.6f}")
# float32 round-off needs a looser absolute tolerance than the default atol=1e-8
print(f"Orthogonality check: {jnp.allclose(result.x.T @ result.x, jnp.eye(p), atol=1e-5)}")
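When Q is square and orthogonal (p x p), as here, the Procrustes problem also has a closed-form solution that can serve as a reference for the optimizer's output. Because ||BQ||_F^2 = ||B||_F^2 for orthogonal Q, minimizing ||A - BQ||_F^2 is equivalent to maximizing trace(Q^T B^T A). If B^T A = U Σ V^T is an SVD, the maximizer is Q* = U V^T. Below is a minimal plain-JAX check, independent of RiemannAX.

```python
import jax
import jax.numpy as jnp

n, p = 6, 4
k_a, k_b, k_q = jax.random.split(jax.random.PRNGKey(789), 3)
A = jax.random.normal(k_a, (n, p))
B = jax.random.normal(k_b, (n, p))

def procrustes_cost(Q):
    return jnp.sum((A - B @ Q) ** 2)

# Closed form: maximize trace(Q^T B^T A) using the SVD of M = B^T A.
u, _, vt = jnp.linalg.svd(B.T @ A)
Q_star = u @ vt

# Any other orthogonal matrix (here a random one from QR) should do no better.
Q_rand, _ = jnp.linalg.qr(jax.random.normal(k_q, (p, p)))
print(procrustes_cost(Q_star), procrustes_cost(Q_rand))
```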

Advanced Usage

Custom Gradient Functions

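import jax

# `manifold` and `cost_fn` are any manifold and scalar cost defined as above.
# The explicit Euclidean gradient is projected onto the tangent space.
def euclidean_grad(x):
    return jax.grad(cost_fn)(x)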

problem = rx.RiemannianProblem(manifold, cost_fn, euclidean_grad_fn=euclidean_grad)
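A custom gradient pays off mainly when it is known in closed form. For example, the Procrustes cost f(Q) = ||A - BQ||_F^2 from the Stiefel example has Euclidean gradient -2 B^T (A - BQ). The plain-JAX sketch below checks a hand-written gradient against `jax.grad`, which is worth doing before passing it in. It does not show how RiemannAX consumes `euclidean_grad_fn`. The function name `procrustes_euclidean_grad` is hypothetical.

```python
import jax
import jax.numpy as jnp

k_a, k_b, k_q = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(k_a, (6, 4))
B = jax.random.normal(k_b, (6, 4))
Q = jax.random.normal(k_q, (4, 4))   # any point in the ambient space works for this check

def procrustes_cost(Q):
    return jnp.sum((A - B @ Q) ** 2)

def procrustes_euclidean_grad(Q):
    # Hypothetical hand-written gradient: d/dQ ||A - BQ||_F^2 = -2 B^T (A - BQ)
    return -2.0 * B.T @ (A - B @ Q)

auto = jax.grad(procrustes_cost)(Q)
manual = procrustes_euclidean_grad(Q)
print(jnp.max(jnp.abs(auto - manual)))   # should be at float32 round-off level
```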

Batch Optimization

# Optimize multiple instances simultaneously
batch_size = 10
x0_batch = manifold.random_point(key, batch_size)

# Vectorized cost function
def batch_cost(x_batch):
    return jax.vmap(cost_fn)(x_batch)

batch_problem = rx.RiemannianProblem(manifold, batch_cost)
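A common JAX pattern for batch optimization is to write the solver for one instance with a scalar cost, then `jax.vmap` the whole solver over a batch of initial points or problem data. The sketch below applies this to sphere RGD in plain JAX and illustrates only the vectorization idea. It does not show how `rx.RiemannianProblem` handles batched costs. `solve_one` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def solve_one(x0, target, learning_rate=0.1, num_steps=100):
    """Hypothetical single-instance RGD on the unit sphere for the cost -<x, target>."""
    cost = lambda x: -jnp.dot(x, target)

    def step(_, x):
        g = jax.grad(cost)(x)
        rg = g - jnp.dot(x, g) * x           # tangent-space projection
        y = x - learning_rate * rg
        return y / jnp.linalg.norm(y)         # retract back onto the sphere

    return jax.lax.fori_loop(0, num_steps, step, x0)

batch_size = 10
k_x, k_t = jax.random.split(jax.random.PRNGKey(0))
x0_batch = jax.random.normal(k_x, (batch_size, 3))
x0_batch = x0_batch / jnp.linalg.norm(x0_batch, axis=1, keepdims=True)
targets = jax.random.normal(k_t, (batch_size, 3))
targets = targets / jnp.linalg.norm(targets, axis=1, keepdims=True)

# One compiled program solves every instance; each instance has its own target.
solve_batch = jax.jit(jax.vmap(solve_one))
x_batch = solve_batch(x0_batch, targets)
```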

Exponential Map vs. Retraction

# Use exponential map for geodesically exact optimization
result_exp = rx.minimize(problem, x0, method='rsgd', use_retraction=False)

# Use retraction for computational efficiency
result_retr = rx.minimize(problem, x0, method='rsgd', use_retraction=True)
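On the sphere, both options have closed forms that make the trade-off concrete. The exponential map exp_x(v) = cos(|v|) x + sin(|v|) v/|v| follows the great circle exactly. The normalization retraction R_x(v) = (x + v)/|x + v| needs only a norm, and the two differ only at third order in |v|. The sketch below is plain JAX, not RiemannAX's implementation; the retraction RiemannAX uses on each manifold may differ.

```python
import jax.numpy as jnp

def sphere_exp(x, v):
    """Exponential map on the unit sphere (v must be tangent at x)."""
    norm_v = jnp.linalg.norm(v)
    safe = jnp.where(norm_v > 0, norm_v, 1.0)    # avoid 0/0 when v == 0
    return jnp.where(norm_v > 0, jnp.cos(norm_v) * x + jnp.sin(norm_v) * v / safe, x)

def sphere_retr(x, v):
    """Projective retraction: step in the ambient space, then renormalize."""
    y = x + v
    return y / jnp.linalg.norm(y)

x = jnp.array([0.0, 0.0, 1.0])
direction = jnp.array([1.0, 0.0, 0.0])   # tangent at the north pole
for t in (0.5, 0.1):
    v = t * direction
    gap = jnp.linalg.norm(sphere_exp(x, v) - sphere_retr(x, v))
    print(t, gap)   # the gap shrinks roughly like t**3 / 3
```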

Comprehensive Examples

Explore detailed implementations in the examples/ directory:

  • sphere_optimization_demo.py: Sphere optimization with visualization
  • grassmann_optimization_demo.py: Subspace fitting and principal angles analysis
  • stiefel_optimization_demo.py: Orthogonal Procrustes with multiple exponential map methods
  • manifolds_comparison_demo.py: Comparative analysis across all manifolds
  • notebooks/: Interactive Jupyter notebooks with step-by-step tutorials

Testing and Development

Running Tests

# Quick test suite
make test

# With coverage analysis
make coverage

# Specific test categories
pytest tests/manifolds/     # Manifold implementations
pytest tests/optimizers/    # Optimization algorithms
pytest tests/integration/   # End-to-end workflows

Development Workflow

# Install development dependencies
pip install -e ".[dev]"

# Code formatting and linting
make format
make lint

# Type checking
make typecheck

# Documentation building
make docs

Performance Characteristics

RiemannAX leverages JAX's XLA compilation for exceptional performance:

  • GPU Acceleration: Automatic device placement and parallel execution
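  • JIT Compilation: One-time compilation overhead on the first call; later calls with the same input shapes and dtypes reuse the compiled program
  • Memory Efficiency: XLA operation fusion and in-place buffer updates (JAX arrays themselves are immutable), plus optimized memory layouts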
  • Batch Processing: Vectorized operations across multiple problem instances

Typical performance improvements over CPU-based alternatives:

  • 10-100x speedup on GPU for large-scale problems
  • 2-5x speedup on CPU through XLA optimization
  • Linear scaling with batch size for parallel optimization
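The first-call compilation overhead can be measured directly. This minimal sketch times the first call of a jitted function (trace, compile, and run) separately from a later call (run only). It calls `block_until_ready()` so that JAX's asynchronous dispatch does not hide the runtime. The function `riemannian_grad_batch` is hypothetical. Absolute timings depend entirely on hardware, so no figures are claimed here.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def riemannian_grad_batch(xs, target):
    """Hypothetical example: sphere Riemannian gradients of -<x, target> for many points."""
    g = -jnp.broadcast_to(target, xs.shape)                  # Euclidean gradient
    return g - jnp.sum(xs * g, axis=1, keepdims=True) * xs   # tangent projection

xs = jax.random.normal(jax.random.PRNGKey(0), (100_000, 3))
xs = xs / jnp.linalg.norm(xs, axis=1, keepdims=True)
target = jnp.array([0.0, 0.0, 1.0])

start = time.perf_counter()
riemannian_grad_batch(xs, target).block_until_ready()          # includes compilation
first_call = time.perf_counter() - start

start = time.perf_counter()
out = riemannian_grad_batch(xs, target).block_until_ready()    # reuses the compiled program
second_call = time.perf_counter() - start

print(f"first call: {first_call:.4f}s, second call: {second_call:.4f}s")
```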

Mathematical Foundation

RiemannAX implements manifolds with rigorous differential geometric operations:

Manifold Interface

Each manifold provides:

  • Exponential Map (exp): Maps a tangent vector to the endpoint of the geodesic it generates
  • Logarithmic Map (log): Inverse of the exponential map (locally, where it is defined)
  • Retraction (retr): Computationally cheaper approximation to the exponential map
  • Parallel Transport (transp): Moves tangent vectors between tangent spaces along the manifold
  • Riemannian Metric (inner): Inner product on each tangent space
  • Projection (proj): Orthogonal projection onto the tangent space
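As a concrete reference for this interface, the sketch below writes the operations out for the unit sphere in R^n using the standard closed forms. Parallel transport follows the minimizing geodesic from x to y. The function names mirror the list above, but the signatures are assumptions and are not RiemannAX's actual methods.

```python
import jax.numpy as jnp

def proj(x, u):
    """Orthogonal projection of an ambient vector u onto the tangent space at x."""
    return u - jnp.dot(x, u) * x

def inner(x, u, v):
    """Riemannian metric: the Euclidean inner product restricted to the tangent space."""
    return jnp.dot(u, v)

def exp(x, v):
    """Follow the great circle from x in direction v for arc length |v|."""
    theta = jnp.linalg.norm(v)
    safe = jnp.where(theta > 0, theta, 1.0)
    return jnp.where(theta > 0, jnp.cos(theta) * x + jnp.sin(theta) * v / safe, x)

def log(x, y):
    """Inverse of exp: the tangent vector at x pointing to y (requires y != -x)."""
    u = proj(x, y - x)                       # equals y - <x, y> x
    theta = jnp.arccos(jnp.clip(jnp.dot(x, y), -1.0, 1.0))
    norm_u = jnp.linalg.norm(u)
    safe = jnp.where(norm_u > 0, norm_u, 1.0)
    return jnp.where(norm_u > 0, theta * u / safe, jnp.zeros_like(x))

def retr(x, v):
    """Cheap retraction: move in the ambient space, then renormalize."""
    y = x + v
    return y / jnp.linalg.norm(y)

def transp(x, y, u):
    """Parallel transport of u from T_x to T_y along the geodesic from x to y."""
    v = log(x, y)
    theta = jnp.linalg.norm(v)
    e = v / jnp.where(theta > 0, theta, 1.0)   # unit geodesic direction at x
    a = jnp.dot(e, u)                          # component of u along the geodesic
    return u + a * ((jnp.cos(theta) - 1.0) * e - jnp.sin(theta) * x)

x = jnp.array([1.0, 0.0, 0.0])
v = jnp.array([0.0, 0.3, 0.4])      # tangent at x, |v| = 0.5
y = exp(x, v)
v_back = log(x, y)                  # recovers v up to round-off
u = jnp.array([0.0, 1.0, -2.0])     # another tangent vector at x
u_y = transp(x, y, u)               # tangent at y, same length as u
```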

Numerical Stability

  • Robust QR-based orthogonalization for Stiefel and Grassmann manifolds
  • Numerically stable distance computations using principal angles
  • Careful handling of edge cases and degenerate configurations
  • Comprehensive validation with appropriate floating-point tolerances
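Two of these techniques can be shown in compact form. This is a plain-JAX sketch, not RiemannAX's own code, and the function names are hypothetical. The first is a QR-based retraction onto the Stiefel manifold, with the signs of R's diagonal fixed so the result is unique. The second is the Grassmann geodesic distance computed from principal angles, the arccosines of the singular values of X^T Y. Those singular values are clipped to at most 1 to guard against round-off.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper names; a sketch, not RiemannAX's implementation.
def stiefel_qr_retraction(x, v):
    """Map x + v back onto the Stiefel manifold with a thin QR decomposition."""
    q, r = jnp.linalg.qr(x + v)
    # Flip column signs so diag(R) > 0, which makes the Q factor unique.
    signs = jnp.where(jnp.diag(r) < 0, -1.0, 1.0)
    return q * signs

def stiefel_tangent_proj(x, v):
    """Project v onto the tangent space at x: v - x sym(x^T v)."""
    xtv = x.T @ v
    return v - x @ ((xtv + xtv.T) / 2)

def grassmann_distance(x, y):
    """Geodesic distance between span(x) and span(y) via principal angles."""
    cosines = jnp.linalg.svd(x.T @ y, compute_uv=False)
    # Round-off can push singular values slightly above 1, so clip before arccos.
    # (arccos loses precision for very small angles; sine-based formulas do better there.)
    angles = jnp.arccos(jnp.clip(cosines, 0.0, 1.0))
    return jnp.linalg.norm(angles)

n, p = 8, 3
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = stiefel_qr_retraction(jax.random.normal(k1, (n, p)), jnp.zeros((n, p)))
v = stiefel_tangent_proj(x, 0.1 * jax.random.normal(k2, (n, p)))
x_new = stiefel_qr_retraction(x, v)

# Same subspace with a rotated basis: the distance should be ~0.
rot, _ = jnp.linalg.qr(jax.random.normal(k3, (p, p)))
d_same = grassmann_distance(x, x @ rot)
d_moved = grassmann_distance(x, x_new)
```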

Contributing

We welcome contributions! Please see our Contributing Guidelines for details on:

  • Development setup and workflow
  • Code style and testing requirements
  • Documentation standards
  • Pull request process

Development Priorities

  • Additional manifold implementations (Hyperbolic, Product manifolds)
  • Advanced optimization algorithms (Conjugate Gradient, L-BFGS)
  • Enhanced visualization and debugging tools
  • Performance optimizations and benchmarking

Citation

If you use RiemannAX in your research, please cite:

@software{riemannax2024,
  title={RiemannAX: Hardware-accelerated Riemannian Manifold Optimization with JAX},
  author={mary},
  year={2024},
  url={https://github.com/lv416e/riemannax}
}

License

Licensed under the Apache License 2.0. See LICENSE for details.

Acknowledgments

RiemannAX draws inspiration from:

  • JAX: Functional programming and automatic differentiation paradigms
  • Optax: Optimization algorithm design patterns
  • Pymanopt: Comprehensive Riemannian optimization reference
  • Geoopt: PyTorch-based Riemannian optimization library

Special thanks to the JAX development team for creating an exceptional foundation for scientific computing.



Download files


Source Distribution

riemannax-0.0.2.tar.gz (28.2 kB)

Uploaded Source

Built Distribution


riemannax-0.0.2-py3-none-any.whl (26.1 kB)

Uploaded Python 3

File details

Details for the file riemannax-0.0.2.tar.gz.

File metadata

  • Download URL: riemannax-0.0.2.tar.gz
  • Upload date:
  • Size: 28.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for riemannax-0.0.2.tar.gz
Algorithm Hash digest
SHA256 19ded9145c90e736bf0cd5b23b25fd3a05974ab3f852b80e782248eec5f7b07e
MD5 96f5f9723421175960854680c85bfebe
BLAKE2b-256 3c60d7bcac25061355df9f315cbd227feb8dd1e06119ee0a82ff0d666562103a


Provenance

The following attestation bundles were made for riemannax-0.0.2.tar.gz:

Publisher: release.yml on lv416e/riemannax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file riemannax-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: riemannax-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 26.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for riemannax-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 31973e7baf59485e11d41a2fd314c42087e43b5740cd8c8f239e384f4117c736
MD5 6b3087615b1683c41510f5b365f15aa0
BLAKE2b-256 ddf00fee3d63150220020edce5870429e11a7d4f69561d10895ac1b324419cb2


Provenance

The following attestation bundles were made for riemannax-0.0.2-py3-none-any.whl:

Publisher: release.yml on lv416e/riemannax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
