Paper - Pytorch

HSSS

Implementation of a Hierarchical Mamba as described in the paper "Hierarchical State Space Models for Continuous Sequence-to-Sequence Modeling", except that instead of traditional SSMs we use Mamba blocks. The data flow is: single input -> low-level Mambas -> concatenation -> high-level SSM -> multiple outputs.
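
The flow above can be illustrated with a minimal, self-contained PyTorch sketch. It is not the hsss implementation: `ToyDiagonalSSM` and `ToyHierarchy` are hypothetical names, a tiny diagonal linear SSM stands in for each Mamba block, and since this README does not say along which axis the low-level outputs are concatenated, the sketch assumes the feature axis followed by a linear fusion projection to the high-level width.

```python
import torch
import torch.nn as nn


class ToyDiagonalSSM(nn.Module):
    """Hypothetical stand-in for a Mamba block: h_t = a * h_{t-1} + B x_t, y_t = C h_t."""

    def __init__(self, dim: int, d_state: int = 4):
        super().__init__()
        self.a_logit = nn.Parameter(torch.zeros(d_state))  # sigmoid keeps a in (0, 1)
        self.B = nn.Linear(dim, d_state, bias=False)
        self.C = nn.Linear(d_state, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim) -> y: (batch, seq_len, dim)
        a = torch.sigmoid(self.a_logit)
        u = self.B(x)
        h = x.new_zeros(x.shape[0], u.shape[-1])
        ys = []
        for t in range(x.shape[1]):  # sequential scan over time
            h = a * h + u[:, t]
            ys.append(self.C(h))
        return torch.stack(ys, dim=1)


class ToyHierarchy(nn.Module):
    """Input -> low-level models -> concat -> fusion projection -> high-level model."""

    def __init__(self, n_low: int, low_dim: int, high_dim: int):
        super().__init__()
        self.low_levels = nn.ModuleList([ToyDiagonalSSM(low_dim) for _ in range(n_low)])
        # Assumed fusion: concatenated features projected to the high-level width
        self.proj = nn.Linear(n_low * low_dim, high_dim)
        self.high = ToyDiagonalSSM(high_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = [m(x) for m in self.low_levels]  # each (batch, seq_len, low_dim)
        fused = torch.cat(feats, dim=-1)  # (batch, seq_len, n_low * low_dim)
        return self.high(self.proj(fused))  # (batch, seq_len, high_dim)


x = torch.randn(1, 10, 8)
model = ToyHierarchy(n_low=3, low_dim=8, high_dim=12)
out = model(x)
print(out.shape)  # torch.Size([1, 10, 12])
```

In this sketch the high-level model emits one output vector per timestep; the low-level models all see the same full-length input, matching the "single input" in the flow rather than the chunked scheme from the paper.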

Paper: arXiv:2402.10211

I believe strongly in this architecture because it separates local learning (the low-level models) from global learning (the high-level model).

Install

pip install hsss

Usage

import torch
from hsss import LowLevelMamba, HSSS


# Random input tensor of shape (batch, seq_len, dim)
x = torch.randn(1, 10, 8)

# Shared settings for the low-level Mamba models
low_level_config = dict(
    dim=8,  # model dimension; matches the last dimension of x
    depth=6,  # number of Mamba layers
    dt_rank=4,  # rank of the timestep (delta) projection
    d_state=4,  # SSM state dimension
    expand_factor=4,  # expansion factor of the inner dimension
    d_conv=6,  # kernel size of the depthwise 1D convolution
    dt_min=0.001,  # minimum initial timestep
    dt_max=0.1,  # maximum initial timestep
    dt_init="random",  # initialization scheme for the timestep projection
    dt_scale=1.0,  # scaling factor for the timestep initialization
    bias=False,  # whether the linear projections use a bias
    conv_bias=True,  # whether the convolution uses a bias
    pscan=True,  # use a parallel scan instead of a sequential loop
)

# Three independently initialized low-level models
mamba = LowLevelMamba(**low_level_config)
mamba2 = LowLevelMamba(**low_level_config)
mamba3 = LowLevelMamba(**low_level_config)


# High-level model that fuses the outputs of the low-level models
hsss = HSSS(
    layers=[mamba, mamba2, mamba3],  # low-level models whose outputs are concatenated
    dim=12,  # model dimension of the high-level model
    depth=3,  # number of layers in the high-level model
    dt_rank=2,  # rank of the timestep (delta) projection
    d_state=2,  # SSM state dimension
    expand_factor=2,  # expansion factor of the inner dimension
    d_conv=3,  # kernel size of the depthwise 1D convolution
    dt_min=0.001,  # minimum initial timestep
    dt_max=0.1,  # maximum initial timestep
    dt_init="random",  # initialization scheme for the timestep projection
    dt_scale=1.0,  # scaling factor for the timestep initialization
    bias=False,  # whether the linear projections use a bias
    conv_bias=True,  # whether the convolution uses a bias
    pscan=True,  # use a parallel scan instead of a sequential loop
    proj_layer=True,  # enable the fusion projection layer (see Todo)
)


# Forward pass
out = hsss(x)
print(out)
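
The dt_min, dt_max, dt_init and dt_scale arguments in the usage example share their names with the timestep (Δ) initialization parameters of the reference Mamba code. The sketch below reproduces that reference scheme in plain PyTorch to show what the arguments control. Whether hsss initializes exactly this way is an assumption, as are the helper name `init_dt_proj` and the `dt_init_floor` value of 1e-4 (the reference default), which the usage example does not set.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def init_dt_proj(
    dt_rank: int,
    d_inner: int,
    dt_min: float = 0.001,
    dt_max: float = 0.1,
    dt_init: str = "random",
    dt_scale: float = 1.0,
    dt_init_floor: float = 1e-4,  # assumption: reference Mamba default
) -> nn.Linear:
    """Hypothetical helper: reference-Mamba-style init of the timestep projection."""
    dt_proj = nn.Linear(dt_rank, d_inner, bias=True)

    # Weights: constant value or uniform bound that scales as dt_rank ** -0.5 * dt_scale
    dt_init_std = dt_rank ** -0.5 * dt_scale
    if dt_init == "constant":
        nn.init.constant_(dt_proj.weight, dt_init_std)
    elif dt_init == "random":
        nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
    else:
        raise ValueError(f"unknown dt_init: {dt_init!r}")

    # Bias: sample one timestep per channel, log-uniformly in [dt_min, dt_max]
    log_min, log_max = math.log(dt_min), math.log(dt_max)
    dt = torch.exp(torch.rand(d_inner) * (log_max - log_min) + log_min)
    dt = dt.clamp(min=dt_init_floor)
    # Store the inverse softplus so that softplus(bias) == dt at initialization
    inv_dt = dt + torch.log(-torch.expm1(-dt))
    with torch.no_grad():
        dt_proj.bias.copy_(inv_dt)
    return dt_proj


# Low-level settings from the usage example: d_inner = expand_factor * dim = 4 * 8
proj = init_dt_proj(dt_rank=4, d_inner=4 * 8)
dt0 = F.softplus(proj.bias)
print(dt0.min().item(), dt0.max().item())  # both lie within [0.001, 0.1]
```

Storing the inverse softplus in the bias means the model starts with timesteps spread over several orders of magnitude, so different channels begin with different memory lengths.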

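The pscan=True argument selects a parallel scan over the sequence instead of a step-by-step loop. This is possible because the core recurrence, h_t = a_t * h_{t-1} + b_t (elementwise), can be expressed as an associative operation on (a, b) pairs. The sketch below checks a Hillis-Steele parallel scan against the sequential loop. It is an illustration only: the README does not say which scan algorithm hsss uses (some Mamba implementations use a Blelloch-style scan instead), and `sequential_scan` and `parallel_scan` are hypothetical names.

```python
import torch


def sequential_scan(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """h_t = a_t * h_{t-1} + b_t over the last axis, with h_{-1} = 0."""
    h = torch.zeros_like(b[..., 0])
    hs = []
    for t in range(b.shape[-1]):
        h = a[..., t] * h + b[..., t]
        hs.append(h)
    return torch.stack(hs, dim=-1)


def parallel_scan(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Same recurrence via a Hillis-Steele inclusive scan in ceil(log2 L) rounds.

    Position t holds a pair (A, B) meaning "h_t = A * h_start + B". Composing an
    earlier pair (A1, B1) with a later pair (A2, B2) gives (A2 * A1, A2 * B1 + B2);
    this operator is associative, so pairs can be combined in any grouping.
    """
    A, B = a.clone(), b.clone()
    step = 1
    while step < b.shape[-1]:
        # Every position t >= step absorbs the pair ending at t - step
        new_A, new_B = A.clone(), B.clone()
        new_A[..., step:] = A[..., step:] * A[..., :-step]
        new_B[..., step:] = A[..., step:] * B[..., :-step] + B[..., step:]
        A, B = new_A, new_B
        step *= 2
    return B  # with h_{-1} = 0, the accumulated B equals h_t


torch.manual_seed(0)
a = torch.rand(2, 4, 16)  # decay factors in [0, 1)
b = torch.randn(2, 4, 16)
print(torch.allclose(sequential_scan(a, b), parallel_scan(a, b), atol=1e-5))  # True
```

Hillis-Steele does O(L log L) total work versus O(L) for a Blelloch scan, but it is shorter to write, and each round is a single vectorized tensor operation.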
Citation

@misc{bhirangi2024hierarchical,
      title={Hierarchical State Space Models for Continuous Sequence-to-Sequence Modeling}, 
      author={Raunaq Bhirangi and Chenyu Wang and Venkatesh Pattabiraman and Carmel Majidi and Abhinav Gupta and Tess Hellebrekers and Lerrel Pinto},
      year={2024},
      eprint={2402.10211},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

License

MIT

Todo

  • Implement chunking of the tokens by splitting the input along the sequence dimension

  • Make the fusion projection layer configurable instead of fixed, so it can be a linear layer, an FFN, cross-attention, or an output head.
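
The chunking item corresponds to the hierarchical scheme in the HiSS paper, where a low-level model processes fixed-size chunks of the sequence and a high-level model runs over one feature per chunk. Below is a minimal sketch of that idea under stated assumptions: each chunk is summarized by the low-level model's last output step, seq_len divides evenly by the chunk size, and a GRU stands in for the Mamba blocks so the snippet runs with plain PyTorch. `chunk_sequence`, `GRUStandIn` and `ChunkedHierarchy` are hypothetical names, not part of hsss.

```python
import torch
import torch.nn as nn


def chunk_sequence(x: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """(batch, seq_len, dim) -> (batch * n_chunks, chunk_size, dim)."""
    batch, seq_len, dim = x.shape
    if seq_len % chunk_size != 0:
        raise ValueError("seq_len must be divisible by chunk_size")
    # Row-major layout keeps the chunks of one sequence adjacent
    return x.reshape(batch * (seq_len // chunk_size), chunk_size, dim)


class GRUStandIn(nn.Module):
    """Stand-in for a Mamba model: returns the per-step outputs of a GRU."""

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.rnn = nn.GRU(d_in, d_out, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.rnn(x)
        return out


class ChunkedHierarchy(nn.Module):
    def __init__(self, low: nn.Module, high: nn.Module, chunk_size: int):
        super().__init__()
        self.low, self.high, self.chunk_size = low, high, chunk_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        n_chunks = seq_len // self.chunk_size
        chunks = chunk_sequence(x, self.chunk_size)  # (batch * n_chunks, chunk_size, dim)
        low_out = self.low(chunks)  # each chunk is processed independently
        summary = low_out[:, -1].reshape(batch, n_chunks, -1)  # last step of each chunk
        return self.high(summary)  # (batch, n_chunks, d_high)


model = ChunkedHierarchy(GRUStandIn(8, 16), GRUStandIn(16, 12), chunk_size=4)
out = model(torch.randn(2, 12, 8))
print(out.shape)  # torch.Size([2, 3, 12])
```

Because the high-level model only sees one step per chunk, it runs over a sequence chunk_size times shorter than the input, which is where the split between local and global learning comes from.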

Download files

Source Distribution

hsss-0.0.9.tar.gz (10.8 kB)

Built Distribution

hsss-0.0.9-py3-none-any.whl (9.6 kB)

File details

Details for the file hsss-0.0.9.tar.gz.

File metadata

  • Download URL: hsss-0.0.9.tar.gz
  • Upload date:
  • Size: 10.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.11.0 Darwin/23.3.0

File hashes

Hashes for hsss-0.0.9.tar.gz
Algorithm Hash digest
SHA256 3b6160a3e2bd55bb62652fcb4977d4f4e4d0e0ce4be02d022ac60585711a5d4a
MD5 bf0288a685a56d42e452ee016d38c96e
BLAKE2b-256 0a44cd2eb140df1a247f2abe628bd3c53c0e0ed3e82a021572e6d9e1e996ec76

File details

Details for the file hsss-0.0.9-py3-none-any.whl.

File metadata

  • Download URL: hsss-0.0.9-py3-none-any.whl
  • Upload date:
  • Size: 9.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.11.0 Darwin/23.3.0

File hashes

Hashes for hsss-0.0.9-py3-none-any.whl
Algorithm Hash digest
SHA256 f8f2124494e91f303777f9f4203e199575011de3af47d1a9eee1d072cc871ce8
MD5 3e5fd6499e0e792343a355596887de92
BLAKE2b-256 25fad03397a1b331e573aa72554800ba64c5440bb1288eb34d71dea75d2c1c6d
