
faster-qwen3-tts-batch

Dynamic batching for Qwen3-TTS with CUDA graph acceleration. Extends faster-qwen3-tts to support batched inference, achieving 2-4x throughput improvement via shared weight computation and batched CUDA graphs.

Features

  • Batched CUDA Graphs: Predictor and talker decode steps captured as batched CUDA graphs for minimal kernel launch overhead
  • Batched Prefill: Single forward pass for all B samples with left-padding alignment (vs B separate prefills)
  • Continuous Batching: Mid-generation slot replacement — when a sample finishes (EOS), its slot is immediately filled with the next pending request
  • Async Scheduler: BatchScheduler with dedicated GPU thread, smart batching with configurable max_wait_ms, and async infer() API
  • Two Execution Modes:
    • continuous=False (default): Batch-and-wait — collect batch, run to completion, optimal for consumer GPUs
    • continuous=True: Slot replacement — fill freed slots mid-generation, optimal for high-end GPUs (H100/A100)

Installation

pip install -e .

Requires faster-qwen3-tts >= 0.2.4 and a CUDA-capable GPU.

Quick Start

Direct Batched Generation

from faster_qwen3_tts_batch import BatchedFasterQwen3TTS

model = BatchedFasterQwen3TTS.from_pretrained(
    "path/to/Qwen3-TTS-12Hz-0.6B-Base",
    max_batch_size=4,
    max_seq_len=512,
)
model.warmup(prefill_len=100)

requests = [
    {"text": "你好世界", "language": "Auto", "ref_audio": "ref.wav", "ref_text": ""},
    {"text": "今天天气不错", "language": "Auto", "ref_audio": "ref.wav", "ref_text": ""},
]
outputs = model.generate_voice_clone_batch(requests, max_new_tokens=500)

for audio_arrays, sr in outputs:
    # audio_arrays[0] is a numpy array; sr is the sample rate (24000 Hz)
    ...

Async Scheduler (for serving)

import asyncio
from faster_qwen3_tts_batch import BatchedFasterQwen3TTS, BatchScheduler

model = BatchedFasterQwen3TTS.from_pretrained("path/to/model", max_batch_size=4)
model.warmup()

scheduler = BatchScheduler(
    model,
    max_batch_size=4,
    max_wait_ms=50,         # wait up to 50ms to collect a batch
    continuous=False,        # batch-and-wait mode (default)
    gen_kwargs={"max_new_tokens": 500, "repetition_penalty": 1.1},
)

async def serve():
    await scheduler.start()

    # Concurrent requests are automatically batched
    audio, sr = await scheduler.infer(
        text="你好世界",
        language="Auto",
        ref_audio="ref.wav",
    )

    await scheduler.stop()

asyncio.run(serve())

Architecture

                     ┌──────────────┐
                     │  infer()     │  ← async API (multiple callers)
                     │  (event loop)│
                     └──────┬───────┘
                            │ queue.Queue (thread-safe)
                     ┌──────▼───────┐
                     │  GPU Thread  │  ← dedicated engine loop
                     │              │
                     │  collect     │  wait up to max_wait_ms
                     │  batch       │  or until max_batch_size
                     │              │
              ┌──────┴──────────────┴──────┐
              │                            │
     ┌────────▼──────────┐     ┌───────────▼───────────┐
     │ batch-and-wait    │     │ continuous batching   │
     │ (continuous=False)│     │ (continuous=True)     │
     │                   │     │                       │
     │ generate_batch()  │     │ continuous_generate() │
     │ → decode audio    │     │ → slot replacement    │
     │ → deliver all     │     │ → decode after done   │
     └───────────────────┘     └───────────────────────┘
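
The async-to-GPU-thread hand-off in the diagram can be pictured with a minimal sketch. Everything below (MiniScheduler, run_batch) is hypothetical and only illustrates the queue.Queue bridge plus the smart-waiting loop; it is not the package's BatchScheduler.

import asyncio
import queue
import threading
import time

class MiniScheduler:
    def __init__(self, run_batch, max_batch_size=4, max_wait_ms=50):
        self._run_batch = run_batch            # callable: list[request] -> list[result]
        self._queue = queue.Queue()            # thread-safe hand-off from the event loop
        self._max_batch_size = max_batch_size
        self._max_wait = max_wait_ms / 1000.0
        threading.Thread(target=self._loop, daemon=True).start()

    async def infer(self, request):
        loop = asyncio.get_running_loop()
        fut = loop.create_future()
        self._queue.put((request, fut, loop))  # enqueue from the async side
        return await fut

    def _loop(self):                           # dedicated "GPU" thread
        while True:
            batch = [self._queue.get()]        # block for the first request
            deadline = time.monotonic() + self._max_wait
            while len(batch) < self._max_batch_size:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                try:                           # collect more until max_wait_ms or a full batch
                    batch.append(self._queue.get(timeout=remaining))
                except queue.Empty:
                    break
            results = self._run_batch([req for req, _, _ in batch])
            for (_, fut, loop), res in zip(batch, results):
                loop.call_soon_threadsafe(fut.set_result, res)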

Key Components

File                         Description
model.py                     BatchedFasterQwen3TTS — main API, wraps base model with batched graphs
scheduler.py                 BatchScheduler — async scheduler with dedicated GPU thread
continuous_generate.py       Continuous batched generation with per-slot state tracking
batched_generate.py          Batch-and-wait generation (all samples run to completion)
batched_talker_graph.py      Batched CUDA graph for talker decode with slot replacement support
batched_predictor_graph.py   Batched CUDA graph for code predictor
batched_sampling.py          Batched top-k/top-p sampling with per-slot EOS suppression

Benchmark Results (RTX 3060, 0.6B model, 8 requests)

Sequential (8x B=1):        ~9.0s  throughput=2.9x realtime
Batch-and-Wait (2x B=4):    ~5.4s  throughput=5.2x realtime
Continuous Batching (B=4):  ~10.6s  throughput=3.8x realtime  (4 replacements)

Key finding: On consumer GPUs (RTX 3060/4090), single-sample eager prefill for slot replacement (~500ms) costs more than it saves. Batch-and-wait is the optimal strategy for consumer GPUs. Continuous batching benefits high-end GPUs (H100/A100) where prefill latency is ~50ms.

Development History

Phase 1: Batched CUDA Graphs + Scheduler (v0.1.0)

The initial version established the core infrastructure:

  • BatchedPredictorGraph: CUDA-graphed batched code predictor — takes [B, 2, H] input (past_hidden + last_token_embed), produces [B, 15] codebook tokens in a single graph replay
  • BatchedTalkerGraph: CUDA-graphed batched talker decode — left-padding alignment so all samples share a single cache_position scalar, with per-sample differences handled by attention_mask and rope_deltas
  • batched_fast_generate(): Batched autoregressive loop with shared prefill, CUDA-graphed decode, and batched sampling
  • BatchScheduler: Async scheduler that collects requests with smart waiting (max_batch_size or max_wait_ms timeout), then dispatches to GPU

Key design decisions (the capture/replay pattern they serve is sketched after this list):

  • Left-padding alignment: All samples padded to max_seq_len in the batch, enabling a single shared cache_position for CUDA graph capture
  • StaticCache: Required for CUDA graphs (no dynamic memory allocation during replay)
  • Per-batch rope_deltas: Each sample carries its own rope_delta to compensate for padding differences
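
A minimal sketch of the capture/replay pattern these batched graphs rely on, with a hypothetical decode_step standing in for the talker/predictor forward pass; the real code additionally routes its KV through a StaticCache so that nothing is allocated during replay.

import torch

B, H = 4, 1024
static_in = torch.zeros(B, 2, H, device="cuda", dtype=torch.float16)  # past_hidden + last_token_embed

def decode_step(x):
    return x.sum(dim=1)            # placeholder for the batched decode forward

# Warm up on a side stream before capture (standard CUDA graph practice).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        decode_step(static_in)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = decode_step(static_in)   # capture once against static buffers

def run_decode(new_inputs):
    static_in.copy_(new_inputs)           # refresh inputs in place
    graph.replay()                        # one replay instead of many kernel launches
    return static_out.clone()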

Phase 2: Batched Prefill (PR #2)

Replaced B separate talker prefill forward passes with a single batched forward:

  • Before: Loop B times calling talker.forward() individually, each producing a separate DynamicCache
  • After: Left-pad all inputs to max_seq_len, single talker.forward() call, one DynamicCache for the whole batch

Added prefill_kv_batched() to BatchedTalkerGraph for direct batched-cache-to-StaticCache copy (vs the original prefill_kv() which left-pads individual caches).

Also introduced set_generation_state() with padding_in_rope_deltas flag — when rope_deltas come from a batched prefill (with attention_mask), they already account for padding, so no additional offset is needed.
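
A rough sketch of the left-padding step that makes the single batched prefill possible; left_pad_for_prefill and pad_embed are made-up names for illustration, not the package's helpers.

import torch

def left_pad_for_prefill(embeds_list, pad_embed):
    # embeds_list: per-sample [L_i, H] prompt embeddings; pad_embed: [H]
    max_len = max(e.shape[0] for e in embeds_list)
    B, H = len(embeds_list), embeds_list[0].shape[-1]
    batch = pad_embed.expand(B, max_len, H).clone()
    attention_mask = torch.zeros(B, max_len, dtype=torch.long)
    for i, e in enumerate(embeds_list):
        L = e.shape[0]
        batch[i, max_len - L:] = e              # real tokens flush right
        attention_mask[i, max_len - L:] = 1     # 1 = attend, 0 = padding
    return batch, attention_mask

# One forward pass over (batch, attention_mask) then replaces B separate
# prefill calls, producing a single cache for the whole batch.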

Phase 3: Continuous Batching (v0.2.0, current)

The most architecturally significant change — enabling mid-generation slot replacement:

On-the-fly attention mask computation (batched_talker_graph.py), with a simplified sketch after this list:

  • Replaced the precomputed attn_mask_table[position] lookup with per-step _compute_and_set_mask() using HuggingFace's create_causal_mask / create_sliding_window_causal_mask
  • This eliminated the need to rebuild the entire mask table when a slot is replaced — just update _attention_mask_2d[slot_idx] and recompute
  • Trade-off: Slightly more compute per decode step, but enables slot replacement without O(max_seq_len) mask rebuilds
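
The per-step mask computation can be sketched roughly as below. This hand-rolled version stands in for the HuggingFace mask helpers named above; the tensor names and the [B, 1, 1, max_seq_len] layout are illustrative assumptions.

import torch

def decode_step_mask(attention_mask_2d, current_pos, dtype=torch.float16):
    # attention_mask_2d: [B, max_seq_len], 1 for real tokens, 0 for padding.
    visible = attention_mask_2d.bool().clone()
    visible[:, current_pos + 1:] = False                   # causal cut-off for this step
    neg_inf = torch.finfo(dtype).min
    mask = torch.full(visible.shape, neg_inf, dtype=dtype, device=visible.device)
    mask[visible] = 0.0                                     # additive mask: 0 = keep, -inf = drop
    return mask[:, None, None, :]                           # broadcast over heads and the 1-token query

# Replacing a slot only touches attention_mask_2d[slot_idx]; no
# O(max_seq_len) mask table has to be rebuilt.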

Slot replacement primitives (batched_talker_graph.py), sketched below:

  • replace_slot_kv(): Zero out a slot's KV cache in StaticCache, then write the new sample's left-padded KV data at [current_pos - prefill_len, current_pos)
  • update_slot_state(): Update _attention_mask_2d and rope_deltas for the replaced slot
  • Key discovery: HuggingFace StaticCache layers use layer.keys / layer.values (not layer.key_cache / layer.value_cache)
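
A rough sketch of what the slot-replacement write might look like over a static KV buffer; the [B, n_heads, max_seq_len, head_dim] layout and the function body are assumptions for illustration, not the package's replace_slot_kv().

import torch

def replace_slot_kv_sketch(key_cache, value_cache, slot_idx,
                           new_keys, new_values, current_pos):
    # key_cache / value_cache: [B, n_heads, max_seq_len, head_dim] static buffers
    # new_keys / new_values:   [n_heads, prefill_len, head_dim] for the replacement sample
    prefill_len = new_keys.shape[1]
    start = current_pos - prefill_len
    key_cache[slot_idx].zero_()                  # wipe the finished sample's history
    value_cache[slot_idx].zero_()
    key_cache[slot_idx, :, start:current_pos] = new_keys      # write left-padded prefill KV
    value_cache[slot_idx, :, start:current_pos] = new_values  # ending at current_pos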

Per-slot generation state (continuous_generate.py), outlined below:

  • gen_steps (per-slot): Controls trailing_text_hidden injection timing — each slot independently indexes into its text hidden sequence
  • slot_steps (per-slot): Controls min_new_tokens EOS suppression — replacement samples start from 0
  • Dynamic trailing_text_padded buffer with auto-expansion when replacement samples have longer text
  • Completed requests tracked by id()-based tags, since the _PendingRequest dataclass is unhashable and cannot be stored in a set
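
As a sketch, the per-slot bookkeeping amounts to two independent counters per slot; the names approximate the description above rather than the actual fields.

from dataclasses import dataclass

@dataclass
class SlotStateSketch:
    gen_step: int = 0    # index into this slot's trailing text hidden states
    slot_step: int = 0   # steps since this slot's own prefill

# A replacement sample resets both to 0, so its min_new_tokens EOS
# suppression (slot_step < min_new_tokens) ignores the slot's prior history.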

Deferred audio decode (scheduler.py), sketched below:

  • on_slot_done callback collects raw (tag, codec_ids.clone(), timing) tuples during generation
  • Audio decode (speech_tokenizer.decode()) runs after continuous_batched_generate() returns
  • This avoids CUDA state conflicts between the speech tokenizer and CUDA-graphed generation loop
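
A sketch of the deferral pattern; the callback signature is inferred from the tuple described above, and the decode call is indicative only.

finished = []

def on_slot_done(tag, codec_ids, timing):
    # Clone immediately: the static generation buffers are reused by
    # whichever sample takes over the freed slot.
    finished.append((tag, codec_ids.clone(), timing))

# continuous_batched_generate(..., on_slot_done=on_slot_done) runs first;
# only after it returns does the speech tokenizer touch the GPU, e.g.:
# audio = [(tag, speech_tokenizer.decode(ids)) for tag, ids, _ in finished]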

Dual-mode scheduler, with a usage snippet below:

  • continuous=False (default): Delegates to generate_voice_clone_batch() — proven faster on consumer GPUs
  • continuous=True: Uses continuous_batched_generate() with try_get_replacement callback
  • Dedicated GPU thread with queue.Queue for thread-safe async→sync bridging
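
Switching modes is just the constructor flag from the Quick Start; for example, on an H100/A100 one might enable continuous mode:

scheduler = BatchScheduler(
    model,
    max_batch_size=4,
    max_wait_ms=50,
    continuous=True,                     # fill freed slots mid-generation
    gen_kwargs={"max_new_tokens": 500},
)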

Performance insight: Continuous batching is not universally beneficial. On RTX 3060, single-sample eager prefill for slot replacement (~500ms) pauses all active slots, costing more time than it saves. The continuous flag defaults to False to reflect this finding.

License

MIT
