
Execute runtime assertions, indexing checks, and more if `jax` code is not traced.


https://github.com/tillahoffmann/ifnt/actions/workflows/build.yml/badge.svg https://readthedocs.org/projects/ifnt/badge/?version=latest


>>> import ifnt
>>> import jax
>>> from jax import numpy as jnp
>>>
>>> def safe_log(x):
...     ifnt.testing.assert_array_less(0, x)
...     return jnp.log(x)
>>>
>>> safe_log(-1)
Traceback (most recent call last):
...
AssertionError: Arrays are not less-ordered
<BLANKLINE>
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 1
Max relative difference: 1.
x: array(0)
y: array(-1)
>>> jax.jit(safe_log)(-1)
Array(nan, dtype=float32, weak_type=True)
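In the doctest above, the check runs when `safe_log` is called eagerly and is skipped under `jax.jit`. One way to get this behavior is to test whether the arguments are JAX tracers. The sketch below assumes that approach; it is not necessarily how ifnt is implemented. The helpers `is_traced` and `assert_array_less_if_concrete` are hypothetical names.

```python
import jax
import numpy as np
from jax import numpy as jnp


def is_traced(*values):
    # Hypothetical helper: True if any value is a JAX tracer, i.e. the code is
    # running inside a transformation such as jax.jit, jax.vmap, or jax.grad.
    return any(isinstance(value, jax.core.Tracer) for value in values)


def assert_array_less_if_concrete(x, y):
    # Hypothetical stand-in for the idea behind ifnt.testing.assert_array_less.
    # Concrete values are checked with NumPy. Traced values are skipped because
    # their runtime values are unknown while the function is being traced.
    if is_traced(x, y):
        return
    np.testing.assert_array_less(np.asarray(x), np.asarray(y))


def safe_log(x):
    assert_array_less_if_concrete(0, x)
    return jnp.log(x)
```

Calling `safe_log(-1.0)` eagerly raises an `AssertionError`. `jax.jit(safe_log)(-1.0)` returns `nan` because the check is skipped during tracing.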

Installation

$ pip install jax-ifnt

Relationship to chex

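DeepMind’s chex provides similar, often complementary, assertions. chex requires runtime assertions to be “functionalized” with chex.chexify. ifnt instead skips assertions in traced code, so the same function runs its checks when called eagerly and runs without them under jax.jit. This makes it easy, for example, to verify that indices are not out of bounds, which JAX does not check for array-valued indices by default.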
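The following sketch shows the kind of index check described above, using the same skip-when-traced approach. `check_in_bounds` and `take_rows` are hypothetical names and are not part of ifnt or chex.

```python
import jax
import numpy as np
from jax import numpy as jnp


def check_in_bounds(indices, size):
    # Hypothetical helper: raise IndexError for out-of-bounds indices, but only
    # when the indices are concrete. Tracers are skipped because their values
    # are not known while jax.jit is tracing the function.
    if isinstance(indices, jax.core.Tracer):
        return
    indices = np.asarray(indices)
    # NumPy-style bounds: valid indices lie in [-size, size).
    if np.any((indices < -size) | (indices >= size)):
        raise IndexError(f"indices {indices} out of bounds for axis of size {size}")


def take_rows(x, indices):
    # Array shapes are static under jit, so x.shape[0] is a plain Python int.
    check_in_bounds(indices, x.shape[0])
    return x[indices]
```

Called eagerly, `take_rows(jnp.arange(3), jnp.array([0, 5]))` raises an `IndexError`. Under `jax.jit` the check is skipped and no error is raised, so JAX's default out-of-bounds handling for array indices applies instead.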



Download files


Source distribution: jax_ifnt-0.1.3.tar.gz (7.7 kB)

Built distribution: jax_ifnt-0.1.3-py3-none-any.whl (8.3 kB, Python 3)
