
Fixed-point iterations for root finding implemented in JAX

Project description

Fixed-point solver

FixedPointJAX is a simple implementation of a fixed-point iteration algorithm for root finding in JAX.

  • Strives to be minimal
  • Has no dependencies other than JAX
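In its plainest form, a fixed-point iteration repeats x ← f(x) until successive iterates stop changing. The sketch below illustrates that idea with `jax.lax.while_loop`. It is not FixedPointJAX's implementation, and `fixed_point_iterate` with its `tol` and `max_iter` parameters are hypothetical names chosen for this example.

```python
import jax
import jax.numpy as jnp


def fixed_point_iterate(f, x0, tol=1e-6, max_iter=1000):
    """Hypothetical helper (not the FixedPointJAX API): iterate x <- f(x)
    until the L2 norm of the step f(x) - x drops below tol or max_iter is hit."""
    x0 = jnp.asarray(x0)

    def cond(state):
        _, step_norm, i = state
        # Keep going while the last step was large and the budget is not used up
        return (step_norm >= tol) & (i < max_iter)

    def body(state):
        x, _, i = state
        x_new = f(x)
        # L2 (Frobenius for matrices) norm of the step
        step_norm = jnp.sqrt(jnp.sum((x_new - x) ** 2))
        return x_new, step_norm, i + 1

    init = (x0, jnp.asarray(jnp.inf, dtype=x0.dtype), jnp.asarray(0, dtype=jnp.int32))
    x, step_norm, iterations = jax.lax.while_loop(cond, body, init)
    return x, step_norm, iterations
```

Applied to `jnp.cos` from `x0 = 1.0`, the iterates should approach the solution of x = cos(x), roughly 0.739085. Using `lax.while_loop` lets the loop run inside `jax.jit`; note that it does not support reverse-mode differentiation.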

Installation

pip install FixedPointJAX

Usage

import jax.numpy as jnp
from jax import random

from FixedPointJAX import FixedPointRoot

# Define the logit (softmax) probabilities along `axis`
def my_logit(x, axis=0):
    numerator = jnp.exp(x - jnp.max(x, axis=axis, keepdims=True))
    denominator = jnp.sum(numerator, axis=axis, keepdims=True)
    return numerator / denominator

# Define the function for the fixed-point iteration: it returns the
# updated guess x + z and z = log(s0 / s), which is zero at the fixed point
def my_fxp(x, s0):
    s = my_logit(x)
    z = jnp.log(s0 / s)
    return x + z, z

print('-----------------------------------------')
# Dimensions of the system of fixed-point equations
shape = (3, 4)

# Simulate probabilities
s0 = my_logit(random.uniform(key=random.PRNGKey(123), shape=shape))

# Set up the fixed-point equation
fun = lambda x: my_fxp(x, s0)

# Initial guess
x0 = jnp.zeros_like(s0)

# Solve the fixed-point equation
x, (step_norm, root_norm, iterations) = FixedPointRoot(fun, x0)
print('-----------------------------------------')
print(f'System of fixed-point equations is solved: {jnp.allclose(x, fun(x)[0])}.')
print(f'Probabilities are identical: {jnp.allclose(s0, my_logit(x))}.')
print('-----------------------------------------')
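Because `my_logit` normalizes along axis 0, the update in `my_fxp` simplifies to x + log(s0) - log softmax(x) = log(s0) + logsumexp(x, axis=0). Any x = log(s0) + c, with c constant within each column, is therefore a fixed point, so comparing `my_logit(x)` with `s0` is a meaningful check even though x itself is only determined up to that constant. The sketch below checks this algebra in plain JAX, using `jax.nn.softmax(..., axis=0)` as a stand-in for `my_logit`; `update` is a hypothetical name, and the sketch says nothing about how `FixedPointRoot` iterates internally.

```python
import jax.numpy as jnp
from jax import random
from jax.nn import softmax
from jax.scipy.special import logsumexp

# Same setup as the usage example; softmax(..., axis=0) stands in for my_logit
s0 = softmax(random.uniform(key=random.PRNGKey(123), shape=(3, 4)), axis=0)


def update(x):
    # Same map as my_fxp: x + log(s0 / s), with s the column-wise softmax of x
    return x + jnp.log(s0 / softmax(x, axis=0))


x0 = jnp.zeros_like(s0)
x1 = update(x0)

# Algebraically update(x) = log(s0) + logsumexp(x, axis=0), so from x0 = 0
# a single step lands on log(s0) + log(3) (three rows per column)
closed_form = jnp.log(s0) + logsumexp(x0, axis=0, keepdims=True)

print(jnp.allclose(x1, closed_form, atol=1e-5))          # x1 matches the closed form
print(jnp.allclose(update(x1), x1, atol=1e-5))           # x1 is numerically a fixed point
print(jnp.allclose(softmax(x1, axis=0), s0, atol=1e-6))  # probabilities are reproduced
```

Since one plain update from zero already reaches a fixed point for this problem, the example is best read as a demonstration of the API rather than a test of convergence speed.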


Download files


Source Distribution

FixedPointJAX-0.0.25.tar.gz (4.2 kB)

Uploaded Source

Built Distribution


FixedPointJAX-0.0.25-py3-none-any.whl (5.3 kB)

Uploaded Python 3

File details

Details for the file FixedPointJAX-0.0.25.tar.gz.

File metadata

  • Download URL: FixedPointJAX-0.0.25.tar.gz
  • Upload date:
  • Size: 4.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.8

File hashes

Hashes for FixedPointJAX-0.0.25.tar.gz
Algorithm Hash digest
SHA256 12d5a9b5ed9d0b5476d610e299674d11cd9f5e8bc37c663b700be7ac3d539146
MD5 7ab29e2170b8d32a5f431b9de502c843
BLAKE2b-256 d2dcb2849c82991d9a149d2de0bc8106d2e807bd108404e9a635bcaa24fc0ca1


File details

Details for the file FixedPointJAX-0.0.25-py3-none-any.whl.

File metadata

  • Download URL: FixedPointJAX-0.0.25-py3-none-any.whl
  • Upload date:
  • Size: 5.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.8

File hashes

Hashes for FixedPointJAX-0.0.25-py3-none-any.whl
Algorithm Hash digest
SHA256 1747538b54ed6f15dc4878a1a3ab48f75a890b9bd9b0b449f7f6f152d2087357
MD5 38e02299a888649f2be0beb2dba14e82
BLAKE2b-256 4df5ec6da53b3df12da7aa5d75ca67df0b6631e298680fcdf6e107e99c0d0932
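The SHA256 digests listed above can be checked locally before installing. The helper below is a minimal sketch using Python's standard `hashlib`; `sha256_matches` is a hypothetical name, and the path is assumed to point at a downloaded copy of one of the files.

```python
import hashlib
from pathlib import Path


def sha256_matches(path, expected_hex):
    """Hypothetical helper: return True if the SHA256 digest of the file
    at `path` equals `expected_hex` (case-insensitive)."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        # Read in 64 KiB chunks so large files need not fit in memory
        for chunk in iter(lambda: fh.read(65536), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_hex.strip().lower()
```

For example, `sha256_matches("FixedPointJAX-0.0.25-py3-none-any.whl", "1747538b54ed6f15dc4878a1a3ab48f75a890b9bd9b0b449f7f6f152d2087357")` should return `True` for an unmodified wheel. pip can also enforce digests itself in hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file).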

