Project description

JetStream is a throughput and memory optimized engine for LLM inference on XLA devices, starting with TPUs (and GPUs in future -- PRs welcome).

About

JetStream is a fast library for LLM inference and serving on TPUs (and GPUs in future -- PRs welcome).

Documentation


JetStream MaxText Inference on v5e Cloud TPU VM User Guide

Outline

  1. Prerequisites: Prepare your GCP project and connect to Cloud TPU VM
  2. Download the JetStream and MaxText GitHub repositories
  3. Set up your MaxText JetStream environment
  4. Convert Model Checkpoints
  5. Run the JetStream MaxText server
  6. Send a test request to the JetStream MaxText server
  7. Run benchmarks with the JetStream MaxText server
  8. Clean up

Prerequisites: Prepare your GCP project and connect to Cloud TPU VM

Follow the steps in Manage TPU resources | Google Cloud to create a Cloud TPU VM (recommended TPU type: v5litepod-8) and connect to it.
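For reference, here is a minimal sketch of creating and connecting to a v5e Cloud TPU VM with gcloud. The VM name, zone, and runtime version below are illustrative placeholders, not values from this guide; pick values that match your project and quota.

# Example only: the VM name, zone, and runtime version are placeholders.
gcloud compute tpus tpu-vm create jetstream-demo \
  --zone=us-west4-a \
  --accelerator-type=v5litepod-8 \
  --version=v2-alpha-tpuv5-lite

# Connect to the VM once it is ready.
gcloud compute tpus tpu-vm ssh jetstream-demo --zone=us-west4-a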

Step 1: Download the JetStream and MaxText GitHub repositories

git clone -b jetstream-v0.2.0 https://github.com/google/maxtext.git
git clone -b v0.2.0 https://github.com/google/JetStream.git

Step 2: Set up MaxText

# Create a python virtual environment for the demo.
sudo apt install python3.10-venv
python -m venv .env
source .env/bin/activate

# Setup MaxText.
cd maxtext/
bash setup.sh

Step 3: Convert Model Checkpoints

You can run the JetStream MaxText Server with Gemma and Llama2 models. This section describes how to run the JetStream MaxText server with various sizes of these models.

Use a Gemma model checkpoint

  • You can download a Gemma checkpoint from Kaggle.
  • After downloading checkpoints, copy them to your GCS bucket at $CHKPT_BUCKET.
    • gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
    • Please refer to the conversion script for an example of $CHKPT_BUCKET (a placeholder staging sketch is also shown after the conversion command below).
  • Then, use the following command to convert the Gemma checkpoint into a MaxText-compatible unscanned checkpoint.
# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET}

# For gemma-7b
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}

Note: For more information about the Gemma model and checkpoints, see About Gemma.
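As a concrete illustration of the staging step above, the sketch below copies a locally downloaded checkpoint into a GCS bucket before conversion. The bucket name and local checkpoint path are placeholders, not values required by the conversion script.

# Example only: the bucket name and local checkpoint path are placeholders.
export CHKPT_BUCKET=gs://your-project-gemma-ckpts
gsutil mb ${CHKPT_BUCKET}                    # create the bucket if it does not exist
gsutil -m cp -r ~/gemma-7b ${CHKPT_BUCKET}   # stage the downloaded checkpoint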

Use a Llama2 model checkpoint

  • You can use a Llama2 checkpoint you have generated or one from the open source community.
  • After downloading checkpoints, copy them to your GCS bucket at $CHKPT_BUCKET.
    • gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
    • Please refer to the conversion script for an example of $CHKPT_BUCKET.
  • Then, use the following command to convert the Llama2 checkpoint into a MaxText-compatible unscanned checkpoint.
# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET}

# For llama2-7b
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}

# For llama2-13b
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}

Note: For more information about the Llama2 model and checkpoints, see About Llama2.

Step 4: Run the JetStream MaxText server

Create model config environment variables for server flags

You can export the following environment variables based on the model you are using.

  • You can copy and export the UNSCANNED_CKPT_PATH from the model_ckpt_conversion.sh output.
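For example (the path below is purely illustrative; copy the exact value printed by model_ckpt_conversion.sh):

# Example only: use the real path from the conversion script's output.
export UNSCANNED_CKPT_PATH=gs://your-output-bucket/gemma-7b/unscanned_chkpt/checkpoints/0/items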

Create Gemma-7b environment variables for server flags

  • Configure the flags passed to the JetStream MaxText server:
export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

Create Llama2-7b environment variables for server flags

  • Configure the flags passed to the JetStream MaxText server:
export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=6

Create Llama2-13b environment variables for server flags

  • Configure the flags passed to the JetStream MaxText server:
export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=2

Run the following command to start the JetStream MaxText server

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}

JetStream MaxText Server flag descriptions:

  • tokenizer_path: file path to a tokenizer (should match your model)
  • load_parameters_path: Loads the parameters (no optimizer states) from a specific directory
  • per_device_batch_size: decoding batch size per device (1 TPU chip = 1 device)
  • max_prefill_predict_length: Maximum length for the prefill when doing autoregression
  • max_target_length: Maximum sequence length
  • model_name: Model name
  • ici_fsdp_parallelism: The number of shards for FSDP parallelism
  • ici_autoregressive_parallelism: The number of shards for autoregressive parallelism
  • ici_tensor_parallelism: The number of shards for tensor parallelism
  • weight_dtype: Weight data type (e.g. bfloat16)
  • scan_layers: Scan layers boolean flag

Note: These flags come from the MaxText config (MaxText/configs/base.yml).

Step 5: Send a test request to the JetStream MaxText server

cd ~
python JetStream/jetstream/tools/requester.py

The output will be similar to the following:

Sending request to: dns:///[::1]:9000
Prompt: Today is a good day
Response:  to be a fan

Step 6: Run benchmarks with the JetStream MaxText server

Note: The JetStream MaxText server started in Step 4 is not running with quantization enabled. To get the best benchmark results, enable quantization for both the weights and the KV cache (use AQT-trained or fine-tuned checkpoints to ensure accuracy), then add the quantization flags and restart the server as follows:

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For the Gemma 7b model, change per_device_batch_size to 12 to optimize performance.
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
quantization=${QUANTIZATION} \
quantize_kvcache=${QUANTIZE_KVCACHE}

Benchmarking Gemma-7b

Instructions

  • Download the ShareGPT dataset.
  • Make sure to use the Gemma tokenizer (tokenizer.gemma) when benchmarking Gemma 7b.
  • Add the --warmup-first flag on your first run to warm up the server.
# Activate the python virtual environment we created in Step 2.
cd ~
source .env/bin/activate

# download dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# run benchmark with the downloaded dataset and the tokenizer in maxtext
# You can control the qps by setting `--request-rate`, the default value is inf.
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer /home/$USER/maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-first true

Benchmarking Llama2-*b

# Same as the Gemma-7b benchmark, except the tokenizer must match your model (now tokenizer.llama2).

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000  \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-first true

Clean Up

# Clean up gcs buckets.
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}
# Clean up repositories.
rm -rf maxtext
rm -rf JetStream
# Clean up python virtual environment
rm -rf .env
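
If you created a Cloud TPU VM solely for this demo, you may also want to delete it. The name and zone below are the same illustrative placeholders used in the prerequisites sketch.

# Example only: the VM name and zone are placeholders.
gcloud compute tpus tpu-vm delete jetstream-demo --zone=us-west4-a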

JetStream Standalone Local Setup

Getting Started

Setup

pip install -r requirements.txt
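
If you have not already cloned the repository (for example, when setting up JetStream on a machine other than the TPU VM used above), a minimal sketch is:

# Clone the repository and install its Python dependencies.
git clone https://github.com/google/JetStream.git
cd JetStream
pip install -r requirements.txt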

Run local server & Testing

Use the following commands to run a server locally:

# Start a server
python -m jetstream.core.implementations.mock.server

# Test local mock server
python -m jetstream.tools.requester

# Load test local mock server
python -m jetstream.tools.load_tester

Test core modules

# Test JetStream core orchestrator
python -m jetstream.core.orchestrator_test

# Test JetStream core server library
python -m jetstream.core.server_test

# Test mock JetStream engine implementation
python -m jetstream.engine.mock_engine_test

# Test mock JetStream token utils
python -m jetstream.engine.utils_test

Download files

Download the file for your platform.

Source Distribution

google-jetstream-0.2.0.tar.gz (33.9 kB)

Uploaded Source

Built Distribution


google_jetstream-0.2.0-py3-none-any.whl (49.2 kB)

Uploaded Python 3

File details

Details for the file google-jetstream-0.2.0.tar.gz.

File metadata

  • Download URL: google-jetstream-0.2.0.tar.gz
  • Size: 33.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.0

File hashes

Hashes for google-jetstream-0.2.0.tar.gz:
  • SHA256: 7f55481cc907138a847b11c3f900654f14e06c2c4887822b0748ed9ac26b9658
  • MD5: 3b55c33a084dc6c8791bc95a0846d283
  • BLAKE2b-256: 876148d65b025e1b923971d604cf75a1e2b8621622e0a73094a043bb83671ef6


File details

Details for the file google_jetstream-0.2.0-py3-none-any.whl.

File metadata

File hashes

Hashes for google_jetstream-0.2.0-py3-none-any.whl:
  • SHA256: 8a4431eb924fe61834582c85dc29cd1f91f274c54ad05bdd35a8b1320178033a
  • MD5: 7e8855bd3d3d4f92c3c7081d2984be28
  • BLAKE2b-256: ccec3cf0eb45ad341f21d4f92dfd227d83618aee2879eea9a570dee310725d6c

