
🔬 IMPORTANT: EXPERIMENTAL AND NOT SUPPORTED 🔬

This is an exploratory repository provided for informational and learning purposes only. The code is not feature-complete and may not be stable.

⚠️ DO NOT USE IN A PRODUCTION ENVIRONMENT.

Develop on a TPU VM

Install vLLM-TPU:

Follow this guide to install vLLM from source.

Install tpu_commons:

cd ~
git clone https://github.com/vllm-project/tpu_commons.git
cd tpu_commons
pip install -r requirements.txt
pip install -e .

Set up pre-commit hooks

pip install pre-commit

# Linting, formatting and static type checking
pre-commit install --hook-type pre-commit --hook-type commit-msg

# You can manually run pre-commit with
pre-commit run --all-files

Run examples

Run JAX models

Run Llama 3.1 8B offline inference on 4 TPU chips:

HF_TOKEN=<huggingface_token> python tpu_commons/examples/offline_inference.py \
    --model=meta-llama/Llama-3.1-8B \
    --tensor_parallel_size=4 \
    --max_model_len=1024
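
The flags above can be sketched with a small argparse parser. This is a hypothetical reconstruction for illustration; the real offline_inference.py may define its interface differently.

```python
import argparse

# Hypothetical sketch of the flags used above; the actual
# offline_inference.py argument parser may differ.
parser = argparse.ArgumentParser(description="Offline inference flags (sketch)")
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--tensor_parallel_size", type=int, default=1)
parser.add_argument("--max_model_len", type=int, default=2048)

args = parser.parse_args([
    "--model=meta-llama/Llama-3.1-8B",
    "--tensor_parallel_size=4",
    "--max_model_len=1024",
])
print(args.model, args.tensor_parallel_size, args.max_model_len)
```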

Run JAX models with local disaggregated serving

Run Llama 3 8B Instruct offline inference on 4 TPU chips in disaggregated mode:

PREFILL_SLICES=2 DECODE_SLICES=2 HF_TOKEN=<huggingface_token> \
python tpu_commons/examples/offline_inference.py \
    --model=meta-llama/Meta-Llama-3-8B-Instruct \
    --max_model_len=1024 \
    --max_num_seqs=8
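
In this example PREFILL_SLICES=2 and DECODE_SLICES=2 split the 4 chips between prefill and decode workers. A minimal sketch of that arithmetic, under the assumption that each slice maps to one chip (the real semantics may differ):

```python
# Hypothetical helper: check that the prefill/decode split covers the
# available chips. Assumes each "slice" corresponds to one chip, which
# may not match the exact semantics of these environment variables.
def split_chips(env, total_chips):
    prefill = int(env.get("PREFILL_SLICES", 0))
    decode = int(env.get("DECODE_SLICES", 0))
    if prefill + decode != total_chips:
        raise ValueError(f"{prefill}+{decode} slices != {total_chips} chips")
    return prefill, decode

print(split_chips({"PREFILL_SLICES": "2", "DECODE_SLICES": "2"}, 4))  # (2, 2)
```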

Run JAX models with llm-d disaggregated serving

We simulate the llm-d scenario using a single TPU VM.

bash examples/disagg/run_disagg_servers.sh

Then follow the instructions output by the command to send requests.

Run JAX models with Ray-based multi-host serving

Run Llama 3.1 70B Instruct offline inference on 4 hosts (v6e-16) in interleaved mode:

  1. Deploy Ray cluster and containers:
~/tpu_commons/scripts/multihost/deploy_cluster.sh \
    -s ~/tpu_commons/scripts/multihost/run_cluster.sh \
    -d "<your_docker_image>" \
    -c "<path_on_remote_hosts_for_hf_cache>" \
    -t "<your_hugging_face_token>" \
    -H "<head_node_public_ip>" \
    -i "<head_node_private_ip>" \
    -W "<worker1_public_ip>,<worker2_public_ip>,<etc...>"
  2. On the head node, enter the container with sudo docker exec -it node /bin/bash, then run:
HF_TOKEN=<huggingface_token> python /workspace/tpu_commons/examples/offline_inference.py \
    --model=meta-llama/Llama-3.1-70B  \
    --tensor_parallel_size=16  \
    --max_model_len=1024
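
The tensor_parallel_size of 16 matches the total chip count of the slice: a v6e-16 spans 4 hosts with 4 chips each. A sketch of that sizing (chip-per-host count is our reading of the v6e topology, not something this repo states):

```python
# Sketch of the sizing above: a v6e-16 slice spans 4 hosts with
# 4 chips each, so tensor_parallel_size is set to the chip total.
HOSTS = 4
CHIPS_PER_HOST = 4  # assumed v6e host topology
tensor_parallel_size = HOSTS * CHIPS_PER_HOST
print(tensor_parallel_size)  # 16
```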

Run vLLM PyTorch models on the JAX path

Run vLLM's implementation of Llama 3.1 8B, which is written in PyTorch. The command is the same as above, with the extra environment variable MODEL_IMPL_TYPE=vllm:

export MODEL_IMPL_TYPE=vllm
export HF_TOKEN=<huggingface_token>
python tpu_commons/examples/offline_inference.py \
    --model=meta-llama/Llama-3.1-8B \
    --tensor_parallel_size=4 \
    --max_model_len=1024

Run the vLLM PyTorch Qwen3-30B-A3B MoE model. Use --enable-expert-parallel for expert parallelism; otherwise it defaults to tensor parallelism:

export MODEL_IMPL_TYPE=vllm
export HF_TOKEN=<huggingface_token>
python vllm/examples/offline_inference/basic/generate.py \
    --model=Qwen/Qwen3-30B-A3B \
    --tensor_parallel_size=4 \
    --max_model_len=1024 \
    --enable-expert-parallel
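
The difference between the two modes: expert parallelism places whole experts on different devices, while tensor parallelism shards every expert's weights across all devices. A rough sketch of the expert-placement arithmetic (the function is hypothetical; the 128-expert count for Qwen3-30B-A3B is taken from its model card, not this repo):

```python
# Hypothetical sketch: distributing MoE experts across devices under
# expert parallelism. With 128 experts and 4 devices, each device
# owns 32 whole experts instead of a shard of every expert.
def experts_per_device(num_experts, num_devices):
    if num_experts % num_devices != 0:
        raise ValueError("experts must divide evenly across devices")
    return num_experts // num_devices

print(experts_per_device(128, 4))  # 32
```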

Run docker containers

Build and push docker image

This can be run on a CPU VM.

cd ~
git clone https://github.com/vllm-project/tpu_commons.git
cd tpu_commons

DOCKER_URI=<Specify a GCR URI>
# example:
# DOCKER_URI=gcr.io/cloud-nas-260507/ullm:$USER-test

docker build -f docker/Dockerfile -t $DOCKER_URI .
docker push $DOCKER_URI

Download docker image and run

Pull the docker image and run it:

DOCKER_URI=<the same URI used in docker build>
docker pull $DOCKER_URI
docker run \
  --rm \
  $DOCKER_URI \
  python /workspace/tpu_commons/examples/offline_inference.py \
  --model=meta-llama/Llama-3.1-8B \
  --tensor_parallel_size=4 \
  --max_model_len=1024

Relevant env

To switch between model implementations (the default is flax_nnx):

MODEL_IMPL_TYPE=flax_nnx
MODEL_IMPL_TYPE=vllm

To run JAX models without precompiling:

SKIP_JAX_PRECOMPILE=1

To run JAX models with randomly initialized weights:

JAX_RANDOM_WEIGHTS=1

To run workloads on multiple hosts:

TPU_MULTIHOST_BACKEND=ray
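
A sketch of how these switches might be consumed. The function and key names are hypothetical; the actual lookup logic lives inside tpu_commons.

```python
import os

# Hypothetical sketch of reading the environment switches above;
# the real parsing lives inside tpu_commons internals.
def read_flags(env=os.environ):
    return {
        "model_impl": env.get("MODEL_IMPL_TYPE", "flax_nnx"),
        "skip_precompile": env.get("SKIP_JAX_PRECOMPILE", "0") == "1",
        "random_weights": env.get("JAX_RANDOM_WEIGHTS", "0") == "1",
        "multihost_backend": env.get("TPU_MULTIHOST_BACKEND", ""),
    }

print(read_flags({}))  # all defaults
print(read_flags({"MODEL_IMPL_TYPE": "vllm", "SKIP_JAX_PRECOMPILE": "1"}))
```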

Profiling

There are two ways to profile your workload:

Using PHASED_PROFILING_DIR

If you set the following environment variable:


PHASED_PROFILING_DIR=<DESIRED PROFILING OUTPUT DIR>

we will automatically capture profiles during three phases of your workload (assuming they are encountered):

  1. Prefill-heavy (the ratio of prefill to total scheduled tokens for the given batch is >= 0.9)
  2. Decode-heavy (the ratio of prefill to total scheduled tokens for the given batch is <= 0.2)
  3. Mixed (the ratio of prefill to total scheduled tokens for the given batch is between 0.4 and 0.6)

To aid in your analysis, we will also log the batch composition for the profiled batches.
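
The phase bucketing above can be sketched as a small classifier. The thresholds come directly from the list; the function name is hypothetical.

```python
# Sketch of the phase classification described above: a batch is
# bucketed by its prefill-token ratio. Thresholds come from the text.
def classify_batch(prefill_tokens, total_tokens):
    ratio = prefill_tokens / total_tokens
    if ratio >= 0.9:
        return "prefill-heavy"
    if ratio <= 0.2:
        return "decode-heavy"
    if 0.4 <= ratio <= 0.6:
        return "mixed"
    return None  # batch does not trigger a profile capture

print(classify_batch(95, 100))  # prefill-heavy
print(classify_batch(10, 100))  # decode-heavy
print(classify_batch(50, 100))  # mixed
```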

Using USE_JAX_PROFILER_SERVER

If you set the following environment variable:


USE_JAX_PROFILER_SERVER=True

you can instead manually decide when to capture a profile and for how long, which can be helpful if your workload (e.g. E2E benchmarking) is large and profiling the entire run (i.e. using the method above) would generate a massive trace file.

You can additionally set the desired profiling port (default is 9999):


JAX_PROFILER_SERVER_PORT=XXXX

To use this approach:

  1. Run your typical vllm serve or offline_inference command (making sure to set USE_JAX_PROFILER_SERVER=True)

  2. Run your benchmarking command (python benchmark_serving.py...)

  3. Once the warmup has completed and your benchmark is running, start a new tensorboard instance with your logdir set to the desired output location of your profiles (e.g. tensorboard --logdir=profiles/llama3-mmlu/)

  4. Open the tensorboard instance and navigate to the profile page (e.g. http://localhost:6006/#profile)

  5. Click Capture Profile and, in the Profile Service URL(s) or TPU name box, enter localhost:XXXX where XXXX is your JAX_PROFILER_SERVER_PORT (default is 9999)

  6. Enter the desired amount of time (in ms) you'd like to capture the profile for and then click Capture. If everything goes smoothly, you should see a success message, and your logdir should be populated.
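
The "Profile Service URL" entered in step 5 is just localhost plus the configured profiler port. A minimal sketch of that construction (the helper name is hypothetical):

```python
# Hypothetical helper: build the Profile Service URL for step 5
# from JAX_PROFILER_SERVER_PORT, defaulting to 9999.
def profile_service_url(env=None):
    env = env or {}
    port = int(env.get("JAX_PROFILER_SERVER_PORT", 9999))
    return f"localhost:{port}"

print(profile_service_url())                                      # localhost:9999
print(profile_service_url({"JAX_PROFILER_SERVER_PORT": "7777"}))  # localhost:7777
```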

How to run an End-To-End (E2E) benchmark?

To run an E2E benchmark, which spins up a vLLM server with Llama 3.1 8B and runs a single request from the MLPerf dataset against it, run the following command locally:


BUILDKITE_COMMIT=0f199f1 .buildkite/scripts/run_in_docker.sh bash /workspace/tpu_commons/tests/e2e/benchmarking/mlperf.sh

While this runs the code inside a Docker image, you can also run the bare tests/e2e/benchmarking/mlperf.sh script directly, making sure to pass the proper arguments for your machine.

You might need to run the benchmark client twice to make sure all compilations are cached server-side.

Quantization

Overview

Currently, we support model weight and activation quantization through the Qwix framework.

To enable quantization, you can do one of the following:

Using a quantization config YAML

Simply pass the name of a quantization config found inside the quantization config directory (tpu_commons/models/jax/utils/quantization/configs/), for example:


... --additional_config='{"quantization": "int8_default.yaml"}'

Using a quantization config JSON

Alternatively, you can pass the explicit quantization configuration as a JSON string, where each entry in rules corresponds to a Qwix rule (see below):


{ "qwix": { "rules": [{ "module_path": ".*", "weight_qtype": "int8", "act_qtype": "int8" }]}}
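
The JSON form above, parsed and lightly checked. The schema (a top-level "qwix" key holding a "rules" list) follows the example; real validation is up to tpu_commons and Qwix.

```python
import json

# Parse the quantization config shown above and sanity-check it.
config = json.loads(
    '{ "qwix": { "rules": [{ "module_path": ".*",'
    ' "weight_qtype": "int8", "act_qtype": "int8" }]}}'
)
for rule in config["qwix"]["rules"]:
    assert "module_path" in rule  # every rule targets modules by regex
print(config["qwix"]["rules"][0]["weight_qtype"])  # int8
```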

Creating your own quantization config YAML

To create your own quantization config YAML file:

  1. Add a new file to the quantization config directory (tpu_commons/models/jax/utils/quantization/configs/)
  2. For Qwix quantization, add a new entry to the file as follows:

qwix:
  rules:
    # NOTE: each entry corresponds to a qwix.QuantizationRule
    - module_path: '.*'
      weight_qtype: 'int8'
      act_qtype: 'int8'

where each entry under rules corresponds to a qwix.QuantizationRule. To learn more about Qwix and defining Qwix rules, please see the relevant docs here.

  3. To use the config, simply pass the name of the file you created in --additional_config, e.g.:

... --additional_config='{"quantization": "YOUR_FILE_NAME_HERE.yaml"}'
