
sympy2jax

Turn SymPy expressions into trainable JAX expressions. The output is an Equinox module in which every SymPy number (floats, integers, rationals, ...) is a leaf, and SymPy symbols are its inputs.

Optimise your symbolic expressions via gradient descent!

Installation

pip install sympy2jax

Requires:

  • Python 3.7+
  • JAX 0.3.4+
  • Equinox 0.5.3+
  • SymPy 1.7.1+

Example

import jax
import jax.numpy as jnp
import sympy
import sympy2jax

x_sym = sympy.symbols("x_sym")
cosx = 1.0 * sympy.cos(x_sym)
sinx = 2.0 * sympy.sin(x_sym)
mod = sympy2jax.SymbolicModule([cosx, sinx])  # PyTree of input expressions

x = jnp.zeros(3)
out = mod(x_sym=x)  # PyTree of results: [1.0 * cos(x), 2.0 * sin(x)].
params = jax.tree_util.tree_leaves(mod)  # Includes 1.0 and 2.0, the parameters.
                                         # (These may be trained in the usual way for Equinox.)

Documentation

sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)

Where:

  • expressions is a PyTree of SymPy expressions.
  • extra_funcs is an optional dictionary from SymPy functions to JAX operations, to extend the built-in translation rules.
  • make_array controls whether integers/floats/rationals are stored as JAX arrays (make_array=True, the default) or as Python integers/floats/etc. (make_array=False).

Instances are called with keyword arguments, one per symbol, mapping each symbol's name to its value, as in the above example.

Instances have a .sympy() method that translates the module back into a PyTree of SymPy expressions.

(That's literally the entire documentation; it's super easy.)

Finally

See also: other tools in the JAX ecosystem

Neural networks: Equinox.

Numerical differential equation solvers: Diffrax.

Type annotations and runtime checking for PyTrees and shape/dtype of JAX arrays: jaxtyping.

Disclaimer

This is not an official Google product.

Download files

Source Distribution

sympy2jax-0.0.3.tar.gz (8.8 kB)

Built Distribution

sympy2jax-0.0.3-py3-none-any.whl (9.3 kB)

File details

Details for the file sympy2jax-0.0.3.tar.gz.

File metadata

  • Download URL: sympy2jax-0.0.3.tar.gz
  • Size: 8.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for sympy2jax-0.0.3.tar.gz
Algorithm Hash digest
SHA256 c3c2d309978a354ef2820355c191731fab4f1d0ee5b3f351f3722ef71c203c09
MD5 f0631ca85f0fed860bba9316a12ac2dc
BLAKE2b-256 6f905a1c2af5e4db7982abce4431a5a4c8df30f3031d3fa5bde4432ac3ab59a0


File details

Details for the file sympy2jax-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: sympy2jax-0.0.3-py3-none-any.whl
  • Size: 9.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for sympy2jax-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 d63f9b5af98441f37bc4dfdf17a90c88afb2593be1c58e0e80e9e601df10203e
MD5 2e7551d07d5fc4b0579df9f26b2e2da5
BLAKE2b-256 1a3ef8697d18aa6de36a32e4ef25b08d5aeb4257bd4983f308e391c8d0a3f06b

