Symbolic Expressions in PyTorch
Fast, optimisable, symbolic expressions in PyTorch.
>>> from symtorch import symtorchify
>>> f = symtorchify("x**2 + 2.5*x + 1.7")
>>> f
x²+2.5x+1.7
>>> len(list(f.parameters()))
2
>>> import torch
>>> f.evalf({"x": torch.tensor(2.0)})
tensor([10.7000], grad_fn=<AddBackward0>)
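The parameter count above suggests that the float constants 2.5 and 1.7 become trainable parameters, while the exponent 2 and the implicit coefficient on x² stay fixed. Assuming that reading, the sketch below reproduces the same behaviour with a hand-written module. Quadratic is a hypothetical name and is not part of symtorch. The fitting loop is ordinary torch.optim.SGD, not a symtorch-specific API. It illustrates what "optimisable" means here: the constants can be fitted to data like any other nn.Parameter.

```python
import torch
from torch import nn


class Quadratic(nn.Module):
    """Hypothetical hand-written equivalent of symtorchify("x**2 + 2.5*x + 1.7").

    Assumption: the float constants 2.5 and 1.7 are learnable, while the
    exponent 2 and the implicit coefficient 1 on x**2 are fixed.
    """

    def __init__(self):
        super().__init__()
        self.b = nn.Parameter(torch.tensor(2.5))
        self.c = nn.Parameter(torch.tensor(1.7))

    def forward(self, x):
        return x**2 + self.b * x + self.c


f = Quadratic()
print(len(list(f.parameters())))  # 2
print(f(torch.tensor(2.0)))  # 2**2 + 2.5*2 + 1.7, approximately 10.7

# Fit the two constants to data generated from x**2 + 3*x + 1
x = torch.linspace(-2.0, 2.0, 21)
y = x**2 + 3 * x + 1
optimizer = torch.optim.SGD(f.parameters(), lr=0.1)
for _ in range(200):
    optimizer.zero_grad()
    loss = ((f(x) - y) ** 2).mean()
    loss.backward()
    optimizer.step()

print(f.b.item(), f.c.item())  # close to 3.0 and 1.0
```

Because the expression is linear in b and c, plain gradient descent on the mean squared error converges quickly in this case.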
Installation
pip install symtorch
Features and Documentation
What about SymPyTorch?
This package attempts to supersede Patrick Kidger's amazing original SymPyTorch. Useful feature improvements here are:
- implementations of state_dict and load_state_dict for all SymTorch objects, allowing for automated saving and loading via the native PyTorch mechanisms
- plays nicely with TorchScript, allowing for integration into C++ code
- a SymbolAssignment helper class to enable "drag-and-drop" replacement of existing NN components with symbolic ones:
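>>> import torch
>>> from torch import nn
>>> from symtorch import SymbolAssignment, symtorchify
>>> model = nn.Sequential(
...     SymbolAssignment(["a", "b"]),
...     symtorchify("3*a + b")
... )
>>> model(torch.tensor([[1, 2], [3, 4]]))
tensor([[ 5.],
        [13.]], grad_fn=<AddBackward0>)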
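The two mechanisms the list above relies on, passing named symbols between modules inside nn.Sequential and the native state_dict round trip, can be sketched in plain PyTorch. ColumnsToSymbols and ThreeAPlusB are hypothetical stand-ins for SymbolAssignment and symtorchify("3*a + b"), not symtorch's implementation. The assumption that SymbolAssignment maps each input column to one name is inferred only from the example output above.

```python
import io

import torch
from torch import nn


class ColumnsToSymbols(nn.Module):
    """Hypothetical stand-in for SymbolAssignment: give each input column a name."""

    def __init__(self, names):
        super().__init__()
        self.names = list(names)

    def forward(self, x):
        # x has shape (batch, len(names)); slicing i:i+1 keeps a trailing
        # dimension of size 1, so each symbol has shape (batch, 1)
        return {name: x[:, i : i + 1] for i, name in enumerate(self.names)}


class ThreeAPlusB(nn.Module):
    """Hypothetical stand-in for symtorchify("3*a + b") with a learnable 3."""

    def __init__(self):
        super().__init__()
        self.coeff = nn.Parameter(torch.tensor(3.0))

    def forward(self, symbols):
        # A float parameter times an integer tensor promotes to float32
        return self.coeff * symbols["a"] + symbols["b"]


def build():
    return nn.Sequential(ColumnsToSymbols(["a", "b"]), ThreeAPlusB())


model = build()
out = model(torch.tensor([[1, 2], [3, 4]]))
print(out)  # values 5 and 13, shape (2, 1), with a grad_fn

# Pretend training moved the constant, then save it via state_dict
with torch.no_grad():
    model[1].coeff.add_(0.5)
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# A freshly built model starts at 3.0 and picks up 3.5 on loading
restored = build()
restored.load_state_dict(torch.load(buffer))
print(restored[1].coeff.item())  # 3.5
```

Keeping the learnable constant in an nn.Parameter is what lets the standard state_dict and load_state_dict calls save and restore it with no extra code.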
Download files
Source Distribution
symtorch-0.0.0.tar.gz (6.0 kB)