Tools for JAX.
This repository implements a variety of tools for the differentiable programming library JAX.
Major components
Tjax’s major components are:
- A dataclass decorator, dataclass, that facilitates defining JAX trees and has a MyPy plugin. (See dataclass and mypy_plugin.)
- A fixed-point finding library heavily based on fax. Our library supports stochastic iterated functions and avoids leaking JAX tracers. (See fixed_point.)
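For JAX to trace through an object, the object's class must be registered as a tree node: a flatten function splits an instance into children (which JAX transforms) and static auxiliary data (which it treats as constant), and an unflatten function rebuilds the instance. The plain-Python sketch below shows that split for a standard dataclass. The "static" field-metadata flag, the Layer class, and the helper names flatten and unflatten are assumptions for illustration, not tjax's API.

```python
from dataclasses import dataclass, field, fields
from typing import Any, List, Tuple


@dataclass
class Layer:
    # Hypothetical convention: a field whose metadata has static=True is
    # treated as auxiliary data instead of as a tree child.
    weight: float
    bias: float
    name: str = field(default="layer", metadata={"static": True})


def flatten(obj: Any) -> Tuple[List[Any], Tuple[Any, ...]]:
    """Split a dataclass instance into (children, static auxiliary data)."""
    children = []
    static = []
    for f in fields(obj):
        value = getattr(obj, f.name)
        if f.metadata.get("static", False):
            static.append((f.name, value))
        else:
            children.append(value)
    return children, tuple(static)


def unflatten(cls: type, children: List[Any], static: Tuple[Any, ...]) -> Any:
    """Rebuild an instance of cls from the output of flatten()."""
    dynamic_names = [f.name for f in fields(cls)
                     if not f.metadata.get("static", False)]
    kwargs = dict(zip(dynamic_names, children))
    kwargs.update(dict(static))
    return cls(**kwargs)


# Usage: transform only the children, then rebuild with the static data intact.
children, aux = flatten(Layer(1.0, 0.5))
doubled = unflatten(Layer, [2.0 * c for c in children], aux)
```

With JAX installed, the same pair could be passed to jax.tree_util.register_pytree_node(Layer, flatten, lambda aux, children: unflatten(Layer, children, aux)), since JAX hands the auxiliary data to the unflatten function first. A decorator that performs this registration automatically saves writing the pair by hand for every class.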
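The basic idea behind fixed-point finding is to apply an iterated function repeatedly until successive iterates stop changing. A minimal deterministic sketch in plain Python follows; the function name fixed_point, the absolute tolerance, and the iteration cap are illustrative assumptions, and the sketch models neither tjax's stochastic iterated functions nor its handling of JAX tracers.

```python
import math
from typing import Callable


def fixed_point(f: Callable[[float], float], x0: float,
                atol: float = 1e-10, max_iterations: int = 1000) -> float:
    """Iterate x <- f(x) until |f(x) - x| <= atol; raise if that never happens."""
    x = x0
    for _ in range(max_iterations):
        x_next = f(x)
        if abs(x_next - x) <= atol:
            return x_next
        x = x_next
    raise RuntimeError("fixed-point iteration did not converge")


# The Babylonian map x <- (x + 2 / x) / 2 has the fixed point sqrt(2).
root = fixed_point(lambda x: (x + 2.0 / x) / 2.0, x0=1.0)
error = abs(root - math.sqrt(2.0))
```

Plain iteration only converges when the map is contractive near the fixed point; the Babylonian map converges quadratically, while a map such as math.cos converges only linearly and needs more iterations.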
Minor components
Tjax also includes:
- A pretty printer, print_generic, for aggregate and vector types, including dataclasses. (See display.)
- Versions of custom_vjp and custom_jvp that support both static and non-differentiable arguments. (See shims.)
- Tools for working with cotangents: copy_cotangent and print_cotangent. (See cotangent_tools.)
- A random number generator class, Generator. (See generator.)
- JAX tree registration for NetworkX graph types. (See graph.)
- Leaky integration, leaky_integrate, and Ornstein-Uhlenbeck process iteration, diffused_leaky_integrate. (See leaky_integral.)
- An improved version of jax.tree_util.Partial. (See partial.)
- A Matplotlib trajectory plotter, PlottableTrajectory. (See plottable_trajectory.)
- A testing function, assert_jax_allclose, that automatically produces testing code, and the related function jax_allclose. (See testing.)
- Basic tools: sum_tensors and is_scalar. (See tools.)
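A pretty printer for aggregate types typically recurses through containers and dataclass fields, indenting one level per nesting depth. The plain-Python sketch below shows that recursion; format_generic, the Point class, and the output layout are hypothetical and are not the output format of tjax's print_generic, which also handles vector (array) types.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, List


def format_generic(value: Any, name: str = "", indent: int = 0) -> List[str]:
    """Return indented lines describing a nested value (hypothetical helper)."""
    prefix = "    " * indent + (f"{name}=" if name else "")
    # Dataclass instances (not dataclass types): recurse into each field.
    if is_dataclass(value) and not isinstance(value, type):
        lines = [prefix + type(value).__name__]
        for f in fields(value):
            lines.extend(format_generic(getattr(value, f.name), f.name, indent + 1))
        return lines
    # Mappings: label each child with its key.
    if isinstance(value, dict):
        lines = [prefix + "dict"]
        for key, item in value.items():
            lines.extend(format_generic(item, str(key), indent + 1))
        return lines
    # Sequences: label each child with its index.
    if isinstance(value, (list, tuple)):
        lines = [prefix + type(value).__name__]
        for i, item in enumerate(value):
            lines.extend(format_generic(item, str(i), indent + 1))
        return lines
    # Leaves fall back to repr().
    return [prefix + repr(value)]


@dataclass
class Point:
    x: float
    y: float


print("\n".join(format_generic({"origin": Point(0.0, 0.0), "path": [1, 2]})))
```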
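Leaky integration can be read as the linear ODE dx/dt = drift - decay * x, whose exact update over a time step dt is x * exp(-decay * dt) + drift * (1 - exp(-decay * dt)) / decay. The Ornstein-Uhlenbeck version adds Gaussian noise with variance diffusion**2 * (1 - exp(-2 * decay * dt)) / (2 * decay). The sketch below implements those textbook formulas in plain Python; the names leaky_step and ou_step, their parameters, and the assumption decay > 0 are illustrative, and tjax's leaky_integrate and diffused_leaky_integrate may be parameterized differently.

```python
import math
import random


def leaky_step(value: float, dt: float, drift: float, decay: float) -> float:
    """Exact update of dx/dt = drift - decay * x over one step (assumes decay > 0)."""
    retention = math.exp(-decay * dt)  # fraction of the old value that survives
    return value * retention + drift * (1.0 - retention) / decay


def ou_step(value: float, dt: float, drift: float, decay: float,
            diffusion: float, rng: random.Random) -> float:
    """One exact step of dx = (drift - decay * x) dt + diffusion dW."""
    mean = leaky_step(value, dt, drift, decay)
    variance = diffusion ** 2 * (1.0 - math.exp(-2.0 * decay * dt)) / (2.0 * decay)
    return mean + math.sqrt(variance) * rng.gauss(0.0, 1.0)


# Usage: simulate a noisy trajectory that relaxes toward drift / decay = 1.5.
rng = random.Random(0)
x = 0.0
for _ in range(100):
    x = ou_step(x, 0.1, drift=3.0, decay=2.0, diffusion=0.5, rng=rng)
```

Because the update is the exact solution rather than an Euler step, two half steps give the same result as one full step, and the process does not become unstable for large dt.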
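One way a test helper can "produce testing code" is to format the observed value as a Python assignment when the comparison fails, so the line can be reviewed and pasted into the test. The plain-Python sketch below shows that idea; the names reproduction_code and assert_allclose_with_code, the default tolerances, and the message format are hypothetical and are not tjax's assert_jax_allclose, which is meant for JAX values.

```python
import math
from typing import Optional, Sequence


def reproduction_code(actual: Sequence[float], desired: Sequence[float],
                      name: str = "desired",
                      rtol: float = 1e-6, atol: float = 1e-8) -> Optional[str]:
    """Return None if the values match, else a line of Python reproducing actual."""
    if len(actual) == len(desired) and all(
            math.isclose(a, d, rel_tol=rtol, abs_tol=atol)
            for a, d in zip(actual, desired)):
        return None
    return f"{name} = {list(actual)!r}"


def assert_allclose_with_code(actual: Sequence[float], desired: Sequence[float],
                              name: str = "desired") -> None:
    """Raise AssertionError whose message is code that can be pasted into the test."""
    code = reproduction_code(actual, desired, name)
    if code is not None:
        raise AssertionError(code)


# Usage: passes silently because the values agree within tolerance.
assert_allclose_with_code([1.0, 2.0], [1.0, 2.0 + 1e-12])
```

Returning the code string from a separate function keeps the formatting logic testable on its own, while the asserting wrapper is what a test suite would call.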
Also, see the documentation.
Contribution guidelines
Conventions: PEP 8.
How to run tests:
pytest .
How to clean the source:
isort tjax
pylint tjax
mypy tjax
flake8 tjax