
Keras Simple Attention

Project description

Keras Attention Mechanism


Many-to-one attention mechanism for Keras.

Installation via pip

pip install attention

Import in the source code

from tensorflow.keras.layers import Dense, LSTM
from tensorflow.keras.models import Sequential

from attention import Attention

# [...]

seq_length = 10  # example sequence length; set this to the length of your inputs

m = Sequential([
    LSTM(128, input_shape=(seq_length, 1), return_sequences=True),
    Attention(name='attention_weight'),  # <--------- here.
    Dense(1, activation='linear')
])
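As a quick usage sketch (the random data, loss and batch settings below are placeholders for illustration, not taken from the package documentation), the model above compiles and trains like any other Keras model:

import numpy as np

x_dummy = np.random.uniform(size=(32, seq_length, 1))  # 32 random input sequences
y_dummy = np.random.uniform(size=(32, 1))              # 32 random targets

m.compile(loss='mse', optimizer='adam')
m.fit(x_dummy, y_dummy, epochs=2, batch_size=8)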

Examples

Install the requirements before running the examples: pip install -r requirements.txt.

IMDB Dataset

In this experiment, we demonstrate that using attention yields a higher accuracy on the IMDB dataset. We consider two LSTM networks: one with this attention layer and the other one with a fully connected layer. Both have the same number of parameters for a fair comparison (250K).

Here are the results over 10 runs. For every run, we train for 10 epochs and record the maximum accuracy reached on the test set.

Measure           No Attention (250K params)   Attention (250K params)
Max accuracy (%)  88.22                        88.76
Avg accuracy (%)  87.02                        87.62
Std dev (%)       0.18                         0.14

As expected, the model with attention gets a boost in accuracy. It also reduces the variability across runs, which is nice to have.
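For reference, here is a minimal sketch of the kind of comparison described above. The vocabulary size and layer widths are assumptions for illustration; the actual example script may use different values, tuned so that both models end up around 250K parameters.

from tensorflow.keras.layers import Dense, Embedding, LSTM
from tensorflow.keras.models import Sequential

from attention import Attention

vocab_size = 10000  # assumed vocabulary size, for illustration only

def build_imdb_model(use_attention):
    # Same backbone in both cases: an embedding followed by an LSTM.
    m = Sequential()
    m.add(Embedding(vocab_size, 32))
    if use_attention:
        m.add(LSTM(64, return_sequences=True))   # keep all timesteps for the attention layer
        m.add(Attention(name='attention_weight'))
    else:
        m.add(LSTM(64))                          # last hidden state only
        m.add(Dense(64, activation='relu'))      # fully connected layer instead of attention
    m.add(Dense(1, activation='sigmoid'))
    m.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return m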

Adding two numbers

Let's consider the task of adding two numbers that come right after some delimiters (0 in this case):

x = [1, 2, 3, 0, 4, 5, 6, 0, 7, 8]. The result is y = 4 + 7 = 11.
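To make the task concrete, here is a hedged sketch of how such training pairs could be generated (the sequence length and value range are assumptions, not taken from the example script):

import numpy as np

def add_two_numbers_sample(seq_length=10, max_value=9):
    # Random positive integers with two 0 delimiters; the target is the sum of
    # the two values that immediately follow the delimiters.
    x = np.random.randint(1, max_value + 1, size=seq_length)
    while True:
        i, j = sorted(np.random.choice(seq_length - 1, size=2, replace=False))
        if j > i + 1:  # keep a real value (not the second delimiter) right after the first one
            break
    x[i], x[j] = 0, 0
    y = x[i + 1] + x[j + 1]
    return x, y

x, y = add_two_numbers_sample()
print(x, y)  # e.g. [1 2 3 0 4 5 6 0 7 8] 11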

The attention is expected to be the highest after the delimiters. An overview of the training is shown below, where the top represents the attention map and the bottom the ground truth. As the training progresses, the model learns the task and the attention map converges to the ground truth.

Finding max of a sequence

We consider many 1D sequences of the same length. The task is to find the maximum of each sequence.

We give the full sequence processed by the RNN layer to the attention layer. We expect the attention layer to focus on the maximum of each sequence.
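A hedged sketch of this setup, with placeholder sequence length, dataset size and layer widths (the actual example script may differ):

import numpy as np
from tensorflow.keras.layers import Dense, LSTM
from tensorflow.keras.models import Sequential

from attention import Attention

seq_length = 10  # placeholder sequence length

# Random 1D sequences; the label is the maximum value of each sequence.
x = np.random.uniform(size=(2000, seq_length, 1))
y = x.max(axis=1)

m = Sequential([
    LSTM(64, input_shape=(seq_length, 1), return_sequences=True),  # full sequence goes to the attention layer
    Attention(name='attention_weight'),
    Dense(1, activation='linear')
])
m.compile(loss='mae', optimizer='adam')
m.fit(x, y, epochs=10, validation_split=0.2)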

After a few epochs, the attention layer converges perfectly to what we expected.




Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

attention-3.0.tar.gz (3.4 kB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

attention-3.0-py3-none-any.whl (7.6 kB)

Uploaded Python 3

File details

Details for the file attention-3.0.tar.gz.

File metadata

  • Download URL: attention-3.0.tar.gz
  • Upload date:
  • Size: 3.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.2.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5

File hashes

Hashes for attention-3.0.tar.gz
Algorithm Hash digest
SHA256 8e6034dd2e6252f02939a69c9bd44a96836a6597b1c0d1373944c8ace8edc6ca
MD5 ac0157f34a871f5bf1bb1ea3e0050930
BLAKE2b-256 c9418290f9078df9db89deefd6f07efc75e0e2669230c2629d2bce7298f3ff96

See more details on using hashes here.
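For example, the SHA256 digest above can be checked locally with Python's standard hashlib module (the file path is an assumption about where the archive was downloaded):

import hashlib

expected = '8e6034dd2e6252f02939a69c9bd44a96836a6597b1c0d1373944c8ace8edc6ca'

with open('attention-3.0.tar.gz', 'rb') as f:  # path to the downloaded sdist
    digest = hashlib.sha256(f.read()).hexdigest()

print(digest == expected)  # True if the download matches the published hash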

File details

Details for the file attention-3.0-py3-none-any.whl.

File metadata

  • Download URL: attention-3.0-py3-none-any.whl
  • Upload date:
  • Size: 7.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.2.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5

File hashes

Hashes for attention-3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 d5acea95b722cef73ebd1e1308f266029832ecc99e0eb812fce9249ed385c56e
MD5 cb41e19a341efe4af42e2b455223b30c
BLAKE2b-256 2d2798d7350db36a3537e24c9ec488d893b71092037f5c74e8984d01e9c1d316

See more details on using hashes here.
