

sympy2jax

Turn SymPy expressions into trainable JAX expressions. The output will be an Equinox module, with all SymPy floats (and integers, rationals, ...) as leaves, and all SymPy symbols as inputs.

Optimise your symbolic expressions via gradient descent!

Installation

pip install sympy2jax

Requires:

  • Python 3.7+
  • JAX 0.3.4+
  • Equinox 0.5.3+
  • SymPy 1.7.1+

Example

import jax
import sympy
import sympy2jax

x_sym = sympy.symbols("x_sym")
cosx = 1.0 * sympy.cos(x_sym)
sinx = 2.0 * sympy.sin(x_sym)
mod = sympy2jax.SymbolicModule([cosx, sinx])  # PyTree of input expressions

x = jax.numpy.zeros(3)
out = mod(x_sym=x)  # PyTree of results.
params = jax.tree_util.tree_leaves(mod)  # 1.0 and 2.0 are parameters,
                                         # which may be trained in the usual way for Equinox.

Documentation

sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)

Where:

  • expressions is a PyTree of SymPy expressions.
  • extra_funcs is an optional dictionary from SymPy functions to JAX operations, to extend the built-in translation rules.
  • make_array is whether SymPy integers/floats/rationals should be stored as Python scalars, or as JAX arrays.

Instances can be called with key-value pairs of symbol-value, as in the above example.

Instances have a .sympy() method that translates the module back into a PyTree of SymPy expressions.

(That's literally the entire documentation; it's super easy.)

See also: other libraries in the JAX ecosystem

Always useful
Equinox: neural networks and everything not already in core JAX!
jaxtyping: type annotations for shape/dtype of arrays.

Deep learning
Optax: first-order gradient (SGD, Adam, ...) optimisers.
Orbax: checkpointing (async/multi-host/multi-device).
Levanter: scalable+reliable training of foundation models (e.g. LLMs).
paramax: parameterizations and constraints for PyTrees.

Scientific computing
Diffrax: numerical differential equation solvers.
Optimistix: root finding, minimisation, fixed points, and least squares.
Lineax: linear solvers.
BlackJAX: probabilistic+Bayesian sampling.
PySR: symbolic regression. (Non-JAX honourable mention!)

Awesome JAX
Awesome JAX: a longer list of other JAX projects.


File details

Details for the file sympy2jax-0.0.8.tar.gz.

File metadata

  • Download URL: sympy2jax-0.0.8.tar.gz
  • Upload date:
  • Size: 9.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for sympy2jax-0.0.8.tar.gz

  • SHA256: 17bacf5503a3813011dce97c3b507bc7f66cd550dd677b0b7b1074e13f54a1d7
  • MD5: 21845f9853183915f457aefc0a5fde9e
  • BLAKE2b-256: 3869f07c85d10904e24eac38e6563abfb33835c6c1f958693d14cf999f0de24e


File details

Details for the file sympy2jax-0.0.8-py3-none-any.whl.

File metadata

  • Download URL: sympy2jax-0.0.8-py3-none-any.whl
  • Upload date:
  • Size: 13.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for sympy2jax-0.0.8-py3-none-any.whl

  • SHA256: da86c3915c69142354d04714ae449a56b4896b043f789bce839fed3d01068745
  • MD5: 8ef7d0700c797eac7cc3c232dd5f7d22
  • BLAKE2b-256: 10067df763dd5f981edf68ad1c371468f0bbf0673f4fc9af82f5fd8a4936fd7d

