HSSS
Implementation of a hierarchical Mamba as described in the paper "Hierarchical State Space Models for Continuous Sequence-to-Sequence Modeling", but using Mamba blocks instead of traditional SSMs. The flow is: single input -> low-level Mambas -> concat -> high-level SSM -> multiple outputs.
I believe in this architecture a lot, as it separates local and global learning.
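A minimal sketch of that data flow, assuming a concat-then-project fusion (the function and argument names here are illustrative, not the package internals):

import torch

def hsss_flow(x, low_level_mambas, proj, high_level_ssm):
    # The single input runs through every low-level Mamba
    local_feats = [m(x) for m in low_level_mambas]  # each: (batch, seq, d_low)
    # Concatenate the local features along the channel dimension
    fused = torch.cat(local_feats, dim=-1)  # (batch, seq, n * d_low)
    # Project to the high-level dimension, then do the global pass
    return high_level_ssm(proj(fused))  # (batch, seq, d_high)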
Install
pip install hsss
Usage
import torch
from hsss import LowLevelMamba, HSSS
# Random input tensor: (batch, seq_len, dim)
x = torch.randn(1, 10, 8)
# Low-level Mamba 1
mamba = LowLevelMamba(
    dim=8,  # input/model dimension
    depth=6,  # number of Mamba layers
    dt_rank=4,  # rank of the delta projection
    d_state=4,  # SSM state dimension
    expand_factor=4,  # inner-dimension expansion factor
    d_conv=6,  # kernel size of the depthwise convolution
    dt_min=0.001,  # minimum delta value at initialization
    dt_max=0.1,  # maximum delta value at initialization
    dt_init="random",  # delta initialization method
    dt_scale=1.0,  # delta initialization scale
    bias=False,  # bias in the linear projections
    conv_bias=True,  # bias in the convolution
    pscan=True,  # use the parallel scan instead of the sequential one
)
# Low-level Mamba 2 (same configuration as the first)
mamba2 = LowLevelMamba(
    dim=8,
    depth=6,
    dt_rank=4,
    d_state=4,
    expand_factor=4,
    d_conv=6,
    dt_min=0.001,
    dt_max=0.1,
    dt_init="random",
    dt_scale=1.0,
    bias=False,
    conv_bias=True,
    pscan=True,
)
# Low-level Mamba 3 (same configuration as the first)
mamba3 = LowLevelMamba(
    dim=8,
    depth=6,
    dt_rank=4,
    d_state=4,
    expand_factor=4,
    d_conv=6,
    dt_min=0.001,
    dt_max=0.1,
    dt_init="random",
    dt_scale=1.0,
    bias=False,
    conv_bias=True,
    pscan=True,
)
# High-level HSSS model that fuses the low-level Mambas
hsss = HSSS(
    layers=[mamba, mamba2, mamba3],  # low-level models to fuse
    dim=12,  # dimension of the high-level model
    depth=3,  # number of layers in the high-level model
    dt_rank=2,  # rank of the delta projection
    d_state=2,  # SSM state dimension
    expand_factor=2,  # inner-dimension expansion factor
    d_conv=3,  # kernel size of the depthwise convolution
    dt_min=0.001,  # minimum delta value at initialization
    dt_max=0.1,  # maximum delta value at initialization
    dt_init="random",  # delta initialization method
    dt_scale=1.0,  # delta initialization scale
    bias=False,  # bias in the linear projections
    conv_bias=True,  # bias in the convolution
    pscan=True,  # use the parallel scan instead of the sequential one
    proj_layer=True,  # project the concatenated low-level outputs
)
# Forward pass
out = hsss(x)
print(out)
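If proj_layer maps the concatenated low-level outputs (3 models x dim 8 = 24 features) down to the high-level dim of 12 (an assumption about the internals), the output keeps the batch and sequence dimensions:

print(out.shape)  # expected: torch.Size([1, 10, 12])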
Citation
@misc{bhirangi2024hierarchical,
    title={Hierarchical State Space Models for Continuous Sequence-to-Sequence Modeling},
    author={Raunaq Bhirangi and Chenyu Wang and Venkatesh Pattabiraman and Carmel Majidi and Abhinav Gupta and Tess Hellebrekers and Lerrel Pinto},
    year={2024},
    eprint={2402.10211},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
License
MIT
Todo
- Implement chunking of the tokens by splitting the input along the sequence dimension (see the sketch below).
- Make the fusion projection layer dynamic: support not just a linear layer, but also an FFN, cross-attention, or an output head.
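A hypothetical sketch of the chunking idea from the first Todo item, using torch.chunk to split the sequence dimension (model names reuse the usage example above):

import torch

x = torch.randn(1, 12, 8)  # (batch, seq_len, dim)
# Split the sequence into one chunk per low-level Mamba
chunks = torch.chunk(x, chunks=3, dim=1)  # three tensors of shape (1, 4, 8)
# Each chunk would then feed its own low-level model:
# local_outs = [m(c) for m, c in zip([mamba, mamba2, mamba3], chunks)]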