Path Signature in Pure Keras
keras_sig: The Most Efficient and Easy-to-Use Path Signature Computation
This package started as a backend-agnostic Keras implementation of path signature computations, focusing on simplicity and ease of integration. Since we proposed a GPU-optimized computation method that leverages fully parallel operations, it has become the fastest and most efficient path signature computation package available to date. This method is available both in full Keras for model training and as a standalone JAX function for direct computation.
Overview
keras_sig provides path signature computations as a Keras layer. It aims to offer:
- Native Keras implementation supporting all backends (JAX, PyTorch, TensorFlow)
- Simple integration within Keras models
- Pure Python implementation avoiding C++ dependencies
- Consistent API across different backends
- GPU-optimized computation for faster training
Historical Context
The package builds upon several key projects in the signature computation ecosystem:
- iisignature (repo): The foundational C++ implementation providing highly optimized signature computations with a Python wrapper
- signatory (repo): A PyTorch-specific implementation using C++ level optimizations for GPU acceleration
- iisignature-tensorflow-2 (repo): An attempt at wrapping iisignature for TensorFlow 2, which faced limitations with model compilation
- signax (repo): A breakthrough pure JAX implementation showing that C++ optimization could be avoided
- keras_sig (this package): Brings the pure Python approach to all Keras backends and further optimizes the computation for the GPU
Installation
```bash
pip install keras_sig
```
Or install from source:
```bash
git clone https://github.com/yourusername/keras_sig
cd keras_sig
pip install -e .
```
Quick Start
Basic usage with Keras:
```python
import keras
from keras_sig import SigLayer

timesteps, features, output_dim = 100, 5, 1  # example values

model = keras.Sequential([
    keras.layers.Input(shape=(timesteps, features)),
    SigLayer(depth=3, stream=False, gpu_optimized=True),  # Enable GPU optimization
    keras.layers.Dense(output_dim),
])
```
Direct JAX computation (fastest option):
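```python
import jax
from keras_sig import jax_gpu_signature
# Example batch: 128 paths, 100 timesteps, 5 features (assumed (batch, length, features) layout)
paths = jax.random.normal(jax.random.PRNGKey(0), (128, 100, 5))
# Pre-compiled GPU-optimized computation
signatures = jax_gpu_signature(paths, depth=3, stream=False)
```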
Performance & Implementation Options
Computation Methods
- GPU-Optimized (recommended when a GPU is available)
  - Uses parallel operations instead of loops
  - About 5x faster than the standard implementation
  - Higher memory usage
  - Enable with `gpu_optimized=True` or use `jax_gpu_signature`
- Standard Implementation
  - Loop-based computation with scan operations
  - Lower memory footprint
  - Better for CPU-only systems
  - Default when no GPU is available
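To build intuition for the loop-based (scan) strategy, here is a plain NumPy reference implementation of Chen's identity: the signature of a piecewise-linear path is the product of the truncated tensor exponentials of its increments, accumulated one increment at a time. This is a minimal sketch for a single unbatched path, not keras_sig's actual code; the function name `signature_loop` is hypothetical, and batching, streaming and the package's flattened output layout are left out.

```python
import numpy as np

def signature_loop(path: np.ndarray, depth: int) -> list:
    """Depth-truncated signature of one piecewise-linear path of shape (length, d).

    Hypothetical reference sketch (not keras_sig's implementation).
    Returns [S1, ..., S_depth], where level k has shape (d,) * k.
    """
    d = path.shape[1]
    # Level 0 is the constant 1; higher levels start at zero (signature of an empty path).
    sig = [np.ones(())] + [np.zeros((d,) * k) for k in range(1, depth + 1)]
    for dx in np.diff(path, axis=0):  # sequential pass over increments, like a scan
        # Truncated tensor exponential of one linear segment: level j = dx^{(x)j} / j!
        exp_dx = [np.ones(())]
        for j in range(1, depth + 1):
            exp_dx.append(np.multiply.outer(exp_dx[-1], dx) / j)
        # Chen's identity: new level k = sum over j of old level (k - j) (x) exp_dx[j]
        sig = [np.ones(())] + [
            sum(np.multiply.outer(sig[k - j], exp_dx[j]) for j in range(k + 1))
            for k in range(1, depth + 1)
        ]
    return sig[1:]

# Usage: a straight segment from the origin to v = (1, 2)
levels = signature_loop(np.array([[0.0, 0.0], [1.0, 2.0]]), depth=3)
```

For a straight segment ending at v, level k equals v^{⊗k}/k!, which makes a quick sanity check. The loop carries state from one increment to the next, which is why this formulation is memory-light but inherently sequential.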
Performance Benchmarks
All benchmarks were run on an AMD EPYC 7302P (16 cores) with an RTX 3090.
Forward Pass (batch size 128, sequence length 100, 5 features, depth 4)
| Backend | Version | GPU Time | CPU Time |
|---|---|---|---|
| JAX | Pure Jax-GPU function | 163µs | 46.5ms |
| JAX | keras Standard | 713ms | 378ms |
| JAX | keras GPU-optimized | - | 80.5ms |
| JAX | signax | 668µs | 11.7ms |
| TensorFlow | keras GPU-optimized | 55.2ms | 180ms |
| TensorFlow | keras Standard | 375ms | 317ms |
| Torch | keras GPU-optimized | 2.84ms | 50.6ms |
| Torch | keras Standard | 92.4ms | 91.4ms |
| None | iisignature | 36.4ms | 36.4ms |
Here the Keras versions do not perform as well as the direct JAX function because, in this forward-pass benchmark, the Keras operations are neither run on the GPU nor JIT-compiled; that only happens at training time. However, we can directly compare the pure JAX function with signax and iisignature, and our proposed approach is the fastest when a GPU is available. When no GPU is available, the standard version is very similar to the signax implementation.
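To put the forward-pass configuration (batch 128, sequence length 100, 5 features, depth 4) in perspective, the number of terms in a truncated signature grows as d + d² + … + d^depth. The sketch below assumes, as iisignature's `siglength` does, that the constant level-0 term is dropped; the helper name `signature_length` is hypothetical, and the streaming shape assumes one signature per path increment, which may differ from keras_sig's exact output.

```python
def signature_length(features: int, depth: int) -> int:
    """Number of terms in a depth-truncated signature, levels 1..depth
    (constant level-0 term excluded, as in iisignature.siglength)."""
    return sum(features ** k for k in range(1, depth + 1))

# Forward-pass benchmark configuration
batch, seq_len, features, depth = 128, 100, 5, 4
siglen = signature_length(features, depth)
print(siglen)                        # 780
print((batch, siglen))               # (128, 780) with stream=False
# Assumption: stream=True yields one signature per increment (seq_len - 1 of them)
print((batch, seq_len - 1, siglen))  # (128, 99, 780)
```

Because the size is exponential in depth, raising the depth by one multiplies the top level by the number of features, which is worth keeping in mind when choosing the GPU-optimized (higher-memory) variant.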
Training Performance
Test conditions: we built a model following the SigKAN paper, as follows:
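```python
import numpy as np
import keras
from keras.layers import Input, Dense, Flatten
from keras_sig import SigLayer
from sigkan import SigDense  # SigKAN layer; assumed import path

depth, n_ahead = 4, 1                              # example values
X = np.random.rand(1024, 500, 5).astype("float32")  # randomly generated inputs

model = keras.Sequential([
    Input(shape=X.shape[1:]),
    Dense(7),
    SigDense(10, depth, SigLayer),
    Flatten(),
    Dense(10, 'relu'),
    Dense(n_ahead),
])
```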
We then trained it for 10 epochs with the Adam optimizer on randomly generated data, enabling `jit_compile` where possible.
Long Sequences (length=500)
| Backend | Version | Compile Time (GPU) | Compile Time (CPU) | Step Time (GPU) | Step Time (CPU) |
|---|---|---|---|---|---|
| JAX | GPU-opt | 5s | 25s | 2ms | 213ms |
| JAX | Standard | 7s | 14s | 14ms | 108ms |
| JAX | Signax | 6s | 12s | 14ms | 83ms |
| TensorFlow | GPU-opt | 9s | 26s | 2ms | 214ms |
| TensorFlow | Standard | Compile fail | Compile fail | - | - |
| TensorFlow | iisignature | No compile | No compile | 340-345ms | 340-345ms |
| Torch | GPU-opt | 53s | 26s | 21ms | 218ms |
| Torch | Standard | No compile | No compile | 590ms | 643ms |
Short Sequences (length=20)
| Backend | Version | Compile Time (GPU) | Compile Time (CPU) | Step Time (GPU) | Step Time (CPU) |
|---|---|---|---|---|---|
| JAX | GPU-opt | 4s | 6s | 1ms | 19ms |
| JAX | Standard | 8s | 8s | 1ms | 9ms |
| JAX | signax | 5s | 4s | 2ms | 6ms |
| TensorFlow | GPU-opt | 4s | 13s | 1ms | 102ms |
| TensorFlow | Standard | 19s | 14s | 2ms | 28ms |
| TensorFlow | iisignature | No compile | No compile | 27ms | 27ms |
| Torch | GPU-opt | 9s | 8s | 21ms | 17ms |
| Torch | Standard | No compile | No compile | 38ms | 31ms |
Key Findings:
- The pure JAX GPU-optimized version is the fastest for the forward pass (about 4x faster than signax)
- GPU-optimized variants excel across all backends when a GPU is available
- For training:
  - JAX: best balance of compilation and execution time
  - TensorFlow: GPU-optimized version required for long sequences
  - PyTorch: longer compilation but good runtime with GPU optimization
- Standard implementations struggle with:
  - PyTorch: compilation issues
  - TensorFlow: compiling long sequences
  - All backends: slower execution without GPU optimization
Implementation Recommendations
- JAX + GPU (best overall)
  - Use the pure JAX implementation for the forward pass
  - Use the GPU-optimized SigLayer for training
- PyTorch + GPU
  - Use the GPU-optimized version only
  - Expect longer compilation times
- TensorFlow + GPU
  - Use the GPU-optimized version
  - Avoid the standard version for long sequences
- CPU-Only Systems
  - The JAX standard implementation offers the best balance
  - GPU-optimized versions are still usable, but with a performance penalty
Features
Currently implements:
- Standard signature computations
- Support for both streaming and non-streaming modes
- Configurable signature depth
- Backend-agnostic implementation
Not yet implemented (available in other packages):
- Log signatures
- Lyndon words
- Other advanced signature computations
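In streaming mode (`stream=True`) a signature is produced for every prefix of the path instead of only for the whole path. At depth 2, Chen's identity reduces to S2 ← S2 + S1 ⊗ Δx + Δx ⊗ Δx / 2 at each step, so all prefix signatures can be obtained with cumulative sums rather than an explicit loop. The NumPy sketch below is illustrative only: `depth2_stream_signature` is a hypothetical, unbatched helper, not keras_sig's implementation, and it assumes one output per increment.

```python
import numpy as np

def depth2_stream_signature(path: np.ndarray):
    """Streaming depth-2 signature of a path of shape (length, d).

    Hypothetical sketch: entry t is the signature of path[: t + 2],
    so the last entry equals the non-streaming (stream=False) result.
    """
    dx = np.diff(path, axis=0)                  # (length - 1, d) increments
    s1 = np.cumsum(dx, axis=0)                  # level 1 of every prefix
    # Level 1 *before* each increment (zero before the first one)
    s1_prev = np.concatenate([np.zeros_like(dx[:1]), s1[:-1]], axis=0)
    # Per-step level-2 contribution from Chen's identity, then a prefix sum
    step = np.einsum("ti,tj->tij", s1_prev, dx) + 0.5 * np.einsum("ti,tj->tij", dx, dx)
    s2 = np.cumsum(step, axis=0)                # (length - 1, d, d)
    return s1, s2

# Usage: an L-shaped path, right then up
s1, s2 = depth2_stream_signature(np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]))
```

For this L-shaped path, the antisymmetric part of the final level-2 term, (s2[-1][0, 1] - s2[-1][1, 0]) / 2, gives the signed (Lévy) area of 0.5 enclosed between the path and its chord.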
Citations
If using this package, please cite both this work and the foundational packages that inspired it:
```bibtex
@article{reizenstein2017iisignature,
  title={iisignature: A python package for computing iterated-integral signatures},
  author={Reizenstein, Jeremy},
  journal={Journal of Open Source Software},
  volume={2},
  number={10},
  pages={189},
  year={2017}
}
@inproceedings{kidger2021signatory,
  title={Signatory: differentiable computations of the signature and logsignature transforms, on both CPU and GPU},
  author={Kidger, Patrick and Lyons, Terry},
  booktitle={International Conference on Learning Representations},
  year={2021}
}
@software{signax2024github,
  author = {Anh Tong},
  title = {signax: Path Signatures in JAX},
  url = {https://github.com/anh-tong/signax},
  year = {2024},
}
@software{genet2024iisignaturetf2,
  author = {Remi Genet and Hugo Inzirillo},
  title = {iisignature-tensorflow-2: TensorFlow 2 Wrapper for iisignature},
  url = {https://github.com/remigenet/iisignature-tensorflow-2},
  year = {2024},
}
```
Contributing
Contributions are welcome! Feel free to submit issues and pull requests.
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
Download files
File details
Details for the file keras_sig-1.0.1.tar.gz.
File metadata
- Download URL: keras_sig-1.0.1.tar.gz
- Upload date:
- Size: 12.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.7.1 CPython/3.12.2 Linux/6.8.0-48-generic
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bb8e9c7583d62fcda4d458fe1dca986ba1ec32567f32274030c908ae24b6b040 |
| MD5 | be971e3b92806cccef336ee85cabcbc8 |
| BLAKE2b-256 | 9d9e56a06573163a8b3ae34ee81a30f98b5f548c8400d052e886189aedb262a9 |
File details
Details for the file keras_sig-1.0.1-py3-none-any.whl.
File metadata
- Download URL: keras_sig-1.0.1-py3-none-any.whl
- Upload date:
- Size: 11.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.7.1 CPython/3.12.2 Linux/6.8.0-48-generic
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3cd32c46d1c6e0efc7c26463c96b86cfcf3faa8bc2face41968ce8c67272f743 |
| MD5 | f8803005fbff700641d69ddfb0411a1b |
| BLAKE2b-256 | b1f58ea544fba8e26e47a982fb05b7c86bdda4b80a336617dbb0d8a2560a7209 |