Execute runtime assertions, indexing checks, and more if `jax` code is not traced.
>>> import ifnt
>>> import jax
>>> from jax import numpy as jnp
>>>
>>> def safe_log(x):
...     ifnt.testing.assert_array_less(0, x)
...     return jnp.log(x)
>>>
>>> safe_log(-1)
Traceback (most recent call last):
  ...
AssertionError: Arrays are not less-ordered
<BLANKLINE>
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 1
Max relative difference: 1.
 x: array(0)
 y: array(-1)
>>> jax.jit(safe_log)(-1)
Array(nan, dtype=float32, weak_type=True)
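To illustrate the mechanism behind the example above, here is a minimal sketch of a "skip if traced" assertion. It is not ifnt's implementation: the helpers `is_traced` and `assert_array_less_if_not_traced` are hypothetical, and the sketch assumes tracing can be detected by checking whether an argument is a `jax.core.Tracer`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_traced(*args) -> bool:
    """Hypothetical helper: True if any argument is a JAX tracer (e.g. inside jit)."""
    return any(isinstance(arg, jax.core.Tracer) for arg in args)


def assert_array_less_if_not_traced(x, y) -> None:
    """Hypothetical helper: run numpy's assertion only on concrete values."""
    if is_traced(x, y):
        # Values are abstract while tracing, so there is nothing to check.
        return
    np.testing.assert_array_less(np.asarray(x), np.asarray(y))


def safe_log(x):
    assert_array_less_if_not_traced(0, x)
    return jnp.log(x)
```

Called eagerly, `safe_log(-1.0)` raises an `AssertionError`; under `jax.jit(safe_log)(-1.0)` the check is skipped and the result is `nan`, mirroring the doctest.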
Installation
$ pip install jax-ifnt
Relationship to chex
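DeepMind’s chex provides similar, often complementary, assertions. While chex requires runtime assertions to be “functionalized” with chex.chexify, ifnt skips assertions in traced code. This makes it easy, for example, to verify that indices are not out of bounds: JAX cannot raise errors for out-of-bounds indices in compiled code (retrieval clamps them to the valid range by default), so eager checks catch mistakes that would otherwise pass silently.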
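A bounds check in the same spirit can be sketched as follows. The function `checked_take` is hypothetical and not part of ifnt's API; it assumes tracing is detected via `jax.core.Tracer` and only checks indexing along the first axis.

```python
import jax
import jax.numpy as jnp
import numpy as np


def checked_take(x, index):
    """Hypothetical helper: bounds-check `index` on axis 0 unless it is traced."""
    if not isinstance(index, jax.core.Tracer):
        n = x.shape[0]  # Shapes are static, even for traced `x`.
        idx = np.asarray(index)
        if np.any((idx < -n) | (idx >= n)):
            raise IndexError(f"index {idx} is out of bounds for axis 0 with size {n}")
    # Under jit the check is skipped and JAX clamps out-of-bounds indices.
    return x[index]
```

Eagerly, `checked_take(jnp.arange(3), 5)` raises an `IndexError`, whereas `jax.jit(checked_take)(jnp.arange(3), 5)` skips the check and returns the clamped element.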
Download files
Source Distribution
jax_ifnt-0.1.3.tar.gz (7.7 kB)