SQUAREM accelerator for JAXopt solvers

squarem-JAXopt

squarem-JAXopt implements the SQUAREM accelerator for solving fixed-point equations; see Du and Varadhan (2020). It is built on JAX and JAXopt; the latter allows for implicit differentiation of the fixed point.
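To give a sense of what the accelerator does, here is a minimal sketch of one squared-extrapolation step in plain JAX. It uses the commonly cited step length alpha = -||r|| / ||v|| and omits step-length bounds and other safeguards; the name `squarem_step` and the test map are illustrative assumptions, not the package's actual code.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def squarem_step(fun, x):
    """One squared-extrapolation update of the fixed-point map `fun` at `x` (sketch)."""
    fx = fun(x)
    ffx = fun(fx)
    r = fx - x            # first difference
    v = (ffx - fx) - r    # second difference
    v_norm = jnp.linalg.norm(v)
    safe_v_norm = jnp.where(v_norm > 0, v_norm, 1.0)
    # Step length alpha = -||r|| / ||v||; alpha = -1 reduces to two plain iterations.
    alpha = jnp.where(v_norm > 0, -jnp.linalg.norm(r) / safe_v_norm, -1.0)
    x_extrap = x - 2.0 * alpha * r + alpha**2 * v
    # Apply the map once more to stabilise the extrapolated point.
    return fun(x_extrap)


# Illustrative affine contraction x -> a + x @ b (values chosen for the example).
a = jnp.array([[0.2], [0.4], [0.6], [0.8]])
b = jnp.array([[0.5]])


def fun(x):
    return a + x @ b


x_squarem = jnp.zeros_like(a)
for _ in range(3):
    x_squarem = squarem_step(fun, x_squarem)

# Closed-form fixed point of this particular map, for comparison.
x_star = a / (1.0 - b)
print(jnp.max(jnp.abs(x_squarem - x_star)))
```

Each step costs three evaluations of the map: two to form the differences and one to stabilise the extrapolated point.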

Installation

pip install squarem-JAXopt

Usage

import jax
import jax.numpy as jnp
from jax import random

from jaxopt import FixedPointIteration, AndersonAcceleration
from squarem_jaxopt import SquaremAcceleration

# Increase precision to 64 bit
jax.config.update("jax_enable_x64", True)

N = 4

# Coefficients of the affine map; b is drawn from [0, 1), so the map is a contraction
a = random.uniform(random.PRNGKey(111), (N, 1))
b = random.uniform(random.PRNGKey(112), (1, 1))


def fun(x: jnp.ndarray) -> jnp.ndarray:
    """Affine fixed-point map x -> a + x @ b."""
    return a + x @ b


# Plain fixed-point iteration (no acceleration)
fxp_none = FixedPointIteration(fixed_point_fun=fun, verbose=True)
result_none = fxp_none.run(jnp.zeros_like(a))

# Anderson acceleration from JAXopt
fxp_anderson = AndersonAcceleration(fixed_point_fun=fun, verbose=True)
result_anderson = fxp_anderson.run(jnp.zeros_like(a))

# SQUAREM acceleration from squarem-JAXopt
fxp_squarem = SquaremAcceleration(fixed_point_fun=fun, verbose=True)
result_squarem = fxp_squarem.run(jnp.zeros_like(a))
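The description mentions implicit differentiation of the fixed point. As a rough illustration of what that computes, the sketch below applies the implicit function theorem by hand in plain JAX to a parametrised version of the map above and checks the result against the closed form. The values of `a` and `b`, the two-argument `fun(x, b)`, and the helper `solve` are assumptions made for the example; they are not part of the package.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

a = jnp.array([[0.2], [0.4], [0.6], [0.8]])  # illustrative values, shape (N, 1)


def fun(x, b):
    """Parametrised affine map x -> a + x @ b, with b of shape (1, 1)."""
    return a + x @ b


def solve(b, num_iter=200):
    """Plain fixed-point iteration; a stand-in for any fixed-point solver."""
    x = jnp.zeros_like(a)
    for _ in range(num_iter):
        x = fun(x, b)
    return x


b = jnp.array([[0.5]])
x_star = solve(b)

# Implicit function theorem at the fixed point x* = fun(x*, b):
#   (I - d fun / d x) @ (d x* / d b) = d fun / d b
n = x_star.size
jac_x = jax.jacobian(fun, argnums=0)(x_star, b).reshape(n, n)
jac_b = jax.jacobian(fun, argnums=1)(x_star, b).reshape(n, 1)
dx_db_implicit = jnp.linalg.solve(jnp.eye(n) - jac_x, jac_b)

# Closed form for this linear map: x* = a / (1 - b), so d x* / d b = a / (1 - b)**2.
dx_db_closed_form = a / (1.0 - b) ** 2
print(jnp.max(jnp.abs(dx_db_implicit - dx_db_closed_form)))
```

The derivative is obtained from a single linear solve at the solution rather than by differentiating through every iteration, which is why it does not depend on which solver or accelerator produced the fixed point.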
