
Custom Keras layers for implementing multi-dimensional recurrent neural networks (MDRNNs)

Project description

Multi-Directional Multi-Dimensional Recurrent Neural Networks

A library built on top of TensorFlow that implements the multi-dimensional RNN model described in the paper by Alex Graves et al. (https://arxiv.org/pdf/0705.2011.pdf). The library provides a set of custom Keras layers. Each of them can be used like a built-in Keras layer to build a model and train it as usual.
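For reference, the forward pass of a single-direction MDRNN from that paper can be written compactly. This is a simplified restatement with one hidden layer, a tanh activation as in Keras SimpleRNN, and symbol names ($W$, $U_k$, $b$) chosen here for illustration. It is not a description of how this library's layers are implemented. For a 2D input with feature vector $x_{i,j}$ at position $(i, j)$, the hidden state depends on the current input and on the hidden states of the predecessors along each axis. In $d$ dimensions, $\mathbf{e}_k$ is the unit step along axis $k$:

```latex
h_{i,j} = \tanh\left( W x_{i,j} + U_1 h_{i-1,j} + U_2 h_{i,j-1} + b \right),
\qquad h_{0,j} = h_{i,0} = \mathbf{0}

h_{\mathbf{p}} = \tanh\Big( W x_{\mathbf{p}} + \sum_{k=1}^{d} U_k \, h_{\mathbf{p} - \mathbf{e}_k} + b \Big)
```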

Status: under development

This repository is in its early stages. The code is not yet stable and has not been extensively tested. Use it at your own risk.

Features

Layers available now:

  • MDRNN: layer analogous to Keras SimpleRNN layer for processing multi-dimensional inputs
  • MDLSTM: analogous to Keras LSTM layer
  • MultiDirectional: layer wrapper, analogous to Keras Bidirectional, for creating multi-directional multi-dimensional RNNs

Layers currently under development (coming soon):

  • MDGRU: analogous to Keras GRU layer

Additional features:

  • easy to use with Keras
  • Keras-like API for each layer
  • option to choose order/direction in which to process inputs
  • computations are run on CPU
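Choosing the order or direction of processing amounts to picking the corner the scan starts from. Every position must be visited after its predecessor along each axis. The following minimal pure-Python sketch produces one valid visiting order. It assumes, hypothetically, that a direction is given as one +1 or -1 per axis. This is not the library's own parameter format.

```python
from itertools import product

def traversal_order(shape, direction):
    """List grid positions in an order that respects a scan direction.

    direction[k] == 1 walks axis k from 0 upwards, -1 walks it downwards.
    Row-major order over the (possibly reversed) axes guarantees that the
    predecessor of each position along every axis is visited first.
    """
    axes = [range(n) if step == 1 else range(n - 1, -1, -1)
            for n, step in zip(shape, direction)]
    return list(product(*axes))

# Scan a 2x3 grid top-to-bottom, right-to-left
print(traversal_order((2, 3), (1, -1)))
# [(0, 2), (0, 1), (0, 0), (1, 2), (1, 1), (1, 0)]
```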

Installation

Install the package from PyPI:

pip install mdrnn

Alternatively, clone the repository and install dependencies:

git clone <repo_url>
cd <repo_directory>
pip install -r requirements.txt

Quick Start

Create a 2-dimensional RNN:

from mdrnn import MDRNN, MDLSTM, MultiDirectional
import numpy as np
import tensorflow as tf
rnn = MDRNN(units=16, input_shape=(5, 4, 10), activation='tanh', return_sequences=True)
output = rnn(np.zeros((1, 5, 4, 10)))
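To make the shapes concrete, here is a naive NumPy sketch of what a single-direction 2D MDRNN computes for one example, with the batch dimension dropped. It follows the recurrence of the paper with a tanh activation and a row-vector convention (`x[i, j] @ W`). The function name and weight names are hypothetical. Returning the bottom-right state when `return_sequences=False` is an assumption made by analogy with Keras SimpleRNN, and this is not the library's code.

```python
import numpy as np

def mdrnn_2d_forward(x, W, U1, U2, b, return_sequences=False):
    """Naive single-direction 2D MDRNN forward pass (illustration only).

    x: (rows, cols, features); W: (features, units); U1, U2: (units, units); b: (units,)
    """
    rows, cols, _ = x.shape
    units = b.shape[0]
    h = np.zeros((rows, cols, units))
    for i in range(rows):
        for j in range(cols):
            # Hidden states of the predecessors along each axis (zeros at the border)
            up = h[i - 1, j] if i > 0 else np.zeros(units)
            left = h[i, j - 1] if j > 0 else np.zeros(units)
            h[i, j] = np.tanh(x[i, j] @ W + up @ U1 + left @ U2 + b)
    # Either every hidden state, or only the last one (bottom-right corner)
    return h if return_sequences else h[-1, -1]

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 4, 10))
W = rng.normal(size=(10, 16)) * 0.1
U1 = rng.normal(size=(16, 16)) * 0.1
U2 = rng.normal(size=(16, 16)) * 0.1
b = np.zeros(16)
print(mdrnn_2d_forward(x, W, U1, U2, b, return_sequences=True).shape)  # (5, 4, 16)
print(mdrnn_2d_forward(x, W, U1, U2, b).shape)  # (16,)
```

The double loop makes the dependency explicit: position (i, j) needs the states at (i-1, j) and (i, j-1) first.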

Build a Keras model consisting of an MDRNN layer followed by a softmax Dense layer, and train it:

model = tf.keras.Sequential()
model.add(MDRNN(units=16, input_shape=(2, 3, 6), activation='tanh'))
model.add(tf.keras.layers.Dense(units=10, activation='softmax'))
model.compile(loss='categorical_crossentropy', metrics=['acc'])
model.summary()
x = np.zeros((10, 2, 3, 6))
y = np.zeros((10, 10,))
model.fit(x, y)
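With all-zero targets, categorical cross-entropy is zero for any prediction, so the snippet above only checks that the model builds and runs. For a smoke test where the loss can actually change, random inputs with random one-hot targets are enough. `make_toy_data` below is a hypothetical NumPy helper and is not part of mdrnn.

```python
import numpy as np

def make_toy_data(n_samples, input_shape, n_classes, seed=0):
    """Random inputs plus random one-hot targets (hypothetical helper for smoke tests)."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n_samples, *input_shape)).astype("float32")
    labels = rng.integers(0, n_classes, size=n_samples)
    # Row k of the identity matrix is the one-hot vector for class k
    y = np.eye(n_classes, dtype="float32")[labels]
    return x, y

x, y = make_toy_data(10, (2, 3, 6), 10)
print(x.shape, y.shape)  # (10, 2, 3, 6) (10, 10)
```

If these arrays are passed to `model.fit(x, y)` in place of the zeros, the reported loss should then be non-zero.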

Similarly, create and train a multi-directional MDRNN:

x = np.zeros((10, 2, 3, 6))
y = np.zeros((10, 40,))

model = tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=(2, 3, 6)))
model.add(MultiDirectional(MDRNN(10, input_shape=[2, 3, 6])))

model.compile(loss='categorical_crossentropy', metrics=['acc'])
model.summary()

model.fit(x, y, epochs=1)
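Graves et al. scan a d-dimensional input in 2^d directions, one per corner, which gives 2^2 = 4 directions for the 2D inputs here. The 40-wide targets above are consistent with concatenating four 10-unit outputs. The NumPy sketch below illustrates that idea with a generic single-direction scan function. Each direction other than the forward one is obtained by flipping the input, scanning, and flipping the result back. The function names are hypothetical, and `cumsum_scan` is only a cheap stand-in for a real MDRNN. How MultiDirectional actually combines its outputs is an assumption here.

```python
import numpy as np
from itertools import product

def multi_directional(scan_fn, x):
    """Run scan_fn once per scan direction and concatenate along the last axis.

    x: (rows, cols, features). scan_fn maps such an array to (rows, cols, units),
    always scanning from the top-left corner; other directions are obtained by
    flipping the input, scanning, and flipping the result back.
    """
    outputs = []
    for direction in product((1, -1), repeat=2):
        axes = tuple(k for k, step in enumerate(direction) if step == -1)
        flipped = np.flip(x, axis=axes) if axes else x
        out = scan_fn(flipped)
        outputs.append(np.flip(out, axis=axes) if axes else out)
    return np.concatenate(outputs, axis=-1)

def cumsum_scan(x):
    # Stand-in "scan": running sum from the top-left corner
    return np.cumsum(np.cumsum(x, axis=0), axis=1)

x = np.ones((2, 3, 6))
out = multi_directional(cumsum_scan, x)
print(out.shape)  # (2, 3, 24)
```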

In the same way, create and train a multi-directional multi-dimensional LSTM (MDLSTM):

x = np.zeros((10, 2, 3, 6))
y = np.zeros((10, 40,))

model = tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=(2, 3, 6)))
model.add(MultiDirectional(MDLSTM(10, input_shape=[2, 3, 6])))

model.compile(loss='categorical_crossentropy', metrics=['acc'])
model.summary()

model.fit(x, y, epochs=1)
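In the paper by Graves et al., MDLSTM extends LSTM by giving the memory cell one self-connection and one forget gate per dimension. This lets the cell decide separately how much state to inherit from each predecessor. The NumPy sketch below shows a single cell step at one grid position. The parameter layout, the omission of peephole connections and the helper names (`init_params`, `mdlstm_cell`) are assumptions made for illustration. They do not reflect the actual weights of the MDLSTM layer.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def init_params(features, units, ndims, seed=0):
    """Random weights for the hypothetical layout used by mdlstm_cell."""
    rng = np.random.default_rng(seed)
    gates = ["i", "o", "g"] + [f"f{d}" for d in range(ndims)]
    params = {}
    for gate in gates:
        params[("W", gate)] = rng.normal(scale=0.1, size=(features, units))
        params[("b", gate)] = np.zeros(units)
        for d in range(ndims):
            params[("U", gate, d)] = rng.normal(scale=0.1, size=(units, units))
    return params

def mdlstm_cell(x, prev_h, prev_c, params):
    """One MDLSTM step at a single grid position (no peephole connections).

    x: (features,) input at this position.
    prev_h, prev_c: lists with one (units,) array per dimension, holding the
    hidden and cell state of the predecessor along that dimension.
    """
    ndims = len(prev_h)

    def pre(gate):
        # Affine pre-activation from the input and every predecessor's hidden state
        z = x @ params[("W", gate)] + params[("b", gate)]
        for d in range(ndims):
            z = z + prev_h[d] @ params[("U", gate, d)]
        return z

    i = sigmoid(pre("i"))  # input gate
    o = sigmoid(pre("o"))  # output gate
    g = np.tanh(pre("g"))  # candidate cell input
    c = i * g
    for d in range(ndims):
        # Separate forget gate for the cell state inherited along dimension d
        c = c + sigmoid(pre(f"f{d}")) * prev_c[d]
    h = o * np.tanh(c)
    return h, c

params = init_params(features=6, units=10, ndims=2)
zeros = np.zeros(10)
h, c = mdlstm_cell(np.ones(6), [zeros, zeros], [zeros, zeros], params)
print(h.shape, c.shape)  # (10,) (10,)
```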

Requirements

  • TensorFlow version >= 2.0
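A small pure-Python check of that requirement can help before installing mdrnn into an existing environment. This is a sketch. `meets_minimum` is a hypothetical helper, and `importlib.metadata` needs Python 3.8 or newer. The check only looks for a distribution named `tensorflow`, so variants such as `tensorflow-cpu` would be reported as missing.

```python
from importlib.metadata import PackageNotFoundError, version

def meets_minimum(version_string, minimum=(2, 0)):
    """Return True if a dotted version string is at least `minimum`.

    Only the leading digits of each component are used, so '2.0.0rc1'
    is compared as (2, 0); pre-release tags are ignored.
    """
    parts = []
    for piece in version_string.split(".")[: len(minimum)]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    parts += [0] * (len(minimum) - len(parts))  # treat '2' as '2.0'
    return tuple(parts) >= tuple(minimum)

try:
    installed = version("tensorflow")
    status = "OK" if meets_minimum(installed) else "too old, need >= 2.0"
    print(f"TensorFlow {installed}: {status}")
except PackageNotFoundError:
    # TensorFlow may be missing, or installed under another distribution name
    print("No 'tensorflow' distribution found")
```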

References

[1] A. Graves, S. Fernández, and J. Schmidhuber. Multi-Dimensional Recurrent Neural Networks.

[2] A. Graves and J. Schmidhuber. Offline Handwriting Recognition with Multidimensional Recurrent Neural Networks.

Download files

Source Distribution

mdrnn-0.3.0.tar.gz (9.3 kB)

Built Distribution

mdrnn-0.3.0-py3-none-any.whl (13.2 kB)

File details

Details for the file mdrnn-0.3.0.tar.gz.

File metadata

  • Download URL: mdrnn-0.3.0.tar.gz
  • Size: 9.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/45.2.0 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.6.9

File hashes

Hashes for mdrnn-0.3.0.tar.gz
Algorithm Hash digest
SHA256 101da97b0e21b661bcafde9b603bcc39d8ac685af34bcd6a5d0f263d133840c1
MD5 5eadcf2d44a3f02e66f78a636817a193
BLAKE2b-256 8d4d2b9b732627f110d0f493bd5abac2a6fca6a0b9c399f075d1063c0d97e3ca
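To check a downloaded archive against the digests above, the SHA256 can be recomputed locally with Python's standard `hashlib`. This is a minimal sketch. `sha256_of` is a hypothetical helper name, and the file path assumes the sdist sits in the current directory.

```python
import hashlib

def sha256_of(path, chunk_size=1 << 16):
    """Return the SHA256 hex digest of a file, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps calling f.read until it returns b""
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# SHA256 listed above for the source distribution
EXPECTED_SDIST_SHA256 = "101da97b0e21b661bcafde9b603bcc39d8ac685af34bcd6a5d0f263d133840c1"

# Assumes mdrnn-0.3.0.tar.gz was downloaded into the working directory:
# print(sha256_of("mdrnn-0.3.0.tar.gz") == EXPECTED_SDIST_SHA256)
```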

File details

Details for the file mdrnn-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: mdrnn-0.3.0-py3-none-any.whl
  • Size: 13.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/45.2.0 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.6.9

File hashes

Hashes for mdrnn-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 e577721857f9cf8ef8138c713d22bf1fae737cd131d1d37f7ef9559392da0b24
MD5 85bfbd67d512ce6b873c8f391ac71b0b
BLAKE2b-256 7f222cbe86d35408ad95c82bf9d3670cbe8b25f80eda7b1d945fecbdd2f2ed3c
