SQUAREM accelerator for JAXopt solvers
squarem-JAXopt
squarem-JAXopt implements the SQUAREM accelerator for solving fixed-point equations; see Du and Varadhan (2020). It is built on JAX and JAXopt; the latter allows for implicit differentiation of the fixed point.
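To give a feel for what the accelerator does, here is a minimal sketch of a single SQUAREM update in plain JAX, following the update rule as it is commonly described in the SQUAREM literature. The names `squarem_step` and `affine_map` are hypothetical and used only for illustration; the package's own implementation may differ, for example in steplength safeguards.

```python
import jax.numpy as jnp


def squarem_step(fun, x):
    """One SQUAREM update (illustrative sketch, not the package's code)."""
    x1 = fun(x)
    x2 = fun(x1)
    r = x1 - x          # first difference
    v = (x2 - x1) - r   # second difference
    # Steplength; assumes v is nonzero, i.e. x is not already a fixed point
    alpha = -jnp.linalg.norm(r) / jnp.linalg.norm(v)
    # Squared extrapolation
    x_extrapolated = x - 2.0 * alpha * r + alpha**2 * v
    # One more application of the map after the extrapolation
    return fun(x_extrapolated)


# Hypothetical test problem: an affine contraction with a known fixed point
a = jnp.array([0.3, 0.7, 0.1])
b = 0.5


def affine_map(x):
    return a + b * x


x_new = squarem_step(affine_map, jnp.zeros(3))
x_star = a / (1.0 - b)  # closed-form fixed point of the affine map
```

For this scalar-coefficient affine map, one update already lands on the fixed point, so `x_new` and `x_star` agree up to floating-point error. General nonlinear maps need repeated updates.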
Installation
pip install squarem-JAXopt
Usage
import jax
import jax.numpy as jnp
from jax import random
from jaxopt import FixedPointIteration, AndersonAcceleration
from squarem_jaxopt import SquaremAcceleration
# Increase precision to 64 bit
jax.config.update("jax_enable_x64", True)
# Problem size and coefficients of an affine fixed-point map
N = 4
a = random.uniform(random.PRNGKey(111), (N, 1))
b = random.uniform(random.PRNGKey(112), (1, 1))

def fun(x: jnp.ndarray) -> jnp.ndarray:
    # The fixed point solves x = a + x @ b
    y = a + x @ b
    return y

# Plain fixed-point iteration (no acceleration)
fxp_none = FixedPointIteration(fixed_point_fun=fun, verbose=True)
result_none = fxp_none.run(jnp.zeros_like(a))

# Anderson acceleration
fxp_anderson = AndersonAcceleration(fixed_point_fun=fun, verbose=True)
result_anderson = fxp_anderson.run(jnp.zeros_like(a))

# SQUAREM acceleration
fxp_squarem = SquaremAcceleration(fixed_point_fun=fun, verbose=True)
result_squarem = fxp_squarem.run(jnp.zeros_like(a))