Differentiable signature calculations in JAX.

Signax: Computing signatures in JAX

Goal

To provide a library for computing signatures in JAX. See this paper for how signatures are used in machine learning.

This implementation is inspired by patrick-kidger/signatory.

Examples

Basic usage

import jax
import jax.random as jrandom
import signax


key = jrandom.PRNGKey(0)
depth = 3

# compute signature for a single path
length = 100
dim = 20
path = jrandom.normal(shape=(length, dim), key=key)
output = signax.signature(path, depth)
# output is a list of arrays representing the tensor algebra

# compute signatures for a batch of (multiple) paths
batch_size = 10  # example value
path = jrandom.normal(shape=(batch_size, length, dim), key=key)
# the new signax API can handle this case too
output = signax.signature(path, depth)
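
To make the meaning of the signature levels concrete, here is a minimal pure-JAX sketch of the first two levels for a piecewise-linear path. It is an illustration only, not signax's implementation; the function name signature_depth2 is hypothetical and does not exist in signax.

```python
import jax.numpy as jnp
import jax.random as jrandom


def signature_depth2(path):
    """Levels 1 and 2 of the signature of a piecewise-linear path.

    Hypothetical helper for illustration; `path` has shape (length, dim).
    """
    increments = jnp.diff(path, axis=0)  # (length - 1, dim)
    # Level 1: the total increment of the path.
    level1 = increments.sum(axis=0)
    # Displacement from the starting point before each linear segment.
    before = jnp.cumsum(increments, axis=0) - increments
    # Level 2: iterated integrals. Each segment contributes
    # outer(before, dx) + outer(dx, dx) / 2.
    level2 = before.T @ increments + 0.5 * increments.T @ increments
    return level1, level2


path = jrandom.normal(jrandom.PRNGKey(0), shape=(100, 3))
level1, level2 = signature_depth2(path)
```

Level 1 is the end point minus the start point, and the symmetric part of level 2 is determined by level 1 (level2 + level2.T equals the outer product of level1 with itself), so the new information at level 2 sits in its antisymmetric part.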

Integrate with the equinox library

import equinox as eqx
import jax.random as jrandom

from signax.module import SignatureTransform

# random generator key
key = jrandom.PRNGKey(0)
mlp_key, data_key = jrandom.split(key)

depth = 3
length, dim = 100, 3

# create a signature transform at the specified depth
signature_layer = SignatureTransform(depth=depth)

# stack an MLP layer after that
last_layer = eqx.nn.MLP(
    depth=1, in_size=3 + 3**2 + 3**3, width_size=4, out_size=1, key=mlp_key
)

model = eqx.nn.Sequential(layers=[signature_layer, last_layer])
x = jrandom.normal(shape=(length, dim), key=data_key)
output = model(x)
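
The in_size passed to the MLP above is the number of signature terms for dim=3 and depth=3. A small sketch of that count follows, assuming the signature layer outputs all levels flattened and concatenated (as the in_size expression in the example suggests); signature_size is a hypothetical helper, not part of signax.

```python
def signature_size(dim: int, depth: int) -> int:
    """Number of signature terms: dim + dim**2 + ... + dim**depth."""
    return sum(dim**k for k in range(1, depth + 1))


# Same value as the in_size used in the example: 3 + 3**2 + 3**3
in_size = signature_size(dim=3, depth=3)
```

The count grows geometrically with depth, so the input layer of any network placed after the signature transform grows quickly as depth or dim increases.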

Also, see the notebooks in the examples folder for experiments that reproduce the results of the deep signature transforms paper.

Installation

Via pip

python3 -m pip install signax

Via source

git clone https://github.com/anh-tong/signax.git
cd signax
python3 -m pip install -v -e .

Parallelism

This implementation uses jax.vmap to parallelize over the batch dimension.

Parallelism over chunks of paths is also done with jax.vmap.
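
The idea of both kinds of parallelism can be sketched in pure JAX at depth 2. This is a simplified illustration, not signax's internal code: the names signature_depth2, chen_combine and chunked_signature are hypothetical, and the sketch assumes the number of path increments divides evenly by the number of chunks. Chunk signatures are computed in parallel with jax.vmap and then merged with Chen's identity.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def signature_depth2(path):
    """Levels 1 and 2 of the signature of a path of shape (length, dim)."""
    dx = jnp.diff(path, axis=0)
    before = jnp.cumsum(dx, axis=0) - dx
    return dx.sum(axis=0), before.T @ dx + 0.5 * dx.T @ dx


def chen_combine(left, right):
    """Chen's identity at depth 2 for two concatenated path pieces."""
    a1, a2 = left
    b1, b2 = right
    return a1 + b1, a2 + jnp.outer(a1, b1) + b2


def chunked_signature(path, num_chunks):
    """Compute chunk signatures in parallel, then combine them in order."""
    steps = (path.shape[0] - 1) // num_chunks  # assumes an exact split
    # Consecutive chunks share an endpoint so that no increment is lost.
    idx = jnp.arange(num_chunks)[:, None] * steps + jnp.arange(steps + 1)[None, :]
    chunks = path[idx]  # (num_chunks, steps + 1, dim)
    lvl1, lvl2 = jax.vmap(signature_depth2)(chunks)
    result = (lvl1[0], lvl2[0])
    for i in range(1, num_chunks):
        result = chen_combine(result, (lvl1[i], lvl2[i]))
    return result


# 129 points give 128 increments, which split evenly into 4 chunks.
paths = jrandom.normal(jrandom.PRNGKey(0), shape=(8, 129, 4))
# Parallelism over the batch dimension.
full = jax.vmap(signature_depth2)(paths)
# Parallelism over the batch dimension and over chunks of each path.
chunked = jax.vmap(lambda p: chunked_signature(p, num_chunks=4))(paths)
```

Both routes give the same result up to floating-point error; chunking trades one long sequential computation for several shorter ones plus a cheap combination step.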

A quick comparison can be found in the notebook examples/compare.ipynb. The plots below compare the forward and backward passes on both GPU and CPU for path size=(32, 128, 8) and signature depth=5.

[Plots: forward pass, backward pass]

Why is using pure JAX good enough?

Because JAX uses just-in-time (JIT) compilation with XLA, Signax can be reasonably fast.

We observe that the performance of this implementation is similar to Signatory on CPU and slightly better on GPU. This may be due to XLA's optimized operators in JAX. Note that Signatory contains highly optimized C++ source code (PyTorch with Pybind11).

Acknowledgement

This repo is based on
