Keras implementation of a NALU layer
Keras NALU (Neural Arithmetic Logic Units)
Keras implementation of a NALU layer (Neural Arithmetic Logic Units). See: https://arxiv.org/pdf/1808.00508.pdf.
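For orientation, the paper defines the NALU as follows. A neural accumulator (NAC) computes a = Wx with W = tanh(Ŵ) ⊙ σ(M̂). A multiplicative path reuses the same W in log space, m = exp(W log(|x| + ε)). A learned gate g = σ(Gx) then mixes the two paths as y = g ⊙ a + (1 − g) ⊙ m. The plain-Python sketch below implements those equations for a single input vector. It is an illustration only, not the keras-nalu source, and the names `nalu_forward`, `W_hat`, `M_hat` and `G` are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def matvec(W, x):
    # Multiply a (units x inputs) matrix, given as nested lists, by a vector.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def nalu_forward(x, W_hat, M_hat, G, eps=1e-7):
    # NAC weights tanh(W_hat) * sigmoid(M_hat) are pushed toward -1, 0 or 1.
    W = [[math.tanh(wh) * sigmoid(mh) for wh, mh in zip(w_row, m_row)]
         for w_row, m_row in zip(W_hat, M_hat)]
    a = matvec(W, x)  # additive path: signed weighted sums of the inputs
    log_x = [math.log(abs(xi) + eps) for xi in x]
    m = [math.exp(v) for v in matvec(W, log_x)]  # multiplicative path via logs
    g = [sigmoid(v) for v in matvec(G, x)]  # one learned gate per output unit
    return [gi * ai + (1.0 - gi) * mi for gi, ai, mi in zip(g, a, m)]

# Saturated weights make W close to 1 for both inputs.
W_hat = [[20.0, 20.0]]
M_hat = [[20.0, 20.0]]
y_add = nalu_forward([3.0, 4.0], W_hat, M_hat, G=[[20.0, 20.0]])    # gate near 1
y_mul = nalu_forward([3.0, 4.0], W_hat, M_hat, G=[[-20.0, -20.0]])  # gate near 0
```

Driving the gate weights to large positive values selects the additive path (approximately 3 + 4), while large negative values select the multiplicative path (approximately 3 × 4).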
Installation
pip install keras-nalu
Usage
from keras.layers import Input
from keras.models import Model
from keras.optimizers import RMSprop
from keras_nalu.nalu import NALU
# Your dataset
X_train = ... # Interpolation data (training)
Y_train = ... # Interpolation data (training)
X_validation = ... # Extrapolation data (validation)
Y_validation = ... # Extrapolation data (validation)
X_test = ... # Extrapolation data (test)
Y_test = ... # Extrapolation data (test)
# Hyperparameters
epoch_count = 1000
learning_rate = 0.05
sequence_len = 100
inputs = Input(shape=(sequence_len, ))
hidden = NALU(units=2)(inputs)
hidden = NALU(units=2)(hidden)
outputs = NALU(units=1)(hidden)
model = Model(inputs=inputs, outputs=outputs)
model.summary()
model.compile(loss='mse', optimizer=RMSprop(lr=learning_rate))
model.fit(
batch_size=256,
epochs=epoch_count,
validation_data=(X_validation, Y_validation),
x=X_train,
y=Y_train,
)
extrapolation_loss = model.evaluate(
batch_size=256,
x=X_test,
y=Y_test,
)
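The dataset above is left as placeholders. One way to fill it for a quick experiment is the generator sketched below. It is loosely modeled on the paper's static arithmetic task: each row holds `sequence_len` random values, `a` and `b` are sums over two slices, and the target is `a + b` or `a * b`. The extrapolation sets draw from a wider value range than the training set. The helper name `make_dataset`, the slice positions, and the value ranges are all assumptions chosen for illustration.

```python
import random

def make_dataset(n, sequence_len=100, low=0.0, high=1.0, op="add", seed=0):
    # Hypothetical generator: a and b are sums over two fixed slices of each row
    # (requires sequence_len >= 30); the target is a + b, or a * b if op != "add".
    rng = random.Random(seed)
    X, Y = [], []
    for _ in range(n):
        x = [rng.uniform(low, high) for _ in range(sequence_len)]
        a = sum(x[0:10])
        b = sum(x[20:30])
        y = a + b if op == "add" else a * b
        X.append(x)
        Y.append([y])  # one target per row, matching NALU(units=1)
    return X, Y

# Training (interpolation) data uses one range; extrapolation data uses a wider one.
X_train, Y_train = make_dataset(1000, low=0.0, high=1.0, seed=1)
X_validation, Y_validation = make_dataset(200, low=1.0, high=5.0, seed=2)
X_test, Y_test = make_dataset(200, low=1.0, high=5.0, seed=3)
```

Keras generally expects NumPy arrays here, and a plain Python list passed as `x` can be read as a list of separate inputs. Convert the results with `numpy.asarray(...)` before calling `model.fit`.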
Options
cell
Cell to use in the NALU layer: 'a' (addition/subtraction), 'm' (multiplication/division/power), or None, which applies a learned gating function to switch between the 'a' and 'm' cells.
- Default: None
- Type: ?('a' | 'm' | None)
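As a rough illustration of how the three `cell` settings relate, assuming the layer follows the paper's formulation with a sigmoid gate, the hypothetical helper `combine` below takes an already-computed additive output `a` and multiplicative output `m` and applies each mode. Here `gate_logit` stands in for the pre-sigmoid gate value Gx. This is not the library's code.

```python
import math

def combine(a, m, gate_logit, cell=None):
    # 'a' keeps only the additive output and 'm' keeps only the multiplicative
    # output; None blends them with a sigmoid gate g: g * a + (1 - g) * m.
    if cell == "a":
        return a
    if cell == "m":
        return m
    if cell is None:
        g = 1.0 / (1.0 + math.exp(-gate_logit))
        return g * a + (1.0 - g) * m
    raise ValueError("cell must be 'a', 'm' or None")

blended = combine(7.0, 12.0, gate_logit=0.0)  # g = 0.5, so the average of both paths
```

With `cell=None` the gate is learned during training, so each output unit can settle on either path (or a mix). Fixing `cell` to 'a' or 'm' removes that choice when the required operation is known in advance.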
e
Epsilon value added to the inputs before taking their logarithm, to avoid computing log(0).
- Default: 1e-7
- Type: ?float
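The multiplicative path works in log space, where log(0) is undefined. The plain-Python sketch below uses hypothetical helpers and assumes the paper's log(|x| + ε) formulation with the default 1e-7. It shows that adding `e` keeps the logarithm finite at the cost of a small error: a zero input yields roughly `e` times the other factors instead of exactly 0. Because of the absolute value, this path only recovers magnitudes, not signs.

```python
import math

def log_magnitude(x, e=1e-7):
    # log(|x| + e) stays finite when x == 0, unlike math.log(0).
    return math.log(abs(x) + e)

def product_via_logs(xs, e=1e-7):
    # exp(sum(log(|x| + e))) approximates |x1 * x2 * ...|, the idea behind
    # the 'm' cell when every weight equals 1.
    return math.exp(sum(log_magnitude(x, e) for x in xs))

p = product_via_logs([3.0, 4.0])       # close to 12
p_zero = product_via_logs([0.0, 5.0])  # close to 5e-7 rather than 0
p_neg = product_via_logs([-3.0, 4.0])  # close to 12: the sign is lost
```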
Download files
Source Distribution
keras-nalu-1.2.0.tar.gz (4.0 kB)
Built Distribution
keras_nalu-1.2.0-py3-none-any.whl

Hashes for keras_nalu-1.2.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | b5e0b705fe4957dbe7f8239dc0cf5abf8b99042afa4fa5036c930c1c6b7f973c
MD5 | 2d088db6165b15953fd7acf86e5f5ed8
BLAKE2b-256 | b2cdc194bd983a547d13e4fd8f1b0f1d8c23647768b53a380914a68c54d22f39