Fixed-point iterations for root finding implemented in JAX
Fixed-point solver
FixedPointJAX is a simple implementation of a fixed-point iteration algorithm for root finding in JAX.
- Strives to be minimal
- Has no dependencies other than JAX
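The core idea of fixed-point iteration fits in a few lines of JAX: apply a map f repeatedly until successive iterates stop changing. To find a root of g, one can iterate f(x) = x + g(x). The usage example below has this form. The helper `fixed_point_iterate` is a hypothetical illustration built on `jax.lax.while_loop`. It is not FixedPointJAX's implementation or API. The stopping rule (sup-norm of the step below `tol`) is also an assumption made for the sketch.

```python
import jax.numpy as jnp
from jax import lax


def fixed_point_iterate(f, x0, tol=1e-6, max_iter=1000):
    """Iterate x <- f(x) until the sup-norm of the step drops below tol.

    Hypothetical helper for illustration only (not the FixedPointJAX API).
    Returns the final iterate, the last step norm and the iteration count.
    """
    x0 = jnp.asarray(x0)

    def cond(state):
        _, step_norm, it = state
        # Keep iterating while the step is large and the budget is not spent
        return (step_norm > tol) & (it < max_iter)

    def body(state):
        x, _, it = state
        x_new = f(x)
        step_norm = jnp.max(jnp.abs(x_new - x))
        return x_new, step_norm, it + 1

    # Start with an infinite step norm so the loop runs at least once
    init = (x0, jnp.array(jnp.inf, dtype=x0.dtype), jnp.array(0, dtype=jnp.int32))
    return lax.while_loop(cond, body, init)


# Example: the fixed point of cos(x) = x is about 0.739085
x, step_norm, iterations = fixed_point_iterate(jnp.cos, jnp.array(1.0))
print(x, step_norm, iterations)
```

Because the loop uses `lax.while_loop`, the helper can be wrapped in `jax.jit` with `f` marked static (`static_argnums=0`). Convergence is only expected when f is a contraction near the fixed point.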
Installation
pip install FixedPointJAX
Usage
import jax.numpy as jnp
from jax import random
from FixedPointJAX import FixedPointRoot
# Define the logit (softmax) probabilities along `axis`
def my_logit(x, axis=0):
    # Subtract the maximum before exponentiating for numerical stability
    numerator = jnp.exp(x - jnp.max(x, axis=axis, keepdims=True))
    denominator = jnp.sum(numerator, axis=axis, keepdims=True)
    return numerator / denominator
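# Define the function for the fixed-point iteration.
# It returns the next iterate x + z and the residual z = log(s0 / s),
# which is zero once the model probabilities s match s0.
def my_fxp(x, s0):
    s = my_logit(x)
    z = jnp.log(s0 / s)
    return x + z, z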
print('-----------------------------------------')
# Shape of the system of fixed-point equations
shape = (3, 4)
# Simulate probabilities
s0 = my_logit(random.uniform(key=random.PRNGKey(123), shape=shape))
# Set up fixed-point equation
fun = lambda x: my_fxp(x, s0)
# Initial guess
x0 = jnp.zeros_like(s0)
# Solve the fixed-point equation
x, (step_norm, root_norm, iterations) = FixedPointRoot(fun, x0)
print('-----------------------------------------')
print(f'System of fixed-point equations is solved: {jnp.allclose(x, fun(x)[0])}.')
print(f'Probabilities are identical: {jnp.allclose(s0, my_logit(x))}.')
print('-----------------------------------------')
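Why does the example converge so readily? Write s(x) for `my_logit(x)`, a softmax over axis 0. Then log s(x) = x - logsumexp(x, axis=0), so the update x + log(s0 / s(x)) simplifies to log(s0) + logsumexp(x, axis=0). That is log(s0) shifted by one constant per column. Softmax ignores such shifts, and each column of s0 sums to one. So in exact arithmetic a single update already reproduces s0, and the next update leaves x unchanged. It also means x is only determined up to a per-column constant, which is why comparing probabilities (as the last print does) is the meaningful check.

The sketch below checks this numerically with plain JAX. It uses `jax.nn.softmax` in place of `my_logit` (same computation) and does not call `FixedPointRoot`, so it says nothing about how many iterations that solver reports.

```python
import jax
import jax.numpy as jnp
from jax import random


def update(x, s0):
    # Same map as my_fxp in the usage example: x + log(s0 / softmax(x))
    s = jax.nn.softmax(x, axis=0)
    z = jnp.log(s0 / s)
    return x + z, z


shape = (3, 4)
s0 = jax.nn.softmax(random.uniform(random.PRNGKey(123), shape), axis=0)
x0 = jnp.zeros(shape)

# One update should already reproduce s0 up to rounding ...
x1, z0 = update(x0, s0)
# ... and a second update should barely move the iterate
x2, z1 = update(x1, s0)

print(jnp.max(jnp.abs(jax.nn.softmax(x1, axis=0) - s0)))
print(jnp.max(jnp.abs(z1)))

# With x0 = 0, softmax(x0) is 1/3 everywhere, so x1 should equal log(s0) + log(3)
print(jnp.allclose(x1, jnp.log(s0) + jnp.log(3.0), atol=1e-5))
```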
Download files
Source Distribution
FixedPointJAX-0.0.25.tar.gz (4.2 kB)
Built Distribution
FixedPointJAX-0.0.25-py3-none-any.whl (5.3 kB)
File details
Details for the file FixedPointJAX-0.0.25.tar.gz.
File metadata
- Download URL: FixedPointJAX-0.0.25.tar.gz
- Upload date:
- Size: 4.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 12d5a9b5ed9d0b5476d610e299674d11cd9f5e8bc37c663b700be7ac3d539146 |
| MD5 | 7ab29e2170b8d32a5f431b9de502c843 |
| BLAKE2b-256 | d2dcb2849c82991d9a149d2de0bc8106d2e807bd108404e9a635bcaa24fc0ca1 |
File details
Details for the file FixedPointJAX-0.0.25-py3-none-any.whl.
File metadata
- Download URL: FixedPointJAX-0.0.25-py3-none-any.whl
- Upload date:
- Size: 5.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1747538b54ed6f15dc4878a1a3ab48f75a890b9bd9b0b449f7f6f152d2087357 |
| MD5 | 38e02299a888649f2be0beb2dba14e82 |
| BLAKE2b-256 | 4df5ec6da53b3df12da7aa5d75ca67df0b6631e298680fcdf6e107e99c0d0932 |
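A downloaded file can be checked against the SHA256 digests listed above using Python's standard-library `hashlib`. A minimal sketch follows. The names `PUBLISHED_SHA256`, `sha256_of_file` and `matches_published` are illustrative helpers, not part of FixedPointJAX.

```python
import hashlib
from pathlib import Path

# SHA256 digests as published in the file details above
PUBLISHED_SHA256 = {
    "FixedPointJAX-0.0.25.tar.gz":
        "12d5a9b5ed9d0b5476d610e299674d11cd9f5e8bc37c663b700be7ac3d539146",
    "FixedPointJAX-0.0.25-py3-none-any.whl":
        "1747538b54ed6f15dc4878a1a3ab48f75a890b9bd9b0b449f7f6f152d2087357",
}


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published(path):
    """True if the file's SHA256 equals the digest listed for its file name."""
    expected = PUBLISHED_SHA256.get(Path(path).name)
    return expected is not None and sha256_of_file(path) == expected
```

For example, `matches_published("FixedPointJAX-0.0.25.tar.gz")` run in the download directory should return `True` for an intact archive. pip can enforce the same kind of check at install time through its hash-checking mode (`--require-hashes`).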