


Keras Self-Attention
====================


.. image:: https://travis-ci.org/CyberZHG/keras-self-attention.svg
   :target: https://travis-ci.org/CyberZHG/keras-self-attention
   :alt: Travis


.. image:: https://coveralls.io/repos/github/CyberZHG/keras-self-attention/badge.svg?branch=master
   :target: https://coveralls.io/github/CyberZHG/keras-self-attention
   :alt: Coverage


Attention mechanism for processing sequence data that considers the context for each timestep.


*
  .. image:: https://user-images.githubusercontent.com/853842/44248592-1fbd0500-a21e-11e8-9fe0-52a1e4a48329.gif
     :target: https://user-images.githubusercontent.com/853842/44248592-1fbd0500-a21e-11e8-9fe0-52a1e4a48329.gif
     :alt: Attention equation

*
  .. image:: https://user-images.githubusercontent.com/853842/44248591-1e8bd800-a21e-11e8-9ca8-9198c2725108.gif
     :target: https://user-images.githubusercontent.com/853842/44248591-1e8bd800-a21e-11e8-9ca8-9198c2725108.gif
     :alt: Attention equation

*
  .. image:: https://user-images.githubusercontent.com/853842/44248590-1df34180-a21e-11e8-8ff1-268217f466ba.gif
     :target: https://user-images.githubusercontent.com/853842/44248590-1df34180-a21e-11e8-8ff1-268217f466ba.gif
     :alt: Attention equation

*
  .. image:: https://user-images.githubusercontent.com/853842/44249018-8ba06d00-a220-11e8-80e3-802677b658ed.gif
     :target: https://user-images.githubusercontent.com/853842/44249018-8ba06d00-a220-11e8-80e3-802677b658ed.gif
     :alt: Attention equation

Install
-------

.. code-block:: bash

   pip install keras-self-attention

Usage
-----

Basic
^^^^^

By default, the attention layer uses additive attention and considers the whole context when calculating relevance. The following code creates an attention layer that follows the equations at the top of this page (``attention_activation`` is the activation function applied to ``e_{t, t'}``):

.. code-block:: python

   import keras
   from keras_self_attention import Attention


   model = keras.models.Sequential()
   model.add(keras.layers.Embedding(input_dim=10000,
                                    output_dim=300,
                                    mask_zero=True))
   model.add(keras.layers.Bidirectional(keras.layers.LSTM(units=128,
                                                          return_sequences=True)))
   model.add(Attention(attention_activation='sigmoid'))
   # Softmax turns the per-timestep outputs into probabilities for categorical_crossentropy.
   model.add(keras.layers.Dense(units=5, activation='softmax'))
   model.compile(
       optimizer='adam',
       loss='categorical_crossentropy',
       metrics=['categorical_accuracy'],
   )
   model.summary()

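The equations above appear only as images. As a rough illustration, the NumPy sketch below assumes the usual additive formulation: ``h_{t, t'} = tanh(x_t W_t + x_{t'} W_x + b_h)`` and ``e_{t, t'} = W_a h_{t, t'} + b_a``. These scores are followed by a softmax over ``t'`` and a weighted sum of the inputs. It is not the library's implementation. The function ``additive_self_attention`` and all weight names are hypothetical.

.. code-block:: python

   import numpy as np


   def softmax(z, axis=-1):
       # Subtract the maximum for numerical stability before exponentiating.
       z = z - z.max(axis=axis, keepdims=True)
       ez = np.exp(z)
       return ez / ez.sum(axis=axis, keepdims=True)


   def additive_self_attention(x, W_t, W_x, b_h, W_a, b_a, activation=None):
       """Hypothetical additive self-attention over one sequence x of shape (T, d).

       Shapes: W_t and W_x are (d, u); b_h and W_a are (u,); b_a is a scalar.
       """
       q = x @ W_t  # (T, u): term that depends on the current timestep t
       k = x @ W_x  # (T, u): term that depends on the context timestep t'
       h = np.tanh(q[:, None, :] + k[None, :, :] + b_h)  # (T, T, u)
       e = h @ W_a + b_a  # (T, T): relevance score e_{t, t'}
       if activation is not None:
           e = activation(e)  # stands in for attention_activation
       a = softmax(e, axis=-1)  # each row is a distribution over t'
       return a @ x, a  # context vectors (T, d) and attention weights (T, T)


   rng = np.random.default_rng(0)
   T, d, u = 6, 4, 8
   x = rng.normal(size=(T, d))
   out, weights = additive_self_attention(
       x,
       W_t=rng.normal(size=(d, u)),
       W_x=rng.normal(size=(d, u)),
       b_h=np.zeros(u),
       W_a=rng.normal(size=u),
       b_a=0.0,
       activation=lambda z: 1.0 / (1.0 + np.exp(-z)),  # sigmoid, as in the example above
   )
   print(out.shape, weights.shape)  # (6, 4) (6, 6)

Each row of ``weights`` sums to 1, so each output vector is a weighted average of the input vectors. The sketch ignores masking, which ``mask_zero=True`` enables in the Keras example.
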
Local Attention
^^^^^^^^^^^^^^^

The global context may be too broad for a single sequence. The parameter ``attention_width`` restricts each timestep to a local window of neighboring timesteps:

.. code-block:: python

   from keras_self_attention import Attention

   # Each timestep attends only to a window of 15 timesteps around it.
   Attention(
       attention_width=15,
       attention_activation='sigmoid',
       name='Attention',
   )

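One way to picture ``attention_width`` is as a band mask over the ``T x T`` score matrix. Scores outside the window are dropped before the softmax. The sketch below assumes the window holds ``attention_width`` timesteps starting at ``t - attention_width // 2``, so it is roughly centered on ``t``. The helper names are hypothetical, and the library's exact window placement may differ.

.. code-block:: python

   import numpy as np


   def local_window_mask(seq_len, attention_width):
       """Hypothetical helper: True where timestep t may attend to timestep t'.

       Assumes a window of attention_width steps starting at t - attention_width // 2.
       """
       t = np.arange(seq_len)[:, None]  # current timesteps, as a column
       t_prime = np.arange(seq_len)[None, :]  # context timesteps, as a row
       lower = t - attention_width // 2
       upper = lower + attention_width
       return (t_prime >= lower) & (t_prime < upper)


   def masked_softmax(e, mask):
       # Scores outside the window become -inf, so their weights are exactly 0.
       e = np.where(mask, e, -np.inf)
       e = e - e.max(axis=-1, keepdims=True)
       w = np.exp(e)
       return w / w.sum(axis=-1, keepdims=True)


   mask = local_window_mask(seq_len=8, attention_width=3)
   print(mask.sum(axis=-1))  # [2 3 3 3 3 3 3 2]
   weights = masked_softmax(np.zeros((8, 8)), mask)  # uniform weights inside each window

Near the ends of the sequence the window is clipped. This is why the first and last rows have fewer allowed positions.
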
Multiplicative Attention
^^^^^^^^^^^^^^^^^^^^^^^^

To use multiplicative attention, set ``attention_type`` to ``Attention.ATTENTION_TYPE_MUL``:


.. image:: https://user-images.githubusercontent.com/853842/44253887-a03a3080-a233-11e8-9d49-3fd7e622a0f7.gif
   :target: https://user-images.githubusercontent.com/853842/44253887-a03a3080-a233-11e8-9d49-3fd7e622a0f7.gif
   :alt: Multiplicative attention equation


.. code-block:: python

   import keras
   from keras_self_attention import Attention

   Attention(
       attention_width=15,
       attention_type=Attention.ATTENTION_TYPE_MUL,
       attention_activation=None,
       kernel_regularizer=keras.regularizers.l2(1e-6),
       use_attention_bias=False,
       name='Attention',
   )

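In the multiplicative variant, the score is presumably a bilinear form ``e_{t, t'} = x_t^T W_a x_{t'}``, plus ``b_a`` when an attention bias is used. The NumPy sketch below assumes that form. It applies no activation to ``e``, matching ``attention_activation=None``. It also checks that setting ``W_a`` to the identity reduces the scores to plain dot products. The function name is hypothetical.

.. code-block:: python

   import numpy as np


   def multiplicative_scores(x, W_a, b_a=None):
       """Hypothetical bilinear scores e[t, t'] = x_t^T W_a x_t' for x of shape (T, d)."""
       e = x @ W_a @ x.T  # (T, T)
       if b_a is not None:  # optional bias, cf. use_attention_bias
           e = e + b_a
       return e


   rng = np.random.default_rng(1)
   x = rng.normal(size=(5, 3))
   e = multiplicative_scores(x, np.eye(3))  # identity W_a: plain dot products
   e_shifted = e - e.max(axis=-1, keepdims=True)
   a = np.exp(e_shifted) / np.exp(e_shifted).sum(axis=-1, keepdims=True)  # softmax over t'
   context = a @ x
   print(context.shape)  # (5, 3)

Compared with the additive form, this scoring uses a single ``d x d`` matrix and no ``tanh`` hidden layer. That makes it cheaper to compute but gives a less flexible scoring function.
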
Regularizer
^^^^^^^^^^^

To use the attention regularizer shown below, the layer must return its attention weights (``return_attention=True``) so that a loss can be computed on them:

.. image:: https://user-images.githubusercontent.com/853842/44250188-f99b6300-a225-11e8-8fab-8dcf0d99616e.gif
   :target: https://user-images.githubusercontent.com/853842/44250188-f99b6300-a225-11e8-8fab-8dcf0d99616e.gif
   :alt: Attention regularizer equation

.. code-block:: python

   import keras
   import numpy
   from keras_self_attention import Attention

   # Toy data: random token ids in [1, 32); index 0 is reserved for padding.
   batch_size, sentence_len = 16, 10
   x = numpy.random.randint(1, 32, size=(batch_size, sentence_len))

   inputs = keras.layers.Input(shape=(None,))
   embd = keras.layers.Embedding(input_dim=32,
                                 output_dim=16,
                                 mask_zero=True)(inputs)
   lstm = keras.layers.Bidirectional(keras.layers.LSTM(units=16,
                                                       return_sequences=True))(embd)
   att, weights = Attention(return_attention=True,
                            attention_width=5,
                            attention_type=Attention.ATTENTION_TYPE_MUL,
                            kernel_regularizer=keras.regularizers.l2(1e-4),
                            bias_regularizer=keras.regularizers.l1(1e-4),
                            name='Attention')(lstm)
   dense = keras.layers.Dense(units=5, activation='softmax', name='Dense')(att)
   model = keras.models.Model(inputs=inputs, outputs=[dense, weights])
   model.compile(
       optimizer='adam',
       loss={'Dense': 'sparse_categorical_crossentropy', 'Attention': Attention.loss(1e-2)},
       metrics={'Dense': 'sparse_categorical_accuracy'},
   )
   model.summary(line_length=100)
   model.fit(
       x=x,
       y=[
           # Sparse class labels for 'Dense' (all class 0 in this toy example).
           numpy.zeros((batch_size, sentence_len, 1)),
           # Placeholder targets for 'Attention'; the regularizer loss is assumed
           # to depend only on the predicted attention weights.
           numpy.zeros((batch_size, sentence_len, sentence_len)),
       ],
       epochs=10,
   )

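The regularizer formula above is only available as an image. The NumPy sketch below assumes it is the squared Frobenius norm ``||A A^T - I||^2`` of each attention matrix ``A``. It further assumes this norm is scaled by the weight passed to ``Attention.loss`` and averaged over the batch. Treat both the form and the averaging as assumptions; the function name is hypothetical. Under this form, the penalty is zero when every timestep attends only to itself. It grows as the attention rows of different timesteps overlap.

.. code-block:: python

   import numpy as np


   def attention_penalty(a, weight):
       """Assumed regularizer: weight * ||A A^T - I||_F^2, averaged over the batch.

       a: attention weights of shape (batch, T, T), each row summing to 1.
       """
       batch, seq_len, _ = a.shape
       gram = a @ np.transpose(a, (0, 2, 1))  # A A^T for every sample
       return weight * np.sum((gram - np.eye(seq_len)) ** 2) / batch


   a_identity = np.tile(np.eye(4), (2, 1, 1))  # each timestep attends only to itself
   a_uniform = np.full((2, 4, 4), 0.25)  # each timestep attends equally to all four
   print(attention_penalty(a_identity, 1e-2))  # 0.0
   print(attention_penalty(a_uniform, 1e-2) > 0)  # True
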
Load the Model
^^^^^^^^^^^^^^

When loading a saved model, pass ``Attention`` in ``custom_objects``. If the model was compiled with the regularizer loss, also pass ``attention_regularizer``:

.. code-block:: python

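   import keras
   from keras_self_attention import Attention

   model_path = 'model.h5'  # hypothetical path to a model saved with model.save()
   model = keras.models.load_model(model_path, custom_objects={
       'Attention': Attention,
       # Needed only if the model was compiled with Attention.loss(...).
       'attention_regularizer': Attention.loss(1e-2),
   })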

Source distribution: ``keras-self-attention-0.0.14.tar.gz`` (5.1 kB).
