Keras Self-Attention
====================


.. image:: https://travis-ci.org/CyberZHG/keras-self-attention.svg
   :target: https://travis-ci.org/CyberZHG/keras-self-attention
   :alt: Travis


.. image:: https://coveralls.io/repos/github/CyberZHG/keras-self-attention/badge.svg?branch=master
   :target: https://coveralls.io/github/CyberZHG/keras-self-attention
   :alt: Coverage


Attention mechanism for processing sequence data that considers the context for each timestamp.


.. image:: https://user-images.githubusercontent.com/853842/44248592-1fbd0500-a21e-11e8-9fe0-52a1e4a48329.gif
   :target: https://user-images.githubusercontent.com/853842/44248592-1fbd0500-a21e-11e8-9fe0-52a1e4a48329.gif
   :alt:

.. image:: https://user-images.githubusercontent.com/853842/44248591-1e8bd800-a21e-11e8-9ca8-9198c2725108.gif
   :target: https://user-images.githubusercontent.com/853842/44248591-1e8bd800-a21e-11e8-9ca8-9198c2725108.gif
   :alt:

.. image:: https://user-images.githubusercontent.com/853842/44248590-1df34180-a21e-11e8-8ff1-268217f466ba.gif
   :target: https://user-images.githubusercontent.com/853842/44248590-1df34180-a21e-11e8-8ff1-268217f466ba.gif
   :alt:

.. image:: https://user-images.githubusercontent.com/853842/44249018-8ba06d00-a220-11e8-80e3-802677b658ed.gif
   :target: https://user-images.githubusercontent.com/853842/44249018-8ba06d00-a220-11e8-80e3-802677b658ed.gif
   :alt:
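
The GIFs above render the layer's equations. For readers of the raw source, the computation follows the usual self-attention form (a hedged restatement rather than a transcription of the images): each timestamp ``t`` is scored against every timestamp ``t'``, the scores are normalized with a softmax, and the output at ``t`` is the resulting weighted sum of the inputs.

.. math::

   e_{t, t'} = f(x_t, x_{t'}), \qquad
   a_{t, t'} = \frac{\exp(e_{t, t'})}{\sum_{\tau} \exp(e_{t, \tau})}, \qquad
   l_t = \sum_{t'} a_{t, t'} x_{t'}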

Install
-------

.. code-block:: bash

    pip install keras-self-attention

Usage
-----

Basic
^^^^^

By default, the attention layer uses additive attention and considers the whole context while calculating the relevance. The following code creates an attention layer that follows the equations in the first section (``attention_activation`` is the activation function of ``e_{t, t'}``):

.. code-block:: python

    import keras
    from keras_self_attention import Attention


    model = keras.models.Sequential()
    model.add(keras.layers.Embedding(input_dim=10000,
                                     output_dim=300,
                                     mask_zero=True))
    model.add(keras.layers.Bidirectional(keras.layers.LSTM(units=128,
                                                           return_sequences=True)))
    model.add(Attention(attention_activation='sigmoid'))
    model.add(keras.layers.Dense(units=5))
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['categorical_accuracy'],
    )
    model.summary()
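
As an illustration of what additive self-attention computes, here is a standalone NumPy sketch (not the layer's internal implementation; the weight names ``W_t``, ``W_x`` and ``w_a`` are hypothetical):

.. code-block:: python

    import numpy as np

    def additive_self_attention(x):
        """Toy additive self-attention over a sequence x with shape (T, d)."""
        T, d = x.shape
        rng = np.random.default_rng(0)
        W_t = rng.normal(size=(d, d))  # projects the "query" timestamp
        W_x = rng.normal(size=(d, d))  # projects the "key" timestamp
        w_a = rng.normal(size=(d,))    # scoring vector
        # h[t, t', :] mixes timestamp t with timestamp t'
        h = np.tanh((x @ W_t)[:, None, :] + (x @ W_x)[None, :, :])
        e = h @ w_a                                             # scores, shape (T, T)
        a = np.exp(e) / np.exp(e).sum(axis=-1, keepdims=True)   # softmax over t'
        return a @ x                                            # weighted sum, shape (T, d)

    print(additive_self_attention(np.random.rand(7, 16)).shape)  # (7, 16)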

Local Attention
^^^^^^^^^^^^^^^

The global context may be too broad for one piece of data. The parameter ``attention_width`` controls the width of the local context:

.. code-block:: python

    from keras_self_attention import Attention

    Attention(
        attention_width=15,
        attention_activation='sigmoid',
        name='Attention',
    )
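
Conceptually, ``attention_width`` limits each timestamp to a window of nearby positions; everything outside the window is ignored when the scores are normalized. A rough NumPy sketch of such a window mask (illustrative only; exactly how the library centers the window may differ):

.. code-block:: python

    import numpy as np

    def local_attention_mask(seq_len, attention_width):
        """True where position t' falls inside the local window around t."""
        idx = np.arange(seq_len)
        # keep |t - t'| within half the window on either side (assumed centering)
        return np.abs(idx[:, None] - idx[None, :]) <= attention_width // 2

    print(local_attention_mask(6, 3).astype(int))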

Multiplicative Attention
^^^^^^^^^^^^^^^^^^^^^^^^

You can use multiplicative attention by setting ``attention_type``:


.. image:: https://user-images.githubusercontent.com/853842/44253887-a03a3080-a233-11e8-9d49-3fd7e622a0f7.gif
   :target: https://user-images.githubusercontent.com/853842/44253887-a03a3080-a233-11e8-9d49-3fd7e622a0f7.gif
   :alt:


.. code-block:: python

    import keras
    from keras_self_attention import Attention

    Attention(
        attention_width=15,
        attention_type=Attention.ATTENTION_TYPE_MUL,
        attention_activation=None,
        kernel_regularizer=keras.regularizers.l2(1e-6),
        use_attention_bias=False,
        name='Attention',
    )
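
Multiplicative attention swaps the additive scoring function for a bilinear form, roughly ``e_{t, t'} = x_t^T W_a x_{t'}``. A small NumPy sketch of that score matrix (an illustration under that assumption, not the layer's code):

.. code-block:: python

    import numpy as np

    def multiplicative_scores(x):
        """Bilinear attention scores e[t, t'] = x_t^T W_a x_t' for x of shape (T, d)."""
        d = x.shape[1]
        W_a = np.random.default_rng(0).normal(size=(d, d))  # hypothetical weight name
        return x @ W_a @ x.T  # shape (T, T)

    print(multiplicative_scores(np.random.rand(5, 8)).shape)  # (5, 5)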

Regularizer
^^^^^^^^^^^

To use the attention regularizer shown below, the attention weights should be returned so that the regularization loss can be calculated:

.. image:: https://user-images.githubusercontent.com/853842/44250188-f99b6300-a225-11e8-8fab-8dcf0d99616e.gif
   :target: https://user-images.githubusercontent.com/853842/44250188-f99b6300-a225-11e8-8fab-8dcf0d99616e.gif
   :alt:

.. code-block:: python

    import keras
    import numpy
    from keras_self_attention import Attention

    inputs = keras.layers.Input(shape=(None,))
    embd = keras.layers.Embedding(input_dim=32,
                                  output_dim=16,
                                  mask_zero=True)(inputs)
    lstm = keras.layers.Bidirectional(keras.layers.LSTM(units=16,
                                                        return_sequences=True))(embd)
    att, weights = Attention(return_attention=True,
                             attention_width=5,
                             attention_type=Attention.ATTENTION_TYPE_MUL,
                             kernel_regularizer=keras.regularizers.l2(1e-4),
                             bias_regularizer=keras.regularizers.l1(1e-4),
                             name='Attention')(lstm)
    dense = keras.layers.Dense(units=5, name='Dense')(att)
    model = keras.models.Model(inputs=inputs, outputs=[dense, weights])
    model.compile(
        optimizer='adam',
        loss={'Dense': 'sparse_categorical_crossentropy', 'Attention': Attention.loss(1e-2)},
        metrics={'Dense': 'categorical_accuracy'},
    )
    model.summary(line_length=100)
    # `x`, `batch_size` and `sentence_len` are assumed to be provided by your data pipeline.
    model.fit(
        x=x,
        y=[
            numpy.zeros((batch_size, sentence_len, 1)),
            numpy.zeros((batch_size, sentence_len, sentence_len))
        ],
        epochs=10,
    )
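
Because the model above has two outputs, the attention weights can also be inspected at prediction time. A minimal sketch, assuming ``x`` is a padded batch of integer-encoded sequences as in the training example:

.. code-block:: python

    import numpy

    x = numpy.random.randint(low=1, high=32, size=(4, 7))  # toy batch: 4 sequences of length 7
    predictions, attention = model.predict(x)
    print(predictions.shape)  # expected (4, 7, 5): per-timestamp class scores
    print(attention.shape)    # expected (4, 7, 7): attention weights per timestamp pair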

Load the Model
^^^^^^^^^^^^^^

Make sure to add ``Attention`` to ``custom_objects``, and include ``attention_regularizer`` as well if the regularizer has been used:

.. code-block:: python

    import keras
    from keras_self_attention import Attention

    # `model_path` points to a model saved earlier
    model = keras.models.load_model(model_path, custom_objects={
        'Attention': Attention,
        'attention_regularizer': Attention.loss(1e-2),
    })
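
For completeness, a minimal save/load round trip under the same assumptions (``model`` is the compiled model from the regularizer example, and the file name is hypothetical):

.. code-block:: python

    import keras
    from keras_self_attention import Attention

    model_path = 'attention_model.h5'
    model.save(model_path)
    restored = keras.models.load_model(model_path, custom_objects={
        'Attention': Attention,
        'attention_regularizer': Attention.loss(1e-2),
    })
    restored.summary(line_length=100)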
