
sympy2jax

Turn SymPy expressions into trainable JAX expressions. The output will be an Equinox module with all SymPy constants (floats, integers, rationals, ...) as leaves. SymPy symbols will be inputs.

Optimise your symbolic expressions via gradient descent!

Installation

pip install sympy2jax

Requires:

  • Python 3.7+
  • JAX 0.3.4+
  • Equinox 0.5.3+
  • SymPy 1.7.1+

Example

import jax
import sympy
import sympy2jax

x_sym = sympy.symbols("x_sym")
cosx = 1.0 * sympy.cos(x_sym)
sinx = 2.0 * sympy.sin(x_sym)
mod = sympy2jax.SymbolicModule([cosx, sinx])  # PyTree of input expressions

x = jax.numpy.zeros(3)
out = mod(x_sym=x)  # PyTree of results.
params = jax.tree_util.tree_leaves(mod)  # 1.0 and 2.0 are parameters.
                                         # (Which may be trained in the usual way for Equinox.)

Documentation

sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)

Where:

  • expressions is a PyTree of SymPy expressions.
  • extra_funcs is an optional dictionary from SymPy functions to JAX operations, to extend the built-in translation rules.
  • make_array determines whether SymPy integers/floats/rationals are stored as JAX arrays (and hence appear as trainable leaves), or as plain Python numbers.

Instances can be called with key-value pairs of symbol-value, as in the above example.

Instances have a .sympy() method that translates the module back into a PyTree of SymPy expressions.

(That's literally the entire documentation, it's super easy.)

Finally

See also: other libraries in the JAX ecosystem

  • jaxtyping: type annotations for shape/dtype of arrays.
  • Equinox: neural networks.
  • Optax: first-order gradient (SGD, Adam, ...) optimisers.
  • Diffrax: numerical differential equation solvers.
  • Optimistix: root finding, minimisation, fixed points, and least squares.
  • Lineax: linear solvers.
  • BlackJAX: probabilistic+Bayesian sampling.
  • Orbax: checkpointing (async/multi-host/multi-device).
  • Eqxvision: computer vision models.
  • Levanter: scalable+reliable training of foundation models (e.g. LLMs).
  • PySR: symbolic regression. (Non-JAX honourable mention!)

Disclaimer

This is not an official Google product.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

sympy2jax-0.0.5.tar.gz (9.5 kB)

Uploaded Source

Built Distribution

sympy2jax-0.0.5-py3-none-any.whl (13.6 kB)

Uploaded Python 3

File details

Details for the file sympy2jax-0.0.5.tar.gz.

File metadata

  • Download URL: sympy2jax-0.0.5.tar.gz
  • Upload date:
  • Size: 9.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.6

File hashes

Hashes for sympy2jax-0.0.5.tar.gz:

  • SHA256: 8a039d7b62c18b4c9f72dc92f260d7ec48ce253bcb9d4913f675de61a7afc0c3
  • MD5: d1c56fe42ea64eaa70bbabe516ff3010
  • BLAKE2b-256: 9027af7c5bec9a55a05bc9fb8c2c95165f4fa5910753c22077218782004d6a33


File details

Details for the file sympy2jax-0.0.5-py3-none-any.whl.

File metadata

  • Download URL: sympy2jax-0.0.5-py3-none-any.whl
  • Upload date:
  • Size: 13.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.6

File hashes

Hashes for sympy2jax-0.0.5-py3-none-any.whl:

  • SHA256: 0ab3a78199ecd44b970728d4724d9fb175679126c6b246a7488c0e0349a11717
  • MD5: 7adc05ee6524f1d13b6179af6dcd3753
  • BLAKE2b-256: 378581076aca3de8521f39876384fa8315bf5f88f736b602fe81be1757c6ee67

