JAX bindings and operations for RoughPy

roughpy-jax

roughpy-jax provides JAX bindings and operations for RoughPy. It includes stream classes and dense algebraic objects (such as free tensors, shuffle tensors, and elements of the free Lie algebra) for computational rough path theory, and it supports JAX JIT compilation and automatic differentiation.

This library is currently in an alpha stage. The API is still evolving, and some features are incomplete or subject to change as the package matures.

What This Package Provides

roughpy-jax builds on top of roughpy and jax and currently includes:

  • dense tensor, shuffle tensor, and Lie algebra wrappers
  • algebraic operations such as multiplication, exponentials, logarithms, the Campbell-Baker-Hausdorff (CBH) product, pairings, and adjoint operations
  • JAX-compatible derivative and adjoint-derivative rules for core operations
  • interval and partition types for stream queries
  • stream types including piecewise Abelian streams and Lie increment streams
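Because the wrappers are dense, every coefficient up to the truncation depth is stored, so their size grows quickly with width and depth. The sketch below uses only standard facts and no roughpy-jax API. A free or shuffle tensor of width d truncated at depth n has 1 + d + ... + d^n coefficients. The free Lie algebra truncated at depth n has a dimension given by Witt's formula. It is an assumption that roughpy-jax stores the level-0 scalar exactly this way.

```python
def divisors(n):
    """All positive divisors of n."""
    return [d for d in range(1, n + 1) if n % d == 0]


def mobius(n):
    """Möbius function mu(n), computed by trial division."""
    result = 1
    p = 2
    while p * p <= n:
        if n % p == 0:
            n //= p
            if n % p == 0:
                return 0  # squared prime factor
            result = -result
        p += 1
    if n > 1:
        result = -result  # one remaining prime factor
    return result


def tensor_dimension(width, depth):
    """Coefficients in a dense tensor truncated at depth (levels 0..depth)."""
    return sum(width ** k for k in range(depth + 1))


def lie_dimension(width, depth):
    """Dimension of the truncated free Lie algebra (levels 1..depth), via Witt's formula."""
    total = 0
    for k in range(1, depth + 1):
        # dim of level k = (1/k) * sum over d | k of mu(k/d) * width^d
        total += sum(mobius(k // d) * width ** d for d in divisors(k)) // k
    return total


for w, n in [(2, 2), (2, 4), (3, 3)]:
    print(f"width={w} depth={n}: tensor={tensor_dimension(w, n)}, lie={lie_dimension(w, n)}")
```

The gap between the two numbers is one reason log-signatures (Lie elements) are often preferred over signatures (tensors) as features.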

Installation

roughpy-jax can be installed from PyPI with:

pip install roughpy-jax

The package requires roughpy 0.3.0 (the latest release at the time of writing) and Python 3.11 or newer.
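Before installing, you can check your environment against these requirements with a quick script. This is a sketch that uses only the standard library. The helpers `parse_release` and `find_problems` are hypothetical and not part of roughpy-jax. The script also assumes that roughpy 0.3.0 is a minimum version rather than an exact pin.

```python
import re
import sys
from importlib import metadata

MIN_PYTHON = (3, 11)
MIN_ROUGHPY = (0, 3, 0)  # assumption: treated as a minimum, not an exact pin


def parse_release(version):
    """Parse the leading numeric part of a version, e.g. '0.3.0rc1' -> (0, 3, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    parts = tuple(int(p) for p in match.group(0).split("."))
    return parts + (0,) * (3 - len(parts))  # pad '1.0' to (1, 0, 0)


def find_problems(python_version, roughpy_version):
    """Return a list of human-readable problems (empty if requirements are met)."""
    problems = []
    if tuple(python_version[:2]) < MIN_PYTHON:
        problems.append(f"Python {python_version[0]}.{python_version[1]} is older than 3.11")
    if roughpy_version is None:
        problems.append("roughpy is not installed")
    elif parse_release(roughpy_version) < MIN_ROUGHPY:
        problems.append(f"roughpy {roughpy_version} is older than 0.3.0")
    return problems


def installed_version(dist_name):
    """Installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    issues = find_problems(sys.version_info, installed_version("roughpy"))
    print("OK" if not issues else "\n".join(issues))
```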

Release artifacts can also be downloaded from the GitHub Releases page for this repository.

Installing From Source

Installing from source is useful when working on the package itself or testing changes before a release. A working C/C++ toolchain and a CMake-compatible build environment are required.

Clone the repository and install it into a virtual environment:

git clone https://github.com/datasig-ac-uk/roughpy-jax.git
cd roughpy-jax
python -m venv .venv
. .venv/bin/activate
pip install -U pip
pip install .

If you are using uv, the equivalent workflow is:

uv venv
. .venv/bin/activate
uv pip install .

Streams and Intervals

Streams are the central objects in RoughPy, and the same is true in roughpy-jax. Like RoughPy itself, roughpy-jax pays careful attention to intervals and stream queries.

Current stream-facing functionality includes:

  • PiecewiseAbelianStream for streams built from piecewise log-signature data
  • LieIncrementStream for querying log-signatures and signatures over intervals, backed by a dyadic cache
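A dyadic cache works by splitting a query interval into dyadic pieces [k/2^m, (k+1)/2^m) whose values can be precomputed and reused. The pure-Python sketch below shows that decomposition. It illustrates the general technique and is not LieIncrementStream's actual implementation. It assumes half-open intervals [a, b) and simply drops any part finer than the chosen resolution. `dyadic_cover` and `to_interval` are hypothetical helpers.

```python
import math
from fractions import Fraction


def dyadic_cover(a, b, resolution):
    """Split [a, b) into maximal dyadic intervals, returned as (k, m) pairs meaning
    [k / 2**m, (k + 1) / 2**m). Only the part of [a, b) aligned to the grid of
    spacing 2**-resolution is covered; m may be negative for very long pieces."""
    scale = 2 ** resolution
    lo = math.ceil(Fraction(a) * scale)   # first grid point inside [a, b)
    hi = math.floor(Fraction(b) * scale)  # last grid point not beyond b
    pieces = []
    while lo < hi:
        step = 1
        # Double the piece while it stays aligned at lo and fits inside [lo, hi).
        while lo % (2 * step) == 0 and lo + 2 * step <= hi:
            step *= 2
        level = resolution - (step.bit_length() - 1)  # step == 2**(resolution - level)
        pieces.append((lo // step, level))
        lo += step
    return pieces


def to_interval(k, m):
    """Exact endpoints of the dyadic interval (k, m)."""
    width = Fraction(2) ** m
    return Fraction(k) / width, Fraction(k + 1) / width


for k, m in dyadic_cover(Fraction(1, 3), Fraction(7, 8), 3):
    print((k, m), to_interval(k, m))
```

Each piece can then be looked up in a cache, and the cached signatures multiplied in order (Chen's identity) to answer the full query.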

These pieces are intended to make it practical to move between algebraic objects and stream queries within JAX-oriented workflows.
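Moving between stream queries and algebraic objects rests on Chen's identity: the signature over [a, c) is the truncated tensor product of the signatures over [a, b) and [b, c). The signature of a straight segment is the tensor exponential of its increment. The pure-Python sketch below stores tensors as {word: coefficient} dicts with 0-based letters. This representation is for illustration only and is not roughpy-jax's dense layout, and the function names are hypothetical.

```python
from itertools import product
from math import factorial


def tensor_mul(a, b, depth):
    """Truncated tensor (concatenation) product of two {word: coeff} tensors."""
    out = {}
    for u, x in a.items():
        for v, y in b.items():
            w = u + v
            if len(w) <= depth:
                out[w] = out.get(w, 0.0) + x * y
    return out


def segment_signature(increment, depth):
    """Signature of a linear segment: exp(increment) truncated at depth.
    The coefficient of word (i1, ..., ik) is increment[i1] * ... * increment[ik] / k!."""
    sig = {}
    for k in range(depth + 1):
        for word in product(range(len(increment)), repeat=k):
            coeff = 1.0
            for letter in word:
                coeff *= increment[letter]
            sig[word] = coeff / factorial(k)
    return sig


def piecewise_linear_signature(increments, depth):
    """Chen's identity: multiply the segment signatures in time order."""
    sig = {(): 1.0}
    for inc in increments:
        sig = tensor_mul(sig, segment_signature(inc, depth), depth)
    return sig


sig = piecewise_linear_signature([(1.0, 0.0), (0.0, 1.0)], depth=2)
print(sig[(0, 1)], sig[(1, 0)])  # order matters: the level-2 terms differ
```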

API Differences From RoughPy

There are some deliberate API differences compared to roughpy.

Context objects are not used in roughpy-jax. Instead, explicit basis objects and conversion functions handle translation between algebraic objects with different configurations. At present, only depth changes are supported explicitly.
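As an illustration of what an explicit depth change means for a dense free tensor, the sketch below assumes the coefficients are stored level by level: level 0 first, then all words of length 1, and so on. That layout and the function `change_depth` are hypothetical, and roughpy-jax's own conversion functions may behave differently. Lowering the depth keeps the leading levels. This truncation is compatible with multiplication. Raising the depth pads the new levels with zeros.

```python
def tensor_size(width, depth):
    """Number of coefficients for levels 0..depth."""
    return sum(width ** k for k in range(depth + 1))


def change_depth(coeffs, width, old_depth, new_depth):
    """Convert a level-ordered dense tensor between truncation depths."""
    if len(coeffs) != tensor_size(width, old_depth):
        raise ValueError("coefficient count does not match width and depth")
    new_size = tensor_size(width, new_depth)
    if new_depth <= old_depth:
        # Truncation: keep levels 0..new_depth, which form a prefix in this layout.
        return list(coeffs[:new_size])
    # Extension: the missing higher levels are set to zero.
    return list(coeffs) + [0.0] * (new_size - len(coeffs))


print(change_depth([1.0, 2.0, 3.0], width=2, old_depth=1, new_depth=2))
```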

Streams may have several associated bases, depending on the stream type. These can include:

  • the basis of the underlying data
  • the basis used for stored or cached data
  • the basis used for answering queries

These bases do not need to be identical, but they do need to be compatible. Exactly which bases exist, and whether they are user-facing, is stream-type dependent.

Only very basic interval support is currently implemented. This area still needs to be expanded.

It might not be possible to convert RoughPy objects directly to roughpy-jax equivalents.

JAX Notes

All algebra objects and algebraic operations are intended to support JIT compilation and to be fully differentiable. In particular, the package provides explicit derivative and adjoint-derivative functions alongside the corresponding primal operations. Rely on these functions, rather than on generic JAX transformations, when you need correct algebraic type information for tangents and cotangents.
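A derivative (forward, JVP-style) and an adjoint derivative (reverse, VJP-style) are tied together by the pairing between the two spaces: for z = a * b, the identity <ct, dz> = <ct_a, da> + <ct_b, db> must hold. The cotangent ct is naturally a shuffle tensor paired against free tensors. The pure-Python sketch below does this for truncated free tensor multiplication. It uses {word: coefficient} dicts, and all function names are hypothetical and do not mirror roughpy-jax's API.

```python
import random
from itertools import product


def mul(a, b, depth):
    """Truncated free tensor product."""
    out = {}
    for u, x in a.items():
        for v, y in b.items():
            if len(u) + len(v) <= depth:
                out[u + v] = out.get(u + v, 0.0) + x * y
    return out


def add(a, b):
    out = dict(a)
    for w, x in b.items():
        out[w] = out.get(w, 0.0) + x
    return out


def pairing(s, t):
    """Pair a shuffle tensor s with a free tensor t by matching words."""
    return sum(x * t.get(w, 0.0) for w, x in s.items())


def mul_derivative(a, b, da, db, depth):
    """Forward derivative of mul at (a, b) in direction (da, db) (product rule)."""
    return add(mul(da, b, depth), mul(a, db, depth))


def mul_adjoint_derivative(a, b, ct, depth):
    """Reverse derivative: map a cotangent ct on the output to cotangents on a and b."""
    ct_a, ct_b = {}, {}
    for w, c in ct.items():
        for split in range(len(w) + 1):  # every factorisation w = u v
            u, v = w[:split], w[split:]
            ct_a[u] = ct_a.get(u, 0.0) + c * b.get(v, 0.0)
            ct_b[v] = ct_b.get(v, 0.0) + c * a.get(u, 0.0)
    return ct_a, ct_b


def random_tensor(width, depth, rng):
    return {word: rng.uniform(-1.0, 1.0)
            for k in range(depth + 1)
            for word in product(range(width), repeat=k)}
```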

Stream objects are more limited. Some stream types may support JIT in some contexts, but stream support is not yet uniform. In particular, LieIncrementStream is not currently registered as a pytree because of technical limitations that have not yet been resolved.

There is also an important JAX-specific subtlety in reverse-mode code. Because JAX tree handling does not preserve the intended algebraic type information in all backward-pass cotangents, cotangents may be represented using the wrong algebra wrapper. For example, a value that should be treated as a shuffle tensor may arrive as a free tensor, or vice versa. To handle this, internal JAX-facing code applies corrective conversions on incoming and outgoing cotangents. The public derivative and adjoint-derivative APIs expose the correct algebraic types.
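The reason for this subtlety is that JAX builds gradients with the same tree structure as the primal input, so a custom pytree wrapper comes back as the same Python class. The sketch below uses hypothetical `FreeTensor` and `ShuffleTensor` classes, which are not roughpy-jax's own. They are registered with `jax.tree_util.register_pytree_node_class`. The example shows a free tensor's cotangent arriving wrapped as a free tensor, and a corrective conversion that re-wraps it as a shuffle tensor.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class _Dense:
    """Hypothetical dense wrapper: coefficient array plus static metadata."""

    def __init__(self, data, width, depth):
        self.data = data
        self.width = width
        self.depth = depth

    def tree_flatten(self):
        # data is the only traced leaf; width and depth are static auxiliary data.
        return (self.data,), (self.width, self.depth)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], *aux_data)


@register_pytree_node_class
class FreeTensor(_Dense):
    pass


@register_pytree_node_class
class ShuffleTensor(_Dense):
    pass


def as_shuffle(t):
    """Corrective conversion: reinterpret the same coefficients as a shuffle tensor."""
    return ShuffleTensor(t.data, t.width, t.depth)


def loss(x):
    return jnp.sum(x.data ** 2)


x = FreeTensor(jnp.arange(7.0), width=2, depth=2)
g = jax.grad(loss)(x)  # the gradient has x's tree structure, so it is a FreeTensor
print(type(g).__name__, type(as_shuffle(g)).__name__)
```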

Testing

The test suite exercises both the pure Python layer and the compiled CPU backend. Locally, the main test command is:

pytest -m "not extra" roughpy_jax/tests

Wheel builds are tested through cibuildwheel in CI, and release artifacts are validated before publishing.

Example

For examples of how to use the higher-level stream objects, see the examples/ folder. It includes the 'words' example from the RoughPy documentation, converted to use the new stream objects.

Support

If you hit a bug or need a feature, open an issue on GitHub. Bug reports with a minimal reproducer are the most useful.

Contributing

Contributions are welcome, especially:

  • bug fixes
  • tests
  • documentation improvements
  • examples and API polish

If you plan to make a larger change, open an issue first so the design can be discussed before implementation.

License

roughpy-jax is licensed under the BSD 3-Clause License. See LICENSE.txt.

Download files

Download the file for your platform.

Source Distribution

roughpy_jax-1.0.0.tar.gz (88.9 kB)

Uploaded Source

Built Distributions

roughpy_jax-1.0.0-cp311-abi3-win_amd64.whl (185.7 kB)

Uploaded CPython 3.11+, Windows x86-64

roughpy_jax-1.0.0-cp311-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (192.0 kB)

Uploaded CPython 3.11+, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

roughpy_jax-1.0.0-cp311-abi3-macosx_11_0_arm64.whl (116.8 kB)

Uploaded CPython 3.11+, macOS 11.0+ ARM64

File details

Details for the file roughpy_jax-1.0.0.tar.gz.

File metadata

  • Download URL: roughpy_jax-1.0.0.tar.gz
  • Upload date:
  • Size: 88.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for roughpy_jax-1.0.0.tar.gz
Algorithm Hash digest
SHA256 06a271d98c91bf8814793c68497e47096a63c211c5b9a9e86e47c54dc6837ce7
MD5 98e522775c42c823c53213aabb356ea4
BLAKE2b-256 4f23c4ec8a69593767db335a5afd86ad505e6a5eb5a7fb38150725319124fdff

Provenance

The following attestation bundles were made for roughpy_jax-1.0.0.tar.gz:

Publisher: release.yml on datasig-ac-uk/roughpy-jax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file roughpy_jax-1.0.0-cp311-abi3-win_amd64.whl.

File metadata

  • Download URL: roughpy_jax-1.0.0-cp311-abi3-win_amd64.whl
  • Upload date:
  • Size: 185.7 kB
  • Tags: CPython 3.11+, Windows x86-64
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for roughpy_jax-1.0.0-cp311-abi3-win_amd64.whl
Algorithm Hash digest
SHA256 f5ca535e1c5a3232e557d8aa76ed36050ac1b7f6e961f5a81b412557127a78a1
MD5 17af0a9409154abf8089408c58ef063c
BLAKE2b-256 32060708270bfdb2aebad55f1f80f4225ee0bf5d633d88479f7ad976cd0eaa70

Provenance

The following attestation bundles were made for roughpy_jax-1.0.0-cp311-abi3-win_amd64.whl:

Publisher: release.yml on datasig-ac-uk/roughpy-jax

File details

Details for the file roughpy_jax-1.0.0-cp311-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File hashes

Hashes for roughpy_jax-1.0.0-cp311-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 7c2e38553d256ca91f08d1cfb6729ec875278fe4f956bad9dbd6695bac74e37c
MD5 1b356e3a9ade90099240292441939a5a
BLAKE2b-256 b9800985d43924b0de9cbbb6deab3ab9b48aaee6e3bb59c8083806b8b071cbbb

Provenance

The following attestation bundles were made for roughpy_jax-1.0.0-cp311-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

Publisher: release.yml on datasig-ac-uk/roughpy-jax

File details

Details for the file roughpy_jax-1.0.0-cp311-abi3-macosx_11_0_arm64.whl.

File hashes

Hashes for roughpy_jax-1.0.0-cp311-abi3-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 0736344d55300921ee86c8a90e6a7b4fc964f05a2697832a53780745aca23b1b
MD5 52f9358c5b98e3436e33b5101be6b1de
BLAKE2b-256 b50369cd9dedc9adc0e3d9914e6fa24ff1caf0e35193b67b887312cfeb268416

Provenance

The following attestation bundles were made for roughpy_jax-1.0.0-cp311-abi3-macosx_11_0_arm64.whl:

Publisher: release.yml on datasig-ac-uk/roughpy-jax
