Tools for JAX.
This repository implements a variety of tools for the differentiable programming library JAX.
Tjax’s major components are:
A dataclass decorator for defining JAX pytrees, with:
the ability to mark fields as static (not available in chex.dataclass),
a MyPy plugin, and
a display method that produces formatted text according to the tree structure.
A fixed-point finding library, which
supports stochastic iterated functions, and
uses dataclasses instead of closures to avoid leaking JAX tracers.
A shim for the gradient transformation library optax, which supports:
easy differentiation and vectorization of “gradient transformation” (learning rule) parameters,
gradient transformation objects that can be passed dynamically to jitted functions, and
generic type annotations.
A pretty printer print_generic for aggregate and vector types, including dataclasses. (See display.) It features:
a version for printing traced values,
decoding the size of the batched axes when printing ordinary and traced values,
colorized tree output for aggregate structures, and
formatted tabular output for arrays (or statistics when there’s no room for tabular output).
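The idea of marking a dataclass field as static can be illustrated with plain JAX. The following is a minimal sketch, not Tjax's own API: it registers an ordinary dataclass as a pytree by hand, with the class PowerScale and the helpers _flatten and _unflatten being hypothetical names chosen for this example. The static field travels in the tree structure rather than as a traced leaf.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class PowerScale:  # hypothetical example class, not part of Tjax
    weights: jax.Array  # dynamic field: a traced, differentiable leaf
    power: int          # static field: stored in the tree structure


def _flatten(obj):
    # Dynamic children first, then hashable auxiliary (static) data.
    return (obj.weights,), (obj.power,)


def _unflatten(aux, children):
    return PowerScale(weights=children[0], power=aux[0])


jax.tree_util.register_pytree_node(PowerScale, _flatten, _unflatten)


@jax.jit
def apply(scale, x):
    # scale.power is a plain Python int here, never a tracer.
    return jnp.sum(scale.weights * x ** scale.power)


scale = PowerScale(weights=jnp.ones(3), power=2)
x = jnp.arange(3.0)
result = apply(scale, x)
leaves = jax.tree_util.tree_leaves(scale)  # only the dynamic field
grad = jax.grad(apply)(scale, x)           # gradient has the same structure
```

Because power is part of the tree structure, calling apply with a different power triggers a recompilation, while changing weights does not.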
Tjax also includes:
A version of custom_jvp that supports being used on methods. (See shims.)
Tools for working with cotangents. (See cotangent_tools.)
leaky_integrate and the Ornstein-Uhlenbeck process iteration diffused_leaky_integrate. (See leaky_integral.)
An improved version of jax.tree_util.Partial. (See partial.)
A testing function assert_tree_allclose that automatically produces testing code, and a related function tree_allclose. (See testing.)
Basic tools like divide_where. (See tools.)
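As an illustration of what a masked division helper such as divide_where is useful for, here is a minimal sketch in plain JAX. It is an assumption about the general technique, not Tjax's implementation or signature: the function safe_divide and its parameters are hypothetical. It uses the common double jnp.where idiom so that neither the value nor the gradient becomes NaN where the divisor is zero.

```python
import jax
import jax.numpy as jnp


def safe_divide(dividend, divisor, where, otherwise=0.0):
    # Hypothetical helper: replace masked-out divisors with 1 so the
    # division itself never sees a zero...
    safe_divisor = jnp.where(where, divisor, 1.0)
    # ...then select either the quotient or the fallback value.
    return jnp.where(where, dividend / safe_divisor, otherwise)


x = jnp.array([1.0, 2.0, 3.0])
d = jnp.array([2.0, 0.0, 4.0])
mask = d != 0.0

quotient = safe_divide(x, d, where=mask)


def total(divisor):
    return jnp.sum(safe_divide(x, divisor, where=mask))


# The gradient with respect to the divisor stays finite at the masked entry.
grad = jax.grad(total)(d)
```

A single jnp.where around x / d would give the right values but a NaN gradient at the zero divisor, which is why the divisor is sanitized before dividing.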
How to run tests:
How to clean the source:
pylint tjax tests