Tools for JAX.
This repository implements a variety of tools for the differentiable programming library JAX.
Major components
Tjax’s major components are:
- A dataclass decorator, dataclass, that facilitates defining JAX trees and has a MyPy plugin. (See dataclass and mypy_plugin.)
- A fixed-point finding library, fixed_point, heavily based on fax. Our library:
  - supports stochastic iterated functions,
  - uses dataclasses instead of closures to avoid leaking JAX tracers, and
  - supports higher-order differentiation.
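The decorator's own signature is not shown here. As a rough sketch of the boilerplate such a decorator saves, the following registers a standard-library dataclass as a JAX tree by hand with jax.tree_util.register_pytree_node_class, so that jax.grad returns gradients with the same structure as the input. The LinearModel class, its fields, and the choice to treat name as static metadata are illustrative assumptions, not tjax code.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class LinearModel:
    """Hypothetical model: weight and bias are array leaves, name is static metadata."""

    weight: jax.Array
    bias: jax.Array
    name: str = "linear"

    def tree_flatten(self):
        # Children are traced and differentiated; aux_data is kept static.
        return (self.weight, self.bias), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        weight, bias = children
        return cls(weight=weight, bias=bias, name=aux_data)


def loss(model: LinearModel, x: jax.Array) -> jax.Array:
    return jnp.sum((model.weight * x + model.bias) ** 2)


model = LinearModel(weight=jnp.array(2.0), bias=jnp.array(1.0))
# The gradient is itself a LinearModel: d/dweight = 42.0, d/dbias = 14.0.
grads = jax.grad(loss)(model, jnp.array(3.0))
```

Writing tree_flatten and tree_unflatten by hand for every class is exactly the repetition a dataclass decorator for JAX trees is meant to remove.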
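The fixed_point API itself is not reproduced here. Assuming a deterministic scalar map, this minimal sketch shows plain fixed-point iteration with jax.lax.while_loop, where the parameter theta is carried in the loop state and passed to the iterated function explicitly rather than captured in a closure. The names fixed_point, babylonian_step, tol, and max_iterations are hypothetical.

```python
import jax
import jax.numpy as jnp


def fixed_point(f, theta, x0, tol=1e-6, max_iterations=100):
    """Iterate x <- f(theta, x) until successive iterates differ by at most tol.

    theta is threaded through the loop state and handed to f explicitly,
    so f never needs to capture traced values in a closure.
    """
    theta = jnp.asarray(theta, dtype=jnp.float32)
    x0 = jnp.asarray(x0, dtype=jnp.float32)

    def keep_going(state):
        iteration, _, x, x_previous = state
        return (iteration < max_iterations) & (jnp.abs(x - x_previous) > tol)

    def step(state):
        iteration, theta, x, _ = state
        return iteration + 1, theta, f(theta, x), x

    # Apply f once up front so the loop condition has two iterates to compare.
    initial = (jnp.array(1, dtype=jnp.int32), theta, f(theta, x0), x0)
    _, _, x, _ = jax.lax.while_loop(keep_going, step, initial)
    return x


def babylonian_step(theta, x):
    # The fixed point of this map is sqrt(theta).
    return 0.5 * (x + theta / x)


root = fixed_point(babylonian_step, 2.0, 1.0)  # approximately 1.4142135
```

JAX does not support reverse-mode differentiation through lax.while_loop, which is one reason fax-style libraries differentiate a fixed point implicitly instead of unrolling the loop; that part is not sketched here.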
Minor components
Tjax also includes:
- A pretty printer, print_generic, for aggregate and vector types, including dataclasses. (See display.)
- Versions of custom_vjp and custom_jvp that support being used on methods. (See shims.)
- Tools for working with cotangents. (See cotangent_tools.)
- A random number generator class, Generator. (See generator.)
- JAX tree registration for NetworkX graph types. (See graph.)
- Leaky integration, leaky_integrate, and Ornstein-Uhlenbeck process iteration, diffused_leaky_integrate. (See leaky_integral.)
- An improved version of jax.tree_util.Partial. (See partial.)
- A Matplotlib trajectory plotter, PlottableTrajectory. (See plottable_trajectory.)
- A testing function, assert_tree_allclose, that automatically produces testing code, and a related function, tree_allclose. (See testing.)
- Basic tools such as divide_where. (See tools.)
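For reference, this is the standard jax.tree_util.Partial that the partial module improves on (what tjax changes is not described here). Unlike functools.partial, which is not a valid JAX type and so cannot be passed as an argument to a jitted function, a Partial is a pytree whose bound arguments are leaves and whose callable is static. The scale and apply functions are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def scale(factor, x):
    return factor * x


@jax.jit
def apply(f, x):
    # f arrives as a pytree: its bound array is a traced leaf, the function is static.
    return f(x)


triple = Partial(scale, jnp.array(3.0))
result = apply(triple, jnp.array(2.0))  # 6.0
# The bound array 3.0 is the only leaf of the Partial.
leaves = jax.tree_util.tree_leaves(triple)
```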
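The signature of tjax's tree_allclose is not given above. As a hypothetical sketch of what comparing JAX trees involves, trees_allclose below checks that two trees have the same structure and leaf shapes, then compares leaves with numpy.allclose; the name and the default tolerances (numpy's own defaults) are assumptions. It does not generate testing code the way assert_tree_allclose is described as doing.

```python
import jax
import numpy as np


def trees_allclose(actual, desired, rtol=1e-5, atol=1e-8):
    """Hypothetical helper: same tree structure, same leaf shapes, and close values."""
    actual_leaves, actual_structure = jax.tree_util.tree_flatten(actual)
    desired_leaves, desired_structure = jax.tree_util.tree_flatten(desired)
    if actual_structure != desired_structure:
        return False
    return all(
        np.shape(a) == np.shape(d) and np.allclose(a, d, rtol=rtol, atol=atol)
        for a, d in zip(actual_leaves, desired_leaves)
    )


expected = {"weights": np.array([1.0, 2.0]), "bias": (0.5,)}
observed = {"weights": np.array([1.0, 2.0 + 1e-9]), "bias": (0.5,)}
```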
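The arguments of leaky_integrate and diffused_leaky_integrate are not specified here. As a sketch under stated assumptions, ou_step applies the standard exact discretization of the Ornstein-Uhlenbeck process dX = theta (mu - X) dt + sigma dW over a step dt: the mean decays toward mu by a factor exp(-theta dt), and Gaussian noise with variance sigma^2 (1 - exp(-2 theta dt)) / (2 theta) is added. With sigma = 0 it reduces to deterministic leaky integration. The function name and parameterization are hypothetical, not tjax's.

```python
import jax
import jax.numpy as jnp


def ou_step(key, x, dt, theta, mu, sigma):
    """One exact step of dX = theta * (mu - X) dt + sigma dW (hypothetical helper)."""
    decay = jnp.exp(-theta * dt)
    # Deterministic part: leaky integration toward mu.
    mean = mu + (x - mu) * decay
    # Exact standard deviation of the diffusion accumulated over dt.
    std = sigma * jnp.sqrt((1.0 - decay ** 2) / (2.0 * theta))
    return mean + std * jax.random.normal(key, jnp.shape(x))


key = jax.random.PRNGKey(0)
# With sigma = 0 the step is a pure exponential decay: 1.0 * exp(-1.0).
decayed = ou_step(key, 1.0, dt=1.0, theta=1.0, mu=0.0, sigma=0.0)

# After a long step the sample variance should approach sigma**2 / (2 * theta) = 0.25.
keys = jax.random.split(jax.random.PRNGKey(1), 10_000)
samples = jax.vmap(lambda k: ou_step(k, 0.0, dt=10.0, theta=2.0, mu=0.0, sigma=1.0))(keys)
```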
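What divide_where does is not stated above; its name suggests division restricted to a mask. As a hypothetical illustration of why such a helper is useful in JAX, safe_divide uses the double-jnp.where pattern: masking the divisor before dividing keeps gradients finite, whereas masking only the result leaves NaN gradients at the masked entries. safe_divide and naive_divide are illustrative names, not tjax functions.

```python
import jax
import jax.numpy as jnp


def safe_divide(dividend, divisor, where, otherwise=0.0):
    """Hypothetical helper: dividend / divisor where `where` holds, else `otherwise`.

    The inner jnp.where swaps masked-out divisors for 1.0 so that the division
    never produces inf, which would otherwise turn their gradients into NaN.
    """
    safe_divisor = jnp.where(where, divisor, 1.0)
    return jnp.where(where, dividend / safe_divisor, otherwise)


def naive_divide(dividend, divisor, where, otherwise=0.0):
    # Masks only the result; the division itself still sees the zero divisor.
    return jnp.where(where, dividend / divisor, otherwise)


x = jnp.array([0.0, 2.0])
safe_grad = jax.grad(lambda d: jnp.sum(safe_divide(1.0, d, d != 0.0)))(x)  # [0., -0.25]
naive_grad = jax.grad(lambda d: jnp.sum(naive_divide(1.0, d, d != 0.0)))(x)  # first entry is nan
```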
Also, see the documentation.
Contribution guidelines
Conventions: PEP8.
How to run tests:
pytest .
How to clean the source:
isort tjax
pylint tjax
mypy tjax
flake8 tjax