Keras Self-Attention

Attention mechanism for processing sequential data that considers the context for each timestamp.

Install

pip install keras-self-attention

Usage

Basic

By default, the attention layer uses additive attention and considers the whole context when calculating relevance. The following code creates such a layer; attention_activation is the activation function applied to the attention score e_{t, t'} between timestamps t and t':

import keras
from keras_self_attention import SeqSelfAttention


model = keras.models.Sequential()
model.add(keras.layers.Embedding(input_dim=10000,
                                 output_dim=300,
                                 mask_zero=True))
model.add(keras.layers.Bidirectional(keras.layers.LSTM(units=128,
                                                       return_sequences=True)))
model.add(SeqSelfAttention(attention_activation='sigmoid'))
model.add(keras.layers.Dense(units=5))
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['categorical_accuracy'],
)
model.summary()
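
To make the role of e_{t, t'} concrete, here is a minimal NumPy sketch of additive self-attention over a single sequence. It assumes the common formulation h_{t, t'} = tanh(x_t W_t + x_{t'} W_x + b_h), e_{t, t'} = sigmoid(h_{t, t'} W_a + b_a), followed by a softmax over t'. The function and weight names are illustrative and are not part of the keras-self-attention API; the layer's actual internals may differ.

```python
import numpy as np


def additive_self_attention(x, w_t, w_x, b_h, w_a, b_a):
    """Additive self-attention over one sequence x of shape (seq_len, dim).

    Illustrative sketch only; all parameter names are hypothetical.
    """
    q = x @ w_t                      # (seq_len, units), term for position t
    k = x @ w_x                      # (seq_len, units), term for position t'
    # h[t, t'] = tanh(q[t] + k[t'] + b_h), shape (seq_len, seq_len, units)
    h = np.tanh(q[:, None, :] + k[None, :, :] + b_h)
    # e[t, t'] = sigmoid(h[t, t'] . w_a + b_a); the sigmoid plays the role
    # that attention_activation='sigmoid' is described as playing above.
    e = 1.0 / (1.0 + np.exp(-(h @ w_a + b_a)))
    # Softmax over t' so that each row of weights sums to 1.
    a = np.exp(e - e.max(axis=-1, keepdims=True))
    a /= a.sum(axis=-1, keepdims=True)
    return a @ x, a                  # context vectors and attention weights


rng = np.random.default_rng(0)
seq_len, dim, units = 6, 4, 8
x = rng.normal(size=(seq_len, dim))
out, attn = additive_self_attention(
    x,
    rng.normal(size=(dim, units)),
    rng.normal(size=(dim, units)),
    np.zeros(units),
    rng.normal(size=units),
    0.0,
)
print(out.shape)
print(attn.sum(axis=-1))
```

Each output row is a weighted average of all input rows, which is why the output keeps the input's (seq_len, dim) shape.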

Local Attention

The global context may be too broad for a single position in the sequence. The attention_width parameter limits each position to a local context of that width:

from keras_self_attention import SeqSelfAttention

SeqSelfAttention(
    attention_width=15,
    attention_activation='sigmoid',
    name='Attention',
)
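
The effect of a local context can be sketched as a band-shaped mask over the score matrix. The example below is a NumPy illustration, not the library's code: it assumes the window is roughly centred on each position and that out-of-window scores are pushed to a large negative value before the softmax. The helper name local_window_mask is hypothetical.

```python
import numpy as np


def local_window_mask(seq_len, attention_width):
    """Boolean mask: entry [t, t'] is True when t' lies in t's window.

    Assumption: the window starts attention_width // 2 steps before t.
    """
    positions = np.arange(seq_len)
    lower = positions - attention_width // 2   # inclusive lower bound
    upper = lower + attention_width            # exclusive upper bound
    cols = positions[None, :]
    return (lower[:, None] <= cols) & (cols < upper[:, None])


mask = local_window_mask(seq_len=6, attention_width=3)
scores = np.zeros((6, 6))                      # placeholder scores
# Push out-of-window scores far down so softmax gives them ~zero weight.
masked = np.where(mask, scores, -1e9)
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(mask.astype(int))
```

Positions near the sequence boundaries simply get a truncated window, so their weights are spread over fewer neighbours.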

Multiplicative Attention

You can use multiplicative attention by setting attention_type:

import keras  # needed for keras.regularizers below
from keras_self_attention import SeqSelfAttention

SeqSelfAttention(
    attention_width=15,
    attention_type=SeqSelfAttention.ATTENTION_TYPE_MUL,
    attention_activation=None,
    kernel_regularizer=keras.regularizers.l2(1e-6),
    use_attention_bias=False,
    name='Attention',
)
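
For comparison with the additive form, multiplicative attention is usually written as a bilinear score e_{t, t'} = x_t W_a x_{t'} + b_a. The NumPy sketch below assumes that formulation, with no activation and no bias to mirror attention_activation=None and use_attention_bias=False; the function name and weights are hypothetical and the library's implementation may differ.

```python
import numpy as np


def multiplicative_scores(x, w_a, b_a=0.0):
    """Bilinear scores e[t, t'] = x[t] @ w_a @ x[t'] + b_a (sketch)."""
    return x @ w_a @ x.T + b_a


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 4))          # (seq_len, dim)
w_a = rng.normal(size=(4, 4))        # single (dim, dim) weight matrix
e = multiplicative_scores(x, w_a)

# Softmax over t' and weighted sum, as in the additive case.
a = np.exp(e - e.max(axis=-1, keepdims=True))
a /= a.sum(axis=-1, keepdims=True)
context = a @ x
print(e.shape, context.shape)
```

The bilinear form needs only one weight matrix and a single matrix product per sequence, which is the usual reason to prefer it when speed matters.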

Regularizer

To use the regularizer, set attention_regularizer_weight to a positive number:

import keras
from keras_self_attention import SeqSelfAttention

inputs = keras.layers.Input(shape=(None,))
embd = keras.layers.Embedding(input_dim=32,
                              output_dim=16,
                              mask_zero=True)(inputs)
lstm = keras.layers.Bidirectional(keras.layers.LSTM(units=16,
                                                    return_sequences=True))(embd)
att = SeqSelfAttention(attention_type=SeqSelfAttention.ATTENTION_TYPE_MUL,
                       kernel_regularizer=keras.regularizers.l2(1e-4),
                       bias_regularizer=keras.regularizers.l1(1e-4),
                       attention_regularizer_weight=1e-4,
                       name='Attention')(lstm)
dense = keras.layers.Dense(units=5, name='Dense')(att)
model = keras.models.Model(inputs=inputs, outputs=[dense])
model.compile(
    optimizer='adam',
    loss={'Dense': 'sparse_categorical_crossentropy'},
    metrics={'Dense': 'sparse_categorical_accuracy'},  # matches the sparse loss
)
model.summary(line_length=100)
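
As a rough intuition for what an attention regularizer penalises, the sketch below computes a penalty of the form weight * ||A A^T - I||^2, averaged over the batch, where A is the attention matrix. This form is an assumption made for illustration; the exact term used by the layer is not stated here, and attention_penalty is a hypothetical helper.

```python
import numpy as np


def attention_penalty(attention, weight):
    """weight * batch mean of the squared Frobenius norm of A @ A^T - I.

    attention has shape (batch, seq_len, seq_len). Illustrative only.
    """
    batch, seq_len, _ = attention.shape
    gram = attention @ attention.transpose(0, 2, 1)
    diff = gram - np.eye(seq_len)
    return weight * np.sum(diff ** 2) / batch


focused = np.eye(4)[None, :, :]        # each position attends to itself only
uniform = np.full((1, 4, 4), 0.25)     # each position attends to everything
print(attention_penalty(focused, 1e-4))
print(attention_penalty(uniform, 1e-4))
```

Under this assumed form, rows of the attention matrix that overlap heavily are penalised, which pushes different positions to attend to different parts of the sequence.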

Load the Model

Make sure to add SeqSelfAttention to custom objects:

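import keras
from keras_self_attention import SeqSelfAttention

# model_path is the path to a previously saved model file
keras.models.load_model(model_path, custom_objects=SeqSelfAttention.get_custom_objects())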

Select Positions

When there are multiple inputs, the second input is considered as positions:

positions = keras.layers.Input(shape=(seq_len,), name='Input-Pos')
SeqSelfAttention(name='Attention')([lstm, positions])

History Only

Set history_only to True when only historical data can be used, that is, when each position may attend only to itself and earlier positions:

SeqSelfAttention(
    attention_width=3,
    history_only=True,
    name='Attention',
)
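
A history-only window can be pictured as a mask in which each position sees itself and a few preceding positions, never later ones. The NumPy sketch below assumes the window ends at the current position and spans attention_width steps in total; the helper name history_only_mask is hypothetical and the library's exact boundaries may differ.

```python
import numpy as np


def history_only_mask(seq_len, attention_width):
    """Boolean mask: [t, t'] is True when t - attention_width < t' <= t.

    Assumption: the current position counts as part of the window.
    """
    positions = np.arange(seq_len)
    lower = positions - (attention_width - 1)  # inclusive lower bound
    cols = positions[None, :]
    return (lower[:, None] <= cols) & (cols <= positions[:, None])


mask = history_only_mask(seq_len=5, attention_width=3)
print(mask.astype(int))
```

Nothing above the diagonal is ever True, so no information flows backward from future positions.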

Multi-Head

Please refer to keras-multi-head.
