
Turning SymPy expressions into PyTorch modules.

Project description

sympytorch

A micro-library for conveniently turning SymPy expressions into PyTorch Modules.

SymPy floats (optionally) become trainable parameters. SymPy symbols are inputs to the Module.

Installation

pip install git+https://github.com/patrick-kidger/sympytorch.git

Example

import sympy, torch, sympytorch

x = sympy.symbols('x_name')
cosx = 1.0 * sympy.cos(x)
sinx = 2.0 * sympy.sin(x)
mod = sympytorch.SymPyModule(expressions=[cosx, sinx])

x_ = torch.rand(3)
out = mod(x_name=x_)  # out has shape (3, 2)

assert torch.equal(out[:, 0], x_.cos())
assert torch.equal(out[:, 1], 2 * x_.sin())
assert out.requires_grad  # from the two Parameters initialised as 1.0 and 2.0
assert {x.item() for x in mod.parameters()} == {1.0, 2.0}

API

sympytorch.SymPyModule(*, expressions, extra_funcs=None)

Where:

  • expressions is a list of SymPy expressions.
  • extra_funcs is a dictionary mapping from custom sympy.Functions to their PyTorch implementation. Defaults to no extra functions.

Instances of SymPyModule can be called, passing the values of the symbols as in the above example.

SymPyModule has a method .sympy(), which returns the corresponding list of SymPy expressions. (These may not be the same as the expressions it was initialised with, if the values of its Parameters have been changed, i.e. have been learnt.)
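
For instance, a minimal sketch (the expression, optimiser, and loss below are illustrative, not part of sympytorch):

import sympy, torch, sympytorch

x = sympy.symbols('x_name')
mod = sympytorch.SymPyModule(expressions=[1.0 * sympy.cos(x)])

# Nudge the parameter with a toy objective, purely so that its value changes.
optim = torch.optim.SGD(mod.parameters(), lr=0.1)
for _ in range(10):
    loss = mod(x_name=torch.rand(8)).sum()
    optim.zero_grad()
    loss.backward()
    optim.step()

print(mod.sympy())  # e.g. [c*cos(x_name)], with the learnt value of c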

Wrapping floats in sympy.UnevaluatedExpr will cause them not to be trained, by registering them as buffers rather than parameters.
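
For example, a minimal sketch (the expression below is illustrative):

import sympy, sympytorch

x = sympy.symbols('x_name')
expr = 1.0 * sympy.cos(x) + sympy.UnevaluatedExpr(2.0) * sympy.sin(x)
mod = sympytorch.SymPyModule(expressions=[expr])

# Only the un-wrapped 1.0 should show up as a trainable parameter;
# the wrapped 2.0 should be registered as a buffer instead.
assert {p.item() for p in mod.parameters()} == {1.0}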

sympytorch.hide_floats(expression)

As a convenience, hide_floats will take an expression and return a new expression with every float wrapped in a sympy.UnevaluatedExpr, so that it is interpreted as a buffer rather than a parameter.
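
A minimal sketch of its use (the expression below is illustrative):

import sympy, torch, sympytorch

x = sympy.symbols('x_name')
hidden = sympytorch.hide_floats(2.7 * sympy.cos(x) + 0.3)
mod = sympytorch.SymPyModule(expressions=[hidden])

# Every float was wrapped, so there should be no trainable parameters left.
assert len(list(mod.parameters())) == 0
out = mod(x_name=torch.rand(3))  # the forward pass works as before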

Extensions

Not every PyTorch or SymPy operation is supported -- just the ones I've found I needed! There's a dictionary in the source code that lists the supported operations. Feel free to submit PRs for any extra operations you think should be in by default. You can also use the extra_funcs argument to specify extra functions, including custom functions, as in the sketch below.
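
A minimal sketch of extra_funcs (Swish and swish_torch are illustrative names, not part of sympytorch):

import sympy, torch, sympytorch

class Swish(sympy.Function):
    # A custom SymPy function with no symbolic evaluation rules of its own.
    pass

def swish_torch(x):
    return x * torch.sigmoid(x)

x = sympy.symbols('x_name')
mod = sympytorch.SymPyModule(expressions=[2.0 * Swish(x)],
                             extra_funcs={Swish: swish_torch})

out = mod(x_name=torch.rand(3))  # Swish is evaluated via swish_torch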


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

sympytorch-0.1.1.tar.gz (8.4 kB)

Built Distribution

sympytorch-0.1.1-py3-none-any.whl (8.8 kB)

File details

Details for the file sympytorch-0.1.1.tar.gz.

File metadata

  • Download URL: sympytorch-0.1.1.tar.gz
  • Upload date:
  • Size: 8.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.4.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.7.10

File hashes

Hashes for sympytorch-0.1.1.tar.gz

  • SHA256: 7aa8bd08e7673d1716de33cea341371d6960e080e5014d69c1cb361638788a16
  • MD5: 02076cabf3a9ab33b716c444ff1e86d9
  • BLAKE2b-256: 0033da6e3db1c7d7feb30cb7e51ea703a79b7f8ce2d25d9eabbe91c39d1e176a


File details

Details for the file sympytorch-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: sympytorch-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 8.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.4.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.7.10

File hashes

Hashes for sympytorch-0.1.1-py3-none-any.whl

  • SHA256: 9ae2ab6c5917c453b524148869e1f0367967794830c9d0b7e13cac8f2fe66484
  • MD5: dfa980d3c53ec3363c866a822c2173f1
  • BLAKE2b-256: 0fc92f04dce318945c2be589ab39e20ce3af7f2dc2d744e7bf0d5e563f4c6df0

