Project description

sympytorch

Turn SymPy expressions into PyTorch Modules.

SymPy floats (optionally) become trainable parameters. SymPy symbols are inputs to the Module.

Optimise your symbolic expressions via gradient descent!

Installation

pip install sympytorch

Requires Python 3.7+, PyTorch 1.6.0+, and SymPy 1.7.1+.

Example

import sympy, torch, sympytorch

x = sympy.symbols('x_name')
cosx = 1.0 * sympy.cos(x)
sinx = 2.0 * sympy.sin(x)
mod = sympytorch.SymPyModule(expressions=[cosx, sinx])

x_ = torch.rand(3)
out = mod(x_name=x_)  # out has shape (3, 2)

assert torch.equal(out[:, 0], x_.cos())
assert torch.equal(out[:, 1], 2 * x_.sin())
assert out.requires_grad  # from the two Parameters initialised as 1.0 and 2.0
assert {p.item() for p in mod.parameters()} == {1.0, 2.0}
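
As a sketch of optimising an expression by gradient descent (the target 3 * sin(x), learning rate, and step count here are illustrative, not part of the library):

import sympy, torch, sympytorch

x = sympy.symbols('x_name')
mod = sympytorch.SymPyModule(expressions=[1.0 * sympy.sin(x)])

# Fit the trainable coefficient (initialised at 1.0) to match 3 * sin(x).
optim = torch.optim.SGD(mod.parameters(), lr=0.1)
x_ = torch.rand(100)
target = 3 * x_.sin()
for _ in range(100):
    optim.zero_grad()
    loss = ((mod(x_name=x_)[:, 0] - target) ** 2).mean()
    loss.backward()
    optim.step()
# The learnt coefficient is now close to 3.0.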

API

sympytorch.SymPyModule(*, expressions, extra_funcs=None)

Where:

  • expressions is a list of SymPy expressions.
  • extra_funcs is a dictionary mapping custom sympy.Function subclasses to their PyTorch implementations. Defaults to no extra functions. (See the sketch in the Extensions section below.)

Instances of SymPyModule can be called, passing the values of the symbols as in the above example.

SymPyModule has a method .sympy(), which returns the corresponding list of SymPy expressions. (These may not be the same as the expressions it was initialised with, if the values of its Parameters have changed, i.e. have been learnt.)
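
For instance (a minimal sketch; the exact printed value depends on the random data):

import sympy, torch, sympytorch

x = sympy.symbols('x_name')
mod = sympytorch.SymPyModule(expressions=[2.0 * sympy.sin(x)])

# Take a single gradient step, changing the trainable coefficient...
mod(x_name=torch.rand(3)).sum().backward()
torch.optim.SGD(mod.parameters(), lr=0.1).step()

# ...and recover the updated expression.
print(mod.sympy())  # e.g. [1.86...*sin(x_name)], no longer [2.0*sin(x_name)]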

Wrapping floats in sympy.UnevaluatedExpr will cause them not to be trained, by registering them as buffers rather than parameters.
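
A minimal sketch of the difference:

import sympy, torch, sympytorch

x = sympy.symbols('x_name')
frozen = sympy.UnevaluatedExpr(2.0) * sympy.sin(x)  # 2.0 becomes a buffer
trainable = 1.0 * sympy.cos(x)                      # 1.0 becomes a parameter
mod = sympytorch.SymPyModule(expressions=[frozen, trainable])

assert {p.item() for p in mod.parameters()} == {1.0}  # only 1.0 is trainable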

sympytorch.hide_floats(expression)

As a convenience, hide_floats will take an expression and return a new expression with every float wrapped in a sympy.UnevaluatedExpr, so that it is interpreted as a buffer rather than a parameter.
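
For example (a minimal sketch):

import sympy, sympytorch

x = sympy.symbols('x_name')
expr = 1.0 * sympy.cos(x) + 2.0
hidden = sympytorch.hide_floats(expr)

# Every float is now wrapped, so nothing is registered as a parameter.
mod = sympytorch.SymPyModule(expressions=[hidden])
assert len(list(mod.parameters())) == 0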

Extensions

Not every PyTorch or SymPy operation is supported -- just the ones I've found I needed! There's a dictionary in the source code that lists the supported operations. Feel free to submit PRs for any extra operations you think should be in by default. You can also use the extra_funcs argument to specify extra functions, including custom functions, as in the sketch below.
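
A sketch of extra_funcs (the custom function and its softplus implementation are illustrative):

import sympy, torch, sympytorch

# A custom SymPy function with no built-in PyTorch translation.
class softplus(sympy.Function):
    pass

x = sympy.symbols('x_name')
expr = 1.0 * softplus(x)
mod = sympytorch.SymPyModule(
    expressions=[expr],
    extra_funcs={softplus: torch.nn.functional.softplus},
)
out = mod(x_name=torch.rand(3))  # shape (3, 1)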

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

sympytorch-0.1.2.tar.gz (8.7 kB, Source)

Built Distribution

sympytorch-0.1.2-py3-none-any.whl (9.0 kB, Python 3)

File details

Details for the file sympytorch-0.1.2.tar.gz.

File metadata

  • Download URL: sympytorch-0.1.2.tar.gz
  • Upload date:
  • Size: 8.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.16

File hashes

Hashes for sympytorch-0.1.2.tar.gz:

  • SHA256: 9fa4d0464fa3fcb59e65089a2595da6d7ce296528d57371d8bbb8cf7f362b06d
  • MD5: 1e9c3919fa4042782cf9da163d1e29ea
  • BLAKE2b-256: 0281c5b3d37dc8868ae2c6082c4b77502dcb4fbb2079d5d90fada0edd1ccdf19

See more details on using hashes here.

File details

Details for the file sympytorch-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: sympytorch-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 9.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.8.16

File hashes

Hashes for sympytorch-0.1.2-py3-none-any.whl:

  • SHA256: 15617a2af84d65e0399b42328d274f02a8962081778b0e2965d773cff89e8b9f
  • MD5: 9081251627e9baceec8fae9a02d4c1bb
  • BLAKE2b-256: bed897f62a585b48921f491bd5197b76d2936f65352d71a8c5f4c5a6ceb7278d

See more details on using hashes here.
