Tools for JAX.
This repository implements a variety of tools for the differentiable programming library JAX.
Major components
Tjax’s major components are:
A dataclass decorator, dataclass, that facilitates defining structured JAX objects (so-called “pytrees”), which benefits from:
the ability to mark fields as static (not available in chex.dataclass),
a MyPy plugin, and
a display method that produces formatted text according to the tree structure.
A fixed point finding library heavily based on fax. Our library:
supports stochastic iterated functions, and
uses dataclasses instead of closures to avoid leaking JAX tracers.
A shim for the gradient transformation library optax that supports:
easy differentiation and vectorization of “gradient transformation” (learning rule) parameters,
gradient transformation objects that can be passed dynamically to jitted functions, and
generic type annotations.
A pretty printer, print_generic, for aggregate and vector types, including dataclasses. (See display.) It features:
a version for printing traced values, tapped_print_generic,
decoding the size of the batched axes when printing ordinary and traced values,
colorized tree output for aggregate structures, and
formatted tabular output for arrays (or statistics when there’s no room for tabular output).
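The static-field feature of the dataclass decorator above can be illustrated with JAX alone. The following is a minimal sketch that does not use tjax: it registers an ordinary frozen dataclass through jax.tree_util.register_pytree_node_class and places one field in the auxiliary (static) data. The class Layer and the function apply are hypothetical examples, and tjax's own decorator may express static fields differently.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclass(frozen=True)
class Layer:
    """Hypothetical pytree with one dynamic and one static field."""

    weights: jax.Array  # dynamic field: a pytree leaf that JAX traces
    activation: str     # static field: stored in the tree structure, not traced

    def tree_flatten(self):
        # Children are traced; auxiliary data is static and must be hashable.
        return (self.weights,), (self.activation,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)


@jax.jit
def apply(layer: Layer, x: jax.Array) -> jax.Array:
    y = layer.weights @ x
    # `activation` is an ordinary Python string here, even under jit,
    # so a Python-level branch on it is allowed.
    return jnp.tanh(y) if layer.activation == "tanh" else y


layer = Layer(jnp.eye(2), "linear")
print(apply(layer, jnp.array([1.0, 2.0])))
```

Because activation lives in the tree structure rather than among the leaves, changing it makes jit retrace, while passing new weights of the same shape reuses the compiled function.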
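The fixed point item above contrasts dataclasses with closures. As a rough illustration of keeping the iterated function's parameters explicit, here is a minimal forward-iteration sketch built on jax.lax.while_loop, in which the parameters travel through the loop state instead of being captured by the loop body. The function fixed_point shown here is hypothetical; it is not tjax's or fax's API, and it implements neither their implicit-differentiation rules nor stochastic iteration.

```python
import jax.numpy as jnp
from jax import lax


def fixed_point(f, params, x0, tol=1e-6, max_iterations=1000):
    """Iterate x <- f(params, x) until successive iterates differ by < tol.

    Hypothetical sketch: `params` is carried in the loop state, so the loop
    body captures only plain Python values (f, tol, max_iterations).
    """

    def keep_going(state):
        _, x, x_previous, iteration = state
        not_converged = jnp.max(jnp.abs(x - x_previous)) > tol
        return not_converged & (iteration < max_iterations)

    def step(state):
        params, x, _, iteration = state
        return params, f(params, x), x, iteration + 1

    # Loop state: (parameters, current iterate, previous iterate, count).
    initial = (params, f(params, x0), x0, jnp.asarray(1))
    _, x, _, iterations = lax.while_loop(keep_going, step, initial)
    return x, iterations


# Solves x = cos(a * x) with a = 1.0, starting from x = 1.0.
x, iterations = fixed_point(lambda a, x: jnp.cos(a * x), 1.0, jnp.asarray(1.0))
```

With a = 1.0 the iteration converges to the solution of x = cos x, approximately 0.739.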
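The first two features of the optax shim rest on a general JAX pattern: if a learning rule is itself a pytree whose hyperparameters are leaves, it can be passed into jitted code as an ordinary argument, differentiated with jax.grad, and vectorized with jax.vmap. The sketch below shows that pattern with a hypothetical SGD class and a one-step loss. It uses neither optax nor tjax, and its names are illustrative only.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map


@register_pytree_node_class
@dataclass(frozen=True)
class SGD:
    """Hypothetical learning rule whose hyperparameter is a pytree leaf."""

    learning_rate: jax.Array

    def update(self, gradients):
        # Plain gradient descent: step = -learning_rate * gradient.
        return tree_map(lambda g: -self.learning_rate * g, gradients)

    def tree_flatten(self):
        return (self.learning_rate,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(weights):
    return jnp.sum(weights ** 2)


@jax.jit
def loss_after_one_step(transform: SGD, weights):
    # The learning rule is an ordinary (dynamic) argument of the jitted function.
    step = transform.update(jax.grad(loss)(weights))
    new_weights = tree_map(jnp.add, weights, step)
    return loss(new_weights)


weights = jnp.array(2.0)
# Derivative of the post-step loss with respect to the learning rate.
meta_gradient = jax.grad(loss_after_one_step)(SGD(jnp.array(0.1)), weights)
# meta_gradient.learning_rate is approximately -12.8
sweep = jax.vmap(loss_after_one_step, in_axes=(0, None))(
    SGD(jnp.array([0.0, 0.25, 0.5])), weights
)
# sweep is approximately [4.0, 1.0, 0.0]
```

Here jax.grad returns an SGD instance whose learning_rate field holds the derivative, and jax.vmap evaluates three learning rates in a single call.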
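To make the idea of tree-structured output with a statistics fallback concrete, here is a deliberately simplified, hypothetical format_tree function using NumPy. It is not print_generic: it omits colorization, traced values, and batch-axis decoding, and it handles only dicts, lists, tuples, and array-like leaves.

```python
import numpy as np


def format_tree(value, name="root", indent=0, max_elements=6):
    """Return lines describing a nested dict/list/tuple of array-like leaves."""
    prefix = "    " * indent + f"{name}="
    if isinstance(value, dict):
        lines = [prefix + "dict"]
        for key, child in value.items():
            lines += format_tree(child, str(key), indent + 1, max_elements)
        return lines
    if isinstance(value, (list, tuple)):
        lines = [prefix + type(value).__name__]
        for index, child in enumerate(value):
            lines += format_tree(child, str(index), indent + 1, max_elements)
        return lines
    array = np.asarray(value)
    header = f"{prefix}{array.dtype}{list(array.shape)}"
    if array.size <= max_elements:
        # Small arrays: print the values themselves.
        return [f"{header} {np.array2string(array, precision=3)}"]
    # Large arrays: fall back to summary statistics.
    return [f"{header} mean={array.mean():.3g} std={array.std():.3g}"]


example = {"bias": np.zeros(3), "layers": [np.ones((100, 100))]}
print("\n".join(format_tree(example)))
```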
Minor components
Tjax also includes:
Versions of custom_vjp and custom_jvp that support being used on methods: custom_vjp_method and custom_jvp_method. (See shims.)
Tools for working with cotangents. (See cotangent_tools.)
JAX tree registration for NetworkX graph types. (See graph.)
Leaky integration, leaky_integrate, and Ornstein-Uhlenbeck process iteration, diffused_leaky_integrate. (See leaky_integral.)
An improved version of jax.tree_util.Partial. (See partial.)
A testing function, assert_tree_allclose, that automatically produces testing code, and a related function, tree_allclose. (See testing.)
Basic tools like divide_where. (See tools.)
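For the leaky integration item, one reasonable reading is exact integration of the linear ODE dx/dt = drift - decay·x over a time step. Assuming that model, the closed-form update is x(t + Δt) = x(t)·exp(-decay·Δt) + (drift / decay)·(1 - exp(-decay·Δt)), sketched below as a hypothetical leaky_step function. It covers only the deterministic case, not the Ornstein-Uhlenbeck iteration of diffused_leaky_integrate, and it is not tjax's implementation.

```python
import jax.numpy as jnp


def leaky_step(value, time_step, drift, decay):
    """Advance x under dx/dt = drift - decay * x by `time_step` (decay != 0).

    Hypothetical sketch of exact (closed-form) leaky integration.
    """
    # Fraction of the current value that survives the step.
    retention = jnp.exp(-decay * time_step)
    # Relax toward the equilibrium value drift / decay.
    return value * retention + drift / decay * (1.0 - retention)


two_half_steps = leaky_step(leaky_step(jnp.asarray(1.0), 0.3, 2.0, 0.5), 0.3, 2.0, 0.5)
one_full_step = leaky_step(jnp.asarray(1.0), 0.6, 2.0, 0.5)
```

Because the update is exact rather than an Euler step, two steps of length Δt agree with one step of length 2Δt, and long steps approach the equilibrium drift / decay.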
Contribution guidelines
Conventions: PEP8.
How to run tests:
pytest .
How to clean the source:
ruff .
pyright
mypy
isort .
pylint tjax tests