SQUAREM accelerator for JAXopt solvers
squarem-JAXopt
squarem-JAXopt is an implementation of the SQUAREM accelerator for solving fixed-point equations; see Du and Varadhan (2020). SQUAREM is implemented in JAX and JAXopt. The latter allows for implicit differentiation of the fixed point.
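To give a feel for what a squared-extrapolation accelerator does, here is a minimal pure-Python sketch of one common form of the SQUAREM update (two evaluations of the map, an extrapolation with step length alpha = -||r|| / ||v||, then one stabilizing evaluation). It is an illustration under stated assumptions, not this package's implementation: the names `squarem_step`, `solve`, and the values of `a` and `b` are hypothetical, and production implementations typically add safeguards such as step-length bounds.

```python
import math


def squarem_step(fun, x):
    """One squared-extrapolation cycle on a list of floats (illustrative sketch)."""
    x1 = fun(x)
    x2 = fun(x1)
    # First difference r = F(x) - x and second difference v = F(F(x)) - 2 F(x) + x.
    r = [b1 - a0 for a0, b1 in zip(x, x1)]
    v = [c2 - b1 - ri for b1, c2, ri in zip(x1, x2, r)]
    norm_r = math.sqrt(sum(ri * ri for ri in r))
    norm_v = math.sqrt(sum(vi * vi for vi in v))
    if norm_v == 0.0:
        # No second-difference information: fall back to the plain iterate.
        return x2
    alpha = -norm_r / norm_v
    # Squared extrapolation: x - 2*alpha*r + alpha^2*v.
    x_new = [xi - 2.0 * alpha * ri + alpha * alpha * vi
             for xi, ri, vi in zip(x, r, v)]
    # Stabilizing step: apply the fixed-point map once more.
    return fun(x_new)


def solve(fun, x0, tol=1e-10, maxiter=100):
    """Repeat SQUAREM cycles until max |F(x) - x| < tol; returns (x, cycles used)."""
    x = list(x0)
    for it in range(1, maxiter + 1):
        x = squarem_step(fun, x)
        err = max(abs(fi - xi) for fi, xi in zip(fun(x), x))
        if err < tol:
            return x, it
    return x, maxiter


# Hypothetical stand-ins for the random draws used in the usage example.
a = [0.2, 0.5, 0.1, 0.9]
b = 0.6


def fun(x):
    # Affine contraction x -> a + b * x, applied elementwise.
    return [ai + b * xi for ai, xi in zip(a, x)]


x_star, n_cycles = solve(fun, [0.0] * len(a))
print(x_star, n_cycles)
```

For an affine map with a scalar slope like this one, a single cycle already lands on the fixed point a / (1 - b) up to rounding, which is why the extrapolation can save many plain iterations on slowly contracting maps.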
Installation
pip install squarem-jaxopt
Usage
import jax
import jax.numpy as jnp
from jax import random
from jaxopt import FixedPointIteration, AndersonAcceleration
from squarem_jaxopt import SquaremAcceleration
# Increase precision to 64 bit
jax.config.update("jax_enable_x64", True)
N = 4
a = random.uniform(random.PRNGKey(111), (N, 1))
b = random.uniform(random.PRNGKey(112), (1, 1))
def fun(x: jnp.ndarray) -> jnp.ndarray:
    y = a + x @ b
    return y
fxp_none = FixedPointIteration(fixed_point_fun=fun, verbose=True)
result_none = fxp_none.run(jnp.zeros_like(a))
fxp_anderson = AndersonAcceleration(fixed_point_fun=fun, verbose=True)
result_anderson = fxp_anderson.run(jnp.zeros_like(a))
fxp_squarem = SquaremAcceleration(fixed_point_fun=fun, verbose=True)
result_squarem = fxp_squarem.run(jnp.zeros_like(a))
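For this particular example the fixed point is available in closed form, which gives a quick way to sanity-check the three solvers. The derivation below writes β for the single entry of b (notation introduced here) and assumes β ≠ 1, which holds for a uniform draw on [0, 1).

```latex
x^\ast = a + x^\ast \beta
\;\Longrightarrow\;
x^\ast = \frac{a}{1 - \beta},
\qquad
\frac{\partial x^\ast}{\partial \beta} = \frac{a}{(1 - \beta)^2}.
```

Because 0 ≤ β < 1, the map is a contraction, so plain iteration converges as well, only more slowly as β approaches 1. Assuming each solver follows the JAXopt convention of returning a `(params, state)` pair from `run`, the `params` of all three results should agree with a / (1 - b) up to the solver tolerance, and the derivative above is what implicit differentiation of the fixed point with respect to b would be expected to recover.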