Signature gated RNN

SigRNN: Signature-Enhanced Recurrent Neural Networks

This repository implements novel RNN layers that incorporate path signatures for enhanced time series modeling. The implementation is built on top of keras_sig and is compatible with Keras 3.0, supporting multiple backends (TensorFlow, JAX, and PyTorch).

Overview

SigRNN introduces two novel layer architectures:

  • SignatureLSTM: An LSTM variant where the forget gate is computed using path signatures
  • SignatureGRU: A GRU variant where the reset gate is computed using path signatures

The key idea is to leverage path signatures to enhance the gating mechanisms in traditional RNN architectures. The signature computations provide a richer representation of the temporal dynamics, potentially improving the model's ability to capture long-term dependencies.

Installation

pip install sig_rnn

Quick Start

import keras
from sig_rnn import SignatureLSTM, SignatureGRU

# Placeholder input dimensions; set these to match your data
sequence_length = 100
n_features = 20

# Example with SignatureLSTM (returns one output per timestep)
model = keras.Sequential([
    keras.layers.Input(shape=(sequence_length, n_features)),
    SignatureLSTM(
        units=64,
        signature_depth=2,
        signature_input_size=5,
        return_sequences=True
    ),
    keras.layers.Dense(1)
])

# Example with SignatureGRU (returns only the last output)
model = keras.Sequential([
    keras.layers.Input(shape=(sequence_length, n_features)),
    SignatureGRU(
        units=64,
        signature_depth=2,
        signature_input_size=5,
        return_sequences=False
    ),
    keras.layers.Dense(1)
])

Layer Parameters

Common Parameters

  • units: Dimensionality of the output space
  • signature_depth: Maximum depth for signature computation (default: 2)
  • signature_input_size: Input dimension for signature computation (default: 5)
  • return_sequences: Whether to return the full sequence or just the last output
  • return_state: Whether to return states in addition to output
  • unroll_level: Level of unrolling for scan operations (default: 10)
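As a rough guide to how signature_depth and signature_input_size interact: a truncated signature of a d-dimensional path at depth N has d + d^2 + ... + d^N terms (not counting the constant zeroth term). The helper below is a hypothetical illustration, not part of sig_rnn, and it assumes that signature_input_size plays the role of d.

```python
def signature_size(input_size: int, depth: int) -> int:
    """Number of truncated-signature terms for levels 1..depth.

    Hypothetical helper for illustration; not part of the sig_rnn API.
    """
    return sum(input_size ** k for k in range(1, depth + 1))


# With the documented defaults (input size 5, depth 2): 5 + 25 terms.
default_size = signature_size(5, 2)
# One more level multiplies the last level by the input size again: 5 + 25 + 125.
deeper_size = signature_size(5, 3)
```

The count grows geometrically with depth, which is one reason to keep both parameters small.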

SignatureLSTM

The SignatureLSTM modifies the standard LSTM by computing the forget gate using path signatures while maintaining the traditional computation for input, cell, and output gates. This allows the model to potentially capture more complex temporal patterns when deciding what information to forget.
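The idea can be sketched for a single timestep in plain NumPy. This is not the package's implementation: the function name, the parameter dictionary, and the assumption that the forget gate is a sigmoid of a linear map of the signature features are all illustrative choices.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step_with_signature_forget(x_t, h_prev, c_prev, sig_t, params):
    """One LSTM step where only the forget gate reads signature features.

    Illustrative sketch; how sig_rnn actually feeds the signature into
    the gate is an assumption here.
    """
    z = np.concatenate([x_t, h_prev])
    i = sigmoid(params["W_i"] @ z + params["b_i"])      # input gate (standard)
    o = sigmoid(params["W_o"] @ z + params["b_o"])      # output gate (standard)
    g = np.tanh(params["W_c"] @ z + params["b_c"])      # candidate cell (standard)
    f = sigmoid(params["W_f"] @ sig_t + params["b_f"])  # forget gate from signature
    c = f * c_prev + i * g
    h = o * np.tanh(c)
    return h, c
```

Here sig_t stands for the signature features of the input path up to the current step, so the decision of how much of c_prev to keep depends on the path seen so far rather than on the current input and hidden state alone.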

SignatureGRU

The SignatureGRU modifies the standard GRU by computing the reset gate using path signatures while maintaining the traditional computation for the update gate and candidate activation. This enhances the model's ability to reset its memory based on more sophisticated temporal features.
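A single GRU step with a signature-driven reset gate might look like the following NumPy sketch. It is not the package's implementation: the names, the parameter dictionary, and the linear-plus-sigmoid form of the reset gate are assumptions, and the final blend follows the Keras convention h = z * h_prev + (1 - z) * candidate.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_step_with_signature_reset(x_t, h_prev, sig_t, params):
    """One GRU step where only the reset gate reads signature features.

    Illustrative sketch; how sig_rnn actually feeds the signature into
    the gate is an assumption here.
    """
    xh = np.concatenate([x_t, h_prev])
    z = sigmoid(params["W_z"] @ xh + params["b_z"])     # update gate (standard)
    r = sigmoid(params["W_r"] @ sig_t + params["b_r"])  # reset gate from signature
    x_rh = np.concatenate([x_t, r * h_prev])
    h_cand = np.tanh(params["W_h"] @ x_rh + params["b_h"])  # candidate activation
    return z * h_prev + (1.0 - z) * h_cand
```

The reset gate r scales h_prev before the candidate activation is computed, so signature features control how much past state enters the candidate.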

Backend Compatibility

The package is compatible with all Keras 3.0 backends:

  • TensorFlow 2.x
  • JAX
  • PyTorch

However, for optimal performance, we recommend using JAX as the backend due to its efficient handling of the signature computations.

Note: While the PyTorch backend is supported, JIT compilation is currently not available with PyTorch.
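Keras 3 reads the backend from the KERAS_BACKEND environment variable (or from the Keras config file) at import time, so it has to be set before keras is first imported. A minimal sketch follows; the active_backend helper is hypothetical and only there to show the check.

```python
import os

# Must be set before the first `import keras` anywhere in the process.
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" / "torch"


def active_backend():
    """Import Keras lazily and report which backend it picked up."""
    import keras
    return keras.config.backend()
```

Calling active_backend() after this should report the selected backend, provided Keras and that backend are installed and keras was not already imported earlier in the process.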

Example Usage

Here's a more complete example showing how to use SignatureLSTM for time series prediction:

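import numpy as np
import keras
from sig_rnn import SignatureLSTM

# Placeholder random data so the example runs end to end; replace with your own arrays
X_train = np.random.rand(256, 100, 20).astype("float32")
y_train = np.random.rand(256, 1).astype("float32")
X_val = np.random.rand(64, 100, 20).astype("float32")
y_val = np.random.rand(64, 1).astype("float32")

# Create a model for time series prediction
model = keras.Sequential([
    keras.layers.Input(shape=(100, 20)),  # 100 timesteps, 20 features
    SignatureLSTM(
        units=64,
        signature_depth=2,
        signature_input_size=5,
        return_sequences=True  # pass the full sequence to the next recurrent layer
    ),
    SignatureLSTM(
        units=32,
        signature_depth=2,
        signature_input_size=5,
        return_sequences=False  # keep only the last output
    ),
    keras.layers.Dense(16, activation='relu'),
    keras.layers.Dense(1)  # Single step prediction
])

# Compile and train
model.compile(optimizer='adam', loss='mse')
model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=10)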

Implementation Details

The implementation uses the streaming signature computation from keras_sig, which allows for efficient processing of sequential data. The signature values are normalized by time to ensure stable training dynamics. Both layers support return_sequences and return_state options, making them compatible with standard Keras RNN patterns.
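To make "streaming" concrete, here is a minimal NumPy sketch of a depth-2 signature that is updated one segment at a time using Chen's identity for a piecewise-linear path. It does not call keras_sig, the function name is hypothetical, and the time normalization mentioned above is left out because its exact form is not specified here.

```python
import numpy as np


def streaming_signature_depth2(path):
    """Depth-2 signature features of path[:t+1] for every t >= 1.

    path: array of shape (T, d) with T >= 2.
    Returns an array of shape (T - 1, d + d * d).
    Illustrative only; keras_sig may order or scale terms differently.
    """
    path = np.asarray(path, dtype=float)
    n_steps, d = path.shape
    level1 = np.zeros(d)       # running increment X_t - X_0
    level2 = np.zeros((d, d))  # running iterated integrals
    out = []
    for t in range(1, n_steps):
        step = path[t] - path[t - 1]
        # Chen's identity for appending one linear segment.
        level2 = level2 + np.outer(level1, step) + 0.5 * np.outer(step, step)
        level1 = level1 + step
        out.append(np.concatenate([level1, level2.ravel()]))
    return np.stack(out)


# An L-shaped path in the plane: right by one unit, then up by one unit.
features = streaming_signature_depth2([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])
```

Each step reuses the previous signature rather than recomputing from the start of the sequence, which is what makes a per-timestep signature affordable inside a recurrent layer.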

Citation

If you use this package in your research, please cite our work:

[Citation information to be added upon paper release]

This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
