Simple fixed-point solver implemented in JAX


Fixed-point solver

FixedPointJAX is a simple implementation of a fixed-point iteration algorithm for root finding in JAX. The implementation allows the user to solve a system of fixed-point equations either by standard fixed-point iterations or with the SQUAREM accelerator; see Du and Varadhan (2020).

  • Strives to be minimal
  • Has no dependencies other than JAX
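
To make the two strategies named above concrete, here is a minimal sketch of a plain fixed-point loop and of one SQUAREM-style accelerated step. It is an illustration only and not FixedPointJAX's internal code: the names `fixed_point_iterate` and `squarem_step`, the tolerance, and the choice of step length are all assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def fixed_point_iterate(f, x0, tol=1e-6, max_iter=1000):
    """Hypothetical helper: iterate x <- f(x) until the largest step is <= tol."""

    def cond(state):
        _, step_norm, i = state
        return (step_norm > tol) & (i < max_iter)

    def body(state):
        x, _, i = state
        x_new = f(x)
        # Track the sup-norm of the step as the convergence measure
        return x_new, jnp.max(jnp.abs(x_new - x)), i + 1

    init = (x0, jnp.array(jnp.inf, dtype=x0.dtype), 0)
    return jax.lax.while_loop(cond, body, init)


def squarem_step(f, x):
    """Hypothetical helper: one SQUAREM-style extrapolation step."""
    fx = f(x)
    ffx = f(fx)
    r = fx - x          # first difference
    v = (ffx - fx) - r  # second difference
    r_norm = jnp.linalg.norm(r)
    v_norm = jnp.linalg.norm(v)
    # Step length -||r|| / ||v||; fall back to alpha = -1 when v vanishes,
    # in which case the extrapolation below reduces to f(f(x)).
    safe_v_norm = jnp.where(v_norm > 0, v_norm, 1.0)
    alpha = jnp.where(v_norm > 0, -r_norm / safe_v_norm, -1.0)
    x_acc = x - 2 * alpha * r + alpha**2 * v
    # One ordinary fixed-point step afterwards stabilises the extrapolation
    return f(x_acc)


# Plain iteration on x = cos(x)
x, step_norm, iterations = fixed_point_iterate(jnp.cos, jnp.array(1.0))

# One accelerated step on the linear map f(x) = 0.5 * x + 1 (fixed point at 2)
x_lin = squarem_step(lambda x: 0.5 * x + 1.0, jnp.array([0.0, 4.0]))
```

For a linear map such as `0.5 * x + 1`, a single accelerated step already lands on the fixed point up to floating-point rounding, whereas the plain loop only approaches it geometrically.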

Installation

pip install FixedPointJAX

Usage

import jax.numpy as jnp
from jax import random

from FixedPointJAX import FixedPointRoot

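# Define the logit (softmax) probabilities along the given axis
def logit(x, axis=1):
    # Subtract the maximum before exponentiating for numerical stability
    numerator = jnp.exp(x - jnp.max(x, axis=axis, keepdims=True))
    denominator = jnp.sum(numerator, axis=axis, keepdims=True)
    return numerator / denominator

# Define the function for the fixed-point iteration.
# It returns the updated guess and the step z; s0 is the global defined below.
def fxp(x):
    s = logit(x)
    z = jnp.log(s0 / s)
    return x + z, z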

# Dimensions of system of fixed-point equations
I, J = 3, 4

# Simulate probabilities
s0 = random.dirichlet(key=random.PRNGKey(123), alpha=jnp.ones((J,)), shape=(I,))

# Initial guess
x0 = jnp.zeros_like(s0)

print('-----------------------------------------')
# Solve the fixed-point equation
x, (step_norm, root_norm, iterations) = FixedPointRoot(fxp, x0)
print('-----------------------------------------')
print(f'System of fixed-point equations is solved: {jnp.allclose(x, fxp(x)[0])}.')
print(f'Probabilities are identical: {jnp.allclose(s0, logit(x))}.')
print('-----------------------------------------')
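
Why the second check in the example is expected to hold can be read off the update rule in `fxp`. As a sketch, with notation introduced here only for illustration ($s^{0}$ for `s0`, $s(x)$ for `logit(x)`, and $J$ alternatives per row):

```latex
x^{(k+1)} = x^{(k)} + \log s^{0} - \log s\bigl(x^{(k)}\bigr),
\qquad
s_{ij}(x) = \frac{\exp(x_{ij})}{\sum_{m=1}^{J} \exp(x_{im})}
```

At a fixed point the step `z = jnp.log(s0 / s)` is zero, so $s(x) = s^{0}$, which is the condition that `jnp.allclose(s0, logit(x))` tests.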

