An MLX implementation of ConvLSTM

Project description

An implementation of ConvLSTM in MLX, Apple's array framework.

A Convolutional LSTM recurrent layer.

_conv_lstm_cell

This Module computes the hidden and cell state for a time-step, expressed as:

$i_t = \sigma(W_{xi} \ast X_t + W_{hi} \ast H_{t-1} + W_{ci} \odot C_{t-1} + b_i)$
$f_t = \sigma(W_{xf} \ast X_t + W_{hf} \ast H_{t-1} + W_{cf} \odot C_{t-1} + b_f)$
$C_t = f_t \odot C_{t-1} + i_t \odot \tanh(W_{xc} \ast X_t + W_{hc} \ast H_{t-1} + b_c)$
$o_t = \sigma(W_{xo} \ast X_t + W_{ho} \ast H_{t-1} + W_{co} \odot C_t + b_o)$
$H_t = o_t \odot \tanh(C_t)$

Where $\sigma$, $\ast$, and $\odot$ denote the sigmoid function, the convolution operator, and the Hadamard (element-wise) product, respectively.

The expected input for this layer has shape NHWC or HWC where:

  • N is the optional batch dimension
  • H is the input's spatial height dimension
  • W is the input's spatial width dimension
  • C is the input's channel dimension

The layer returns a tuple of the hidden state, $H_t$, and the cell state, $C_t$, each with shape NHWO, where O is the number of output channels.
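One time-step of a peephole ConvLSTM update (input gate, forget gate, cell state, output gate, hidden state, in that order) can be sketched in plain NumPy. This is an illustrative sketch, not the package's MLX code: the helper names `conv2d_same`, `init_params`, and `conv_lstm_step` are hypothetical, the convolution is assumed to be stride-1 with zero "same" padding, and the peephole weights $W_{ci}$, $W_{cf}$, $W_{co}$ are assumed to have shape (H, W, O) so they broadcast over the batch.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def conv2d_same(x, w):
    """Stride-1, zero-padded 'same' convolution (cross-correlation, as in most
    deep-learning libraries). x: (N, H, W, C_in); w: (k, k, C_in, C_out), k odd."""
    k = w.shape[0]
    pad = k // 2
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
    n, height, width, _ = x.shape
    out = np.zeros((n, height, width, w.shape[-1]))
    for i in range(k):
        for j in range(k):
            # Window shifted by (i, j), contracted over the input channels.
            out += xp[:, i:i + height, j:j + width, :] @ w[i, j]
    return out


def init_params(in_channels, out_channels, kernel_size, height, width, rng):
    """Small random weights; the shapes and the 0.1 scale are assumptions of this sketch."""
    k = kernel_size
    params = {}
    for gate in "ifco":
        params[f"W_x{gate}"] = 0.1 * rng.standard_normal((k, k, in_channels, out_channels))
        params[f"W_h{gate}"] = 0.1 * rng.standard_normal((k, k, out_channels, out_channels))
        params[f"b_{gate}"] = np.zeros(out_channels)
    for gate in "ifo":  # peephole weights act element-wise on the cell state
        params[f"W_c{gate}"] = 0.1 * rng.standard_normal((height, width, out_channels))
    return params


def conv_lstm_step(x, h_prev, c_prev, p):
    """x: (N, H, W, C); h_prev, c_prev: (N, H, W, O). Returns (H_t, C_t)."""
    i = sigmoid(conv2d_same(x, p["W_xi"]) + conv2d_same(h_prev, p["W_hi"])
                + p["W_ci"] * c_prev + p["b_i"])
    f = sigmoid(conv2d_same(x, p["W_xf"]) + conv2d_same(h_prev, p["W_hf"])
                + p["W_cf"] * c_prev + p["b_f"])
    c = f * c_prev + i * np.tanh(conv2d_same(x, p["W_xc"])
                                 + conv2d_same(h_prev, p["W_hc"]) + p["b_c"])
    o = sigmoid(conv2d_same(x, p["W_xo"]) + conv2d_same(h_prev, p["W_ho"])
                + p["W_co"] * c + p["b_o"])
    h = o * np.tanh(c)
    return h, c


rng = np.random.default_rng(0)
params = init_params(in_channels=3, out_channels=4, kernel_size=5, height=6, width=7, rng=rng)
x = rng.standard_normal((2, 6, 7, 3))   # NHWC input
h0 = np.zeros((2, 6, 7, 4))             # NHWO initial hidden state
c0 = np.zeros((2, 6, 7, 4))             # NHWO initial cell state
h1, c1 = conv_lstm_step(x, h0, c0, params)
print(h1.shape, c1.shape)  # (2, 6, 7, 4) (2, 6, 7, 4)
```

Note that the output gate peeks at the updated cell state $C_t$, while the input and forget gates peek at $C_{t-1}$. Because the convolutions preserve H and W, the hidden and cell states keep the input's spatial size and differ only in the channel dimension.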

Args:

in_channels (int): The number of input channels, C.
out_channels (int): The number of output channels, O.
kernel_size (int): The size of the convolution filters, must be odd to keep spatial dimensions with padding. Default: 5.
stride (Union[int, tuple]): The stride of the convolution.
padding (Union[int, tuple]): Padding to add to the input for convolution.
dilation (Union[int, tuple]): Dilation of the convolution.
bias (bool): Whether the convolutional calculation should use biases or not. Default: True.
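The requirement that kernel_size be odd follows from standard convolution arithmetic. The sketch below uses the usual output-size formula and assumes symmetric padding on both sides; the function names `conv_output_size` and `same_padding` are hypothetical helpers for illustration, not part of this package.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis for a convolution with symmetric padding."""
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1


def same_padding(kernel_size, dilation=1):
    """Symmetric padding that preserves the spatial size at stride 1."""
    if kernel_size % 2 == 0:
        raise ValueError("an even kernel cannot preserve size with symmetric padding")
    return dilation * (kernel_size - 1) // 2


print(conv_output_size(32, 5, padding=same_padding(5)))  # 32
print(conv_output_size(32, 4, padding=1))                # 31
print(conv_output_size(32, 4, padding=2))                # 33
```

With an even kernel, no symmetric padding lands exactly on the input size (the 4-wide kernel above gives 31 or 33), which is why an odd kernel_size is needed for the hidden state to keep the input's height and width.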

ConvLSTM

Unrolls a _conv_lstm_cell sequentially over time-steps.

The expected input for this layer has shape NLHWC or LHWC where:

  • N is the optional batch dimension
  • L is the length of the sequence
  • H is the input's spatial height dimension
  • W is the input's spatial width dimension
  • C is the input's channel dimension
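Unrolling a cell over the sequence dimension L amounts to a loop that threads the hidden and cell states from one time-step to the next. The following is a minimal, framework-free sketch: `unroll` and `toy_cell` are hypothetical names, the toy cell operates on scalars rather than HWC frames, and what the package's ConvLSTM actually returns (all hidden states, only the final state, or both) is not specified here, so the return value below is an assumption.

```python
def unroll(cell, xs, h0, c0):
    """Apply cell(x_t, h, c) -> (h, c) to each element of xs in order.

    Returns the list of hidden states and the final (h, c) pair.
    """
    h, c = h0, c0
    hidden_states = []
    for x_t in xs:
        h, c = cell(x_t, h, c)
        hidden_states.append(h)
    return hidden_states, (h, c)


def toy_cell(x_t, h, c):
    """Stand-in for a recurrent cell: decays the cell state by half and adds the input."""
    c = 0.5 * c + x_t
    return c, c


hidden_states, (h_last, c_last) = unroll(toy_cell, [1.0, 2.0, 3.0], h0=0.0, c0=0.0)
print(hidden_states)  # [1.0, 2.5, 4.25]
```

For an array with shape LHWC, iterating over the first axis yields one HWC frame per time-step; for NLHWC input one would instead slice along the second axis and pass each NHWC frame to the cell.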

Args:

in_channels (int): The number of input channels, C.
out_channels (int): The number of output channels, O.
kernel_size (int): The size of the convolution filters, must be odd to keep spatial dimensions with padding. Default: 5.
bias (bool): Whether the convolutional calculation should use biases or not. Default: True.

The following features are not yet implemented as of the initial release:

  • [ ] Bi-directionality: allows the ConvLSTM to unroll both forwards and backwards across the sequence
  • [ ] Stride customization
  • [ ] Customizable padding, along with the modes 'same' and 'valid'

Download files

Source Distribution

convlstm_mlx-0.1.0.tar.gz (3.8 kB)

Uploaded Source

Built Distribution

convlstm_mlx-0.1.0-py3-none-any.whl (4.3 kB)

Uploaded Python 3

File details

Details for the file convlstm_mlx-0.1.0.tar.gz.

File metadata

  • Download URL: convlstm_mlx-0.1.0.tar.gz
  • Upload date:
  • Size: 3.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for convlstm_mlx-0.1.0.tar.gz
Algorithm Hash digest
SHA256 cd811efe950377d616a89ca23fd556a0c157bb8a52bd6a603b3eda618f279d3f
MD5 69748af1cfd5127129c83cd8627d1185
BLAKE2b-256 a971937de5ad203e40772bf29d60cf99490a0f15ddf283414f908a3b049f49eb

Provenance

The following attestation bundles were made for convlstm_mlx-0.1.0.tar.gz:

Publisher: python-publish.yml on tomo-oga/convlstm-mlx

File details

Details for the file convlstm_mlx-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: convlstm_mlx-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 4.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for convlstm_mlx-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 361efc2037d231e7f0a4501cb3c716e634827eda4e742909a3c0e3f55cdd5817
MD5 d4749aed1d02572741c76e196c41aaa9
BLAKE2b-256 5d157124c185d3bd5dfc9c82c4b4f1fc3b267f84a6fab8f5506ebdbe55d64b3c

Provenance

The following attestation bundles were made for convlstm_mlx-0.1.0-py3-none-any.whl:

Publisher: python-publish.yml on tomo-oga/convlstm-mlx
