Differentiable signature calculations in JAX.
Signax: Computing signatures in JAX
Goal
To provide a library for signature computation in JAX. See this paper for how signatures are used in machine learning.
This implementation is inspired by patrick-kidger/signatory.
Examples
Basic usage
import jax
import jax.random as jrandom
import signax
key = jrandom.PRNGKey(0)
depth = 3
# compute signature for a single path
length = 100
dim = 20
path = jrandom.normal(shape=(length, dim), key=key)
output = signax.signature(path, depth)
# output is a list of arrays representing an element of the tensor algebra
# compute signatures for a batch of (multiple) paths
batch_size = 5  # example value
path = jrandom.normal(shape=(batch_size, length, dim), key=key)
# the new signax API can handle this case too
output = signax.signature(path, depth)
Integrate with the equinox library
import equinox as eqx
import jax.random as jrandom
from signax.module import SignatureTransform
# random generator key
key = jrandom.PRNGKey(0)
mlp_key, data_key = jrandom.split(key)
depth = 3
length, dim = 100, 3
# create a signature transform at the specified depth
signature_layer = SignatureTransform(depth=depth)
# stack an MLP layer after that
# in_size is the flattened signature size for dim=3, depth=3: 3 + 3**2 + 3**3
last_layer = eqx.nn.MLP(
    depth=1, in_size=3 + 3**2 + 3**3, width_size=4, out_size=1, key=mlp_key
)
model = eqx.nn.Sequential(layers=[signature_layer, last_layer])
x = jrandom.normal(shape=(length, dim), key=data_key)
output = model(x)
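The `in_size` passed to the MLP above is the number of signature terms in levels 1 through `depth`, that is `dim + dim**2 + ... + dim**depth`. As a minimal sketch, a small helper can compute it for other settings. The name `signature_size` is hypothetical and is not part of signax; it only reproduces the arithmetic used in the example.

```python
def signature_size(dim: int, depth: int) -> int:
    """Number of terms in signature levels 1..depth (hypothetical helper)."""
    # Level k of the signature of a dim-dimensional path has dim**k terms.
    return sum(dim**k for k in range(1, depth + 1))


# Same value as the hand-written 3 + 3**2 + 3**3 in the example above.
in_size = signature_size(dim=3, depth=3)
print(in_size)  # 39
```

Computing the size from `dim` and `depth` avoids having to update the MLP input size by hand when either value changes.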
Also, see the notebooks in the examples folder for experiments that
reproduce the results of the
deep signature transforms paper.
Installation
Via pip
python3 -m pip install signax
Via source
git clone https://github.com/anh-tong/signax.git
cd signax
python3 -m pip install -v -e .
Parallelism
This implementation uses jax.vmap to parallelize over the
batch dimension.
Parallelism over chunks of paths is also done using jax.vmap.
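To illustrate the batching pattern described above, here is a minimal sketch of mapping a single-path function over a batch with `jax.vmap`. The function `toy_signature_depth2` is a hypothetical stand-in written for this example (a depth-2 signature of a piecewise-linear path); it is not signax's implementation, which may be organised quite differently.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def toy_signature_depth2(path):
    """Depth-2 signature of one piecewise-linear path of shape (length, dim).

    Illustrative stand-in only; this is not how signax computes signatures.
    """
    increments = jnp.diff(path, axis=0)  # (length - 1, dim)
    level1 = jnp.sum(increments, axis=0)  # (dim,)
    # Displacement accumulated strictly before each step (exclusive cumsum).
    before = jnp.cumsum(increments, axis=0) - increments
    # Level 2: sum_{s<t} d_s (x) d_t + 0.5 * sum_t d_t (x) d_t
    level2 = jnp.einsum("ti,tj->ij", before + 0.5 * increments, increments)
    return level1, level2


# jax.vmap maps the single-path function over the leading (batch) axis.
batched_signature = jax.vmap(toy_signature_depth2)

# paths has shape (batch, length, dim)
paths = jrandom.normal(jrandom.PRNGKey(0), shape=(4, 16, 3))
level1, level2 = batched_signature(paths)
print(level1.shape, level2.shape)  # (4, 3) (4, 3, 3)
```

Writing the computation for one path and letting `jax.vmap` add the batch axis keeps the per-path code free of explicit batch indexing.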
A quick comparison can be found in the notebook examples/compare.ipynb.
The plots below compare the forward and backward passes on both GPU and CPU for
path size (32, 128, 8) and signature depth 5.
[Figure: forward and backward pass timing plots]
Why is using pure JAX good enough?
Because JAX makes use of just-in-time (JIT) compilation with XLA, Signax can be reasonably fast.
We observe that the performance of this implementation is similar to Signatory on CPU and slightly better on GPU. This may be due to the optimized XLA operators in JAX. Note that Signatory contains highly optimized C++ source code (PyTorch with Pybind11).
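As a rough sketch of how forward and backward passes of a JIT-compiled function can be timed, the snippet below wraps a function in `jax.jit` and `jax.grad`. The function `toy_loss` and the helper `timed` are hypothetical stand-ins written for this example, not signax code and not the benchmark from the notebook; only the path size (32, 128, 8) is taken from the comparison above.

```python
import time

import jax
import jax.numpy as jnp
import jax.random as jrandom


def toy_loss(paths):
    """Scalar stand-in for a signature-based loss; paths: (batch, length, dim)."""
    increments = jnp.diff(paths, axis=1)
    # Displacement accumulated strictly before each step (exclusive cumsum).
    before = jnp.cumsum(increments, axis=1) - increments
    level2 = jnp.einsum("bti,btj->bij", before + 0.5 * increments, increments)
    return jnp.mean(level2**2)


forward = jax.jit(toy_loss)             # compiled forward pass
backward = jax.jit(jax.grad(toy_loss))  # compiled gradient w.r.t. paths


def timed(fn, x):
    """Run fn once to compile, then time a second call."""
    fn(x).block_until_ready()  # first call triggers XLA compilation
    start = time.perf_counter()
    out = fn(x).block_until_ready()  # wait for asynchronous dispatch to finish
    return out, time.perf_counter() - start


paths = jrandom.normal(jrandom.PRNGKey(0), shape=(32, 128, 8))
loss, forward_seconds = timed(forward, paths)
grads, backward_seconds = timed(backward, paths)
print(f"forward: {forward_seconds:.6f}s, backward: {backward_seconds:.6f}s")
```

The warm-up call keeps compilation time out of the measurement, and `block_until_ready` is needed because JAX dispatches work asynchronously.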
Acknowledgement
This repo is based on
Download files
File details
Details for the file signax-0.2.1.tar.gz.
File metadata
- Download URL: signax-0.2.1.tar.gz
- Upload date:
- Size: 645.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bd0a9c43433f482abd1e56be3145727d9fbe61d1c5451f00d8fa36278ad72c2c |
| MD5 | 76024f957eb038829254a46da724e430 |
| BLAKE2b-256 | 741c7fabf79b01ac07193d5f5fcf91af32d847754a11da70f7c094e14b6146cb |
File details
Details for the file signax-0.2.1-py3-none-any.whl.
File metadata
- Download URL: signax-0.2.1-py3-none-any.whl
- Upload date:
- Size: 11.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 62593038cc7a6c0822c7529279f802dd7d44db509c7cf7cdc4b8d46ad82bbc04 |
| MD5 | 885e50a7b4fa898352a8f51e97a66766 |
| BLAKE2b-256 | 82a900d030e2649c77b0d6398da3c064bc3bd8eb7795970a3f91a31a388d62dc |