Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads
News 🔥
- [2023/09] Medusa v0.1 is released! 🎉
Introduction
Medusa is a simple framework that democratizes the acceleration techniques for LLM generation with multiple decoding heads.
We aim to tackle the three pain points of popular acceleration techniques like speculative decoding:
- Requirement of a good draft model.
- System complexity.
- Inefficiency when using sampling-based generation.
In a nutshell, we solve the challenges of speculative decoding with the following ideas:
- Instead of introducing a new model, we train multiple decoding heads on the same model.
- The training is parameter-efficient, so even the GPU-poor can do it. And since no additional model is introduced, there is no need to adjust the distributed computing setup.
- Relaxing the requirement of matching the original model's distribution makes non-greedy generation even faster than greedy decoding.
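To make the first two ideas concrete, here is a toy, dependency-free sketch of several decoding heads sitting on one shared hidden state. It assumes each head is a residual block (h + SiLU(W1·h)) followed by a projection to vocabulary logits, which is our reading of the design behind the `--medusa_num_layers` training flag; the function names and tiny weights are hypothetical and for illustration only. Each head proposes a guess for a different future position, and the parameter count helper shows why training only the heads is cheap: with Vicuna-7B-like sizes (hidden size 4096, vocabulary 32,000), each head adds on the order of 0.15B trainable parameters while the roughly 7B base weights stay frozen.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a row-major matrix by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def silu(x: float) -> float:
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def head_logits(h: Vector, w1: Matrix, w_out: Matrix) -> Vector:
    """Hypothetical head: residual block h + SiLU(W1 h), then projection to vocab logits."""
    hidden = [hi + silu(z) for hi, z in zip(h, matvec(w1, h))]
    return matvec(w_out, hidden)


def propose_tokens(h: Vector, heads: List[Tuple[Matrix, Matrix]]) -> List[int]:
    """Every head reads the same final hidden state and proposes its top-1 token."""
    proposals = []
    for w1, w_out in heads:
        logits = head_logits(h, w1, w_out)
        proposals.append(max(range(len(logits)), key=logits.__getitem__))
    return proposals


def head_param_count(hidden: int, vocab: int, num_layers: int = 1) -> int:
    """Trainable parameters per head, assuming each residual layer is a biased
    hidden x hidden linear and the vocab projection has no bias."""
    return num_layers * (hidden * hidden + hidden) + hidden * vocab


if __name__ == "__main__":
    # Toy setup: hidden size 2, vocabulary of 3 tokens, 2 heads (illustrative only).
    h = [1.0, -1.0]
    identity = [[1.0, 0.0], [0.0, 1.0]]
    heads = [
        (identity, [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]),
        (identity, [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]),
    ]
    print(propose_tokens(h, heads))  # [0, 1]
    # Assumed Vicuna-7B-like sizes: hidden 4096, vocabulary 32000.
    print(head_param_count(4096, 32000))  # 147853312
```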
In this initial release, our primary focus is on optimizing Medusa for a batch size of 1, a setting commonly used for local model hosting. In this configuration, Medusa delivers roughly a 2x speedup across a range of Vicuna models. We are actively working to extend Medusa's capabilities by integrating it into additional inference frameworks, with the aim of achieving even greater performance gains and extending Medusa to broader settings.
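A back-of-the-envelope model helps show where a gain of this size can come from. The sketch assumes a single chain of top-1 candidates (one per head), a hypothetical acceptance rate for each head conditional on all earlier heads being accepted, and an assumed cost ratio between one Medusa step and one ordinary decoding step. It is a simplified illustration, not how the reported speedups were measured, and the input numbers are made up.

```python
from typing import Sequence


def expected_tokens_per_step(accept_rates: Sequence[float]) -> float:
    """Expected tokens emitted per decoding step for a single candidate chain.

    The base model always contributes one token; head k's guess only counts
    if every earlier guess in the chain was accepted too.
    """
    expected = 1.0  # token from the base model's own LM head
    survive = 1.0   # probability that the chain is still unbroken
    for rate in accept_rates:
        survive *= rate
        expected += survive
    return expected


def estimated_speedup(accept_rates: Sequence[float], step_cost_ratio: float) -> float:
    """Speedup over plain decoding; step_cost_ratio is the (assumed) latency of
    one Medusa step divided by the latency of one ordinary decoding step."""
    return expected_tokens_per_step(accept_rates) / step_cost_ratio


if __name__ == "__main__":
    # Hypothetical inputs: 3 heads with decreasing acceptance, 10% extra cost per step.
    print(round(estimated_speedup([0.6, 0.4, 0.2], 1.1), 3))
```

Under this model the gain is capped at 1 + K tokens per step for K heads, and later heads contribute little when their conditional acceptance is low.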
Installation
Method 1: With pip
pip install medusa-llm
Method 2: From source
git clone https://github.com/FasterDecoding/Medusa.git
cd Medusa
pip install -e .
Model Weights
Size | Chat Command | Hugging Face Repo |
---|---|---|
7B | python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-7b-v1.3 | FasterDecoding/medusa-vicuna-7b-v1.3 |
13B | python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-13b-v1.3 | FasterDecoding/medusa-vicuna-13b-v1.3 |
33B | python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-33b-v1.3 | FasterDecoding/medusa-vicuna-33b-v1.3 |
Inference
We currently support inference on a single GPU with batch size 1, which is the most common setup for local model hosting. We are actively working to extend Medusa's capabilities by integrating it into other inference frameworks; please don't hesitate to reach out if you are interested in contributing to this effort.
You can use the following command to launch a CLI:
python -m medusa.inference.cli --model [path of medusa model]
You can also pass --load-in-8bit or --load-in-4bit to load the base model in quantized format.
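If you prefer to launch the CLI from a Python script, one option is to build the argument list yourself and hand it to subprocess. This is a minimal sketch: build_cli_args is a hypothetical helper, it only uses the --model, --load-in-8bit and --load-in-4bit options described above, and it assumes the two quantization flags are meant to be used one at a time.

```python
import subprocess
import sys
from typing import List, Optional


def build_cli_args(model: str, quantization: Optional[str] = None) -> List[str]:
    """Build the argv for `python -m medusa.inference.cli`.

    quantization: None, "8bit" or "4bit" (assumed to be mutually exclusive).
    """
    args = [sys.executable, "-m", "medusa.inference.cli", "--model", model]
    if quantization == "8bit":
        args.append("--load-in-8bit")
    elif quantization == "4bit":
        args.append("--load-in-4bit")
    elif quantization is not None:
        raise ValueError(f"unsupported quantization: {quantization!r}")
    return args


if __name__ == "__main__":
    cmd = build_cli_args("FasterDecoding/medusa-vicuna-7b-v1.3", quantization="4bit")
    print(" ".join(cmd))
    # Uncomment to actually start the interactive CLI (needs medusa-llm and a GPU):
    # subprocess.run(cmd, check=True)
```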
Training
For training, please install:
pip install -e ".[train]"
Prepare the data
We take a public version of the ShareGPT dataset, which is a subset of the Vicuna training data. For other models, you can use the corresponding training dataset.
git clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered
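Before training, it can help to sanity-check the downloaded file. The sketch below assumes a FastChat-style ShareGPT layout: a JSON list of records, each with an "id" and a "conversations" list whose turns carry "from" (for example "human" or "gpt") and "value" fields. That layout is our assumption, so verify it against the actual file; load_sharegpt and count_turns are hypothetical helpers.

```python
import json
from collections import Counter
from typing import Dict, List


def load_sharegpt(path: str) -> List[dict]:
    """Load the whole JSON list into memory."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def count_turns(records: List[dict]) -> Dict[str, int]:
    """Count conversation turns per speaker, assuming FastChat-style records."""
    counts: Counter = Counter()
    for record in records:
        for turn in record.get("conversations", []):
            counts[turn.get("from", "unknown")] += 1
    return dict(counts)


if __name__ == "__main__":
    sample = [
        {"id": "a1", "conversations": [
            {"from": "human", "value": "Hi!"},
            {"from": "gpt", "value": "Hello, how can I help?"},
        ]},
        {"id": "a2", "conversations": [{"from": "human", "value": "Bye"}]},
    ]
    print(len(sample), count_turns(sample))
    # With the cloned dataset:
    # records = load_sharegpt(
    #     "ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json")
```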
Train the model
We follow the training setup from FastChat, but with a much larger learning rate, because we freeze the original model and only train the new heads. Here is the training command for the Vicuna-7B model on 4 GPUs. Since only the new heads are trained, training does not require much memory, and only data parallelism is needed. You can modify the script to fit your own setup; for larger models, we use the same setup. You can also pass --load_in_8bit or --load_in_4bit to load the base model in quantized format.
torchrun --nproc_per_node=4 medusa/train/train.py --model_name_or_path lmsys/vicuna-7b-v1.3 \
--data_path ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \
--bf16 True \
--output_dir test \
--num_train_epochs 1 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 4 \
--evaluation_strategy "no" \
--save_strategy "no" \
--learning_rate 1e-3 \
--weight_decay 0.0 \
--warmup_ratio 0.1 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--lazy_preprocess True \
--medusa_num_heads 3 \
--medusa_num_layers 1
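Two quantities implied by the command above are worth spelling out. The global batch per optimizer step is nproc_per_node × per_device_train_batch_size × gradient_accumulation_steps = 4 × 8 × 4 = 128 sequences. With --warmup_ratio 0.1 and --lr_scheduler_type "cosine", the learning rate should ramp linearly from 0 to 1e-3 and then decay along a half cosine to 0. The sketch reproduces that shape under the assumption that it matches the Hugging Face Trainer's cosine-with-warmup schedule; the total step count is hypothetical, since the real value depends on the dataset size.

```python
import math


def global_batch_size(num_gpus: int, per_device: int, grad_accum: int) -> int:
    """Sequences consumed per optimizer step under pure data parallelism."""
    return num_gpus * per_device * grad_accum


def cosine_lr(step: int, total_steps: int, peak_lr: float, warmup_ratio: float) -> float:
    """Linear warmup, then half-cosine decay to zero (assumed Trainer-style schedule)."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


if __name__ == "__main__":
    print(global_batch_size(4, 8, 4))  # values taken from the command above
    total = 500  # hypothetical number of optimizer steps
    for step in (0, 25, 50, 275, 500):
        print(step, f"{cosine_lr(step, total, 1e-3, 0.1):.2e}")
```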
Push to Hugging Face Hub
You can use the following command to push your model to the Hugging Face Hub:
python -m medusa.hf_utils --folder [path of the model folder] --repo [name of the repo]
Citation
@misc{medusa,
author = {Tianle Cai and Yuhong Li and Zhengyang Geng and Hongwu Peng and Tri Dao},
title = {Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads},
year = {2023},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/FasterDecoding/Medusa}},
}
Codebase Guide
medusa/model/medusa_model.py is the key file for Medusa. It contains the MedusaModel class, a wrapper around the original model and the new heads. This class also implements a streaming generation method. If you want to dive into the details of Medusa, this is the place to start.
We also provide some illustrative notebooks in notebooks/ to help you understand the codebase.
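To show the general shape of such a wrapper, here is a toy, dependency-free sketch: it holds a base "model" and a list of heads and exposes a generator that streams tokens. Everything in it is hypothetical (ToyMedusaWrapper, the toy model functions, and the greedy accept-the-longest-matching-prefix rule), and verification is done one call at a time for readability; in practice the speedup depends on scoring all candidates in a single forward pass, which this toy does not model. Refer to medusa/model/medusa_model.py for how Medusa actually builds and accepts candidates.

```python
from typing import Callable, Iterator, List, Sequence

# A "model" here is any function mapping a token prefix to its greedy next token.
NextToken = Callable[[Sequence[int]], int]


class ToyMedusaWrapper:
    """Hypothetical wrapper: a base model plus K heads that guess further-ahead tokens."""

    def __init__(self, base_next: NextToken, heads: List[NextToken]):
        self.base_next = base_next
        # heads[k] guesses the token k + 1 positions after the base model's next token.
        self.heads = heads

    def stream(self, prompt: Sequence[int], max_new_tokens: int) -> Iterator[int]:
        tokens = list(prompt)
        produced = 0
        while produced < max_new_tokens:
            # 1) The base model's own next token is always accepted.
            accepted = [self.base_next(tokens)]
            # 2) Every head proposes one more token from the same prefix.
            guesses = [head(tokens) for head in self.heads]
            # 3) Keep guesses only while they match the base model's greedy choice.
            for guess in guesses:
                if self.base_next(tokens + accepted) != guess:
                    break
                accepted.append(guess)
            # 4) Stream the accepted tokens without exceeding max_new_tokens.
            for tok in accepted[: max_new_tokens - produced]:
                tokens.append(tok)
                produced += 1
                yield tok


def toy_base(prefix: Sequence[int]) -> int:
    """Toy base model: always continues with (last token + 1) mod 10."""
    return (prefix[-1] + 1) % 10


def toy_good_head(prefix: Sequence[int]) -> int:
    """Toy head that correctly guesses one position past the base model."""
    return (prefix[-1] + 2) % 10


def toy_bad_head(prefix: Sequence[int]) -> int:
    """Toy head that always guesses token 0 (usually wrong)."""
    return 0


if __name__ == "__main__":
    model = ToyMedusaWrapper(toy_base, [toy_good_head, toy_bad_head])
    print(list(model.stream([3], max_new_tokens=5)))  # [4, 5, 6, 7, 8]
```

Because a guess is only kept when it equals the base model's greedy choice, the streamed output matches plain greedy decoding; the heads only change how many tokens each step can emit.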
Contributing
We welcome community contributions to Medusa. If you have an idea for how to improve it, please open an issue to discuss it with us. When submitting a pull request, please ensure that your changes are well-tested. Please split each major change into a separate pull request. We also have a Roadmap summarizing our future plans for Medusa. Don't hesitate to reach out if you are interested in contributing to any of the items on the roadmap.
Acknowledgements
This codebase is influenced by amazing works from the community, including FastChat, TinyChat, vLLM and many others.