Attention mechanism for processing sequence data that considers the context for each timestep

Project description

Install

pip install keras-self-attention

Usage

Basic

By default, the attention layer uses additive attention and considers the whole context when calculating relevance. The following code creates an attention layer with these defaults (attention_activation is the activation function applied to the score e_{t, t'}):

import keras
from keras_self_attention import Attention


model = keras.models.Sequential()
model.add(keras.layers.Embedding(input_dim=10000,
                                 output_dim=300,
                                 mask_zero=True))
model.add(keras.layers.Bidirectional(keras.layers.LSTM(units=128,
                                                       return_sequences=True)))
model.add(Attention(attention_activation='sigmoid'))
model.add(keras.layers.Dense(units=5, activation='softmax'))  # softmax so categorical_crossentropy receives probabilities
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['categorical_accuracy'],
)
model.summary()
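For intuition, here is a minimal NumPy sketch of additive self-attention over a single sequence. It assumes the standard additive formulation, h_{t, t'} = tanh(x_t W_t + x_{t'} W_x + b_h), e_{t, t'} = W_a h_{t, t'} + b_a, a_t = softmax(e_t), output_t = sum over t' of a_{t, t'} x_{t'}. The weight names W_t, W_x, b_h, W_a and b_a are hypothetical, and this is not the library's own code.

```python
import numpy as np


def softmax_rows(e):
    # Numerically stable softmax over the last axis.
    e = e - e.max(axis=-1, keepdims=True)
    w = np.exp(e)
    return w / w.sum(axis=-1, keepdims=True)


def additive_self_attention(x, W_t, W_x, b_h, W_a, b_a, activation=None):
    """Additive self-attention over one sequence (hypothetical parameter names).

    x: (n, d) inputs; W_t, W_x: (d, u); b_h: (u,); W_a: (u,); b_a: scalar.
    Returns (outputs of shape (n, d), attention weights of shape (n, n)).
    """
    q = x @ W_t                                        # (n, u), contribution of x_t
    k = x @ W_x                                        # (n, u), contribution of x_{t'}
    h = np.tanh(q[:, None, :] + k[None, :, :] + b_h)   # (n, n, u), hidden state per pair
    e = h @ W_a + b_a                                  # (n, n), scores e_{t, t'}
    if activation is not None:
        e = activation(e)                              # plays the role of attention_activation
    a = softmax_rows(e)                                # each row sums to 1
    return a @ x, a                                    # weighted sum of the inputs


rng = np.random.default_rng(0)
n, d, u = 4, 3, 5
x = rng.normal(size=(n, d))
out, a = additive_self_attention(
    x,
    W_t=rng.normal(size=(d, u)),
    W_x=rng.normal(size=(d, u)),
    b_h=np.zeros(u),
    W_a=rng.normal(size=u),
    b_a=0.0,
    activation=lambda z: 1.0 / (1.0 + np.exp(-z)),     # sigmoid, as in the example above
)
print(out.shape, a.shape)  # (4, 3) (4, 4)
```

The output keeps the input's shape per timestep, which is why the layer can be followed directly by a per-timestep Dense layer.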

Local Attention

Attending over the whole sequence may be too broad for some data. The attention_width parameter limits each timestep to a local context window of that width:

from keras_self_attention import Attention

Attention(
    attention_width=15,
    attention_activation='sigmoid',
    name='Attention',
)
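The sketch below shows one way a local window can be applied to the scores before the softmax, so that positions outside the window get zero weight. It assumes a window of `width` positions roughly centred on t (t - width // 2 <= t' < t - width // 2 + width); that exact placement and the helper name local_attention_weights are our assumptions, not taken from the library.

```python
import numpy as np


def local_attention_weights(e, width):
    """Softmax over scores e (n, n), keeping only `width` positions per row.

    Assumed window for row t: t - width // 2 <= t' < t - width // 2 + width.
    """
    n = e.shape[0]
    idx = np.arange(n)
    lower = idx - width // 2                 # first allowed position for each row
    upper = lower + width                    # one past the last allowed position
    mask = (idx[None, :] >= lower[:, None]) & (idx[None, :] < upper[:, None])
    e = np.where(mask, e, -np.inf)           # outside the window -> weight 0 after exp
    e = e - e.max(axis=1, keepdims=True)     # diagonal is always inside, so max is finite
    w = np.exp(e)
    return w / w.sum(axis=1, keepdims=True)


# With equal scores, each row spreads its weight evenly over its window.
a = local_attention_weights(np.zeros((6, 6)), width=3)
print(np.round(a, 3))
```

Near the sequence edges the window is clipped, so the first and last rows here only see two positions.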

Multiplicative Attention

You can use multiplicative attention by setting attention_type:

import keras
from keras_self_attention import Attention

Attention(
    attention_width=15,
    attention_type=Attention.ATTENTION_TYPE_MUL,
    attention_activation=None,
    kernel_regularizer=keras.regularizers.l2(1e-6),
    use_attention_bias=False,
    name='Attention',
)
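Multiplicative attention replaces the additive scoring network with a bilinear form. A minimal sketch of the scores, assuming e_{t, t'} = x_t^T W_a x_{t'} + b_a (our assumption about the layer's formula; W_a, b_a and multiplicative_scores are hypothetical names). With use_attention_bias=False the bias term would presumably be dropped, and the scores are then normalized with a softmax as in the additive case.

```python
import numpy as np


def multiplicative_scores(x, W_a, b_a=None):
    """Scores e[t, t'] = x_t^T W_a x_{t'} (+ b_a) for one sequence x of shape (n, d)."""
    e = x @ W_a @ x.T        # (n, d) @ (d, d) @ (d, n) -> (n, n)
    if b_a is not None:      # skipped when no attention bias is used
        e = e + b_a
    return e


x = np.array([[1.0, 0.0],
              [0.0, 2.0],
              [1.0, 1.0]])
# With W_a = I the score reduces to the dot product x_t . x_{t'}.
e = multiplicative_scores(x, np.eye(2))
print(e)
```

Compared with the additive form, this needs only one d x d matrix and no tanh layer, which makes it cheaper per pair of timesteps.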

Regularizer

To use the attention regularizer, set return_attention=True so that the layer also returns its attention weights, then attach Attention.loss to that output when compiling:

import keras
import numpy
from keras_self_attention import Attention

# Dummy data for illustration: token ids in [1, 32), since 0 is reserved for masking.
batch_size, sentence_len = 32, 10
x = numpy.random.randint(1, 32, size=(batch_size, sentence_len))

inputs = keras.layers.Input(shape=(None,))
embd = keras.layers.Embedding(input_dim=32,
                              output_dim=16,
                              mask_zero=True)(inputs)
lstm = keras.layers.Bidirectional(keras.layers.LSTM(units=16,
                                                    return_sequences=True))(embd)
# return_attention=True makes the layer return (output, attention weights).
att, weights = Attention(return_attention=True,
                         attention_width=5,
                         attention_type=Attention.ATTENTION_TYPE_MUL,
                         kernel_regularizer=keras.regularizers.l2(1e-4),
                         bias_regularizer=keras.regularizers.l1(1e-4),
                         name='Attention')(lstm)
dense = keras.layers.Dense(units=5, activation='softmax', name='Dense')(att)
model = keras.models.Model(inputs=inputs, outputs=[dense, weights])
model.compile(
    optimizer='adam',
    loss={'Dense': 'sparse_categorical_crossentropy', 'Attention': Attention.loss(1e-2)},
    metrics={'Dense': 'sparse_categorical_accuracy'},
)
model.summary(line_length=100)
model.fit(
    x=x,
    y=[
        # Sparse class labels, one per timestep.
        numpy.zeros((batch_size, sentence_len, 1)),
        # Placeholder target for the attention output; the regularizer is
        # assumed to depend only on the predicted attention weights.
        numpy.zeros((batch_size, sentence_len, sentence_len)),
    ],
    epochs=10,
)
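Attention.loss(1e-2) adds a penalty computed from the returned attention weights. We believe it has the form coef * ||A A^T - I||_F^2, summed over the batch and divided by the batch size, which pushes different timesteps toward attending to different positions; treat that formula, and the helper name attention_penalty, as assumptions. A NumPy sketch of such a penalty:

```python
import numpy as np


def attention_penalty(a, coef):
    """coef * ||A A^T - I||_F^2, summed over the batch and divided by the batch size.

    a: attention weights of shape (batch, n, n).
    """
    batch, n = a.shape[0], a.shape[-1]
    gram = a @ np.transpose(a, (0, 2, 1))     # A A^T for every sample
    return coef * np.sum((gram - np.eye(n)) ** 2) / batch


one_hot = np.eye(3)[None]                     # each step attends to one distinct position
uniform = np.full((1, 2, 2), 0.5)             # each step spreads attention evenly
print(attention_penalty(one_hot, 1e-2))       # 0.0
print(attention_penalty(uniform, 1.0))        # 1.0
```

Because the penalty only looks at the predicted weights, the all-zeros array passed as the attention target in fit() serves as a shape-compatible placeholder.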

Load the Model

When loading a saved model, add Attention to custom_objects, and also add attention_regularizer if the model was compiled with the attention regularizer:

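import keras
from keras_self_attention import Attention

# model_path is the path of the saved model file.
model = keras.models.load_model(model_path, custom_objects={
    'Attention': Attention,
    # Use the same coefficient that was passed to Attention.loss in compile().
    'attention_regularizer': Attention.loss(1e-2),
})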

Download files

Source Distribution

keras-self-attention-0.0.17.tar.gz (5.2 kB)

File details

Details for the file keras-self-attention-0.0.17.tar.gz.

File metadata

  • Download URL: keras-self-attention-0.0.17.tar.gz
  • Upload date:
  • Size: 5.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.11.0 pkginfo/1.4.2 requests/2.18.4 setuptools/28.8.0 requests-toolbelt/0.8.0 tqdm/4.24.0 CPython/3.6.4

File hashes

Hashes for keras-self-attention-0.0.17.tar.gz
Algorithm     Hash digest
SHA256        0c2c40b77102ce9c68b4cde14c66f78f6fdb93eb2a91ab06acadc5268c2aadb0
MD5           06744eacc3d2f0cdc26845bd862c2d64
BLAKE2b-256   8070799461e149104506bcae2dc469c18c608e8952440734888c58f112262eee
