JAX backend for SGL

SGL-JAX: High-Performance LLM Inference on JAX/TPU

SGL-JAX is a high-performance, JAX-based inference engine for Large Language Models (LLMs), specifically optimized for Google TPUs. It is engineered from the ground up to deliver exceptional throughput and low latency for the most demanding LLM serving workloads.

The engine integrates state-of-the-art techniques to maximize hardware utilization and serving efficiency, making it well suited to deploying large-scale models in production on TPUs.

Key Features

  • High-Throughput Continuous Batching: Implements a sophisticated scheduler that dynamically batches incoming requests, maximizing TPU utilization and overall throughput.
  • Optimized KV Cache with Radix Tree: Utilizes a Radix Tree for KV cache management (conceptually similar to PagedAttention), enabling memory-efficient prefix sharing between requests and significantly reducing computation for prompts with common prefixes.
  • FlashAttention Integration: Leverages a high-performance FlashAttention kernel for faster and more memory-efficient attention calculations, crucial for long sequences.
  • Tensor Parallelism: Natively supports tensor parallelism to distribute large models across multiple TPU devices, enabling inference for models that exceed the memory of a single accelerator.
  • OpenAI-Compatible API: Provides a drop-in replacement for the OpenAI API, allowing for seamless integration with a wide range of existing clients, SDKs, and tools (e.g., LangChain, LlamaIndex).
  • Native Qwen Support: Includes first-class, optimized support for the Qwen model family, including recent Mixture-of-Experts (MoE) variants.
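To make the prefix-sharing idea concrete, here is a minimal sketch in plain Python. It uses a simple token-level trie rather than a compressed radix tree, and it tracks only which token prefixes are cached, not the KV tensors themselves. The names `TrieNode`, `PrefixCache`, `match_prefix`, and `insert` are hypothetical and are not taken from the SGL-JAX code base.

```python
from dataclasses import dataclass, field


@dataclass
class TrieNode:
    """One cached token; `children` maps the next token id to its node."""
    children: dict = field(default_factory=dict)


class PrefixCache:
    """Toy token-level prefix cache (hypothetical; not the SGL-JAX class)."""

    def __init__(self):
        self.root = TrieNode()

    def match_prefix(self, tokens):
        """Return how many leading tokens are already cached."""
        node, matched = self.root, 0
        for tok in tokens:
            child = node.children.get(tok)
            if child is None:
                break
            node, matched = child, matched + 1
        return matched

    def insert(self, tokens):
        """Cache a token sequence; return the number of newly added tokens."""
        node, added = self.root, 0
        for tok in tokens:
            if tok not in node.children:
                node.children[tok] = TrieNode()
                added += 1
            node = node.children[tok]
        return added


cache = PrefixCache()
system_prompt = [1, 2, 3, 4]            # shared prefix, e.g. a system prompt
cache.insert(system_prompt + [10, 11])  # first request populates the cache
# The second request shares the system prompt, so only its tail is new work.
reused = cache.match_prefix(system_prompt + [20, 21])
print(reused)
```

In a real engine each node would also reference the KV cache entries for its tokens, so a matched prefix lets the prefill step skip recomputing attention keys and values for those positions.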

Architecture Overview

SGL-JAX operates on a distributed architecture designed for scalability and performance:

  1. HTTP Server: The entry point for all requests, compatible with the OpenAI API standard.
  2. Scheduler: The core of the engine. It receives requests, manages prompts, and schedules token generation in batches. It intelligently groups requests to form optimal batches for the model executor.
  3. TP Worker (Tensor Parallel Worker): A set of distributed workers that host the model weights, distributed via tensor parallelism. They execute the forward pass for the model.
  4. Model Runner: Manages the actual JAX-based model execution, including the forward pass, attention computation, and KV cache operations.
  5. Radix Cache: A global, memory-efficient KV cache that is shared across all requests, enabling prefix reuse and reducing the memory footprint.
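The scheduler's role can be illustrated with a small simulation of continuous batching. This is a sketch under simplifying assumptions (one token per request per step, a fixed batch-size limit, no prefill phase); `Request`, `run_continuous_batching`, and `max_batch_size` are hypothetical names and do not reflect the actual SGL-JAX scheduler.

```python
from collections import deque
from dataclasses import dataclass


@dataclass
class Request:
    rid: str
    max_new_tokens: int
    generated: int = 0


def run_continuous_batching(requests, max_batch_size):
    """Simulate decode steps; return the request ids batched at each step."""
    waiting = deque(requests)
    running = []
    schedule = []
    while waiting or running:
        # Admit waiting requests into free batch slots before every step.
        while waiting and len(running) < max_batch_size:
            running.append(waiting.popleft())
        schedule.append([r.rid for r in running])
        # One decode step produces one token per running request.
        for r in running:
            r.generated += 1
        # Finished requests leave immediately, freeing slots for the next step.
        running = [r for r in running if r.generated < r.max_new_tokens]
    return schedule


steps = run_continuous_batching(
    [Request("a", 1), Request("b", 3), Request("c", 2)], max_batch_size=2
)
print(steps)
```

The point of the design is that request "c" joins the batch as soon as "a" finishes, instead of waiting for the whole first batch to drain as it would under static batching.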

Quick Start

Follow these steps to get a model server up and running.

1. Installation

First, clone the repository and install the necessary dependencies. It is recommended to do this in a virtual environment.

git clone https://github.com/your-org/sgl-jax.git
cd sgl-jax/python
pip install -e .

2. Launch the Server

You can launch the OpenAI-compatible API server using the sgl_jax.launch_server module.

# Example: Launching a server for Qwen1.5-7B-Chat
python -m sgl_jax.launch_server \
    --model-path Qwen/Qwen1.5-7B-Chat \
    --tp-size 4 \
    --port 8000 \
    --host 0.0.0.0

Key Arguments:

  • --model-path: The path to the model on the Hugging Face Hub or a local directory.
  • --tp-size: The number of TPU devices to use for tensor parallelism.
  • --port: The port for the API server.
  • --host: The host address to bind the server to.
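As a rough illustration of what --tp-size controls, the sketch below splits a weight matrix column-wise across a number of simulated devices and checks that the sharded result matches the unsharded one. It is plain Python with no JAX dependency, and the helper names (`matmul`, `split_columns`, `tp_matmul`) are hypothetical; how SGL-JAX actually shards each layer is not described here.

```python
def matmul(x, w):
    """Multiply a vector x (length k) by a matrix w (k rows, n columns)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def split_columns(w, tp_size):
    """Split w column-wise into tp_size equal shards, one per device."""
    n = len(w[0])
    assert n % tp_size == 0, "column count must divide evenly across devices"
    width = n // tp_size
    return [[row[d * width:(d + 1) * width] for row in w] for d in range(tp_size)]


def tp_matmul(x, w, tp_size):
    """Each shard computes its slice of the output; slices are concatenated."""
    out = []
    for shard in split_columns(w, tp_size):
        out.extend(matmul(x, shard))  # on real hardware these run in parallel
    return out


x = [1.0, 2.0]
w = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]
result = tp_matmul(x, w, tp_size=2)
assert result == matmul(x, w)
print(result)
```

Each simulated device holds only a fraction of the weight columns, which is why a larger tensor-parallel size lets a model fit when it exceeds the memory of a single accelerator.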

3. Send a Request

Once the server is running, you can interact with it using any OpenAI-compatible client, such as curl or the openai Python library.

Using curl:

curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Qwen/Qwen1.5-7B-Chat",
    "messages": [
      {"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": "Hello, what is JAX?"}
    ],
    "max_tokens": 100,
    "temperature": 0.7
  }'

Using the openai Python client:

import openai

# Point the client to the local server
client = openai.OpenAI(
    api_key="your-api-key",  # Can be any string
    base_url="http://localhost:8000/v1"
)

response = client.chat.completions.create(
    model="Qwen/Qwen1.5-7B-Chat",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, what is JAX?"},
    ],
)

print(response.choices[0].message.content)
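If you prefer not to install a client SDK, the same request can be built with the Python standard library. The sketch below only constructs the request object, so it can be inspected without a running server; `build_chat_request` is a hypothetical helper, and the response shape shown in the comments assumes the server follows the OpenAI chat completions format.

```python
import json
import urllib.request


def build_chat_request(base_url, model, messages, max_tokens=100, temperature=0.7):
    """Build (but do not send) a POST request for the chat completions endpoint."""
    payload = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return urllib.request.Request(
        url=f"{base_url}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_chat_request(
    base_url="http://localhost:8000/v1",
    model="Qwen/Qwen1.5-7B-Chat",
    messages=[{"role": "user", "content": "Hello, what is JAX?"}],
)

# To send it once the server is running:
# with urllib.request.urlopen(req) as resp:
#     body = json.load(resp)
#     print(body["choices"][0]["message"]["content"])
```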

Documentation

For more features and usage details, see the documents in the docs directory.

Supported Models

SGL-JAX is designed for easy extension to new model architectures. It currently provides first-class, optimized support for:

  • Qwen
  • Qwen 3
  • Qwen 3 MoE

Performance and Benchmarking

Performance is a core focus of SGL-JAX. The engine is continuously benchmarked to ensure high throughput and low latency. For detailed performance evaluation and to run the benchmarks yourself, please see the scripts located in the benchmark/ and python/sgl_jax/ directories (e.g., bench_serving.py).

Testing

The project includes a comprehensive test suite to ensure correctness and stability. To run the full suite of tests:

cd test/srt
python run_suite.py

Contributing

Contributions are welcome! If you would like to contribute, please feel free to open an issue to discuss your ideas or submit a pull request.
