
sympy2jax

Turn SymPy expressions into trainable JAX expressions. The output is an Equinox module whose leaves are all the SymPy numbers in the expressions (floats, integers, rationals, ...). SymPy symbols become inputs.

Optimise your symbolic expressions via gradient descent!
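To make the idea concrete, here is a minimal sketch of the underlying technique using only SymPy and JAX. It is not sympy2jax's implementation: the helper `compile_expr` and the `_OPS` table are hypothetical and cover only a handful of operations. Every numeric constant is pulled out into a list of JAX arrays, each symbol becomes a keyword input, and `jax.grad` then differentiates with respect to the constants.

```python
import functools
import operator

import jax
import jax.numpy as jnp
import sympy

# Hypothetical translation table: only a few SymPy heads are covered here.
_OPS = {
    sympy.Add: lambda *args: functools.reduce(operator.add, args),
    sympy.Mul: lambda *args: functools.reduce(operator.mul, args),
    sympy.Pow: jnp.power,
    sympy.cos: jnp.cos,
    sympy.sin: jnp.sin,
}


def compile_expr(expr):
    """Split `expr` into (params, fn), where fn(params, **inputs) evaluates it."""
    params = []

    def build(e):
        if e.is_Symbol:
            name = e.name  # symbols are looked up by name at call time
            return lambda p, env: env[name]
        if e.is_Number:
            index = len(params)  # each constant gets its own slot in `params`
            params.append(jnp.asarray(float(e)))
            return lambda p, env: p[index]
        op = _OPS[e.func]  # raises KeyError for heads missing from the table
        children = [build(arg) for arg in e.args]
        return lambda p, env: op(*(child(p, env) for child in children))

    root = build(expr)
    return params, lambda p, **inputs: root(p, inputs)


x = sympy.Symbol("x")
params, fn = compile_expr(2.0 * sympy.cos(x) + 0.5)
xs = jnp.linspace(0.0, 1.0, 5)
values = fn(params, x=xs)
# Gradient of a scalar loss with respect to the extracted constants.
grads = jax.grad(lambda p: fn(p, x=xs).sum())(params)
```

Here `params` is a bare list of arrays; sympy2jax instead keeps the constants as leaves of an Equinox module, so they travel together with the expression structure.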

Installation

pip install sympy2jax

Requires:
  • Python 3.7+
  • JAX 0.3.4+
  • Equinox 0.5.3+
  • SymPy 1.7.1+

Example

import jax
import jax.numpy as jnp
import sympy
import sympy2jax

x_sym = sympy.symbols("x_sym")
cosx = 1.0 * sympy.cos(x_sym)
sinx = 2.0 * sympy.sin(x_sym)
mod = sympy2jax.SymbolicModule([cosx, sinx])  # PyTree of input expressions

x = jnp.zeros(3)
out = mod(x_sym=x)  # PyTree of results, with the same structure as the inputs.
params = jax.tree_util.tree_leaves(mod)  # 1.0 and 2.0 are parameters.
                                         # (They may be trained in the usual way for Equinox.)

Documentation

sympy2jax.SymbolicModule(expressions, extra_funcs=None)

Where:

  • expressions is a PyTree of SymPy expressions.
  • extra_funcs is an optional dictionary from SymPy functions to JAX operations, to extend the built-in translation rules.

Instances are called with keyword arguments that map each symbol's name to its value, as in the example above.

Instances have a .sympy() method that translates the module back into a PyTree of SymPy expressions.

(That's literally the entire documentation; it's super easy.)

Finally

See also: other tools in the JAX ecosystem

Neural networks: Equinox.

Numerical differential equation solvers: Diffrax.

Type annotations and runtime checking for PyTrees and shape/dtype of JAX arrays: jaxtyping.

Disclaimer

This is not an official Google product.

Download files

Source Distribution

sympy2jax-0.0.2.tar.gz (8.7 kB)

Built Distribution

sympy2jax-0.0.2-py3-none-any.whl (9.1 kB)

File details

Details for the file sympy2jax-0.0.2.tar.gz.

File metadata

  • Download URL: sympy2jax-0.0.2.tar.gz
  • Size: 8.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for sympy2jax-0.0.2.tar.gz
Algorithm Hash digest
SHA256 d1028a7efd877d4a201ad179f170224a0256722388ed962722eb19c99b44ac4a
MD5 959fc23c4a59218a162be686c6d5a49a
BLAKE2b-256 a03a3e43799fcf73c705dbb35818f1bd160b6e685b208ac4abba49006d944503


File details

Details for the file sympy2jax-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: sympy2jax-0.0.2-py3-none-any.whl
  • Size: 9.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for sympy2jax-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 23aa2271f43d1da1ba383085937e742dc8de43e8e053ca5bdcde16c9bea0fbd9
MD5 e2c9f44e9192df1403a3f761fe3a172f
BLAKE2b-256 d30e76b500ef7e328dcc09d5b5cc08096708e0b2c35f272f30d85150f66b50aa

