SQUAREM accelerator for JAXopt solvers

squarem-JAXopt

JAX implementation of the SQUAREM accelerator for solving fixed-point equations, as described by Du and Varadhan (2020).

The accelerator is built on JAXopt, so the solution of the fixed-point equation can be differentiated efficiently with automatic differentiation via the implicit function theorem (see Blondel et al., 2022 for details).
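To make the implicit function theorem concrete, here is a minimal scalar sketch in plain Python. It does not use JAXopt or this package. The names `f`, `solve`, and `implicit_grad` are hypothetical helpers chosen for illustration. If x*(a) solves x = f(x, a), differentiating both sides gives dx*/da = (∂f/∂a) / (1 - ∂f/∂x), evaluated at x*. The sketch checks that formula against a central finite difference for the map f(x, a) = (1 - a) + a cos(x).

```python
import math


def f(x, a):
    """Scalar fixed-point map: x -> (1 - a) + a * cos(x)."""
    return (1.0 - a) + a * math.cos(x)


def solve(a, tol=1e-12, max_iter=10_000):
    """Plain fixed-point iteration from x = 0 (illustrative helper)."""
    x = 0.0
    for _ in range(max_iter):
        x_new = f(x, a)
        if abs(x_new - x) < tol:
            return x_new
        x = x_new
    return x


def implicit_grad(a):
    """dx*/da from the implicit function theorem, using hand-written partials."""
    x_star = solve(a)
    df_dx = -a * math.sin(x_star)   # partial derivative of f w.r.t. x at x*
    df_da = math.cos(x_star) - 1.0  # partial derivative of f w.r.t. a at x*
    return df_da / (1.0 - df_dx)


a = 0.5
h = 1e-4
grad_ift = implicit_grad(a)
# Central finite difference through the solver, for comparison only.
grad_fd = (solve(a + h) - solve(a - h)) / (2.0 * h)
print(f"implicit: {grad_ift:.8f}  finite difference: {grad_fd:.8f}")
```

The implicit formula needs only the solution and two partial derivatives, and no derivatives of the solver's intermediate iterates. This is why the derivative does not depend on which solver or accelerator produced the fixed point.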

Installation

pip install squarem-jaxopt

Usage

import jax
import jax.numpy as jnp
from jax import random

from jaxopt import FixedPointIteration, AndersonAcceleration
from squarem_jaxopt import SquaremAcceleration

# Enable 64-bit floating point (JAX defaults to 32-bit)
jax.config.update("jax_enable_x64", True)

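# Number of independent scalar fixed-point equations
N = 100_000

# One random coefficient in [0, 1) per equation
a = random.uniform(random.PRNGKey(111), (N, 1))


def fun(x: jax.Array) -> jax.Array:
    # Fixed-point map, applied element-wise: x = (1 - a) + a * cos(x)
    y = (1 - a) + a * jnp.cos(x)
    return y


initial_guess = jnp.zeros_like(a)

# Baseline: plain fixed-point iteration without acceleration
fxp_none = FixedPointIteration(fixed_point_fun=fun, verbose=False)
result_none = fxp_none.run(initial_guess)

# Anderson acceleration from JAXopt
fxp_anderson = AndersonAcceleration(fixed_point_fun=fun, verbose=False)
result_anderson = fxp_anderson.run(initial_guess)

# SQUAREM acceleration from this package
fxp_squarem = SquaremAcceleration(fixed_point_fun=fun, verbose=False)
result_squarem = fxp_squarem.run(initial_guess)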

print("\n" + "="*60)
print("ALGORITHM COMPARISON TABLE")
print("="*60)
print(f"{'Algorithm':<25} {'Iterations':<12} {'Func Evals':<12} {'Error':<12}")
print("-"*60)
print(f"{'FixedPointIteration':<25} {result_none.state.iter_num:<12} {result_none.state.num_fun_eval:<12} {result_none.state.error:<12.2e}")
print(f"{'AndersonAcceleration':<25} {result_anderson.state.iter_num:<12} {result_anderson.state.num_fun_eval:<12} {result_anderson.state.error:<12.2e}")
print(f"{'SquaremAcceleration':<25} {result_squarem.state.iter_num:<12} {result_squarem.state.num_fun_eval:<12} {result_squarem.state.error:<12.2e}")
print("="*60)
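For intuition about why an accelerated solver can need fewer function evaluations than plain iteration, here is a minimal pure-Python sketch of one common form of the SQUAREM update. It is an assumption-laden illustration, not this package's implementation: it uses the step length alpha = -||r|| / ||v||, follows each extrapolation with one stabilizing map evaluation, and omits the safeguards (step-length bounds, monotonicity checks) that practical implementations usually add. The coefficients `A` and the helpers `plain_iteration` and `squarem_iteration` are hypothetical names.

```python
import math

# Hypothetical coefficients for a small element-wise problem (illustration only).
A = [0.3, 0.6, 0.9]


def fun(x):
    """Fixed-point map x -> (1 - a) + a * cos(x), applied element-wise."""
    return [(1.0 - a) + a * math.cos(xi) for a, xi in zip(A, x)]


def norm(v):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(vi * vi for vi in v))


def plain_iteration(fun, x, tol=1e-10, max_iter=1000):
    """Unaccelerated iteration; returns (solution, number of map evaluations)."""
    evals = 0
    for _ in range(max_iter):
        x_new = fun(x)
        evals += 1
        if norm([b - a for a, b in zip(x, x_new)]) < tol:
            return x_new, evals
        x = x_new
    return x, evals


def squarem_iteration(fun, x, tol=1e-10, max_iter=1000):
    """Sketch of a SQUAREM-style cycle: two map steps, extrapolate, stabilize."""
    evals = 0
    for _ in range(max_iter):
        x1 = fun(x)
        evals += 1
        r = [b - a for a, b in zip(x, x1)]  # first difference
        if norm(r) < tol:
            return x1, evals
        x2 = fun(x1)
        evals += 1
        v = [(c - b) - ri for b, c, ri in zip(x1, x2, r)]  # second difference
        norm_v = norm(v)
        if norm_v == 0.0:
            x = x2  # degenerate case: fall back to the plain iterate
            continue
        alpha = -norm(r) / norm_v  # assumed step-length choice
        x_acc = [xi - 2.0 * alpha * ri + alpha * alpha * vi
                 for xi, ri, vi in zip(x, r, v)]
        x = fun(x_acc)  # stabilizing map evaluation after the extrapolation
        evals += 1
    return x, evals


x0 = [0.0, 0.0, 0.0]
x_plain, evals_plain = plain_iteration(fun, x0)
x_sq, evals_sq = squarem_iteration(fun, x0)
residual_sq = norm([b - a for a, b in zip(x_sq, fun(x_sq))])
print(f"plain: {evals_plain} evaluations, SQUAREM sketch: {evals_sq} evaluations")
```

Each cycle of the sketch costs up to three map evaluations, so the fair comparison is by function evaluations rather than by iteration count. This is why the comparison table reports both columns.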
