
Paper - Pytorch



HSSS

Implementation of a Hierarchical Mamba, as described in the paper "Hierarchical State Space Models for Continuous Sequence-to-Sequence Modeling", but using Mamba blocks instead of traditional SSMs. The data flow is: single input -> low-level Mambas -> concatenation -> high-level SSM -> multiple outputs.
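The flow above can be sketched in plain Python. This is a toy illustration, not the hsss implementation: every "Mamba" is replaced by a hypothetical diagonal linear SSM (`toy_ssm`, with h_t = a*h_{t-1} + b*x_t and y_t = c*h_t per channel), and the concatenation is assumed to happen along the feature dimension at each time step.

```python
from typing import List, Sequence, Tuple

Vector = List[float]
Params = Tuple[float, float, float]  # (a, b, c) of the toy SSM


def toy_ssm(xs: Sequence[Vector], params: Params) -> List[Vector]:
    """Hypothetical stand-in for a Mamba: a diagonal linear SSM applied per channel.

    h_t = a * h_{t-1} + b * x_t,  y_t = c * h_t
    """
    a, b, c = params
    h = [0.0] * len(xs[0])
    ys = []
    for x in xs:
        h = [a * hi + b * xi for hi, xi in zip(h, x)]
        ys.append([c * hi for hi in h])
    return ys


def hierarchical_forward(
    xs: Sequence[Vector], low_params: Sequence[Params], high_params: Params
) -> List[Vector]:
    # 1. The same single input is fed to every low-level model.
    low_outs = [toy_ssm(xs, p) for p in low_params]
    # 2. Concatenate the low-level outputs along the feature dimension (assumption).
    fused = [[v for out in low_outs for v in out[t]] for t in range(len(xs))]
    # 3. The high-level model processes the fused sequence.
    return toy_ssm(fused, high_params)


xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 time steps, 2 features
out = hierarchical_forward(
    xs, low_params=[(0.5, 1.0, 1.0), (0.9, 1.0, 2.0)], high_params=(0.0, 1.0, 1.0)
)
print(len(out), len(out[0]))  # 3 4
```

With two low-level models producing 2 features each, the high-level model sees 4 features per time step.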

Paper: arXiv:2402.10211

I strongly believe in this architecture because it separates local learning (in the low-level models) from global learning (in the high-level model).

Install

pip install hsss

Usage

import torch
from hsss import LowLevelMamba, HSSS


# Random input tensor of shape (batch, sequence length, feature dim)
x = torch.randn(1, 10, 8)

# Hyperparameters shared by the three low-level Mambas
low_level_config = dict(
    dim=8,  # model dimension; must match the last dimension of x
    depth=6,  # number of stacked Mamba layers
    dt_rank=4,  # rank of the time step (Δ) projection
    d_state=4,  # size of the SSM hidden state
    expand_factor=4,  # expansion factor of the inner dimension
    d_conv=6,  # kernel size of the 1D convolution
    dt_min=0.001,  # minimum time step Δ at initialization
    dt_max=0.1,  # maximum time step Δ at initialization
    dt_init="random",  # initialization method for Δ
    dt_scale=1.0,  # scaling factor for the Δ initialization
    bias=False,  # whether the linear projections use a bias
    conv_bias=True,  # whether the convolution uses a bias
    pscan=True,  # whether to use the parallel scan
)

# Low-level models; each one can also be given its own configuration
mamba = LowLevelMamba(**low_level_config)
mamba2 = LowLevelMamba(**low_level_config)
mamba3 = LowLevelMamba(**low_level_config)


# High-level model: runs the low-level models, concatenates their outputs,
# and processes the result
hsss = HSSS(
    layers=[mamba, mamba2, mamba3],
    dim=12,  # dimension of the high-level model
    depth=3,  # number of stacked Mamba layers
    dt_rank=2,  # rank of the time step (Δ) projection
    d_state=2,  # size of the SSM hidden state
    expand_factor=2,  # expansion factor of the inner dimension
    d_conv=3,  # kernel size of the 1D convolution
    dt_min=0.001,  # minimum time step Δ at initialization
    dt_max=0.1,  # maximum time step Δ at initialization
    dt_init="random",  # initialization method for Δ
    dt_scale=1.0,  # scaling factor for the Δ initialization
    bias=False,  # whether the linear projections use a bias
    conv_bias=True,  # whether the convolution uses a bias
    pscan=True,  # whether to use the parallel scan
    proj_layer=True,  # fuse the concatenated low-level outputs with a projection layer
)


# Forward pass
out = hsss(x)
print(out)
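In the reference Mamba implementation, `dt_min` and `dt_max` bound the initial time step Δ: for each inner channel, Δ is sampled log-uniformly in [dt_min, dt_max], and the bias of the Δ projection is set to the inverse softplus of that value. The plain-Python sketch below reproduces that scheme. Whether `LowLevelMamba` and `HSSS` initialize Δ exactly this way is an assumption, `init_dt_bias` is a hypothetical helper (not part of hsss), and the inner dimension `expand_factor * dim` is taken from the reference Mamba design.

```python
import math
import random
from typing import List, Optional


def softplus(x: float) -> float:
    return math.log1p(math.exp(x))


def init_dt_bias(
    d_inner: int, dt_min: float = 0.001, dt_max: float = 0.1, seed: Optional[int] = None
) -> List[float]:
    """Hypothetical helper: Mamba-style initialization of the Δ projection bias."""
    rng = random.Random(seed)
    log_min, log_max = math.log(dt_min), math.log(dt_max)
    biases = []
    for _ in range(d_inner):
        # Sample Δ log-uniformly in [dt_min, dt_max).
        dt = math.exp(log_min + rng.random() * (log_max - log_min))
        # Inverse softplus, so that softplus(bias) == dt at initialization.
        biases.append(dt + math.log(-math.expm1(-dt)))
    return biases


# Assumed inner dimension: expand_factor * dim = 4 * 8 for the low-level example
biases = init_dt_bias(d_inner=4 * 8, dt_min=0.001, dt_max=0.1, seed=0)
steps = [softplus(b) for b in biases]
print(min(steps), max(steps))
```

Sampling in log space spreads the initial time steps over both fast (small Δ) and slow (large Δ) dynamics rather than clustering them near `dt_max`.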

Citation

@misc{bhirangi2024hierarchical,
      title={Hierarchical State Space Models for Continuous Sequence-to-Sequence Modeling}, 
      author={Raunaq Bhirangi and Chenyu Wang and Venkatesh Pattabiraman and Carmel Majidi and Abhinav Gupta and Tess Hellebrekers and Lerrel Pinto},
      year={2024},
      eprint={2402.10211},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

License

MIT

Todo

  • Implement token chunking by splitting the input along the sequence dimension.

  • Make the fusion projection layer configurable instead of fixed, so it can be a linear layer, an FFN, cross-attention, or even an output head.
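The chunking item could look roughly like the following plain-Python sketch. It assumes, as a design choice not taken from the hsss code, that the sequence is split into consecutive fixed-size chunks along the time axis and that each chunk is summarized by the low-level model's last output, so the high-level model sees a much shorter sequence. `chunk_sequence` and `chunk_summaries` are hypothetical names, and a running sum (`itertools.accumulate`) stands in for a low-level model.

```python
from itertools import accumulate
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def chunk_sequence(xs: Sequence[T], chunk_size: int) -> List[List[T]]:
    """Split xs along the sequence dimension into consecutive chunks.

    The last chunk is shorter when len(xs) is not a multiple of chunk_size.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    return [list(xs[i:i + chunk_size]) for i in range(0, len(xs), chunk_size)]


def chunk_summaries(
    xs: Sequence[T], chunk_size: int, low_level: Callable[[List[T]], List[U]]
) -> List[U]:
    """Run the low-level model on each chunk and keep its last output."""
    return [low_level(chunk)[-1] for chunk in chunk_sequence(xs, chunk_size)]


def running_sum(chunk: List[int]) -> List[int]:
    """Toy low-level model: cumulative sum over a chunk."""
    return list(accumulate(chunk))


xs = list(range(10))
print(chunk_sequence(xs, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(chunk_summaries(xs, 4, running_sum))  # [6, 22, 17]
```

Here a 10-step input becomes a 3-step sequence for the high-level model. A shorter final chunk could instead be padded or dropped; which is preferable depends on the data.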



Download files

Source distribution: hsss-0.0.9.tar.gz (10.8 kB)

Built distribution: hsss-0.0.9-py3-none-any.whl (9.6 kB, Python 3)
