
Simple fixed-point solver implemented in JAX

Project description


Fixed-point solver

FixedPointJAX is a simple JAX implementation of fixed-point iteration for solving systems of equations of the form x = f(x) (equivalently, finding a root of f(x) - x). The user can solve the system either with standard fixed-point iterations or with the SQUAREM accelerator; see Du and Varadhan (2020).
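For readers new to the two schemes, here is a small standalone sketch in plain JAX. It is not the fxp-jax implementation. Plain iteration repeats x <- F(x). A SQUAREM cycle takes two plain steps, forms the differences r = F(x) - x and v = F(F(x)) - 2F(x) + x, extrapolates to x + 2*alpha*r + alpha^2*v, and finishes with one more plain step. Several details are assumptions chosen for illustration: the helper names (`fixed_point_iteration`, `squarem_step`, `squarem`), the max-norm stopping rule, and the steplength alpha = ||r|| / ||v|| clamped to alpha >= 1 (one common SQUAREM steplength choice). With alpha = 1 the extrapolation reduces to the second plain iterate F(F(x)). Float64 is enabled so the second difference v stays accurate near convergence.

```python
import jax

jax.config.update("jax_enable_x64", True)  # keep r and v accurate near convergence
import jax.numpy as jnp


def fixed_point_iteration(F, x0, tol=1e-8, max_iter=1000):
    """Plain fixed-point iteration x_{k+1} = F(x_k) (illustrative sketch)."""
    x = x0
    for k in range(max_iter):
        x_next = F(x)
        if float(jnp.max(jnp.abs(x_next - x))) < tol:
            return x_next, k + 1
        x = x_next
    return x, max_iter


def squarem_step(F, x0):
    """One SQUAREM cycle: extrapolate from two plain steps, then stabilise."""
    x1 = F(x0)
    x2 = F(x1)
    r = x1 - x0               # first difference
    v = x2 - 2.0 * x1 + x0    # second difference
    norm_r = jnp.linalg.norm(r)
    norm_v = jnp.linalg.norm(v)
    # Assumed steplength alpha = ||r|| / ||v||, with a safe fallback when v == 0
    alpha = jnp.where(norm_v > 0, norm_r / jnp.where(norm_v > 0, norm_v, 1.0), 1.0)
    alpha = jnp.maximum(alpha, 1.0)  # alpha = 1 reproduces x2 exactly
    x_extrap = x0 + 2.0 * alpha * r + alpha**2 * v
    return F(x_extrap)  # one plain step after the extrapolation


def squarem(F, x0, tol=1e-8, max_iter=1000):
    """Repeat SQUAREM cycles until successive iterates stop changing."""
    x = x0
    for k in range(max_iter):
        x_next = squarem_step(F, x)
        if float(jnp.max(jnp.abs(x_next - x))) < tol:
            return x_next, k + 1
        x = x_next
    return x, max_iter


# Example: a slowly converging map whose fixed point satisfies cos(x*) = x*
def damped_cos(x):
    return x + 0.1 * (jnp.cos(x) - x)


x_fp, n_fp = fixed_point_iteration(damped_cos, jnp.zeros(3))
x_sq, n_sq = squarem(damped_cos, jnp.zeros(3))
print(n_fp, n_sq)
```

On this map the SQUAREM loop should need far fewer cycles than plain iteration needs steps, although each SQUAREM cycle costs three evaluations of the map.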

Installation

pip install fxp-jax

Usage

import jax.numpy as jnp
from jax import random

from fxp_jax import fxp_root

# Define the logit (softmax) probabilities row-wise;
# subtracting the row maximum before exp avoids overflow
def logit(x, axis=1):
	numerator = jnp.exp(x - jnp.max(x, axis=axis, keepdims=True))
	denominator = jnp.sum(numerator, axis=axis, keepdims=True)
	return numerator / denominator

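# Define the fixed-point map. It returns the updated guess and z = log(s0 / logit(x)),
# which is zero at the fixed point (s0 is simulated further below)
def fun(x):
	s = logit(x)
	z = jnp.log(s0 / s)
	return x + z, z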

# Dimensions of system of fixed-point equations
I, J = 3, 4

# Simulate probabilities
s0 = random.dirichlet(key=random.PRNGKey(123), alpha=jnp.ones((J,)), shape=(I,))

# Initial guess
x0 = jnp.zeros_like(s0)

print('--------------------------------------------------------')
# Solve the fixed-point equation with standard fixed-point iterations (no acceleration)
fxp = fxp_root(fun)
result = fxp.solve(guess=x0, accelerator="None")
print('--------------------------------------------------------')
print(f'System of fixed-point equations is solved: {jnp.allclose(result.x,fun(result.x)[0])}.')
print(f'Probabilities are identical: {jnp.allclose(s0, logit(result.x))}.')
print('--------------------------------------------------------')
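Why the example recovers s0: at a fixed point of `fun` the auxiliary output z = log(s0 / logit(x)) is zero, so logit(x) = s0. Because the softmax is unchanged by adding a constant to each row, x* = log(s0) (plus any row-wise constant) solves the system. In exact arithmetic a single plain iteration from the zero guess already lands on such a point, since log logit(x) = x - logsumexp(x). The check below sketches this with `jax.nn.softmax` standing in for `logit`; the helper `update` and the names `x_star` and `x1` are illustrative and not part of fxp-jax.

```python
import jax.numpy as jnp
from jax import random
from jax.nn import softmax

# Same simulated probabilities as in the usage example
I, J = 3, 4
s0 = random.dirichlet(key=random.PRNGKey(123), alpha=jnp.ones((J,)), shape=(I,))


def update(x):
    # Same fixed-point map as `fun`, with jax.nn.softmax in place of `logit`
    return x + jnp.log(s0 / softmax(x, axis=1))


# Closed-form fixed point: any x with softmax(x) == s0, e.g. log(s0)
x_star = jnp.log(s0)

# One plain iteration from the zero guess gives x1 = log(s0) + logsumexp(x0),
# a row-wise shift of log(s0), hence also a fixed point
x1 = update(jnp.zeros_like(s0))

print(jnp.allclose(softmax(x1, axis=1), s0, atol=1e-6))
```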



Download files


Source Distribution

fxp_jax-0.1.21.tar.gz (4.3 kB view details)

Uploaded Source

Built Distribution


fxp_jax-0.1.21-py3-none-any.whl (4.5 kB view details)

Uploaded Python 3

File details

Details for the file fxp_jax-0.1.21.tar.gz.

File metadata

  • Download URL: fxp_jax-0.1.21.tar.gz
  • Upload date:
  • Size: 4.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for fxp_jax-0.1.21.tar.gz
Algorithm Hash digest
SHA256 4f17c02203d32f892f292b906176dbb54de16c5122c84d1c915ca1721ab47e00
MD5 f28a56c38d585214389fb040ca0cf112
BLAKE2b-256 6e4f5d39822a5773c82fa950b85303139a90e09eb8943c57a6e2b69d22cb5504


Provenance

The following attestation bundles were made for fxp_jax-0.1.21.tar.gz:

Publisher: cd.yml on esbenscriver/fxp-jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file fxp_jax-0.1.21-py3-none-any.whl.

File metadata

  • Download URL: fxp_jax-0.1.21-py3-none-any.whl
  • Upload date:
  • Size: 4.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for fxp_jax-0.1.21-py3-none-any.whl
Algorithm Hash digest
SHA256 2fbb2f28393f645e57b60200b3f07af26817ba31ec5dfb10d70c0dfe39df6601
MD5 e17de67cb952c1eff27c52f91206b731
BLAKE2b-256 b6278708c69e6439fe0e7a438bdead9438d1fc2af060cfa1fc2bb245a16cdb6d


Provenance

The following attestation bundles were made for fxp_jax-0.1.21-py3-none-any.whl:

Publisher: cd.yml on esbenscriver/fxp-jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
