sympy2jax

Turn SymPy expressions into trainable JAX expressions. The output will be an Equinox module with all SymPy floats (integers, rationals, ...) as leaves. SymPy symbols will be inputs.

Optimise your symbolic expressions via gradient descent!

Installation

pip install sympy2jax

Requires:
Python 3.7+
JAX 0.3.4+
Equinox 0.5.3+
SymPy 1.7.1+.

Example

import jax
import sympy
import sympy2jax

x_sym = sympy.symbols("x_sym")
cosx = 1.0 * sympy.cos(x_sym)
sinx = 2.0 * sympy.sin(x_sym)
mod = sympy2jax.SymbolicModule([cosx, sinx])  # PyTree of input expressions

x = jax.numpy.zeros(3)
out = mod(x_sym=x)  # PyTree of results.
params = jax.tree_util.tree_leaves(mod)  # 1.0 and 2.0 are parameters.
                                         # (Which may be trained in the usual way for Equinox.)

Documentation

sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)

Where:

  • expressions is a PyTree of SymPy expressions.
  • extra_funcs is an optional dictionary from SymPy functions to JAX operations, to extend the built-in translation rules.
  • make_array controls whether integers/floats/rationals are stored as JAX arrays (True) or as Python integers/etc. (False).

Instances can be called with key-value pairs of symbol-value, as in the above example.

Instances have a .sympy() method that translates the module back into a PyTree of SymPy expressions.

(That's literally the entire documentation, it's super easy.)

Finally

See also: other tools in the JAX ecosystem

Neural networks: Equinox.

Numerical differential equation solvers: Diffrax.

Type annotations and runtime checking for PyTrees and shape/dtype of JAX arrays: jaxtyping.

Disclaimer

This is not an official Google product.
