Model and inference code for architectures beyond the Transformer

StripedHyena

Minimal implementation of a StripedHyena model.

About

One of the focus areas at Together Research is new architectures that improve on the Transformer in long-context handling, training, and inference performance. We are excited to introduce the StripedHyena models, which spin out of a research program by our team and academic collaborators with roots in signal processing-inspired sequence models.

StripedHyena is the first alternative model competitive with the best open-source Transformers of similar sizes in short and long-context evaluations.

StripedHyena-Nous-7B (SH-N 7B) is our chat model for this release, and was developed with our collaborators at Nous Research.

SH-N 7B uses this prompt format: ### Instruction:\n{prompt}\n\n### Response:\n{response}
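For scripted use, the template can be filled in with a small helper. This is a minimal sketch, not part of the stripedhyena package: format_prompt is a hypothetical name, and it assumes the \n sequences above stand for literal newlines and that the model should continue right after the "### Response:" line.

```python
# Hypothetical helper that fills in the SH-N 7B prompt template quoted above.
TEMPLATE = "### Instruction:\n{prompt}\n\n### Response:\n"


def format_prompt(prompt: str, response: str = "") -> str:
    """Build a chat prompt; leave `response` empty to let the model complete it."""
    # The response is appended (not passed to .format) so braces in it are left untouched.
    return TEMPLATE.format(prompt=prompt) + response


if __name__ == "__main__":
    print(format_prompt("Name the four species of hyenas."))
```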

Model Architecture

StripedHyena is a hybrid architecture composed of multi-head, grouped-query attention and gated convolutions arranged in Hyena blocks, which sets it apart from traditional decoder-only Transformers.

  • Constant-memory decoding in Hyena blocks via representation of convolutions as state-space models (modal or canonical form), or as truncated filters.
  • Lower latency, faster decoding, and higher throughput than Transformers.
  • Improvement to training and inference-optimal scaling laws, compared to optimized Transformer architectures such as Llama-2.
  • Trained on sequences of up to 32k tokens, allowing it to process longer prompts.
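To see why a state-space representation keeps decoding memory constant, consider a toy long convolution whose filter is a sum of decaying complex exponentials, h[t] = Re(sum_n R_n * p_n^t), with poles p_n and residues R_n. The sketch below is a minimal illustration of the modal-form idea under that assumption, not StripedHyena's implementation (all function names are hypothetical): the direct convolution reads the full input history at every step, while the equivalent recurrence carries only one complex state per pole.

```python
import cmath
import random


def modal_filter(poles, residues, length):
    """Impulse response h[t] = Re(sum_n R_n * p_n**t) for t = 0..length-1."""
    return [sum(r * p**t for p, r in zip(poles, residues)).real for t in range(length)]


def conv_direct(u, h):
    """Causal convolution y[t] = sum_{k<=t} h[k] * u[t-k]; each step reads the whole history."""
    return [sum(h[k] * u[t - k] for k in range(t + 1)) for t in range(len(u))]


def conv_recurrent(u, poles, residues):
    """Same outputs from a diagonal (modal) recurrence that keeps one state per pole."""
    state = [0j] * len(poles)  # memory is O(number of poles), independent of len(u)
    ys = []
    for u_t in u:
        # x_n[t] = p_n * x_n[t-1] + u[t]
        state = [p * x + u_t for p, x in zip(poles, state)]
        # y[t] = Re(sum_n R_n * x_n[t])
        ys.append(sum(r * x for r, x in zip(residues, state)).real)
    return ys


random.seed(0)
poles = [0.9 * cmath.exp(1j * random.uniform(0.0, cmath.pi)) for _ in range(4)]
residues = [complex(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(4)]
u = [random.gauss(0, 1) for _ in range(32)]

y_direct = conv_direct(u, modal_filter(poles, residues, len(u)))
y_recurrent = conv_recurrent(u, poles, residues)
# The two agree up to floating-point rounding.
print(max(abs(a - b) for a, b in zip(y_direct, y_recurrent)))
```

Unrolling the recurrence gives x_n[t] = sum_k p_n^k u[t-k], so summing R_n x_n[t] reproduces the convolution term by term; a truncated filter, the other option listed above, would instead cap how much history the direct form has to keep.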

Quick Start

The most direct way to test StripedHyena models is via our playground, which includes a variety of architecture-specific optimizations.

Playground:

Standalone

Checkpoints

We provide a checkpoint for StripedHyena-Hessian 7B, our base model. Download pytorch-model.bin from the HuggingFace repository. As an alternative, we also provide HuggingFace-compatible checkpoints for use with AutoClasses.

Environment Setup

To run our standalone StripedHyena implementation, you will need to install the packages in requirements.txt, as well as rotary and normalization kernels from flash_attn.

The easiest way to ensure all requirements are installed is to build a Docker image using the provided Dockerfile, or to follow the steps in the Dockerfile within a separate virtual environment. For example, to build a Docker image, run:

docker build --tag sh:test .

Installing the dependencies and kernels could take several minutes. Then run the container interactively with:

docker run -it --gpus all --network="host" --shm-size 900G -v=<path_to_this_repo>:/mnt:rw --rm sh:test

Generation

Once the environment is set up, you will be able to generate text with:

python generate.py --config_path ./configs/7b-sh-32k-v1.yml \
--checkpoint_path <path_to_ckpt> --cached_generation \
--prompt_file ./test_prompt.txt

If you are generating with prompt.txt, set prefill_style: fft in the config. For very long prompts, you may want to opt for prefill_style: recurrence, which will be slower but use less memory.
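As a rough illustration of this trade-off, assume (as the setting names suggest) that fft prefill evaluates the long convolution over the whole prompt at once, while recurrence steps through it token by token with a fixed-size state. The sketch below computes a causal convolution with a zero-padded FFT: for a prompt of length L it materializes buffers of length at least 2L, which is fast but grows with the prompt. The radix-2 FFT is written in pure Python only to keep the example self-contained; it is not the kernel StripedHyena uses, and fft_causal_conv and direct_causal_conv are hypothetical names.

```python
import cmath
import random


def fft(a, invert=False):
    """Recursive radix-2 FFT; len(a) must be a power of two. Unnormalized when invert=True."""
    n = len(a)
    if n == 1:
        return list(a)
    even = fft(a[0::2], invert)
    odd = fft(a[1::2], invert)
    sign = 1 if invert else -1
    out = [0j] * n
    for k in range(n // 2):
        w = cmath.exp(sign * 2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + w
        out[k + n // 2] = even[k] - w
    return out


def fft_causal_conv(u, h):
    """y[t] = sum_{k<=t} h[k] * u[t-k] for t < len(u), computed in O(L log L)."""
    L = len(u)
    n = 1
    while n < 2 * L:  # pad to >= 2L so the circular convolution does not wrap around
        n *= 2
    U = fft([complex(x) for x in u] + [0j] * (n - L))
    hp = [complex(x) for x in h[:L]]
    H = fft(hp + [0j] * (n - len(hp)))
    y = fft([a * b for a, b in zip(U, H)], invert=True)
    return [v.real / n for v in y[:L]]  # keep the causal part and undo the n scaling


def direct_causal_conv(u, h):
    """Reference O(L^2) implementation of the same causal convolution."""
    return [sum(h[k] * u[t - k] for k in range(min(t + 1, len(h)))) for t in range(len(u))]


random.seed(1)
u = [random.gauss(0, 1) for _ in range(50)]
h = [0.95**t for t in range(50)]  # toy decaying filter
y_fft = fft_causal_conv(u, h)
y_ref = direct_causal_conv(u, h)
```

The recurrent alternative would instead loop over the prompt once, updating a state whose size does not depend on L, which matches the description of recurrence as slower but lighter on memory.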

If the installation is correct, the test prompt will generate the following paragraph:

The four species of hyenas are the striped hyena (Hyaena hyaena), the brown hyena (Parahyaena brunnea), the spotted hyena (Crocuta crocuta), and the aardwolf (Proteles cristata).\n\nThe striped hyena is the most widespread species, occurring in Africa, the Middle East, and Asia.

HuggingFace

We also provide an entry script to generate with StripedHyena models hosted on HuggingFace. The model ids are:

  • Base model: togethercomputer/StripedHyena-Hessian-7B
  • Chat model: togethercomputer/StripedHyena-Nous-7B

Choose your model id, then run the following command:

python generate_transformers.py --model-name <model_id> --input-file ./test_prompt.txt

Testing Correctness

We report lm-evaluation-harness (10-shot) scores to use as a proxy for (standalone) model correctness in your environment.

  • arc_challenge: 0.570 (acc norm)
  • hellaswag: 0.816 (acc norm)
  • winogrande: 0.735 (acc)
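One way to use these numbers is to compare your own lm-evaluation-harness results against them with some tolerance. The helper below is a hypothetical sketch: check_scores is not part of any library, the tolerance of 0.01 is an illustrative assumption rather than a threshold given by the authors, and the measured values would come from your own run.

```python
# Reference 10-shot scores listed above.
REFERENCE = {
    "arc_challenge": 0.570,  # acc norm
    "hellaswag": 0.816,      # acc norm
    "winogrande": 0.735,     # acc
}


def check_scores(measured, reference=REFERENCE, tol=0.01):
    """Return {task: (measured, reference)} for tasks missing or off by more than tol.

    tol=0.01 is an assumed tolerance for small differences between environments.
    """
    return {
        task: (measured.get(task), ref)
        for task, ref in reference.items()
        if task not in measured or abs(measured[task] - ref) > tol
    }


if __name__ == "__main__":
    mismatches = check_scores({"arc_challenge": 0.568, "hellaswag": 0.817, "winogrande": 0.70})
    print(mismatches)  # {'winogrande': (0.7, 0.735)}
```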

More extensive benchmark results are provided in the blog post and on HuggingFace.

Optional Dependencies

The standalone implementation provides integration with some custom kernels for StripedHyena such as FlashFFTConv (see the model config 7b-sh-32k-v1.yml for more information). These additional kernels are not required to run the model.

Issues

Several issues can be resolved by reinstalling the latest version of flash_attn (pip freeze | grep flash-attn should return a version >= 2.0.0).
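The same check can be done from Python with the standard-library importlib.metadata. This is a minimal sketch: it assumes the distribution is registered as flash-attn or flash_attn (both spellings are tried), and version_tuple is a deliberately simple, hypothetical parser that compares only the leading numeric components of the version string.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def version_tuple(v):
    """Leading numeric components of a version string, e.g. '2.3.6+cu118' -> (2, 3, 6)."""
    parts = []
    for piece in v.split("."):
        m = re.match(r"\d+", piece)
        if not m:
            break
        parts.append(int(m.group()))
        if m.end() != len(piece):  # stop at suffixes such as '0rc1' or '6+cu118'
            break
    return tuple(parts)


def flash_attn_ok(minimum=(2, 0, 0)):
    """True if flash-attn is installed with a version >= minimum."""
    for name in ("flash-attn", "flash_attn"):  # distribution name spelling is an assumption
        try:
            installed = version_tuple(version(name))
        except PackageNotFoundError:
            continue
        installed += (0,) * (len(minimum) - len(installed))  # pad e.g. (2,) to (2, 0, 0)
        return installed >= minimum
    return False


if __name__ == "__main__":
    print("flash-attn >= 2.0.0:", flash_attn_ok())
```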

StripedHyena is a mixed precision model. Make sure to keep your poles and residues in float32 precision.
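To illustrate why this matters, the toy example below rounds a single pole to float32 and to bfloat16 and compares how the decaying term p^t behaves late in a 32k-token context. It is a pure-Python sketch under stated assumptions: the rounding helpers emulate the two formats with struct, bfloat16 stands in for whatever reduced-precision dtype the rest of the model runs in, and the pole value 0.9999 is arbitrary.

```python
import struct


def to_float32(x):
    """Round a Python float to the nearest IEEE float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bfloat16(x):
    """Round to bfloat16 (round-to-nearest-even on the top 16 bits of float32); ignores NaN."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


pole = 0.9999  # a slowly decaying pole, as a long-range filter might need
t = 32768      # a position late in a 32k context

p32 = to_float32(pole)
p16 = to_bfloat16(pole)
print(p32 ** t)  # about 0.038: the term has decayed as intended
print(p16 ** t)  # 1.0: the pole rounded to exactly 1.0, so it never decays
```

Near 1.0, bfloat16 can only represent steps of 2^-8, so every pole within about 0.002 of 1 collapses to a non-decaying value; float32 resolves steps of 2^-24 there.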

Cite

If you have found the pretrained models or architecture useful for your research or application, consider citing:

@software{stripedhyena,
  title        = {{StripedHyena: Moving Beyond Transformers with Hybrid Signal Processing Models}},
  author       = { Poli, Michael and Wang, Jue and Massaroli, Stefano and Quesnelle, Jeffrey and Carlow, Ryan and Nguyen, Eric and Thomas, Armin},
  month        = 12,
  year         = 2023,
  url          = { https://github.com/togethercomputer/stripedhyena },
  doi          = { 10.57967/hf/1595 },
}

Download files

Built Distribution

stripedhyena-0.2.2-py3-none-any.whl (30.2 kB)

Uploaded Python 3

File details

Details for the file stripedhyena-0.2.2-py3-none-any.whl.

File metadata

  • Download URL: stripedhyena-0.2.2-py3-none-any.whl
  • Upload date:
  • Size: 30.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for stripedhyena-0.2.2-py3-none-any.whl
Algorithm Hash digest
SHA256 27286a9bcac5db72ad1412c3d4a91b0cf1b84f5334be30b250d6b545e1e63164
MD5 1efcf877453226540d1549ea2e5a6f41
BLAKE2b-256 04bd604b10267316caab887d304e2796d9a56f1ae46d1890f9d8f0ca0f838c31
