Mamba state-space model

Mamba

Mamba: Linear-Time Sequence Modeling with Selective State Spaces
Albert Gu*, Tri Dao*
Paper: https://arxiv.org/abs/2312.00752

Installation

  • pip install causal-conv1d: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
  • pip install mamba-ssm: the core Mamba package.

If pip complains about PyTorch versions, try passing --no-build-isolation to pip.

Other requirements:

  • Linux
  • NVIDIA GPU
  • PyTorch 1.12+
  • CUDA 11.6+

Usage

We expose several levels of interface with the Mamba model.

Selective SSM

Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).

Source: ops/selective_scan_interface.py.
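
As a rough illustration of what a selective SSM computes, the recurrence can be written as a plain sequential loop. This is a sketch under stated assumptions, not the fused kernel in ops/selective_scan_interface.py: the function name selective_scan_reference and the list-based shapes are hypothetical, and the discretization (exp(delta * A) for the state matrix, delta * B for the input matrix) is assumed to follow the paper.

```python
import math


def selective_scan_reference(x, delta, A, B, C, D=None):
    """Naive sequential selective scan over one sequence (illustrative only).

    x, delta: seq_len lists of d_inner floats (input and input-dependent step size)
    A:        d_inner x d_state nested list (per-channel diagonal state matrix)
    B, C:     seq_len lists of d_state floats (input-dependent projections)
    D:        optional d_inner skip-connection weights
    """
    seq_len, d_inner, d_state = len(x), len(A), len(A[0])
    h = [[0.0] * d_state for _ in range(d_inner)]  # hidden state starts at zero
    y = []
    for t in range(seq_len):
        y_t = []
        for d in range(d_inner):
            out = 0.0
            for n in range(d_state):
                a_bar = math.exp(delta[t][d] * A[d][n])    # assumed discretization of A
                b_bar_x = delta[t][d] * B[t][n] * x[t][d]  # assumed discretization of B, times input
                h[d][n] = a_bar * h[d][n] + b_bar_x        # state update
                out += C[t][n] * h[d][n]                   # readout
            if D is not None:
                out += D[d] * x[t][d]                      # skip connection
            y_t.append(out)
        y.append(y_t)
    return y
```

With A = 0 and delta = B = C = 1, the loop reduces to a running sum of the input, which makes a convenient sanity check.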

Mamba Block

The main module of this repository is the Mamba architecture block wrapping the selective SSM.

Source: modules/mamba_simple.py.

Usage:

import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape

Mamba Language Model

Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.

Source: models/mixer_seq_simple.py.

This is an example of how to integrate Mamba into an end-to-end neural network. This example is used in the generation scripts below.

Pretrained Models

Pretrained models are uploaded to HuggingFace: mamba-130m, mamba-370m, mamba-790m, mamba-1.4b, mamba-2.8b.

The models will be autodownloaded by the generation script below.

These models were trained on the Pile, and follow the standard model dimensions described by GPT-3 and followed by many open source models:

Parameters   Layers   Model dim.
130M         12       768
370M         24       1024
790M         24       1536
1.4B         24       2048
2.8B         32       2560

(The number of Mamba blocks is twice the layer count listed here, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
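
The layer counts and model dimensions can be combined with the rule of thumb from the Mamba block usage (roughly 3 * expand * d_model^2 parameters per block) to sanity-check the model sizes. The sketch below is only an approximation: the helper name estimate_mamba_params is hypothetical, expand=2 is taken from the usage example, and the vocabulary size of 50280 with a single shared embedding matrix is an assumption not stated in this document. Norms, convolutions, and SSM-specific weights are ignored, so the estimates are expected to come out somewhat below the nominal sizes.

```python
def estimate_mamba_params(transformer_layers, d_model, expand=2, vocab_size=50280):
    """Rough parameter estimate; vocab_size and the shared embedding are assumptions."""
    n_blocks = 2 * transformer_layers      # two Mamba blocks per Transformer "layer"
    per_block = 3 * expand * d_model ** 2  # rule of thumb for one Mamba block
    embedding = vocab_size * d_model       # assumed single shared embedding matrix
    return n_blocks * per_block + embedding


# (name, layers, model dim.) rows for the pretrained models
sizes = [
    ("130M", 12, 768),
    ("370M", 24, 1024),
    ("790M", 24, 1536),
    ("1.4B", 24, 2048),
    ("2.8B", 32, 2560),
]
for name, layers, d_model in sizes:
    print(f"{name}: ~{estimate_mamba_params(layers, d_model) / 1e6:.0f}M parameters")
```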

Note: these are base models trained for only 300B tokens, without any form of downstream modification (instruction tuning, etc.). Performance is expected to be comparable to or better than other architectures trained on similar data, but not to match larger or fine-tuned models.

Evaluations

To run zero-shot evaluations of models (corresponding to Table 3 of the paper), we use the lm-evaluation-harness library.

  1. Pull the lm-evaluation-harness repo with git submodule update --init --recursive. We use the big-refactor branch.
  2. Install lm-evaluation-harness: pip install -e 3rdparty/lm-evaluation-harness
  3. Run evaluation with (more documentation at the lm-evaluation-harness repo):
python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64

Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.

Inference

The script benchmarks/benchmark_generation_mamba_simple.py

  1. autoloads a model from the HuggingFace Hub,
  2. generates completions of a user-specified prompt,
  3. benchmarks the inference speed of this generation.

Other configurable options include the top-p (nucleus sampling) probability and the softmax temperature.
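
For readers unfamiliar with these two options, the following is a generic sketch of temperature scaling followed by top-p (nucleus) filtering for a single sampling step. It is not taken from the benchmark script, whose internals are not shown here; the function name sample_top_p and its signature are hypothetical, and the temperature is assumed to be strictly positive.

```python
import math
import random


def sample_top_p(logits, top_p=0.9, temperature=1.0, rng=random):
    """Sample one token index: temperature-scaled softmax, then nucleus filtering."""
    # Temperature-scaled softmax; subtracting the max keeps exp() numerically stable.
    scaled = [logit / temperature for logit in logits]
    max_scaled = max(scaled)
    exps = [math.exp(s - max_scaled) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the smallest set of most-likely tokens whose total probability reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break

    # Sample among the kept tokens; random.choices renormalizes the weights.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
```

Lower temperatures sharpen the distribution before filtering, so fewer tokens survive the top-p cut; a top-p of 1.0 keeps every token.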

Examples

To test generation latency (e.g. batch size = 1) with different sampling strategies:

python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5

To test generation throughput with random prompts (e.g. large batch size):

python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128

Citation

If you use this codebase, or otherwise find our work valuable, please cite Mamba:

@article{mamba,
  title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  author={Gu, Albert and Dao, Tri},
  journal={arXiv preprint arXiv:2312.00752},
  year={2023}
}
