Skip to main content

No project description provided

Project description

🔬 IMPORTANT: EXPERIMENTAL AND NOT SUPPORTED 🔬

This is an exploratory repository provided for informational and learning purposes only. The code is not feature-complete and may not be stable.

⚠️ DO NOT USE IN A PRODUCTION ENVIRONMENT.

Develop on a TPU VM

Install vLLM-TPU:

Follow this guide to install vLLM from source.

Install tpu_commons:

cd ~
git clone https://github.com/vllm-project/tpu_commons.git
cd tpu_commons
pip install -r requirements.txt
pip install -e .

Setup pre-commit hooks

pip install pre-commit

# Linting, formatting and static type checking
pre-commit install --hook-type pre-commit --hook-type commit-msg

# You can manually run pre-commit with
pre-commit run --all-files

Run examples

Run JAX models

Run Llama 3.1 8B offline inference on 4 TPU chips:

HF_TOKEN=<huggingface_token> python tpu_commons/examples/offline_inference.py \
    --model=meta-llama/Llama-3.1-8B \
    --tensor_parallel_size=4 \
    --max_model_len=1024

Run JAX models with local disaggregated serving

Run Llama 3.1 8B Instruct offline inference on 4 TPU chips in disaggregated mode:

PREFILL_SLICES=2 DECODE_SLICES=2 HF_TOKEN=<huggingface_token> \
python tpu_commons/examples/offline_inference.py \
    --model=meta-llama/Meta-Llama-3-8B-Instruct \
    --max_model_len=1024 \
    --max_num_seqs=8

Run JAX models with llm-d disaggregated serving

We simulate the llm-d scenario using a single TPU VM.

bash examples/disagg/run_disagg_servers.sh

Then follow the instructions output by the command to send requests.

Run JAX model with Ray-based multi-host serving

Run Llama 3.1 70B Instruct offline inference on 4 hosts (v6e-16) in interleaved mode:

  1. Deploy Ray cluster and containers:
~/tpu_commons/scripts/multihost/deploy_cluster.sh \
    -s ~/tpu_commons/scripts/multihost/run_cluster.sh \
    -d "<your_docker_image>" \
    -c "<path_on_remote_hosts_for_hf_cache>" \
    -t "<your_hugging_face_token>" \
    -H "<head_node_public_ip>" \
    -i "<head_node_private_ip>" \
    -W "<worker1_public_ip>,<worker2_public_ip>,<etc...>"
  1. On the head node, use sudo docker exec -it node /bin/bash to enter the container. And then execute:
HF_TOKEN=<huggingface_token> python /workspace/tpu_commons/examples/offline_inference.py \
    --model=meta-llama/Llama-3.1-70B  \
    --tensor_parallel_size=16  \
    --max_model_len=1024

Run vLLM Pytorch models on the JAX path

Run the vLLM's implementation of Llama 3.1 8B, which is in Pytorch. It is the same command as above with the extra env var MODEL_IMPL_TYPE=vllm:

export MODEL_IMPL_TYPE=vllm
export HF_TOKEN=<huggingface_token>
python tpu_commons/examples/offline_inference.py \
    --model=meta-llama/Llama-3.1-8B \
    --tensor_parallel_size=4 \
    --max_model_len=1024

Run the vLLM Pytorch Qwen3-30B-A3B MoE model, use --enable-expert-parallel for expert parallelism, otherwise it defaults to tensor parallelism:

export MODEL_IMPL_TYPE=vllm
export HF_TOKEN=<huggingface_token>
python vllm/examples/offline_inference/basic/generate.py \
    --model=Qwen/Qwen3-30B-A3B \
    --tensor_parallel_size=4 \
    --max_model_len=1024 \
    --enable-expert-parallel

Run docker containers

Build and push docker image

This can be run on a CPU VM.

cd ~
git clone https://github.com/vllm-project/tpu_commons.git
cd tpu_commons

DOCKER_URI=<Specify a GCR URI>
# example:
# DOCKER_URI=gcr.io/cloud-nas-260507/ullm:$USER-test

docker build -f docker/Dockerfile -t $DOCKER_URI .
docker push $DOCKER_URI

Download docker image and run

Pull the docker image and run it:

DOCKER_URI=<the same URI used in docker build>
docker pull $DOCKER_URI
docker run \
  --rm \
  $DOCKER_URI \
  python /workspace/tpu_commons/examples/offline_inference.py \
  --model=meta-llama/Llama-3.1-8B \
  --tensor_parallel_size=4 \
  --max_model_len=1024 \

Relevant env

To switch different model implementations (default is flax_nnx):

MODEL_IMPL_TYPE=flax_nnx
MODEL_IMPL_TYPE=vllm

To run JAX models without precompiling:

SKIP_JAX_PRECOMPILE=1

To run JAX models with random initialized weights:

JAX_RANDOM_WEIGHTS=1

To run workloads on multi-host:

TPU_MULTIHOST_BACKEND=ray

Profiling

There are two ways to profile your workload:

Using PHASED_PROFILING_DIR

If you set the following environment variable:


PHASED_PROFILING_DIR=<DESIRED PROFILING OUTPUT DIR>

we will automatically capture profiles during three phases of your workload (assuming they are encountered):

  1. Prefill-heavy (the quotient of prefill / total scheduled tokens for the given batch is => 0.9)
  2. Decode-heavy (the quotient of prefill / total scheduled tokens for the given batch is <= 0.2)
  3. Mixed (the quotient of prefill / total scheduled tokens for the given batch is between 0.4 and 0.6)

To aid in your analysis, we will also log the batch composition for the profiled batches.

Using USE_JAX_PROFILER_SERVER

If you set the following environment variable:


USE_JAX_PROFILER_SERVER=True

you can instead manually decide when to capture a profile and for how long, which can helpful if your workload (e.g. E2E benchmarking) is large and taking a profile of the entire workload (i.e. using the above method) will generate a massive tracing file.

You can additionally set the desired profiling port (default is 9999):


JAX_PROFILER_SERVER_PORT=XXXX

In order to use this approach, you can do the following:

  1. Run your typical vllm serve or offline_inference command (making sure to set USE_JAX_PROFILER_SERVER=True)

  2. Run your benchmarking command (python benchmark_serving.py...)

  3. Once the warmup has completed and your benchmark is running, start a new tensorboard instance with your logdir set to the desired output location of your profiles (e.g. tensorboard --logdir=profiles/llama3-mmlu/)

  4. Open the tensorboard instance and navigate to the profile page (e.g. http://localhost:6006/#profile)

  5. Click Capture Profile and, in the Profile Service URL(s) or TPU name box, enter localhost:XXXX where XXXX is your JAX_PROFILER_SERVER_PORT (default is 9999)

  6. Enter the desired amount of time (in ms) you'd like to capture the profile for and then click Capture. If everything goes smoothly, you should see a success message, and your logdir should be populated.

How to run an End-To-End (E2E) benchmark?

In order to run an E2E benchmark test, which will spin up a vLLM server with Llama 3.1 8B and run a single request from the MLPerf dataset against it, you can run the following command locally:


BUILDKITE_COMMIT=0f199f1 .buildkite/scripts/run_in_docker.sh bash /workspace/tpu_commons/tests/e2e/benchmarking/mlperf.sh

While this will run the code in a Docker image, you can also run the bare tests/e2e/benchmarking/mlperf.sh script itself, being sure to pass the proper args for your machine.

You might need to run the benchmark client twice to make sure all compilations are cached server-side.

Quantization

Overview

Currently, we support overall model weight/activation quantization through the Qwix framework.

To enable quantization, you can do one of the following:

Using a quantization config YAML

Simply pass the name of a quantization config found inside the quantization config directory (tpu_commons/models/jax/utils/quantization/configs/), for example:


... --additional_config='{"quantization": "int8_default.yaml"}'

Using a quantization config JSON

Alternatively, you can pass the explicit quantization configuration as JSON string, where each entry in rules corresponds to a Qwix rule (see below):


{ "qwix": { "rules": [{ "module_path": ".*", "weight_qtype": "int8", "act_qtype": "int8" }]}}

Creating your own quantization config YAML

To create your own quantization config YAML file:

  1. Add a new file to the quantization config directory (tpu_commons/models/jax/utils/quantization/configs/)
  2. For Qwix quantization, add a new entry to the file as follows:

qwix:
  rules:
    # NOTE: each entry corresponds to a qwix.QuantizationRule
    - module_path: '.*'
      weight_qtype: 'int8'
      act_qtype: 'int8'

where each entry under rules corresponds to a qwix.QuantizationRule. To learn more about Qwix and defining Qwix rules, please see the relevant docs here.

  1. To use the config, simply pass the name of the file you created in the --additional_config, e.g.:

... --additional_config='{"quantization": "YOUR_FILE_NAME_HERE.yaml"}'

Project details


Release history Release notifications | RSS feed

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

test_ylang-0.11.0.tar.gz (94.6 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

test_ylang-0.11.0-py3-none-any.whl (115.7 kB view details)

Uploaded Python 3

File details

Details for the file test_ylang-0.11.0.tar.gz.

File metadata

  • Download URL: test_ylang-0.11.0.tar.gz
  • Upload date:
  • Size: 94.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for test_ylang-0.11.0.tar.gz
Algorithm Hash digest
SHA256 3958902bdd03c633204a308dce0516e79a6792b59f544eca75b8d4a9520e65ea
MD5 73759e6cc51180f998b290232eaf29a6
BLAKE2b-256 b07ec8be2aac854b715b1814571040a403c666f6b4d63be3adde3645a0942346

See more details on using hashes here.

Provenance

The following attestation bundles were made for test_ylang-0.11.0.tar.gz:

Publisher: release.yml on ylangtsou/tpu_commons

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file test_ylang-0.11.0-py3-none-any.whl.

File metadata

  • Download URL: test_ylang-0.11.0-py3-none-any.whl
  • Upload date:
  • Size: 115.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for test_ylang-0.11.0-py3-none-any.whl
Algorithm Hash digest
SHA256 6adb649ed8487c3e60348874e6256f63059a7711c2caaac18be033d2a359f4e1
MD5 a87c1dd23e4a6393dcf0ed1830c22f78
BLAKE2b-256 0e31de8da6a69a60dca9efab8415d1ce03b1b31296d7551e936373ad6c03787e

See more details on using hashes here.

Provenance

The following attestation bundles were made for test_ylang-0.11.0-py3-none-any.whl:

Publisher: release.yml on ylangtsou/tpu_commons

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page