Tools for JAX.
Project description
This repository implements a variety of tools for the differentiable programming library JAX.
Major components
Tjax’s major components are:
- A dataclass decorator, dataclass, that facilitates defining structured JAX objects (so-called “pytrees”), which benefits from:
  - the ability to mark fields as static (not available in chex.dataclass),
  - a MyPy plugin (mypy_plugin), and
  - a display method that produces formatted text according to the tree structure.
- A fixed-point finding library, fixed_point, heavily based on fax. Our library
  - supports stochastic iterated functions, and
  - uses dataclasses instead of closures to avoid leaking JAX tracers.
- A shim for the gradient transformation library optax that supports:
  - easy differentiation and vectorization of “gradient transformation” (learning rule) parameters,
  - gradient transformation objects that can be passed dynamically to jitted functions, and
  - generic type annotations.
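To illustrate what marking a field as static means for a pytree, here is a minimal sketch in plain Python (no JAX or tjax required). The helpers static_field, flatten, and unflatten and the class Layer are hypothetical names chosen for this example, not tjax's API. Dynamic fields become the pytree's children (the values JAX would trace and differentiate), while static fields travel in the hashable auxiliary data.

```python
# Hypothetical sketch of static vs. dynamic dataclass fields in a pytree.
# This is NOT tjax's implementation; the "static" metadata key and the
# helpers below are illustrative assumptions.
from dataclasses import dataclass, field, fields
from typing import Any


def static_field(**kwargs: Any) -> Any:
    """Hypothetical helper: mark a dataclass field as static via metadata."""
    return field(metadata={"static": True}, **kwargs)


def _is_static(f: Any) -> bool:
    return bool(f.metadata.get("static", False))


def flatten(obj: Any) -> tuple[list[Any], tuple[Any, ...]]:
    """Split a dataclass into dynamic children and static auxiliary data."""
    children = [getattr(obj, f.name) for f in fields(obj) if not _is_static(f)]
    aux = tuple((f.name, getattr(obj, f.name)) for f in fields(obj) if _is_static(f))
    return children, aux


def unflatten(cls: type, aux: tuple[Any, ...], children: list[Any]) -> Any:
    """Rebuild an instance of cls from static aux data and dynamic children."""
    dynamic_names = [f.name for f in fields(cls) if not _is_static(f)]
    kwargs = dict(zip(dynamic_names, children))
    kwargs.update(dict(aux))
    return cls(**kwargs)


@dataclass
class Layer:
    weight: float  # dynamic: a child that JAX would trace
    bias: float    # dynamic
    activation: str = static_field(default="relu")  # static: kept in aux data


layer = Layer(weight=2.0, bias=0.5, activation="tanh")
children, aux = flatten(layer)
print(children, aux)  # [2.0, 0.5] (('activation', 'tanh'),)
```

With JAX installed, a pair like this could be registered through jax.tree_util.register_pytree_node(Layer, flatten, functools.partial(unflatten, Layer)); tjax's dataclass decorator is meant to spare you this bookkeeping.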
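The point about dataclasses versus closures can be sketched without JAX: the iterated function below stores its parameter c as a dataclass field, so the parameter is explicit data rather than a value captured from an enclosing scope. In JAX, such a dataclass would additionally be registered as a pytree so that transformations see its fields as inputs. BabylonianSqrt, FixedPointResult, and iterate_to_fixed_point are hypothetical names for this sketch; it is not tjax's fixed_point interface, and it omits the stochastic case.

```python
# Hypothetical sketch: fixed-point iteration where the iterated function's
# parameters are dataclass fields rather than closure variables.
# Not tjax's fixed_point API; the stochastic case is omitted.
import math
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class BabylonianSqrt:
    """Iterated function x -> (x + c / x) / 2, whose fixed point is sqrt(c)."""
    c: float  # explicit data, not captured from an enclosing scope

    def __call__(self, x: float) -> float:
        return 0.5 * (x + self.c / x)


@dataclass(frozen=True)
class FixedPointResult:
    x: float          # last iterate
    iterations: int   # number of applications of the iterated function
    converged: bool   # whether the relative-change test passed


def iterate_to_fixed_point(f: Callable[[float], float], x0: float, *,
                           rtol: float = 1e-12,
                           max_iterations: int = 100) -> FixedPointResult:
    """Apply f repeatedly until successive iterates agree to within rtol."""
    x = x0
    for i in range(1, max_iterations + 1):
        x_next = f(x)
        if abs(x_next - x) <= rtol * abs(x_next):
            return FixedPointResult(x_next, i, True)
        x = x_next
    return FixedPointResult(x, max_iterations, False)


result = iterate_to_fixed_point(BabylonianSqrt(c=2.0), x0=1.0)
print(result.converged, result.x, math.sqrt(2.0))
```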
Minor components
Tjax also includes:
- A pretty printer, print_generic, for aggregate and vector types, including dataclasses. (See display.)
- Versions of custom_vjp and custom_jvp that support being used on methods. (See shims.)
- Tools for working with cotangents. (See cotangent_tools.)
- JAX tree registration for NetworkX graph types. (See graph.)
- Leaky integration, leaky_integrate, and Ornstein-Uhlenbeck process iteration, diffused_leaky_integrate. (See leaky_integral.)
- An improved version of jax.tree_util.Partial. (See partial.)
- A Matplotlib trajectory plotter, PlottableTrajectory. (See plottable_trajectory.)
- A testing function, assert_tree_allclose, that automatically produces testing code, and a related function, tree_allclose. (See testing.)
- Basic tools like divide_where. (See tools.)
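Assuming leaky integration means advancing the linear ODE dx/dt = -decay * x + drift, one time step dt has the exact solution x(t + dt) = x(t) * exp(-decay * dt) + drift * (1 - exp(-decay * dt)) / decay. The sketch below implements that closed form in plain Python; the name leaky_integrate_sketch and its argument order are assumptions for illustration, not the signature of tjax's leaky_integrate, and the stochastic Ornstein-Uhlenbeck variant is not shown.

```python
# Hypothetical sketch of exact one-step leaky integration of
#   dx/dt = -decay * x + drift.
# The name and argument order are assumptions, not tjax's signature.
import math


def leaky_integrate_sketch(value: float, time_step: float,
                           drift: float, decay: float) -> float:
    """Advance value by time_step under exponential decay plus constant drift."""
    if decay == 0.0:
        # Limit decay -> 0: plain integration of the drift.
        return value + drift * time_step
    retained = math.exp(-decay * time_step)  # fraction of value that survives
    # (1 - exp(-a)) / decay, using expm1 for accuracy when a is small.
    gain = -math.expm1(-decay * time_step) / decay
    return value * retained + drift * gain


# After a long time the value settles at the steady state drift / decay.
print(leaky_integrate_sketch(5.0, 100.0, drift=3.0, decay=2.0))  # 1.5
```

Because the update is the exact solution rather than an Euler step, two half steps agree with one full step up to rounding.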
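The idea behind a tree-wise closeness test such as tree_allclose can be sketched in plain Python: walk two nested structures in parallel, require identical structure, and compare leaves with the same tolerance rule as numpy.allclose. tree_allclose_sketch is a hypothetical name; unlike tjax's function it handles only dicts, lists, tuples, and real numbers, and it does not generate testing code.

```python
# Hypothetical sketch of a tree-wise allclose over nested Python containers.
# Not tjax's tree_allclose; only dicts, lists, tuples, and numbers are handled.
from typing import Any


def tree_allclose_sketch(actual: Any, desired: Any,
                         rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Return True if both trees share a structure and all leaves are close."""
    if isinstance(actual, dict) and isinstance(desired, dict):
        # Same keys, then compare the values under each key.
        return actual.keys() == desired.keys() and all(
            tree_allclose_sketch(actual[k], desired[k], rtol, atol) for k in actual)
    if isinstance(actual, (list, tuple)) and isinstance(desired, (list, tuple)):
        # Same container type and length, then compare element-wise.
        return (type(actual) is type(desired) and len(actual) == len(desired)
                and all(tree_allclose_sketch(a, d, rtol, atol)
                        for a, d in zip(actual, desired)))
    if isinstance(actual, (int, float)) and isinstance(desired, (int, float)):
        # Same rule as numpy.allclose: |a - d| <= atol + rtol * |d|.
        return abs(actual - desired) <= atol + rtol * abs(desired)
    return False  # mismatched structure or unsupported leaf type


params = {"w": [1.0, 2.0], "b": (0.5,)}
print(tree_allclose_sketch(params, {"w": [1.0, 2.0 + 1e-9], "b": (0.5,)}))
```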
Also, see the documentation.
Contribution guidelines
Conventions: PEP8.
How to run tests:
    pytest .
How to clean the source:
    isort .
    mypy .
    pylint tjax tests
    pflake8 .
Download files
Source distribution: tjax-0.22.4.tar.gz
Built distribution: tjax-0.22.4-py3-none-any.whl
File details
Details for the file tjax-0.22.4.tar.gz.
File metadata
- Download URL: tjax-0.22.4.tar.gz
- Upload date:
- Size: 36.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.2.0b3.dev0 CPython/3.10.4 Linux/5.15.0-52-generic
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d2ae6b275a14ff5d7350add95f57ffd6671549cb1d8dc7baf7736930685dbca0 |
| MD5 | f14bc87400ecf76974086089029b08ae |
| BLAKE2b-256 | f58dcfcf21f2e596f551be6a0bf51f3afadf6c97066647eeaf9556c083013d7e |
File details
Details for the file tjax-0.22.4-py3-none-any.whl.
File metadata
- Download URL: tjax-0.22.4-py3-none-any.whl
- Upload date:
- Size: 46.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.2.0b3.dev0 CPython/3.10.4 Linux/5.15.0-52-generic
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1d3769a68256e39f31f382206b6eaa05645d5265b61162cd5449f636e269d26d |
| MD5 | 897a6ff2d6a189fa3d1e2f3c1dbc529a |
| BLAKE2b-256 | 116d54ea928517da7e3ffd9fa606cdde4f161673db3fdf0b925bfd7f5b2cc1a7 |