
SQUAREM accelerator for JAXopt solvers


squarem-JAXopt

squarem-JAXopt is an implementation of the SQUAREM accelerator for solving fixed-point equations x = f(x); see Du and Varadhan (2020). SQUAREM is implemented here in JAX and JAXopt. The latter allows implicit differentiation of the fixed point.
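To illustrate what the accelerator does, here is a minimal pure-JAX sketch of one SQUAREM cycle using the "SqS3" step length alpha = ||r|| / ||v|| from the SQUAREM literature. It is not this package's internal code: the names `squarem_step` and `squarem_solve`, the `step_min` safeguard and the stopping rule are hypothetical, and the objective-based monotonicity checks of full SQUAREM implementations are omitted. Each cycle takes two plain steps x1 = f(x) and x2 = f(x1), forms r = x1 - x and v = x2 - 2 x1 + x, extrapolates to x + 2 alpha r + alpha^2 v (alpha = 1 gives back x2), and finishes with one stabilizing plain step.

```python
import jax
import jax.numpy as jnp
from jax import random

jax.config.update("jax_enable_x64", True)


def squarem_step(f, x, step_min=1.0):
    """One SQUAREM cycle (sketch): two plain steps, extrapolate, then stabilize."""
    x1 = f(x)
    x2 = f(x1)
    r = x1 - x              # first difference
    v = (x2 - x1) - r       # second difference: x2 - 2 x1 + x
    sr2 = jnp.sum(r ** 2)
    sv2 = jnp.sum(v ** 2)
    # SqS3 step length ||r|| / ||v||; fall back to step_min when v == 0
    safe_sv2 = jnp.where(sv2 > 0, sv2, 1.0)
    alpha = jnp.where(sv2 > 0, jnp.sqrt(sr2 / safe_sv2), step_min)
    # alpha = 1 reproduces the plain double step x2; never extrapolate less
    alpha = jnp.maximum(alpha, step_min)
    x_extrap = x + 2.0 * alpha * r + alpha ** 2 * v
    return f(x_extrap)      # stabilization step


def squarem_solve(f, x0, tol=1e-10, maxiter=100):
    """Repeat SQUAREM cycles until successive iterates stop changing (sketch)."""
    x = x0
    for k in range(maxiter):
        x_new = squarem_step(f, x)
        if float(jnp.max(jnp.abs(x_new - x))) < tol:
            return x_new, k + 1
        x = x_new
    return x, maxiter


# Same linear map as in the usage example: f(x) = a + x b, fixed point a / (1 - b)
N = 4
a = random.uniform(random.PRNGKey(111), (N, 1))
b = random.uniform(random.PRNGKey(112), (1, 1))
fun = lambda x: a + x @ b

x_star, iters = squarem_solve(fun, jnp.zeros_like(a))
```

For this scalar-slope linear map, v = (b - 1) r, so alpha = 1 / (1 - b) and a single cycle lands on a / (1 - b) up to rounding; plain iteration instead converges geometrically at rate b.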

Installation

pip install squarem-jaxopt

Usage

import jax
import jax.numpy as jnp
from jax import random

from jaxopt import FixedPointIteration, AndersonAcceleration
from squarem_jaxopt import SquaremAcceleration

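# Enable 64-bit precision (set at startup, before creating arrays)
jax.config.update("jax_enable_x64", True)

# Problem dimension
N = 4

a = random.uniform(random.PRNGKey(111), (N, 1))
b = random.uniform(random.PRNGKey(112), (1, 1))


# Linear map f(x) = a + x b. With 0 <= b < 1 it is a contraction whose
# unique fixed point is x* = a / (1 - b).
def fun(x: jnp.ndarray) -> jnp.ndarray:
    y = a + x @ b
    return y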


# Plain fixed-point iteration (no acceleration)
fxp_none = FixedPointIteration(fixed_point_fun=fun, verbose=True)
result_none = fxp_none.run(jnp.zeros_like(a))

# Anderson acceleration (built into JAXopt)
fxp_anderson = AndersonAcceleration(fixed_point_fun=fun, verbose=True)
result_anderson = fxp_anderson.run(jnp.zeros_like(a))

# SQUAREM acceleration (this package)
fxp_squarem = SquaremAcceleration(fixed_point_fun=fun, verbose=True)
result_squarem = fxp_squarem.run(jnp.zeros_like(a))
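Implicit differentiation of a fixed point rests on the implicit function theorem: if x* = f(x*, theta), then (I - df/dx) dx*/dtheta = df/dtheta, so the derivative comes from one linear solve at the solution rather than from backpropagating through every iteration. The sketch below applies that identity by hand to the usage example, treating b as the parameter theta. It does not call JAXopt or `SquaremAcceleration`; the two-argument `fun_theta` and the helper `fixed_point_jacobian` are hypothetical names introduced for illustration. For this map the answer is known in closed form, dx*/db = a / (1 - b)^2, which gives a check.

```python
import jax
import jax.numpy as jnp
from jax import random

jax.config.update("jax_enable_x64", True)

N = 4
a = random.uniform(random.PRNGKey(111), (N, 1))
b = random.uniform(random.PRNGKey(112), (1, 1))


def fun_theta(x, theta):
    """Hypothetical two-argument version of `fun`, with b passed as theta."""
    return a + x @ theta


def fixed_point_jacobian(f, x_star, theta):
    """dx*/dtheta at a fixed point x* = f(x*, theta), via the implicit function theorem."""
    n, p = x_star.size, theta.size
    # Jacobians of f with respect to x and theta, flattened to 2-D matrices
    jac_x = jax.jacobian(f, argnums=0)(x_star, theta).reshape(n, n)
    jac_theta = jax.jacobian(f, argnums=1)(x_star, theta).reshape(n, p)
    # Solve (I - df/dx) J = df/dtheta instead of unrolling the solver iterations
    return jnp.linalg.solve(jnp.eye(n) - jac_x, jac_theta)


x_star = a / (1.0 - b)  # closed-form fixed point of the example map
dx_db = fixed_point_jacobian(fun_theta, x_star, b)
```

In practice x_star would come from a solver run rather than the closed form; the linear solve only needs the converged point, not the iteration history.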


Download files

Source Distribution

squarem_jaxopt-0.1.3.tar.gz (4.6 kB)

Built Distribution

squarem_jaxopt-0.1.3-py3-none-any.whl (4.8 kB, Python 3)

File details

Details for the file squarem_jaxopt-0.1.3.tar.gz.

File metadata

  • Download URL: squarem_jaxopt-0.1.3.tar.gz
  • Size: 4.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.8.17

File hashes

Hashes for squarem_jaxopt-0.1.3.tar.gz
Algorithm Hash digest
SHA256 cf04761a85ac6e40d56e8e113c2918030067389897e9aeca13d6461757ff6c3e
MD5 079e4f84dfcf53724f0e34bc3ef0b8e6
BLAKE2b-256 db1e948607d145f3bd3f32d826862ac7933a38a8d9ae835406ae63cbcb78bb0a

File details

Details for the file squarem_jaxopt-0.1.3-py3-none-any.whl.

File hashes

Hashes for squarem_jaxopt-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 46a95f64b3f19caab4d306a1162b496cd432ed268410aac88fbf0c164974a36d
MD5 dae3ba5a1c26668e8146b6761983d0ce
BLAKE2b-256 8cf6e5abe3a3aa0f7c3500538dc29c1b7b087a28126e0dbd3450fac5d31fed51
