Tools for JAX.
This repository implements a variety of tools for the differentiable programming library JAX.
Major components
Tjax’s major components are:
A dataclass decorator, dataclass, that facilitates defining structured JAX objects (so-called “pytrees”) and that benefits from:
the ability to mark fields as static (not available in chex.dataclass),
a MyPy plugin (mypy_plugin), and
a display method that produces formatted text according to the tree structure.
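tjax's dataclass decorator automates pytree registration. The sketch below shows by hand what a "static" field means, using only JAX's public pytree API. Scaler and apply are hypothetical names, and this is not how tjax implements its decorator. A dynamic field becomes a pytree leaf that JAX traces and differentiates. A static field is stored in the tree definition and stays an ordinary Python value under jit.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Scaler:
    # Hypothetical example class, not part of tjax.
    factor: jax.Array    # dynamic field: a pytree leaf that JAX traces
    mode: str = "scale"  # static field: kept in the tree definition

    def tree_flatten(self):
        # Children (traced) first, then hashable auxiliary data (static).
        return (self.factor,), (self.mode,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data[0])


@jax.jit
def apply(s: Scaler, x: jax.Array) -> jax.Array:
    # s.mode is a plain Python string under jit, so ordinary control flow works;
    # s.factor is a tracer.
    if s.mode == "negate":
        return -s.factor * x
    return s.factor * x


s = Scaler(jnp.asarray(2.0))
leaves = jax.tree_util.tree_leaves(s)  # only the factor array is a leaf
y = apply(s, jnp.asarray(3.0))
# Differentiating with respect to the whole object yields a Scaler of gradients.
grads = jax.grad(apply)(s, jnp.asarray(3.0))
```

Changing a static field, such as mode, changes the tree definition. That triggers a new compilation, so static fields suit configuration values that rarely change.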
A fixed-point finding library (fixed_point) heavily based on fax. Our library
supports stochastic iterated functions, and
uses dataclasses instead of closures to avoid leaking JAX tracers.
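The following sketch illustrates the second point in plain JAX. The iterated function is a pytree dataclass that travels through jax.lax.while_loop as part of the loop state, rather than being a closure that captures traced values. SqrtIteration and fixed_point are hypothetical names, not tjax's API. This simple loop also omits the implicit differentiation that fax-style solvers provide.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class SqrtIteration:
    # Hypothetical iterated function x -> (x + a / x) / 2, whose fixed point is sqrt(a).
    a: jax.Array

    def tree_flatten(self):
        return (self.a,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def __call__(self, x):
        return 0.5 * (x + self.a / x)


def fixed_point(f, x0, tol=1e-6, max_iterations=100):
    """Iterate x <- f(x) until successive iterates differ by at most tol.

    f is carried through the loop state as a pytree instead of being captured
    by the loop body's closure, so any traced values inside it are explicit
    loop inputs.
    """
    def cond(state):
        _, x, x_prev, i = state
        return (jnp.max(jnp.abs(x - x_prev)) > tol) & (i < max_iterations)

    def body(state):
        f, x, _, i = state
        return f, f(x), x, i + 1

    x0 = jnp.asarray(x0)
    init = (f, f(x0), x0, jnp.asarray(1, dtype=jnp.int32))
    _, x, _, _ = jax.lax.while_loop(cond, body, init)
    return x


root = fixed_point(SqrtIteration(jnp.asarray(2.0)), 1.0)
# The same solver works under vmap because f is an ordinary argument.
roots = jax.vmap(lambda a: fixed_point(SqrtIteration(a), 1.0))(jnp.array([4.0, 9.0]))
```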
A shim for the gradient transformation library optax that supports:
easy differentiation and vectorization of “gradient transformation” (learning rule) parameters,
gradient transformation objects that can be passed dynamically to jitted functions, and
generic type annotations.
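The benefit of treating a learning rule as data can be sketched in plain JAX. Below, a hypothetical rule, ScaledSGD, stores its learning rate as a pytree leaf. The rule can therefore be passed to a jitted function as a dynamic argument, differentiated with respect to its learning rate, and vmapped over several learning rates. This illustrates the idea only; it is neither optax's API nor tjax's shim.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class ScaledSGD:
    # Hypothetical learning rule: update = -learning_rate * gradient.
    learning_rate: jax.Array  # a pytree leaf, so it can be traced and differentiated

    def tree_flatten(self):
        return (self.learning_rate,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def update(self, gradient):
        return jax.tree_util.tree_map(lambda g: -self.learning_rate * g, gradient)


def loss(params):
    return jnp.sum(params ** 2)


@jax.jit
def loss_after_step(rule, params):
    # rule is an ordinary traced argument: changing its learning rate does not
    # force recompilation.
    updates = rule.update(jax.grad(loss)(params))
    return loss(params + updates)


params = jnp.array([1.0, -2.0])
# Meta-gradient of the post-step loss with respect to the learning rate.
meta_grad = jax.grad(loss_after_step)(ScaledSGD(jnp.asarray(0.1)), params)
# Post-step losses for several learning rates in one vectorized call.
losses = jax.vmap(loss_after_step, in_axes=(0, None))(
    ScaledSGD(jnp.array([0.0, 0.1, 0.5])), params)
```

Here the post-step loss is 5(1 - 2η)² for learning rate η. Its derivative at η = 0.1 is -20(1 - 2η) = -16, which is the value stored in meta_grad.learning_rate.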
Minor components
Tjax also includes:
A pretty printer, print_generic, for aggregate and vector types, including dataclasses. (See display.)
Versions of custom_vjp and custom_jvp that support being used on methods. (See shims.)
Tools for working with cotangents. (See cotangent_tools.)
JAX tree registration for NetworkX graph types. (See graph.)
Leaky integration, leaky_integrate, and Ornstein-Uhlenbeck process iteration, diffused_leaky_integrate. (See leaky_integral.)
An improved version of jax.tree_util.Partial. (See partial.)
A Matplotlib trajectory plotter, PlottableTrajectory. (See plottable_trajectory.)
A testing function, assert_tree_allclose, that automatically produces testing code, and a related function, tree_allclose. (See testing.)
Basic tools like divide_where. (See tools.)
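A helper like divide_where matters in JAX because of a well-known pitfall. In jnp.where(mask, x / y, 0.0), the masked-out branch is still evaluated, so a zero divisor turns the gradient into NaN even though the forward value looks fine. The usual fix is a second jnp.where that replaces masked-out divisors before dividing. The sketch below assumes that divide_where performs this kind of masked division, as its name suggests. divide_where_sketch and naive are hypothetical names, and tjax's actual signature may differ.

```python
import jax
import jax.numpy as jnp


def divide_where_sketch(dividend, divisor, where, otherwise=0.0):
    """Hypothetical helper: dividend / divisor where `where` is True, else `otherwise`.

    The inner jnp.where swaps masked-out divisors for 1.0 so that neither the
    value nor its gradient is poisoned by a division by zero.
    """
    safe_divisor = jnp.where(where, divisor, 1.0)
    return jnp.where(where, dividend / safe_divisor, otherwise)


def naive(dividend, divisor, where):
    # Forward value is fine, but the gradient of the masked-out branch is NaN.
    return jnp.where(where, dividend / divisor, 0.0)


x = jnp.array([1.0, 2.0])
y = jnp.array([2.0, 0.0])
mask = y != 0
out = divide_where_sketch(x, y, mask)
grad_safe = jax.grad(lambda d: jnp.sum(divide_where_sketch(x, d, mask)))(y)
grad_naive = jax.grad(lambda d: jnp.sum(naive(x, d, mask)))(y)
```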
Also, see the documentation.
Contribution guidelines
Conventions: PEP8.
How to run tests:
pytest .
How to clean the source:
isort .
mypy .
pylint tjax tests
pflake8 .