
JAX backend for SGL

Project description

SGL-JAX: High-Performance LLM Inference on JAX/TPU

SGL-JAX is a high-performance, JAX-based inference engine for Large Language Models (LLMs), specifically optimized for Google TPUs. It is engineered from the ground up to deliver exceptional throughput and low latency for the most demanding LLM serving workloads.

The engine incorporates state-of-the-art techniques to maximize hardware utilization and serving efficiency, making it ideal for deploying large-scale models in production on TPUs.


Key Features

  • High-Throughput Continuous Batching: Implements a sophisticated scheduler that dynamically batches incoming requests, maximizing TPU utilization and overall throughput.
  • Optimized KV Cache with Radix Tree: Manages the KV cache with a radix tree (the RadixAttention technique, as in SGLang), enabling memory-efficient prefix sharing between requests and significantly reducing prefill computation for prompts with common prefixes.
  • FlashAttention Integration: Leverages a high-performance FlashAttention kernel for faster and more memory-efficient attention calculations, crucial for long sequences.
  • Tensor Parallelism: Natively supports tensor parallelism to distribute large models across multiple TPU devices, enabling inference for models that exceed the memory of a single accelerator.
  • OpenAI-Compatible API: Provides a drop-in replacement for the OpenAI API, allowing for seamless integration with a wide range of existing clients, SDKs, and tools (e.g., LangChain, LlamaIndex).
  • Native Qwen Support: Includes first-class, optimized support for the Qwen model family, including recent Mixture-of-Experts (MoE) variants.
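The prefix sharing behind the radix cache can be illustrated with a toy token trie. This is a hypothetical simplification, not the SGL-JAX implementation: real cache nodes hold KV tensors on TPU, while here a node merely records that its prefix has been computed.

```python
# Toy radix-style prefix cache: longest cached prefix of a new prompt
# can skip prefill; only the unmatched suffix must be computed.

class _Node:
    def __init__(self):
        self.children = {}  # token id -> _Node

class PrefixCache:
    def __init__(self):
        self.root = _Node()

    def insert(self, tokens):
        """Record that KV entries for this token sequence now exist."""
        node = self.root
        for t in tokens:
            node = node.children.setdefault(t, _Node())

    def longest_shared_prefix(self, tokens):
        """Number of leading tokens whose KV entries are already cached."""
        node, matched = self.root, 0
        for t in tokens:
            if t not in node.children:
                break
            node = node.children[t]
            matched += 1
        return matched

cache = PrefixCache()
cache.insert([1, 2, 3, 4])                       # first request's prompt
hit = cache.longest_shared_prefix([1, 2, 3, 9])  # shares a 3-token prefix
print(hit)  # 3 -> only one new token needs prefill
```

Requests issued through few-shot templates or shared system prompts benefit most, since their long common prefixes are computed once and reused.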

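Because the server speaks the OpenAI wire format, a request can be built with nothing but the standard library. The host, port, and model name below are placeholder assumptions; use the values your server was started with.

```python
# Sketch of a chat-completion request against the OpenAI-compatible
# endpoint. Only builds the request; sending it requires a live server.
import json
import urllib.request

def build_chat_request(base_url, model, messages):
    payload = {"model": model, "messages": messages}
    return urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )

req = build_chat_request(
    "http://localhost:30000",        # assumed serving address
    "Qwen/Qwen3-8B",                 # whichever model the server loaded
    [{"role": "user", "content": "Hello!"}],
)
# With a server running: urllib.request.urlopen(req) returns the completion.
```

The same shape works through the official `openai` SDK by pointing its `base_url` at the server, which is what makes existing tools drop-in compatible.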
Architecture Overview

SGL-JAX operates on a distributed architecture designed for scalability and performance:

  1. HTTP Server: The entry point for all requests, compatible with the OpenAI API standard.
  2. Scheduler: The core of the engine. It receives requests, manages prompts, and schedules token generation in batches. It intelligently groups requests to form optimal batches for the model executor.
  3. TP Worker (Tensor Parallel Worker): A set of workers that hold the model weights, sharded across devices via tensor parallelism, and execute the model's forward pass.
  4. Model Runner: Manages the actual JAX-based model execution, including the forward pass, attention computation, and KV cache operations.
  5. Radix Cache: A global, memory-efficient KV cache that is shared across all requests, enabling prefix reuse and reducing the memory footprint.
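The scheduler's continuous-batching loop can be sketched as follows. This is a deliberately simplified toy, not the SGL-JAX scheduler: token budgets, the prefill/decode split, and real model execution are all omitted, and names are invented for illustration.

```python
# Toy continuous-batching loop: each step, top the running batch up from
# the waiting queue, run one decode step for every active request, and
# retire finished requests so new ones can take their slots immediately.
from collections import deque

def serve(request_ids, max_batch_size, max_new_tokens):
    waiting = deque(request_ids)
    running, finished = [], []
    while waiting or running:
        # Admit new requests as soon as slots free up (continuous batching).
        while waiting and len(running) < max_batch_size:
            running.append({"id": waiting.popleft(), "generated": 0})
        for req in running:
            req["generated"] += 1  # stand-in for one batched decode step
        next_running = []
        for req in running:
            if req["generated"] >= max_new_tokens:
                finished.append(req["id"])  # slot freed for the next request
            else:
                next_running.append(req)
        running = next_running
    return finished

order = serve(list(range(5)), max_batch_size=2, max_new_tokens=3)
print(order)  # [0, 1, 2, 3, 4] -- admitted and retired in waves of two
```

The key property is that the batch is reformed every step rather than waiting for an entire batch to drain, which keeps the accelerator busy even when requests finish at different times.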

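The tensor-parallel sharding used by the TP workers boils down to the following identity, shown here with plain NumPy standing in for JAX-on-TPU: a column-sharded linear layer produces exactly the unsharded result once the per-device outputs are gathered.

```python
# Column-parallel linear layer: each "device" holds a vertical slice of
# the weight matrix, computes its shard of the output, and the shards are
# concatenated (in practice via an all-gather across TPU devices).
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))    # activations: [batch, hidden]
w = rng.standard_normal((8, 16))   # full weight:  [hidden, out]

num_devices = 4
shards = np.split(w, num_devices, axis=1)   # each device's weight slice
partial = [x @ s for s in shards]           # local matmuls, one per device
y_tp = np.concatenate(partial, axis=1)      # "all-gather" of the outputs

assert np.allclose(y_tp, x @ w)  # identical to the unsharded computation
```

Since each device stores only `1/num_devices` of the weights, a model too large for one accelerator's memory can still be served across several.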
Getting Started
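A typical install from PyPI looks like the following; the package name is inferred from the distribution files listed further down this page, and TPU users will additionally need a TPU-compatible JAX build.

```shell
pip install sglang-jax
```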

Documentation

For more features and usage details, please see the documentation in the docs/ directory.

Supported Models

SGL-JAX is designed for easy extension to new model architectures. It currently provides first-class, optimized support for:

  • Qwen
  • Qwen 3
  • Qwen 3 MoE

Performance and Benchmarking

For detailed performance evaluation and to run the benchmarks yourself, please see the scripts located in the benchmark/ and python/sgl_jax/ directories (e.g., bench_serving.py).

Testing

The project includes a comprehensive test suite to ensure correctness and stability. To run the full suite of tests:

```shell
cd test/srt
python run_suite.py
```

Contributing

Contributions are welcome! If you would like to contribute, please feel free to open an issue to discuss your ideas or submit a pull request.

Download files

Download the file for your platform.

Source Distribution

sglang_jax-0.0.2.tar.gz (329.3 kB)


Built Distribution


sglang_jax-0.0.2-py3-none-any.whl (394.5 kB)


File details

Details for the file sglang_jax-0.0.2.tar.gz.

File metadata

  • Download URL: sglang_jax-0.0.2.tar.gz
  • Upload date:
  • Size: 329.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for sglang_jax-0.0.2.tar.gz:

  • SHA256: 70e38d3513797ea208c4d33e957b6950b47ff069492b8099b65152a70e40416e
  • MD5: 9a314a465eed3b6ebffbd1791a51ae75
  • BLAKE2b-256: c7941b525f5936c296b877ff7fe4c8c8b5f9088a73dd0197b624105a6b9f68ce


File details

Details for the file sglang_jax-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: sglang_jax-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 394.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for sglang_jax-0.0.2-py3-none-any.whl:

  • SHA256: 1ba0d3f7801a47c3961e50995a6ee832d8d094c5595d9ad865f4f9ebbe3fdfcb
  • MD5: 3c9972e58a06cdcd83aeee6bea20d79f
  • BLAKE2b-256: 4619053f86f545869badfb5164df6e314f5cfb87dd22ec89b71efe093a85fbf0

