
JAX backend for SGL


SGL-JAX: High-Performance LLM Inference on JAX/TPU

SGL-JAX is a high-performance, JAX-based inference engine for Large Language Models (LLMs), specifically optimized for Google TPUs. It is built from the ground up for high-throughput, low-latency LLM serving.

The engine combines continuous batching, radix-tree KV cache reuse, FlashAttention, and tensor parallelism to maximize hardware utilization and serving efficiency, targeting production deployment of large-scale models on TPUs.


Key Features

  • High-Throughput Continuous Batching: Implements a sophisticated scheduler that dynamically batches incoming requests, maximizing TPU utilization and overall throughput.
  • Optimized KV Cache with Radix Tree: Utilizes a Radix Tree for KV cache management (conceptually similar to PagedAttention), enabling memory-efficient prefix sharing between requests and significantly reducing computation for prompts with common prefixes.
  • FlashAttention Integration: Leverages a high-performance FlashAttention kernel for faster and more memory-efficient attention calculations, crucial for long sequences.
  • Tensor Parallelism: Natively supports tensor parallelism to distribute large models across multiple TPU devices, enabling inference for models that exceed the memory of a single accelerator.
  • OpenAI-Compatible API: Exposes an API compatible with the OpenAI API, so it works as a drop-in replacement with a wide range of existing clients, SDKs, and tools (e.g., LangChain, LlamaIndex).
  • Native Qwen Support: Includes first-class, optimized support for the Qwen model family, including recent Mixture-of-Experts (MoE) variants.
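To make the prefix-sharing idea concrete, here is a minimal sketch of a radix tree keyed by token IDs. It is a toy written for illustration, not SGL-JAX's implementation: the class and method names (RadixCache, match_prefix, insert) and the kv_slots field, which stands in for KV cache locations, are hypothetical. Eviction, reference counting, and device memory management are omitted.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class RadixNode:
    # Edge label: the run of token IDs leading from the parent to this node.
    key: tuple[int, ...] = ()
    # Children indexed by the first token of their edge label.
    children: dict[int, RadixNode] = field(default_factory=dict)
    # Hypothetical stand-in for KV cache slot indices, one per token in `key`.
    kv_slots: tuple[int, ...] = ()


def _common_len(a, b) -> int:
    """Length of the shared leading run of two token sequences."""
    n = 0
    while n < len(a) and n < len(b) and a[n] == b[n]:
        n += 1
    return n


class RadixCache:
    """Toy prefix cache over token IDs (illustrative only)."""

    def __init__(self) -> None:
        self.root = RadixNode()

    def match_prefix(self, tokens: list[int]) -> list[int]:
        """Return the cached KV slots for the longest cached prefix of `tokens`."""
        node, i, slots = self.root, 0, []
        while i < len(tokens) and tokens[i] in node.children:
            child = node.children[tokens[i]]
            n = _common_len(child.key, tokens[i:])
            slots.extend(child.kv_slots[:n])
            i += n
            if n < len(child.key):
                break  # the request diverges partway along this edge
            node = child
        return slots

    def insert(self, tokens: list[int], kv_slots: list[int]) -> None:
        """Insert a sequence, splitting an edge where it diverges."""
        node, i = self.root, 0
        while i < len(tokens):
            first = tokens[i]
            if first not in node.children:
                node.children[first] = RadixNode(tuple(tokens[i:]), {}, tuple(kv_slots[i:]))
                return
            child = node.children[first]
            n = _common_len(child.key, tokens[i:])
            if n < len(child.key):
                # Split the edge: a new intermediate node holds the shared part.
                mid = RadixNode(child.key[:n], {}, child.kv_slots[:n])
                child.key, child.kv_slots = child.key[n:], child.kv_slots[n:]
                mid.children[child.key[0]] = child
                node.children[first] = mid
                child = mid
            node, i = child, i + n


cache = RadixCache()
cache.insert([1, 2, 3, 4], [10, 11, 12, 13])  # first request's prompt
cache.insert([1, 2, 9], [10, 11, 20])         # shares the prefix [1, 2]
print(cache.match_prefix([1, 2, 3, 7]))       # [10, 11, 12]
```

Because edges store runs of tokens rather than single tokens, a long shared prefix such as a system prompt occupies a single edge until a later request diverges from it, at which point that edge is split in two.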
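FlashAttention-style kernels avoid materializing the full attention matrix by visiting keys and values in blocks while keeping a running softmax maximum and denominator (the "online softmax" recurrence). The pure-Python sketch below shows only that recurrence for a single query vector and can be checked against a naive reference. It is not SGL-JAX's TPU kernel, which also tiles queries and runs on device; the function names are illustrative.

```python
import math


def naive_attention(q, keys, values):
    """Reference: softmax(q . k / sqrt(d)) over all keys, then a weighted sum of values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


def blockwise_attention(q, keys, values, block_size=2):
    """Same result, visiting keys/values one block at a time with an online softmax."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf          # largest score seen so far
    denom = 0.0                      # sum of exp(score - running_max) so far
    acc = [0.0] * len(values[0])     # unnormalized weighted sum of values so far
    for start in range(0, len(keys), block_size):
        block_k = keys[start:start + block_size]
        block_v = values[start:start + block_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in block_k]
        new_max = max(running_max, max(scores))
        # Rescale earlier statistics from the old max to the new one.
        correction = math.exp(running_max - new_max)
        denom *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, block_v):
            w = math.exp(s - new_max)
            denom += w
            acc = [a + w * vi for a, vi in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]  # normalize once at the end
```

Only one block of scores exists at any time, which is why memory no longer grows with the square of the sequence length; the result matches the naive version up to floating-point rounding.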

Architecture Overview

SGL-JAX operates on a distributed architecture designed for scalability and performance:

  1. HTTP Server: The entry point for all requests, compatible with the OpenAI API standard.
  2. Scheduler: The core of the engine. It receives requests, manages prompts, and schedules token generation, grouping requests into batches for the model executor.
  3. TP Worker (Tensor Parallel Worker): A set of workers that each hold a shard of the model weights, partitioned via tensor parallelism, and together execute the model's forward pass.
  4. Model Runner: Manages the actual JAX-based model execution, including the forward pass, attention computation, and KV cache operations.
  5. Radix Cache: A global, memory-efficient KV cache that is shared across all requests, enabling prefix reuse and reducing the memory footprint.
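Because the HTTP server follows the OpenAI API standard, any HTTP client can send it a Chat Completions request. The sketch below uses only the Python standard library. The base URL (host and port) and the model name are placeholders, not documented SGL-JAX defaults; check the docs for how to launch the server and which address it listens on. The official openai Python SDK can likewise be pointed at such a server through its base_url argument.

```python
import json
import urllib.request

# Placeholder: point this at wherever your SGL-JAX server is listening.
BASE_URL = "http://localhost:30000"  # hypothetical host and port


def build_chat_request(model: str, prompt: str, max_tokens: int = 64) -> dict:
    """Build a request body in the OpenAI Chat Completions format."""
    return {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
    }


def chat(model: str, prompt: str, base_url: str = BASE_URL) -> str:
    """POST to /v1/chat/completions and return the first choice's text."""
    body = json.dumps(build_chat_request(model, prompt)).encode("utf-8")
    request = urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        payload = json.load(response)
    # Chat Completions responses carry the text in choices[0].message.content.
    return payload["choices"][0]["message"]["content"]


# Usage, with a running server (the model name is a placeholder):
# print(chat("Qwen/Qwen3-8B", "Summarize continuous batching in one sentence."))
```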
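The key idea behind the scheduler's continuous batching is that the batch is re-formed at every decoding step: finished requests leave and waiting requests take their slots immediately, instead of the whole batch draining first. The simulation below sketches only that admission loop. Request, fake_decode_step, and continuous_batching are hypothetical names, each step generates one token per running request, and a real scheduler's prefill/decode split, token budgets, and KV cache capacity checks are left out.

```python
from __future__ import annotations

from collections import deque
from dataclasses import dataclass, field


@dataclass
class Request:
    rid: str
    max_new_tokens: int
    generated: list[int] = field(default_factory=list)

    @property
    def done(self) -> bool:
        return len(self.generated) >= self.max_new_tokens


def fake_decode_step(batch: list[Request]) -> list[int]:
    """Stand-in for one batched forward pass: emits one dummy token per request."""
    return [len(r.generated) for r in batch]


def continuous_batching(requests: list[Request], max_batch_size: int = 2) -> list[list[str]]:
    """Run requests to completion; return the request IDs batched at each step."""
    waiting = deque(requests)
    running: list[Request] = []
    steps: list[list[str]] = []
    while waiting or running:
        # Fill free slots at every step, not only once the batch has drained.
        while waiting and len(running) < max_batch_size:
            running.append(waiting.popleft())
        steps.append([r.rid for r in running])
        for req, token in zip(running, fake_decode_step(running)):
            req.generated.append(token)
        # Retire finished requests so their slots free up for the next step.
        running = [r for r in running if not r.done]
    return steps


reqs = [Request("A", max_new_tokens=3), Request("B", 1), Request("C", 2)]
print(continuous_batching(reqs))  # [['A', 'B'], ['A', 'C'], ['A', 'C']]
```

In this example C starts as soon as B finishes, so all three requests complete in 3 steps; static batching (run A and B to completion, then C) would take 5.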
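A common way to shard a transformer MLP for tensor parallelism (the Megatron-LM layout) splits the first weight matrix by columns and the second by rows, so each device computes its slice independently and a single all-reduce (a sum) combines the results. The sketch below simulates the devices with a Python loop over plain nested lists. Whether SGL-JAX partitions every layer exactly this way is an assumption, and in JAX such layouts are usually expressed declaratively with jax.sharding (Mesh, NamedSharding, PartitionSpec) rather than explicit loops.

```python
def matmul(x, w):
    """Multiply nested lists: x is [n][k], w is [k][m], result is [n][m]."""
    return [[sum(x[i][p] * w[p][j] for p in range(len(w))) for j in range(len(w[0]))]
            for i in range(len(x))]


def relu(x):
    return [[max(0.0, v) for v in row] for row in x]


def mlp_reference(x, w1, w2):
    """Unsharded two-layer MLP: relu(x @ w1) @ w2."""
    return matmul(relu(matmul(x, w1)), w2)


def mlp_tensor_parallel(x, w1, w2, num_shards):
    """Same MLP with the hidden dimension split across `num_shards` simulated devices."""
    hidden = len(w1[0])
    assert hidden % num_shards == 0, "hidden size must divide evenly across shards"
    size = hidden // num_shards
    partials = []
    for s in range(num_shards):  # each iteration plays the role of one device
        cols = slice(s * size, (s + 1) * size)
        w1_shard = [row[cols] for row in w1]  # column-parallel: this device's slice of W1
        w2_shard = w2[cols]                   # row-parallel: the matching rows of W2
        h = relu(matmul(x, w1_shard))         # ReLU is elementwise, so no communication
        partials.append(matmul(h, w2_shard))  # partial sum of the final output
    # All-reduce: sum the partial outputs from every device.
    n, m = len(x), len(w2[0])
    return [[sum(p[i][j] for p in partials) for j in range(m)] for i in range(n)]


x = [[1.0, -2.0, 0.5]]
w1 = [[1.0, 0.0, -1.0, 2.0], [0.5, 1.0, 0.0, -1.0], [2.0, -1.0, 1.0, 0.0]]
w2 = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5], [3.0, 0.0]]
print(mlp_reference(x, w1, w2))           # [[13.0, 2.0]]
print(mlp_tensor_parallel(x, w1, w2, 2))  # [[13.0, 2.0]]
```

The design choice is that the only cross-device communication is the final sum, once per MLP block, which is what lets a model larger than one accelerator's memory run across several devices.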

Getting Started

Documentation

For more features and usage details, please read the documents in the docs directory.

Supported Models

SGL-JAX is designed for easy extension to new model architectures. It currently provides first-class, optimized support for:

  • Qwen
  • Qwen 3
  • Qwen 3 MoE

Performance and Benchmarking

For detailed performance evaluation and to run the benchmarks yourself, please see the scripts located in the benchmark/ and python/sgl_jax/ directories (e.g., bench_serving.py).

Testing

The project includes a comprehensive test suite to ensure correctness and stability. To run the full suite of tests:

cd test/srt
python run_suite.py

Contributing

Contributions are welcome! If you would like to contribute, please feel free to open an issue to discuss your ideas or submit a pull request.



Download files


Source Distribution

sglang_jax-0.0.1.post2.tar.gz (329.4 kB)

Built Distribution


sglang_jax-0.0.1.post2-py3-none-any.whl (394.5 kB, Python 3)

File details

Details for the file sglang_jax-0.0.1.post2.tar.gz.

File metadata

  • Download URL: sglang_jax-0.0.1.post2.tar.gz
  • Size: 329.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for sglang_jax-0.0.1.post2.tar.gz
Algorithm Hash digest
SHA256 d5476d94271045d990203902083a8d1d6dff4b4492f8c6514e82120e2085ae80
MD5 d455ce44a13b48789d6960cb5a2be0d4
BLAKE2b-256 a0c761801ffc80075eb7306c7bf3a8d1403a71400d058465725227e4e59ab0fb


File details

Details for the file sglang_jax-0.0.1.post2-py3-none-any.whl.


File hashes

Hashes for sglang_jax-0.0.1.post2-py3-none-any.whl
Algorithm Hash digest
SHA256 d2fde0606d9a6a2eeaab04408a12a9da3d756787d7abedc38d25ff297d1da848
MD5 3a5d7628cc6226f4597bfd2afd7b5a05
BLAKE2b-256 884bf42e9647ba04625f6e88e7a33514ada6bd220a07b75df0a5dfadaab67f1e

