
Tools for JAX.

Project description


This repository implements a variety of tools for the differentiable programming library JAX.

Major components

Tjax’s major components are:

  • A dataclass decorator, dataclass, that makes it easy to define structured JAX objects (so-called “pytrees”). It provides:

    • the ability to mark fields as static (not available in chex.dataclass), and

    • a display method that produces formatted text according to the tree structure.

  • A shim for the gradient transformation library optax that supports:

    • easy differentiation and vectorization with respect to the parameters of “gradient transformations” (learning rules),

    • gradient transformation objects that can be passed dynamically to jitted functions, and

    • generic type annotations.

  • A pretty printer print_generic for aggregate and vector types, including dataclasses. (See display.) It features:

    • support for traced values,

    • colorized tree output for aggregate structures, and

    • formatted tabular output for arrays (or statistics when there’s no room for tabular output).
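Marking a dataclass field as static means it is stored in the pytree's structure (treedef) rather than as a traced leaf. tjax's dataclass decorator automates the registration; the following is only a minimal hand-written equivalent in plain JAX using jax.tree_util.register_pytree_node, not tjax's implementation. The Scaler class and its fields are hypothetical.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Scaler:
    """Hypothetical example: `weight` is a dynamic leaf, `name` is static metadata."""
    weight: jax.Array
    name: str  # static: kept in the treedef, never traced


def _flatten(s: Scaler):
    # Return (dynamic children, static auxiliary data).
    return (s.weight,), s.name


def _unflatten(name, children):
    (weight,) = children
    return Scaler(weight=weight, name=name)


jax.tree_util.register_pytree_node(Scaler, _flatten, _unflatten)


@jax.jit
def apply(s: Scaler, x: jax.Array) -> jax.Array:
    # `s.name` is an ordinary Python str here because it is static,
    # so it can drive Python control flow inside a jitted function.
    return s.weight * x if s.name == "scale" else x


s = Scaler(weight=jnp.asarray(2.0), name="scale")
print(jax.tree_util.tree_leaves(s))  # only the weight is a leaf
print(apply(s, jnp.asarray(3.0)))
```

Changing a static field changes the treedef, so jit retraces instead of failing on a non-array value.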
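The point of making a learning rule's parameters pytree leaves is that JAX can then differentiate and vectorize with respect to them. The sketch below shows that idea in plain JAX without optax; it is not the API of tjax's optax shim, and the SGD class, loss, and loss_after_one_step are hypothetical.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class SGD:
    """Hypothetical learning rule whose learning rate is a differentiable leaf."""
    learning_rate: jax.Array

    def update(self, gradient: jax.Array) -> jax.Array:
        return -self.learning_rate * gradient


jax.tree_util.register_pytree_node(
    SGD,
    lambda rule: ((rule.learning_rate,), None),
    lambda _, children: SGD(*children),
)


def loss(params: jax.Array) -> jax.Array:
    return jnp.sum((params - 3.0) ** 2)


@jax.jit
def loss_after_one_step(rule: SGD, params: jax.Array) -> jax.Array:
    # The rule is passed dynamically to a jitted function.
    new_params = params + rule.update(jax.grad(loss)(params))
    return loss(new_params)


rule = SGD(learning_rate=jnp.asarray(0.1))
params = jnp.asarray([0.0, 1.0])
# Gradient of the post-step loss with respect to the rule itself.
meta_grad = jax.grad(loss_after_one_step)(rule, params)
print(meta_grad.learning_rate)

# Vectorize over several learning rates at once.
rules = SGD(learning_rate=jnp.asarray([0.01, 0.1, 0.5]))
losses = jax.vmap(loss_after_one_step, in_axes=(0, None))(rules, params)
print(losses)
```

For this quadratic loss the post-step loss is 13(1 - 2·lr)², so the meta-gradient at lr = 0.1 is -52 · 0.8 = -41.6.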
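print_generic's exact output format is not described here, so the following stdlib-only sketch only illustrates the basic idea of producing text that follows the tree structure. The format_tree function is hypothetical, and it has none of print_generic's colorization, traced-value support, or tabular array output.

```python
import dataclasses
from typing import Any


def format_tree(value: Any, name: str = "", indent: int = 0) -> list[str]:
    """Hypothetical sketch: render nested dataclasses, dicts, lists, and tuples as indented lines."""
    pad = "    " * indent
    label = f"{name}=" if name else ""
    if dataclasses.is_dataclass(value) and not isinstance(value, type):
        lines = [f"{pad}{label}{type(value).__name__}"]
        for f in dataclasses.fields(value):
            lines += format_tree(getattr(value, f.name), f.name, indent + 1)
        return lines
    if isinstance(value, dict):
        lines = [f"{pad}{label}dict"]
        for key, item in value.items():
            lines += format_tree(item, str(key), indent + 1)
        return lines
    if isinstance(value, (list, tuple)):
        lines = [f"{pad}{label}{type(value).__name__}"]
        for i, item in enumerate(value):
            lines += format_tree(item, str(i), indent + 1)
        return lines
    # Leaves are printed on a single line using their repr.
    return [f"{pad}{label}{value!r}"]


@dataclasses.dataclass
class Layer:
    weight: list[float]
    bias: float


@dataclasses.dataclass
class Model:
    layers: tuple[Layer, ...]
    name: str


model = Model(layers=(Layer([1.0, 2.0], 0.5),), name="mlp")
print("\n".join(format_tree(model)))
```

Each level of nesting adds one level of indentation, and every field is labelled with its name.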

Minor components

Tjax also includes:

  • Versions of custom_vjp and custom_jvp that can be used on methods: custom_vjp_method and custom_jvp_method. (See shims.)

  • Tools for working with cotangents. (See cotangent_tools.)

  • JAX tree registration for NetworkX graph types. (See graph.)

  • Leaky integration (leaky_integrate) and Ornstein-Uhlenbeck process iteration (diffused_leaky_integrate). (See leaky_integral.)

  • An improved version of jax.tree_util.Partial. (See partial.)

  • A testing function, assert_tree_allclose, that automatically produces testing code, and a related function, tree_allclose. (See testing.)

  • Basic tools like divide_where. (See tools.)
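jax.custom_vjp decorates plain functions, so applying it to a method means deciding what to do with self. The sketch below shows one workaround in plain JAX: the method delegates to a free function that treats self as a nondifferentiable argument via nondiff_argnums. It is not how custom_vjp_method is implemented, and Clipper and _clip_gradient are hypothetical.

```python
import dataclasses
from functools import partial

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Clipper:
    """Hypothetical module whose call clips the gradient flowing through it."""
    limit: float  # static configuration; frozen makes the instance hashable

    def __call__(self, x: jax.Array) -> jax.Array:
        # Delegate to a free function so `self` becomes a nondifferentiable argument.
        return _clip_gradient(self, x)


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def _clip_gradient(clipper: Clipper, x: jax.Array) -> jax.Array:
    return x  # identity on the forward pass


def _clip_gradient_fwd(clipper: Clipper, x: jax.Array):
    return x, None  # no residuals are needed


def _clip_gradient_bwd(clipper: Clipper, _residuals, g: jax.Array):
    # Nondifferentiable arguments come first; return cotangents only for `x`.
    return (jnp.clip(g, -clipper.limit, clipper.limit),)


_clip_gradient.defvjp(_clip_gradient_fwd, _clip_gradient_bwd)

clipper = Clipper(limit=1.0)
grad = jax.grad(lambda x: jnp.sum(10.0 * clipper(x)))(jnp.array([1.0, 2.0]))
print(grad)  # the incoming cotangent of 10 is clipped to 1
```

This only works when self needs no gradient; a method that must differentiate through its own fields needs self handled as a differentiable pytree argument instead.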
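The update rules behind leaky_integrate and diffused_leaky_integrate are not given here, so the following sketch assumes one standard convention: exponential relaxation toward a target for the leaky integrator, and the exact one-step discretization of the Ornstein-Uhlenbeck SDE dX = rate · (target - X) dt + volatility dW. The functions leaky_step, ou_step, and simulate and their parameter names are hypothetical, not tjax's API.

```python
import jax
import jax.numpy as jnp


def leaky_step(value, target, rate, dt):
    """Hypothetical: exact update of dv/dt = rate * (target - v) over a step dt."""
    retain = jnp.exp(-rate * dt)
    return retain * value + (1.0 - retain) * target


def ou_step(key, value, target, rate, volatility, dt):
    """Hypothetical: exact one-step sample of dX = rate * (target - X) dt + volatility dW."""
    retain = jnp.exp(-rate * dt)
    mean = retain * value + (1.0 - retain) * target
    # Variance of the noise accumulated over dt: volatility^2 (1 - e^{-2 rate dt}) / (2 rate).
    std = volatility * jnp.sqrt((1.0 - retain**2) / (2.0 * rate))
    return mean + std * jax.random.normal(key, jnp.shape(value))


def simulate(key, x0, steps):
    """Iterate the OU step with lax.scan, returning the whole path."""
    keys = jax.random.split(key, steps)

    def body(x, k):
        x = ou_step(k, x, target=1.0, rate=2.0, volatility=0.5, dt=0.1)
        return x, x

    _, path = jax.lax.scan(body, x0, keys)
    return path


path = simulate(jax.random.PRNGKey(0), jnp.asarray(0.0, dtype=jnp.float32), 100)
print(path[-5:])
```

With volatility set to zero the OU step reduces to the leaky step, which is a quick sanity check on the shared mean term.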
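For context on the Partial item, the stock jax.tree_util.Partial that tjax's version improves on is a partially applied function that is itself a pytree: the bound arguments are leaves and the wrapped function is static, so it can be passed into a jitted function. The functions scale and apply_twice below are hypothetical, and this shows only the stock class, not tjax's improvements.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def scale(factor, x):
    return factor * x


@jax.jit
def apply_twice(f, x):
    # `f` is a pytree argument: its bound `factor` is traced, `scale` is static.
    return f(f(x))


double = Partial(scale, jnp.asarray(2.0))
print(apply_twice(double, jnp.asarray(3.0)))
print(jax.tree_util.tree_leaves(double))  # the bound argument is the only leaf
```

Because the bound argument is a leaf, changing its value does not trigger recompilation.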
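tree_allclose's exact semantics are not spelled out here; a minimal sketch of the underlying idea checks that two pytrees have the same structure and that every pair of leaves is close. The tree_allclose_sketch function is hypothetical and does not reproduce assert_tree_allclose's generation of testing code.

```python
import jax
import jax.numpy as jnp


def tree_allclose_sketch(actual, desired, rtol=1e-5, atol=1e-8) -> bool:
    """Hypothetical: True if both pytrees share a structure and all leaves are close."""
    if jax.tree_util.tree_structure(actual) != jax.tree_util.tree_structure(desired):
        return False
    return all(
        bool(jnp.allclose(a, d, rtol=rtol, atol=atol))
        for a, d in zip(jax.tree_util.tree_leaves(actual), jax.tree_util.tree_leaves(desired))
    )


x = {"w": jnp.array([1.0, 2.0]), "b": 0.5}
print(tree_allclose_sketch(x, {"w": jnp.array([1.0, 2.0 + 1e-7]), "b": 0.5}))
print(tree_allclose_sketch(x, {"w": jnp.array([1.0, 2.0])}))  # different structure
```

Comparing structures first avoids silently zipping together leaves from mismatched trees.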
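divide_where's signature is not given here, so the sketch below uses a hypothetical safe_divide to show the standard JAX "double where" pattern that a masked-division helper generally needs: masked-out divisors are replaced before dividing, so neither the forward value nor the gradient ever involves a division by zero.

```python
import jax
import jax.numpy as jnp


def safe_divide(dividend, divisor, *, where, otherwise=0.0):
    """Hypothetical: dividend / divisor where `where` is True, else `otherwise`."""
    # First where: swap masked-out divisors for 1 so 1/0 is never evaluated.
    safe_divisor = jnp.where(where, divisor, 1.0)
    # Second where: select the fallback value for masked-out entries.
    return jnp.where(where, dividend / safe_divisor, otherwise)


num = jnp.array([1.0, 2.0])
den = jnp.array([2.0, 0.0])
out = safe_divide(num, den, where=den != 0.0)
grad = jax.grad(lambda d: jnp.sum(safe_divide(num, d, where=d != 0.0)))(den)
print(out, grad)
```

With a single jnp.where(where, dividend / divisor, otherwise), the masked-out entry would still compute 2.0 / 0.0, and its gradient would come out as NaN even though the forward value looks correct.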

Contribution guidelines

  • Conventions: PEP 8.

  • How to run tests: pytest .

  • How to clean the source:

    • ruff .

    • pyright

    • mypy

    • isort .

    • pylint tjax tests


This version

1.1.0

Download files


Source Distribution

tjax-1.1.0.tar.gz (107.9 kB)

Built Distribution

tjax-1.1.0-py3-none-any.whl (45.1 kB)

File details

Details for the file tjax-1.1.0.tar.gz.

File metadata

  • Download URL: tjax-1.1.0.tar.gz
  • Size: 107.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-httpx/0.27.2

File hashes

Hashes for tjax-1.1.0.tar.gz
Algorithm Hash digest
SHA256 349cef297b2f906b45dd0d576f56840c3620b2d0ec9901256df88c9583432fa8
MD5 039572d3c649868c2c3f1a4a633acf4b
BLAKE2b-256 3db62e06f0c8eff22640e58ed164005700bb9a2a7150b7f5af7abc6072bdacb5


File details

Details for the file tjax-1.1.0-py3-none-any.whl.

File metadata

  • Download URL: tjax-1.1.0-py3-none-any.whl
  • Size: 45.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-httpx/0.27.2

File hashes

Hashes for tjax-1.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 fe73c4959d061d570be98a40a0203f58d5b3dca47c571cfdfca90c710c4b7a77
MD5 b25f785076f4b127e3c1e96bd8c17676
BLAKE2b-256 6a3220a89ea76e32141098902c21ddca2946575d76653f0f473fc352b3a8b45f

