sympy2jax

Turn SymPy expressions into trainable JAX expressions. The output will be an Equinox module with all SymPy floats (integers, rationals, ...) as leaves. SymPy symbols will be inputs.

Optimise your symbolic expressions via gradient descent!
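
To make "floats are leaves, symbols are inputs" concrete, here is a minimal sketch of the idea using only SymPy and JAX. It is not how sympy2jax is implemented. `compile_expr` is a hypothetical helper that supports only a handful of node types, and it trains only `sympy.Float` constants while keeping integers and rationals fixed, which is a simplification compared with the description above.

```python
import functools
import operator

import jax
import jax.numpy as jnp
import sympy


def compile_expr(expr):
    """Hypothetical helper: turn a SymPy expression into (f, params).

    Each sympy.Float becomes one entry of `params`; each Symbol becomes an
    input looked up by name in `env`. Integers and rationals stay fixed.
    """
    init = []  # initial values of the trainable constants, in traversal order

    def build(e):
        if isinstance(e, sympy.Float):  # check before sympy.Number: Float is a Number
            idx = len(init)
            init.append(float(e))
            return lambda params, env: params[idx]
        if isinstance(e, sympy.Number):  # integers, rationals: fixed in this sketch
            value = float(e)
            return lambda params, env: value
        if isinstance(e, sympy.Symbol):
            name = e.name
            return lambda params, env: env[name]
        children = [build(arg) for arg in e.args]
        if isinstance(e, sympy.Add):
            return lambda params, env: functools.reduce(
                operator.add, [c(params, env) for c in children])
        if isinstance(e, sympy.Mul):
            return lambda params, env: functools.reduce(
                operator.mul, [c(params, env) for c in children])
        if isinstance(e, sympy.Pow):
            base, expo = children
            return lambda params, env: base(params, env) ** expo(params, env)
        # A few unary elementwise functions, mapped to their jax.numpy versions.
        for sym_fn, jax_fn in ((sympy.cos, jnp.cos), (sympy.sin, jnp.sin), (sympy.exp, jnp.exp)):
            if isinstance(e, sym_fn):
                (arg,) = children
                return lambda params, env: jax_fn(arg(params, env))
        raise NotImplementedError(f"unsupported node: {e.func}")

    f = build(expr)
    return f, jnp.asarray(init)


x = sympy.Symbol("x")
f, params0 = compile_expr(1.0 * sympy.cos(x) + 2.0 * sympy.sin(x))
xs = jnp.linspace(0.0, 1.0, 5)
target = 3.0 * jnp.cos(xs)  # made-up data for the constants to fit


def loss(p):
    return jnp.mean((f(p, {"x": xs}) - target) ** 2)


params = params0
for _ in range(100):
    params = params - 0.1 * jax.grad(loss)(params)  # plain gradient descent
```

The point of the design is that the expression's structure is fixed at build time while its numeric constants live in an ordinary JAX array, so `jax.grad` can differentiate with respect to them.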

Installation

pip install sympy2jax

Requires:
Python 3.7+
JAX 0.3.4+
Equinox 0.5.3+
SymPy 1.7.1+

Example

import jax
import jax.numpy as jnp
import sympy
import sympy2jax

x_sym = sympy.symbols("x_sym")
cosx = 1.0 * sympy.cos(x_sym)
sinx = 2.0 * sympy.sin(x_sym)
mod = sympy2jax.SymbolicModule([cosx, sinx])  # PyTree of input expressions

x = jnp.zeros(3)
out = mod(x_sym=x)  # PyTree of results: a list [1.0 * cos(x), 2.0 * sin(x)].
params = jax.tree_util.tree_leaves(mod)  # 1.0 and 2.0 are parameters.
                                         # (They may be trained in the usual way for Equinox.)

Documentation

sympy2jax.SymbolicModule(expressions)

Here expressions is a PyTree of SymPy expressions.

Instances are called with keyword arguments that map each symbol's name to its value, as in the example above.

Instances have a .sympy() method that translates the module back into a PyTree of SymPy expressions.

(That's literally the entire documentation; it's super easy.)

Finally

See also: other tools in the JAX ecosystem

Neural networks: Equinox.

Numerical differential equation solvers: Diffrax.

Disclaimer

This is not an official Google product.

Download files

Source Distribution

sympy2jax-0.0.1.tar.gz (7.9 kB)

Uploaded Source

Built Distribution

sympy2jax-0.0.1-py3-none-any.whl (8.8 kB)

Uploaded Python 3

File details

Details for the file sympy2jax-0.0.1.tar.gz.

File metadata

  • Download URL: sympy2jax-0.0.1.tar.gz
  • Upload date:
  • Size: 7.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.7.13

File hashes

Hashes for sympy2jax-0.0.1.tar.gz
Algorithm Hash digest
SHA256 a3e759b631434161788ae7c82ba89198814bb799c9c41b29a41267fd4ca4da2a
MD5 519f63fbaa01bf8e24c764763e8c2959
BLAKE2b-256 e1b323ba580fbe7e225cf22b9b1f421760d2bb3cd2b839b60b20fcbb58a50f92

File details

Details for the file sympy2jax-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: sympy2jax-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 8.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.7.13

File hashes

Hashes for sympy2jax-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 b2c902c96103a4dee72eff87901e62bedc3208142c423236bef5ba2bf1c39b15
MD5 240822927dacd36ad4179dadcabe3376
BLAKE2b-256 e25ba88140f253c946c0adb0001a12bdaae162508f4e4e56e60977dd3fdd214c
