
An MLX implementation of ConvLSTM

Project description


An implementation of ConvLSTM in MLX, Apple's array framework.

A Convolutional LSTM recurrent layer.

_conv_lstm_cell

This nn.Module computes the hidden and cell states for a single time step as follows:

$i_t = \sigma(W_{xi} \ast X_t + W_{hi} \ast H_{t-1} + W_{ci} \odot C_{t-1} + b_i)$
$f_t = \sigma(W_{xf} \ast X_t + W_{hf} \ast H_{t-1} + W_{cf} \odot C_{t-1} + b_f)$
$C_t = f_t \odot C_{t-1} + i_t \odot \tanh(W_{xc} \ast X_t + W_{hc} \ast H_{t-1} + b_c)$
$o_t = \sigma(W_{xo} \ast X_t + W_{ho} \ast H_{t-1} + W_{co} \odot C_t + b_o)$
$H_t = o_t \odot \tanh(C_t)$

where $\ast$ denotes convolution, $\odot$ the Hadamard (element-wise) product, $\sigma$ the logistic sigmoid function, and $\tanh$ the hyperbolic tangent.
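The following minimal NumPy sketch makes the cell equations concrete for a single time step. It is not the package's MLX code. It assumes stride 1, 'same' padding with an odd kernel, NHWC layout, and peephole weights $W_{ci}, W_{cf}, W_{co}$ stored per position with shape (H, W, O). The helpers `conv2d_same`, `conv_lstm_step`, and `init_params`, along with the parameter-dictionary keys, are hypothetical names chosen for illustration.

```python
import numpy as np


def conv2d_same(x, w):
    """Stride-1, 'same'-padded 2D convolution in NHWC layout.

    Like most deep-learning libraries, this computes a cross-correlation.
    x: (N, H, W, C_in); w: (K, K, C_in, C_out) with K odd.
    """
    k = w.shape[0]
    p = (k - 1) // 2  # for odd K, this padding keeps H and W unchanged
    xp = np.pad(x, ((0, 0), (p, p), (p, p), (0, 0)))
    n, h, wd, _ = x.shape
    out = np.zeros((n, h, wd, w.shape[3]))
    for i in range(k):
        for j in range(k):
            # (N, H, W, C_in) @ (C_in, C_out) -> (N, H, W, C_out)
            out += xp[:, i:i + h, j:j + wd, :] @ w[i, j]
    return out


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def conv_lstm_step(x, h_prev, c_prev, p):
    """One ConvLSTM time step; p is a dict of (hypothetical) parameters."""

    def gate_conv(g):
        # W_{xg} * X_t + W_{hg} * H_{t-1}
        return conv2d_same(x, p["W_x" + g]) + conv2d_same(h_prev, p["W_h" + g])

    i = sigmoid(gate_conv("i") + p["W_ci"] * c_prev + p["b_i"])
    f = sigmoid(gate_conv("f") + p["W_cf"] * c_prev + p["b_f"])
    c = f * c_prev + i * np.tanh(gate_conv("c") + p["b_c"])
    o = sigmoid(gate_conv("o") + p["W_co"] * c + p["b_o"])  # peephole uses the new C_t
    h = o * np.tanh(c)
    return h, c


def init_params(c_in, c_out, k, height, width, seed=0):
    """Hypothetical random initialization; peephole weights have shape (H, W, O)."""
    rng = np.random.default_rng(seed)
    p = {}
    for g in "ifco":
        p["W_x" + g] = rng.normal(0.0, 0.1, (k, k, c_in, c_out))
        p["W_h" + g] = rng.normal(0.0, 0.1, (k, k, c_out, c_out))
        p["b_" + g] = np.zeros(c_out)
    for g in "ifo":
        p["W_c" + g] = rng.normal(0.0, 0.1, (height, width, c_out))
    return p


params = init_params(c_in=3, c_out=8, k=5, height=16, width=16)
x_t = np.random.default_rng(1).normal(size=(2, 16, 16, 3))
h0 = np.zeros((2, 16, 16, 8))
c0 = np.zeros_like(h0)
h1, c1 = conv_lstm_step(x_t, h0, c0, params)
print(h1.shape, c1.shape)  # (2, 16, 16, 8) (2, 16, 16, 8)
```

Many implementations fuse the eight convolutions into one. They concatenate $X_t$ and $H_{t-1}$ along the channel axis and produce 4·O output channels, which gives an equivalent result. Whether this package does so is not stated here.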

The expected input for this layer has shape NHWC or HWC where:

  • N is the optional batch dimension
  • H is the input's spatial height dimension
  • W is the input's spatial width dimension
  • C is the input's channel dimension

It returns a tuple of the hidden state $H_t$ and the cell state $C_t$, each with shape NHWO, where O is the number of output channels.
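A common way to accept an optional batch dimension is to add a leading axis to unbatched input before computing. The hypothetical NumPy helper `ensure_batched` below illustrates this. The same helper covers the sequence layer's LHWC/NLHWC case with `batched_ndim=5`. How the package itself shapes outputs for unbatched input is not specified above.

```python
import numpy as np


def ensure_batched(x, batched_ndim):
    """Add a leading batch axis if x is missing one (e.g. HWC -> NHWC).

    Returns the batched array and a flag recording whether an axis was added,
    so the caller can drop it again from the outputs if desired.
    """
    if x.ndim == batched_ndim:
        return x, False
    if x.ndim == batched_ndim - 1:
        return x[None, ...], True
    raise ValueError(
        f"expected {batched_ndim - 1}-D or {batched_ndim}-D input, got {x.ndim}-D"
    )


frame = np.zeros((16, 16, 3))  # HWC, unbatched
batched, added = ensure_batched(frame, 4)
print(batched.shape, added)  # (1, 16, 16, 3) True
```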

Args:

in_channels (int): The number of input channels, C.
out_channels (int): The number of output channels, O.
kernel_size (int): The size of the convolution filters. Must be odd so that padding can preserve the spatial dimensions. Default: 5.
stride (Union[int, tuple]): The stride of the convolution.
padding (Union[int, tuple]): Padding added to the input before convolution.
dilation (Union[int, tuple]): The dilation of the convolution.
bias (bool): Whether the convolutions use bias terms. Default: True.
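The odd-kernel requirement follows from the standard convolution output-size formula. The sketch below assumes symmetric zero padding and uses two hypothetical helpers, `conv_output_size` and `same_padding`. At stride 1, a padding of $d(k-1)/2$ preserves the size only when $d(k-1)$ is even. For dilation 1, that means an odd kernel.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a convolution along one spatial dimension.

    Standard formula for symmetric zero padding:
    floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1
    """
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def same_padding(kernel_size, dilation=1):
    """Padding that preserves the size at stride 1.

    Requires dilation*(kernel_size - 1) to be even, i.e. an odd kernel
    when dilation is 1.
    """
    span = dilation * (kernel_size - 1)
    if span % 2 != 0:
        raise ValueError("no symmetric padding preserves the size for this kernel")
    return span // 2


print(conv_output_size(32, 5, padding=same_padding(5)))  # 32
print(conv_output_size(32, 4, padding=1), conv_output_size(32, 4, padding=2))  # 31 33
```

With a kernel of 4, neither padding 1 nor padding 2 keeps a 32-wide input at 32, which is why the README asks for an odd kernel size.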

ConvLSTM

Unrolls a _conv_lstm_cell sequentially over the time steps of the input sequence.

The expected input for this layer has shape NLHWC or LHWC where:

  • N is the optional batch dimension
  • L is the length of the sequence
  • H is the input's spatial height dimension
  • W is the input's spatial width dimension
  • C is the input's channel dimension

Args:

in_channels (int): The number of input channels, C.
out_channels (int): The number of output channels, O.
kernel_size (int): The size of the convolution filters. Must be odd so that padding can preserve the spatial dimensions. Default: 5.
bias (bool): Whether the convolutions use bias terms. Default: True.
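Unrolling is a loop over the L axis that passes the hidden and cell states from one step to the next. This NumPy sketch assumes the layer collects every $H_t$ and returns them stacked along the time axis, together with the final cell state. The source does not state the return value, so treat that as an assumption. `unroll` and `toy_cell` are hypothetical, and `toy_cell` is only a stand-in for `_conv_lstm_cell`.

```python
import numpy as np


def unroll(cell, x, h0, c0):
    """Run a single-step cell over the L axis of x with shape (N, L, H, W, C).

    cell(x_t, h, c) -> (h, c). Returns the hidden states stacked along a new
    time axis, (N, L, H, W, O), together with the final cell state.
    """
    h, c = h0, c0
    hidden = []
    for t in range(x.shape[1]):
        h, c = cell(x[:, t], h, c)  # x[:, t] has shape (N, H, W, C)
        hidden.append(h)
    return np.stack(hidden, axis=1), c


def toy_cell(x_t, h, c):
    """Hypothetical stand-in for _conv_lstm_cell: accumulates channel sums."""
    c = c + x_t.sum(axis=-1, keepdims=True)
    return np.tanh(c), c


x = np.ones((2, 4, 8, 8, 3))  # N=2, L=4, H=W=8, C=3
h0 = np.zeros((2, 8, 8, 1))
c0 = np.zeros_like(h0)
hs, c_last = unroll(toy_cell, x, h0, c0)
print(hs.shape)  # (2, 4, 8, 8, 1)
```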

The following features are not yet implemented as of the initial release:

  • Bidirectionality: allowing the ConvLSTM to unroll both forwards and backwards across the sequence
  • Customizable stride
  • Customizable padding, including the 'same' and 'valid' modes



Download files


Source Distribution

convlstm_mlx-0.1.1.tar.gz (3.9 kB)

Uploaded Source

Built Distribution

convlstm_mlx-0.1.1-py3-none-any.whl (4.3 kB)

Uploaded Python 3

File details

Details for the file convlstm_mlx-0.1.1.tar.gz.

File metadata

  • Download URL: convlstm_mlx-0.1.1.tar.gz
  • Upload date:
  • Size: 3.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for convlstm_mlx-0.1.1.tar.gz
Algorithm Hash digest
SHA256 9f5991bc12b86af008791b2a40e9cfb9bd703d8f28d2212e59c406a2ee673e79
MD5 3d5a7e396f672fbb0c3bd9234c401280
BLAKE2b-256 63fb8cf19ead053fa3036eb722d280b87203d8a049d3fd04d2cfaaae56ea24c5
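To check a downloaded file against the digests listed here, hash it locally and compare. The following minimal Python sketch uses the standard-library `hashlib`. The helper names `sha256_of` and `matches` are hypothetical. pip can also enforce this during installation through its hash-checking mode: add `--hash=sha256:<digest>` to the requirement line in a requirements file.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 16):
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path, expected_hex):
    """True if the file's SHA256 equals the published digest (case-insensitive)."""
    return sha256_of(path) == expected_hex.strip().lower()


# Example, assuming the sdist has been downloaded into the working directory:
# matches("convlstm_mlx-0.1.1.tar.gz",
#         "9f5991bc12b86af008791b2a40e9cfb9bd703d8f28d2212e59c406a2ee673e79")
```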


Provenance

The following attestation bundles were made for convlstm_mlx-0.1.1.tar.gz:

Publisher: python-publish.yml on tomo-oga/convlstm-mlx


File details

Details for the file convlstm_mlx-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: convlstm_mlx-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 4.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for convlstm_mlx-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 ebc2fa9444df0253134bedcb1056ce83989ef490afbec802a10ee0e469b32297
MD5 572931d5939e2e6d7794973b3b202653
BLAKE2b-256 02d9d1b7c99457c6b1ebd058aa1ad5f47a484bd915f0d16c8fad08009a11787c


Provenance

The following attestation bundles were made for convlstm_mlx-0.1.1-py3-none-any.whl:

Publisher: python-publish.yml on tomo-oga/convlstm-mlx

