
simple tools

Project description

Description

keras-attention-block is an extension for Keras that adds attention layers. It was born from the lack of an existing, straightforward way to add attention inside Keras. The module itself is pure Python, with no dependencies outside the standard Python distribution and Keras.

keywords: keras, deep learning, attention

Features

  • support one-dimensional attention, i.e. inputs of shape batch_size * time_step * hidden_size

  • support two-dimensional attention, i.e. inputs of shape batch_size * X * Y * hidden_size

  • support self-attention, which takes a single tensor. Four well-defined similarity calculations are included: additive, multiplicative, dot-product based, and linear (a plain-NumPy sketch of these calculations follows this list)

  • support attention between two tensors. Three well-defined similarity calculations are included: additive, multiplicative, and dot-product based

  • support multihead attention

  • support customized calculations of similarity between Key and Query

  • support customized calculations of Value
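
The additive, multiplicative and dot-product based calculations named above are the standard attention similarity formulations. The sketch below is library-independent plain NumPy: the weight names are illustrative rather than the layers' internal parameters, and the "linear" option is this package's own variant (see its documentation).

import numpy as np

# One Query vector q scored against time_step Key vectors K (time_step x hidden).
hidden, time_step = 32, 20
q = np.random.randn(hidden)
K = np.random.randn(time_step, hidden)

# dot-product based: sim(q, k) = q . k
sim_dot = K @ q

# multiplicative:    sim(q, k) = q^T W k, with a learned matrix W
W = np.random.randn(hidden, hidden)
sim_mul = K @ (W @ q)

# additive:          sim(q, k) = v^T tanh(W1 q + W2 k), with learned W1, W2, v
W1 = np.random.randn(hidden, hidden)
W2 = np.random.randn(hidden, hidden)
v = np.random.randn(hidden)
sim_add = np.tanh(q @ W1.T + K @ W2.T) @ v

# In every case the attention weights are the softmax of the similarities.
weights = np.exp(sim_dot) / np.exp(sim_dot).sum()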

Example

import numpy as np
from keras.layers import Input, Dense, Flatten
from keras.models import Model
from keras_attention_block import *

INPUT_DIM = 32
TIME_STEPS = 20

# A toy model: 1D self-attention over the input sequence, then a
# single sigmoid unit for binary classification.
inputs = Input(shape=(TIME_STEPS, INPUT_DIM))
attention_mul = SelfAttention1DLayer(similarity="linear", dropout_rate=0.2)(inputs)
# An LSTM (or any other sequence layer) could be stacked here before flattening:
# attention_mul = LSTM(32, return_sequences=False)(attention_mul)
attention_mul = Flatten()(attention_mul)
output = Dense(1, activation='sigmoid')(attention_mul)
m = Model(inputs=[inputs], outputs=output)

m.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
print(m.summary())

# Train on random data just to check that the model runs end to end.
train_data = np.random.random((1000, TIME_STEPS, INPUT_DIM))
train_lab = np.random.randint(0, 2, 1000)
m.fit(train_data, train_lab, epochs=1, batch_size=100)
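
Once trained, the model is used like any other Keras model; for instance, reusing m and train_data from the snippet above:

# Predicted probabilities for the first ten samples, shape (10, 1).
print(m.predict(train_data[:10]))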

Install

  • python -m pip install keras_attention_block

Documentation

Documentation on Readthedocs.

TODO

  • 3D attention

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

keras_attention_block-0.0.2.tar.gz (9.7 kB, Source)

Built Distribution

keras_attention_block-0.0.2-py3-none-any.whl (18.9 kB, Python 3)

File details

Details for the file keras_attention_block-0.0.2.tar.gz.

Hashes for keras_attention_block-0.0.2.tar.gz:

SHA256:      eefc3dbe925bc4faac225405c3b29daf0699f51174011fac9dd3e2e097a70338
MD5:         1f7b29f95d3b314e0e52bcf775d41c28
BLAKE2b-256: 33d11115a726281f49e3bf5c38b65b788485a133cdcb78d1230918e766738b89

See more details on using hashes here.

File details

Details for the file keras_attention_block-0.0.2-py3-none-any.whl.

Hashes for keras_attention_block-0.0.2-py3-none-any.whl:

SHA256:      0fed55018ba5c51523d257c5a3a383020879ad42b48687a22be3fe9d9b8ac370
MD5:         7dbfe2466d12d702b3677a05d770ddec
BLAKE2b-256: 7fbae2869c827348e609a5bff3eb9b2078ff5c52e5fa7a1f33af96800c5d6778

See more details on using hashes here.
