faster-qwen3-tts-batch
Dynamic batching for Qwen3-TTS with CUDA graph acceleration. Extends faster-qwen3-tts to support batched inference, achieving 2-4x throughput improvement via shared weight computation and batched CUDA graphs.
Features
- Batched CUDA Graphs: Predictor and talker decode steps captured as batched CUDA graphs for minimal kernel launch overhead
- Batched Prefill: Single forward pass for all B samples with left-padding alignment (vs B separate prefills)
- Continuous Batching: Mid-generation slot replacement — when a sample finishes (EOS), its slot is immediately filled with the next pending request
- Async Scheduler: `BatchScheduler` with a dedicated GPU thread, smart batching with a configurable `max_wait_ms`, and an async `infer()` API
- Two Execution Modes:
  - `continuous=False` (default): batch-and-wait — collect a batch, run it to completion; optimal for consumer GPUs
  - `continuous=True`: slot replacement — fill freed slots mid-generation; optimal for high-end GPUs (H100/A100)
Installation
```bash
pip install -e .
```
Requires `faster-qwen3-tts >= 0.2.4` and a CUDA-capable GPU.
Quick Start
Direct Batched Generation
```python
from faster_qwen3_tts_batch import BatchedFasterQwen3TTS

model = BatchedFasterQwen3TTS.from_pretrained(
    "path/to/Qwen3-TTS-12Hz-0.6B-Base",
    max_batch_size=4,
    max_seq_len=512,
)
model.warmup(prefill_len=100)

requests = [
    {"text": "你好世界", "language": "Auto", "ref_audio": "ref.wav", "ref_text": ""},
    {"text": "今天天气不错", "language": "Auto", "ref_audio": "ref.wav", "ref_text": ""},
]
outputs = model.generate_voice_clone_batch(requests, max_new_tokens=500)

for audio_arrays, sr in outputs:
    # audio_arrays[0] is a numpy array, sr is the sample rate (24000)
    ...
```
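To persist the results, writing each returned array with the `soundfile` package is one reasonable pattern (the package choice and output paths are assumptions, not part of this project's API):

```python
import soundfile as sf

# `outputs` as returned by generate_voice_clone_batch() above;
# each entry pairs a list of numpy arrays with the sample rate.
for i, (audio_arrays, sr) in enumerate(outputs):
    sf.write(f"out_{i}.wav", audio_arrays[0], sr)
```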
Async Scheduler (for serving)
```python
import asyncio

from faster_qwen3_tts_batch import BatchedFasterQwen3TTS, BatchScheduler

model = BatchedFasterQwen3TTS.from_pretrained("path/to/model", max_batch_size=4)
model.warmup()

scheduler = BatchScheduler(
    model,
    max_batch_size=4,
    max_wait_ms=50,     # wait up to 50 ms to collect a batch
    continuous=False,   # batch-and-wait mode (default)
    gen_kwargs={"max_new_tokens": 500, "repetition_penalty": 1.1},
)

async def serve():
    await scheduler.start()
    # Concurrent requests are automatically batched
    audio, sr = await scheduler.infer(
        text="你好世界",
        language="Auto",
        ref_audio="ref.wav",
    )
    await scheduler.stop()

asyncio.run(serve())
```
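Because `infer()` is async, concurrent callers are what actually trigger batching. A sketch that fires several requests at once so the scheduler can pack them into one batch (the texts and paths are placeholders):

```python
import asyncio

from faster_qwen3_tts_batch import BatchedFasterQwen3TTS, BatchScheduler

async def main():
    model = BatchedFasterQwen3TTS.from_pretrained("path/to/model", max_batch_size=4)
    model.warmup()
    scheduler = BatchScheduler(model, max_batch_size=4, max_wait_ms=50)
    await scheduler.start()
    texts = ["你好世界", "今天天气不错", "欢迎光临", "再见"]
    # Launching the calls concurrently lets the scheduler collect them
    # into a single batch instead of serving them one at a time.
    results = await asyncio.gather(
        *(scheduler.infer(text=t, language="Auto", ref_audio="ref.wav") for t in texts)
    )
    await scheduler.stop()
    for audio, sr in results:
        print(audio.shape, sr)

asyncio.run(main())
```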
Architecture
```
        ┌──────────────┐
        │   infer()    │  ← async API (multiple callers)
        │ (event loop) │
        └──────┬───────┘
               │ queue.Queue (thread-safe)
        ┌──────▼───────┐
        │  GPU Thread  │  ← dedicated engine loop
        │              │
        │   collect    │  wait up to max_wait_ms
        │    batch     │  or until max_batch_size
        │              │
        └──────┬───────┘
               │
     ┌─────────┴────────────┐
     │                      │
┌────▼─────────────┐  ┌─────▼──────────────────┐
│  batch-and-wait  │  │  continuous batching   │
│(continuous=False)│  │   (continuous=True)    │
│                  │  │                        │
│ generate_batch() │  │ continuous_generate()  │
│ → decode audio   │  │ → slot replacement     │
│ → deliver all    │  │ → decode after done    │
└──────────────────┘  └────────────────────────┘
```
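The collect-batch step in the diagram amounts to a blocking get followed by a bounded wait. A simplified, self-contained sketch of that loop (the function and parameter names are illustrative, not the package's internals):

```python
import queue
import time

def engine_loop(requests: queue.Queue, dispatch, max_batch_size: int, max_wait_ms: float):
    """Block for the first request, then wait up to max_wait_ms (or until
    max_batch_size is reached) before handing the batch to the GPU."""
    while True:
        batch = [requests.get()]  # block until at least one request arrives
        deadline = time.monotonic() + max_wait_ms / 1000.0
        while len(batch) < max_batch_size:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                batch.append(requests.get(timeout=remaining))
            except queue.Empty:
                break
        dispatch(batch)  # e.g. run batched generation and deliver results
```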
Key Components
| File | Description |
|---|---|
| `model.py` | `BatchedFasterQwen3TTS` — main API, wraps the base model with batched graphs |
| `scheduler.py` | `BatchScheduler` — async scheduler with a dedicated GPU thread |
| `continuous_generate.py` | Continuous batched generation with per-slot state tracking |
| `batched_generate.py` | Batch-and-wait generation (all samples run to completion) |
| `batched_talker_graph.py` | Batched CUDA graph for talker decode with slot replacement support |
| `batched_predictor_graph.py` | Batched CUDA graph for the code predictor |
| `batched_sampling.py` | Batched top-k/top-p sampling with per-slot EOS suppression |
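The sampling module in the last row is the easiest piece to show in isolation. A minimal sketch of batched top-k/top-p sampling with per-slot EOS suppression (the shapes, defaults, and `suppress_eos` argument are illustrative, not the package's exact implementation):

```python
import torch

def batched_sample(logits, top_k=50, top_p=0.95, eos_id=0, suppress_eos=None):
    """logits: [B, vocab]; suppress_eos: [B] bool mask for slots still under
    min_new_tokens, whose EOS logit is forced to -inf before sampling."""
    if suppress_eos is not None:
        logits[suppress_eos, eos_id] = float("-inf")  # per-slot EOS suppression
    topk_vals, topk_idx = torch.topk(logits, top_k, dim=-1)       # [B, k]
    probs = torch.softmax(topk_vals, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cum = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[cum - sorted_probs > top_p] = 0.0                # nucleus cut
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, 1)                   # [B, 1]
    return topk_idx.gather(-1, sorted_idx.gather(-1, choice))     # vocab ids, [B, 1]
```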
Benchmark Results (RTX 3060, 0.6B model, 8 requests)
| Strategy | Wall time | Throughput |
|---|---|---|
| Sequential (8× B=1) | ~9.0 s | 2.9× realtime |
| Batch-and-wait (2× B=4) | ~5.4 s | 5.2× realtime |
| Continuous batching (B=4) | ~10.6 s | 3.8× realtime (4 replacements) |
Key finding: On consumer GPUs (RTX 3060/4090), the single-sample eager prefill required for slot replacement (~500 ms) costs more than the replacement saves, making batch-and-wait the optimal strategy there. Continuous batching pays off on high-end GPUs (H100/A100), where prefill latency is closer to ~50 ms.
Development History
Phase 1: Batched CUDA Graphs + Scheduler (v0.1.0)
The initial version established the core infrastructure:
- `BatchedPredictorGraph`: CUDA-graphed batched code predictor — takes `[B, 2, H]` input (past_hidden + last_token_embed) and produces `[B, 15]` codebook tokens in a single graph replay
- `BatchedTalkerGraph`: CUDA-graphed batched talker decode — left-padding alignment so all samples share a single `cache_position` scalar, with per-sample differences handled by `attention_mask` and `rope_deltas`
- `batched_fast_generate()`: batched autoregressive loop with shared prefill, CUDA-graphed decode, and batched sampling
- `BatchScheduler`: async scheduler that collects requests with smart waiting (`max_batch_size` or `max_wait_ms` timeout), then dispatches to the GPU
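Both graph wrappers follow the standard PyTorch capture/replay pattern: allocate static buffers, warm up on a side stream, capture once, then copy inputs in and replay. A minimal, self-contained sketch with a stand-in for the real decode step:

```python
import torch

B, H = 4, 1024
decode_step = lambda x: x * 2.0  # stand-in for one batched decode forward

static_in = torch.zeros(B, H, device="cuda")

# Warm up on a side stream so capture sees fully initialized kernels.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out = decode_step(static_in)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = decode_step(static_in)  # recorded into the graph, not run eagerly

# Per step: copy the new inputs into the static buffer and replay.
static_in.copy_(torch.randn(B, H, device="cuda"))
graph.replay()  # static_out now holds the result with one launch, not many
```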
Key design decisions:
- Left-padding alignment: all samples are padded to `max_seq_len` in the batch, enabling a single shared `cache_position` for CUDA graph capture
- StaticCache: required for CUDA graphs (no dynamic memory allocation during replay)
- Per-batch `rope_deltas`: each sample carries its own rope delta to compensate for padding differences
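A toy illustration of the left-padding decision: every sample's last real token lands at the same column, so one scalar `cache_position` serves the whole batch, while `attention_mask` and a per-sample rope delta absorb the differences (the sign convention for the delta is an assumption here):

```python
import torch

def left_pad_batch(seqs, max_seq_len, pad_id=0):
    B = len(seqs)
    input_ids = torch.full((B, max_seq_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros(B, max_seq_len, dtype=torch.long)
    rope_deltas = torch.zeros(B, dtype=torch.long)
    for i, s in enumerate(seqs):
        pad = max_seq_len - len(s)
        input_ids[i, pad:] = torch.tensor(s)
        attention_mask[i, pad:] = 1   # padding positions are never attended to
        rope_deltas[i] = -pad         # shift RoPE so position 0 is the first real token
    return input_ids, attention_mask, rope_deltas
```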
Phase 2: Batched Prefill (PR #2)
Replaced B separate talker prefill forward passes with a single batched forward:
- Before: loop `B` times calling `talker.forward()` individually, each call producing a separate DynamicCache
- After: left-pad all inputs to `max_seq_len`, one `talker.forward()` call, one DynamicCache for the whole batch
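The shape of the change, as a self-contained toy (the `talker_forward` stub stands in for the real talker's prefill; only the call structure is the point):

```python
import torch

def talker_forward(input_ids, attention_mask):
    """Stub prefill returning a fake one-layer KV cache so shapes are concrete."""
    B, T = input_ids.shape
    return [(torch.zeros(B, 8, T, 64), torch.zeros(B, 8, T, 64))]

samples = [torch.ones(t, dtype=torch.long) for t in (7, 11, 4)]

# Before (Phase 1): B separate forwards, B separate caches.
caches = [talker_forward(s[None], torch.ones(1, len(s), dtype=torch.long))
          for s in samples]

# After (Phase 2): left-pad to a shared length, one forward, one cache.
max_seq_len = max(len(s) for s in samples)
ids = torch.zeros(len(samples), max_seq_len, dtype=torch.long)
mask = torch.zeros(len(samples), max_seq_len, dtype=torch.long)
for i, s in enumerate(samples):
    ids[i, max_seq_len - len(s):] = s
    mask[i, max_seq_len - len(s):] = 1
batched_cache = talker_forward(ids, mask)  # one cache for the whole batch
```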
Added `prefill_kv_batched()` to `BatchedTalkerGraph` for a direct batched-cache-to-StaticCache copy (vs the original `prefill_kv()`, which left-pads individual caches).
Also introduced `set_generation_state()` with a `padding_in_rope_deltas` flag — when `rope_deltas` come from a batched prefill (with an `attention_mask`), they already account for padding, so no additional offset is needed.
Phase 3: Continuous Batching (v0.2.0, current)
The most architecturally significant change — enabling mid-generation slot replacement:
On-the-fly attention mask computation (`batched_talker_graph.py`):
- Replaced the precomputed `attn_mask_table[position]` lookup with a per-step `_compute_and_set_mask()` built on HuggingFace's `create_causal_mask` / `create_sliding_window_causal_mask`
- This eliminates the need to rebuild the entire mask table when a slot is replaced — just update `_attention_mask_2d[slot_idx]` and recompute
- Trade-off: slightly more compute per decode step, but slot replacement no longer requires O(max_seq_len) mask rebuilds
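The idea behind the per-step mask, stripped of the HuggingFace helpers (this sketch builds a plain additive causal mask from the 2D row state; the real code delegates to `create_causal_mask` and friends):

```python
import torch

def compute_step_mask(attention_mask_2d, current_pos):
    """attention_mask_2d: [B, max_seq_len] with 1s at valid positions.
    Returns an additive mask [B, 1, 1, max_seq_len] for the single query
    token at current_pos. Replacing a slot only requires editing its row."""
    B, max_seq_len = attention_mask_2d.shape
    allowed = attention_mask_2d.bool().clone()
    allowed[:, current_pos + 1:] = False  # causal: nothing after the query position
    mask = torch.zeros(B, 1, 1, max_seq_len)
    mask.masked_fill_(~allowed.view(B, 1, 1, max_seq_len), float("-inf"))
    return mask
```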
Slot replacement primitives (`batched_talker_graph.py`):
- `replace_slot_kv()`: zero out a slot's KV cache in the StaticCache, then write the new sample's left-padded KV data at `[current_pos - prefill_len, current_pos)`
- `update_slot_state()`: update `_attention_mask_2d` and `rope_deltas` for the replaced slot
- Key discovery: HuggingFace `StaticCache` layers use `layer.keys` / `layer.values` (not `layer.key_cache` / `layer.value_cache`)
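A sketch of the replacement write, assuming StaticCache-like per-layer tensors of shape `[B, heads, max_seq_len, head_dim]` (the tensor names here are illustrative; per the note above, the real attributes are `layer.keys` / `layer.values`):

```python
import torch

def replace_slot_kv(keys, values, slot_idx, new_keys, new_values, current_pos):
    """Zero the departing sample's KV, then write the replacement's prefill KV
    so that it ends exactly at current_pos, matching the left-padded layout."""
    prefill_len = new_keys.shape[-2]
    start = current_pos - prefill_len
    keys[slot_idx].zero_()
    values[slot_idx].zero_()
    keys[slot_idx, :, start:current_pos] = new_keys
    values[slot_idx, :, start:current_pos] = new_values
```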
Per-slot generation state (`continuous_generate.py`):
- `gen_steps` (per-slot): controls `trailing_text_hidden` injection timing — each slot independently indexes into its own text hidden sequence
- `slot_steps` (per-slot): controls `min_new_tokens` EOS suppression — replacement samples start from 0
- Dynamic `trailing_text_padded` buffer with auto-expansion when a replacement sample has longer text
- `id()`-based completed-tag tracking, since the unhashable `_PendingRequest` dataclass cannot go in sets
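Roughly, each slot carries its own counters (a sketch; the real state in `continuous_generate.py` may be laid out differently):

```python
from dataclasses import dataclass

@dataclass
class SlotState:
    tag: object           # caller-supplied handle for the active request
    gen_steps: int = 0    # index into this slot's trailing text hidden states
    slot_steps: int = 0   # steps since this slot (re)started; gates EOS suppression

# On replacement, slot_steps resets to 0 so min_new_tokens suppression restarts
# for the new sample; completed requests are tracked via id(request) because
# the _PendingRequest dataclass is unhashable.
```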
Deferred audio decode (`scheduler.py`):
- The `on_slot_done` callback collects raw `(tag, codec_ids.clone(), timing)` tuples during generation
- Audio decode (`speech_tokenizer.decode()`) runs after `continuous_batched_generate()` returns
- This avoids CUDA state conflicts between the speech tokenizer and the CUDA-graphed generation loop
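The callback pattern, as a self-contained sketch (`fake_decode` stands in for `speech_tokenizer.decode()`; the tuple layout follows the description above):

```python
import torch

def fake_decode(codec_ids):
    """Stub for speech_tokenizer.decode(); returns (waveform, sample_rate)."""
    return torch.zeros(codec_ids.shape[-1] * 2000), 24000

finished = []

def on_slot_done(tag, codec_ids, timing):
    # clone() so the ids survive reuse of the graph-managed output buffers
    finished.append((tag, codec_ids.clone(), timing))

# ... the continuous generation loop invokes on_slot_done as slots finish ...
on_slot_done("req-0", torch.randint(0, 1024, (1, 120)), {"steps": 120})

# Decode only after generation has fully returned, keeping the speech
# tokenizer clear of the CUDA-graphed decode loop.
results = {tag: fake_decode(ids) for tag, ids, _ in finished}
```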
Dual-mode scheduler:
- `continuous=False` (default): delegates to `generate_voice_clone_batch()` — proven faster on consumer GPUs
- `continuous=True`: uses `continuous_batched_generate()` with a `try_get_replacement` callback
- Dedicated GPU thread with `queue.Queue` for thread-safe async→sync bridging
Performance insight: Continuous batching is not universally beneficial. On an RTX 3060, the single-sample eager prefill needed for slot replacement (~500 ms) pauses all active slots, costing more time than it saves. The `continuous` flag defaults to `False` to reflect this finding.
License
MIT