
Qwen3-TTS-Triton


Up to 5x faster Qwen3-TTS inference through Triton kernel fusion and TurboQuant KV cache.

Korean (ํ•œ๊ตญ์–ด) | Benchmark Results

[!NOTE] This project has only been tested on RTX 5090 (Blackwell, sm_120) with WSL2 (CUDA 12.8, PyTorch nightly cu128). Triton kernels are architecture-agnostic (no sm_120-specific code), so they should work on other NVIDIA GPUs (A100, H100, RTX 4090, etc.), but this has not been verified. If you test on a different GPU, please open an issue or PR with your results!


Qwen3-TTS-Triton replaces performance-critical operators in Qwen3-TTS 1.7B with hand-written Triton kernels. Inspired by Liger Kernel (LinkedIn), each kernel fuses multiple HBM round-trips into a single pass, reducing memory traffic without any additional VRAM usage.

It can also be combined with faster-qwen3-tts (CUDA Graph + static KV-cache) as a Hybrid mode for maximum throughput. Hybrid+TQ is the current release-grade TurboQuant path. Base+TQ and Triton+TQ remain experimental until they pass the full Tier 3 gate.

๐Ÿ’ก Why Triton?

  • ๐Ÿชถ Lightweight & Portable โ€” No serving infrastructure needed. Just pip install qwen3-tts-triton and call apply_triton_kernels(). Works in standalone scripts, ComfyUI nodes, Gradio apps, or any Python environment.
  • ๐ŸŽฒ Faster Iteration on Stochastic TTS โ€” Qwen3-TTS generates different output each run. For best results, generate multiple candidates and pick the best one. With Hybrid mode's ~5x speedup, you can produce 5 candidates in the time it used to take for 1 โ€” more takes, better results.

๐ŸŒฑ Why Optimize Qwen3-TTS?

Qwen3-TTS is rapidly becoming the backbone for next-generation TTS models. Darwin-TTS blends just 3% of general LLM weights back into the Qwen3-TTS-1.7B talker โ€” a 10-second, training-free operation โ€” to produce emotionally expressive speech. Projects like OmniVoice further demonstrate the Qwen3 architecture's versatility for multilingual, zero-shot TTS. As more derivative models build on Qwen3-TTS, kernel-level speedups here propagate to the entire ecosystem โ€” every model that shares the same 28-layer transformer talker benefits from these Triton kernels with zero code changes.

โœจ Highlights

  • โšก 4 Fused Triton Kernels โ€” RMSNorm, SwiGLU, M-RoPE, Norm+Residual
  • ๐ŸŽฏ 7 Inference Modes โ€” Base, Base+TQ, Triton, Triton+TQ, Faster, Hybrid, Hybrid+TQ
  • ๐Ÿ—œ๏ธ TurboQuant KV Cache โ€” INT4/INT3 calibration-free KV cache quantization for VRAM savings
  • ๐Ÿ”ฌ 3-Tier Verification โ€” Kernel correctness โ†’ Model parity โ†’ E2E quality distribution
  • ๐Ÿ’พ Zero Extra VRAM โ€” Pure kernel fusion, no model changes
  • ๐Ÿ”Œ Drop-in Patching โ€” Single apply_triton_kernels() call, weight sharing via monkey-patch
  • ๐Ÿ“Š Streamlit Dashboard โ€” Side-by-side comparison UI with live metrics

๐Ÿ“ฆ Install

Requirements: Python 3.12+, CUDA 12.8+, NVIDIA GPU (8GB+ VRAM). Tested on WSL2 (Windows Subsystem for Linux 2).

From PyPI

# 1. Install PyTorch with CUDA support first
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128

# 2. Install qwen3-tts-triton
pip install qwen3-tts-triton

From Source (development)

# Install UV (if not installed)
curl -LsSf https://astral.sh/uv/install.sh | sh

# Clone and setup
git clone https://github.com/newgrit1004/qwen3-tts-triton.git
cd qwen3-tts-triton
make setup  # uv sync --all-extras --dev + pre-commit install + git config

UV handles virtual environments automatically โ€” no need to manually activate a venv. All commands use the uv run prefix (e.g., uv run pytest, uv run python script.py). PyTorch is installed from the cu128 index automatically via pyproject.toml.

Dependency Groups

uv sync                 # Core (triton, transformers, faster-qwen3-tts, streamlit, plotly)
uv sync --extra eval    # + Quality evaluation (cohere-transcribe, jiwer, resemblyzer)
uv sync --extra dev     # + Dev tools (ruff, pytest, pre-commit)
uv sync --extra all     # Everything

๐Ÿš€ Quick Start

[!TIP] On first run, the model (~3.5GB) is automatically downloaded from HuggingFace. To download in advance: huggingface-cli download Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice

Triton Mode

from qwen3_tts_triton import TritonRunner
import soundfile as sf

runner = TritonRunner()
runner.load_model()  # Downloads model on first run (~3.5GB)

result = runner.generate(
    text="Hello, this is optimized with Triton kernels.",
    language="English",
    speaker="vivian",
)

# Save audio
sf.write("output.wav", result["audio"], result["sample_rate"])
print(f"Generated in {result['time_s']:.2f}s, VRAM: {result['peak_vram_gb']:.2f}GB")

runner.unload_model()

Hybrid Mode (Triton + CUDA Graph, ~5x faster)

from qwen3_tts_triton import TritonFasterRunner
import soundfile as sf

runner = TritonFasterRunner()
runner.load_model()  # Triton patches applied before CUDA Graph capture

result = runner.generate(
    text="Hybrid mode: CUDA Graph + Triton fusion.",
    language="English",
    speaker="vivian",
)

sf.write("output.wav", result["audio"], result["sample_rate"])
runner.unload_model()

Hybrid+TQ Mode (Triton + CUDA Graph + TurboQuant KV Cache)

from qwen3_tts_triton import TritonFasterRunner
import soundfile as sf

runner = TritonFasterRunner(enable_turboquant=True, tq_bits=4)
runner.load_model()  # Triton patches + TurboQuant KV cache injected before CUDA Graph capture

result = runner.generate(
    text="Hybrid+TQ mode: CUDA Graph + Triton fusion + INT4 KV cache.",
    language="English",
    speaker="vivian",
)

sf.write("output.wav", result["audio"], result["sample_rate"])
print(f"RTF: {result['rtf']:.2f}x, VRAM: {result['peak_vram_gb']:.2f}GB")
runner.unload_model()

Note: TurboQuant quantizes the KV cache to INT4 (or INT3) at each decode step and is compatible with all runner modes via the enable_turboquant=True flag. In the current full Tier 3 release gate, Hybrid+TQ passes while Base+TQ and Triton+TQ remain caveated.
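
For instance, enabling TurboQuant on the Triton runner looks like the sketch below. Passing tq_bits=3 to select INT3 is an assumption inferred from the INT4/INT3 description above; only tq_bits=4 appears in the examples in this README.

from qwen3_tts_triton import TritonRunner

# enable_turboquant is described as available on every runner mode.
# tq_bits=3 selecting INT3 is assumed from the INT4/INT3 support, not verified here.
runner = TritonRunner(enable_turboquant=True, tq_bits=3)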

๐Ÿ“Š Streamlit Dashboard

make ui  # http://localhost:8501

The dashboard provides:

  • ๐Ÿ”„ Side-by-side inference comparison across all modes
  • ๐Ÿ“ˆ Live metrics (TTFA, RTF, Total Time, Peak VRAM)
  • ๐Ÿ“‰ Plotly charts for visual comparison
  • โœ… 3-Tier verification result cards

๐ŸŽง Audio Samples

Pre-generated samples comparing inference modes (custom voice + voice cloning).

| Mode | Directory |
| --- | --- |
| Base (PyTorch) | assets/audio_samples/base/ |
| Base+TQ | assets/audio_samples/base+tq/ |
| Triton | assets/audio_samples/triton/ |
| Triton+TQ | assets/audio_samples/triton+tq/ |
| Faster (CUDA Graph) | assets/audio_samples/faster/ |
| Hybrid (Faster+Triton) | assets/audio_samples/hybrid/ |
| Hybrid+TQ | assets/audio_samples/hybrid+tq/ |

Each directory contains custom voice samples (5 Korean + 5 English) and voice cloning samples using LJSpeech reference audio (Public Domain).

Use make ui โ†’ Audio Samples tab for side-by-side playback and comparison. Regenerate: make generate-samples (GPU required).

โšก Triton Kernels

All kernels target the Qwen3-TTS Talker (28-layer Transformer, hidden_size=2048, intermediate=6144).

| Kernel | What It Fuses | HBM Savings | File |
| --- | --- | --- | --- |
| RMSNorm | variance + normalize + scale in SRAM | 4→1 round-trips | kernels/rms_norm.py |
| SwiGLU | silu(gate) * up — eliminates intermediate tensor | 3→1 round-trips | kernels/swiglu.py |
| M-RoPE | 3D positional encoding (sections=[24,20,20]) | In-place compute | kernels/rope.py |
| Fused Norm+Residual | residual + x then RMSNorm in one kernel | 2 kernels → 1 | kernels/fused_norm_residual.py |
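
To make the 4→1 round-trip entry concrete, here is a minimal, self-contained sketch of the fused-RMSNorm idea (illustrative only, not the kernel shipped in kernels/rms_norm.py): each row is read from HBM once, the variance, normalization, and scaling all happen on-chip, and a single store writes the result.

import torch
import triton
import triton.language as tl

@triton.jit
def rmsnorm_kernel(x_ptr, w_ptr, out_ptr, n_cols, eps, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    # one HBM load per row; everything below stays in SRAM/registers
    x = tl.load(x_ptr + row * n_cols + cols, mask=mask, other=0.0).to(tl.float32)
    var = tl.sum(x * x, axis=0) / n_cols
    x_hat = x * (1.0 / tl.sqrt(var + eps))
    w = tl.load(w_ptr + cols, mask=mask, other=1.0).to(tl.float32)
    # one HBM store: the separate variance/normalize/scale passes are fused away
    tl.store(out_ptr + row * n_cols + cols, x_hat * w, mask=mask)

def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    x2d = x.reshape(-1, x.shape[-1]).contiguous()
    out = torch.empty_like(x2d, dtype=torch.float32)
    BLOCK = triton.next_power_of_2(x2d.shape[-1])
    rmsnorm_kernel[(x2d.shape[0],)](x2d, weight, out, x2d.shape[-1], eps, BLOCK=BLOCK)
    return out.to(x.dtype).reshape(x.shape)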

Additionally, TurboQuant (kernels/turboquant.py) provides INT4/INT3 KV cache quantization with calibration-free Lloyd-Max codebooks and Hadamard rotation for outlier suppression.
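
The rotate-then-quantize idea in sketch form (plain PyTorch reference code, not the Triton kernel; for brevity it uses uniform rounding where the package uses Lloyd-Max codebooks, so treat it purely as illustration):

import torch

def hadamard(n: int, device: torch.device) -> torch.Tensor:
    # Sylvester construction; n must be a power of two (head_dim=128 qualifies)
    H = torch.ones(1, 1, device=device)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H / n ** 0.5  # orthonormal: H @ H.T = I

def kv_int4_roundtrip(kv: torch.Tensor) -> torch.Tensor:
    d = kv.shape[-1]
    H = hadamard(d, kv.device)
    r = kv.float() @ H  # rotate head_dim to spread outliers across channels
    scale = r.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    q = torch.clamp(torch.round(r / scale), -8, 7)  # INT4 grid [-8, 7]
    return ((q * scale) @ H.T).to(kv.dtype)  # dequantize and rotate back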

๐Ÿ”Œ How Patching Works

apply_triton_kernels() performs in-place monkey-patching:

  1. RMSNorm modules โ†’ replaced with TritonRMSNorm (shares original weights, zero copy)
  2. MLP forward โ†’ patched to use triton_swiglu_forward (fused gate+up projection)
  3. Decoder layer forward → patched for fused residual addition + normalization

from qwen3_tts_triton.models.patching import apply_triton_kernels

# Patches all 28 decoder layers in-place (patch counts logged via logging)
apply_triton_kernels(model)

Advanced: Manual Patching

If you want to apply Triton kernels to a model loaded outside the Runner API:

from qwen_tts import Qwen3TTSModel
from qwen3_tts_triton.models.patching import apply_triton_kernels
import torch

model = Qwen3TTSModel.from_pretrained(
    "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
    device_map="cuda:0",
    dtype=torch.bfloat16,
)

# Patch the internal nn.Module (not the wrapper)
apply_triton_kernels(model.model)

wavs, sr = model.generate_custom_voice(
    text="Hello, this is optimized with Triton kernels.",
    language="English",
    speaker="vivian",
)

For Hybrid mode with manual patching, use find_patchable_model() to resolve the internal module:

from faster_qwen3_tts import FasterQwen3TTS
from qwen3_tts_triton.models.patching import apply_triton_kernels, find_patchable_model

model = FasterQwen3TTS.from_pretrained(
    "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", device="cuda"
)

# FasterQwen3TTS wraps multiple layers: model.model.model reaches the nn.Module
internal = find_patchable_model(model.model)
apply_triton_kernels(internal)

๐Ÿ”ฌ 3-Tier Verification

Inspired by Liger Kernel and industry practices from vLLM and SGLang.

| Tier | What | Threshold | Time | Command |
| --- | --- | --- | --- | --- |
| 1. Kernel | Kernel correctness + CPU-only regression guards | bf16: 0.05, fp16: 1e-3 | ~48s (RTX 5090 WSL2) | make test (197 tests) |
| 2. Model | Layer-by-layer cosine similarity (2 pairs) | > 0.95 at layers 0, 7, 14, 21, 27 | ~46s (RTX 5090 WSL2) | make test-parity |
| 3. E2E | Output quality distribution (UTMOS, CER, Speaker Sim) | See below | 15–80 min | make eval-fast |

Tier 3 Thresholds

Each model generates independently, then task-level metrics are compared via distribution analysis (not pair-level waveform comparison — stochastic TTS makes this unreliable). A minimal sketch of this gate follows the thresholds table below.

| Metric | Threshold | Rationale |
| --- | --- | --- |
| UTMOS delta | \|mean\| < 0.3 | F5-TTS independent generation variance |
| UTMOS floor | Both > 2.5 | Absolute quality lower bound |
| CER delta | \|mean\| < 0.05 | SGLang 1-5% tolerance |
| Speaker Similarity | mean > 0.75 | Qwen3-TTS SIM > 0.79 |
| Mann-Whitney U | p > 0.05 (full mode) | Non-parametric distribution equivalence |
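
A minimal sketch of the UTMOS gate from this table (assumes scipy is installed; the arrays stand in for per-utterance UTMOS scores collected from independent runs of the reference and candidate runners):

import numpy as np
from scipy.stats import mannwhitneyu

def utmos_gate(ref: np.ndarray, cand: np.ndarray) -> bool:
    delta_ok = abs(cand.mean() - ref.mean()) < 0.3           # UTMOS delta threshold
    floor_ok = ref.mean() > 2.5 and cand.mean() > 2.5        # absolute quality floor
    _, p = mannwhitneyu(ref, cand, alternative="two-sided")  # distribution comparison
    return delta_ok and floor_ok and p > 0.05                # high p = indistinguishable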

Running Verification

make test          # Tier 1: Kernel tests
make test-parity   # Tier 2: Model parity (GPU required)
make verify        # Tier 1 + 2 + existing Tier 3 artifact report
make eval-fast     # Tier 3: Fast (~15min, Cohere Transcribe, 1 run/utterance)
make eval-full     # Tier 3: Full (~80min, Cohere Transcribe, 3 runs, Mann-Whitney)
make verify-all    # Run eval-full, then build the 3-Tier report

๐Ÿ“‹ Latest Results

โœ… Tier 1: All kernel tests PASS

โœ… Tier 2: 2 pairs tested, all PASS

Pair A โ€” Base โ†” Triton (cosine > 0.997):

| Layer | Cosine Sim |
| --- | --- |
| L0 | 0.999995 |
| L7 | 0.999977 |
| L14 | 0.999852 |
| L21 | 0.999177 |
| L27 | 0.997900 |
| Output | 0.997156 |

Pair B โ€” Faster โ†” Hybrid (cosine > 0.997): PASS

Floating-point accumulation error naturally decreases similarity across the 28 layers — this is expected behavior for fused kernels that change operation order.

๐Ÿ“Š Benchmarks

Hybrid (Faster+Triton) achieves 5.0x faster inference than PyTorch baseline at equivalent VRAM on RTX 5090.

๐Ÿ—๏ธ Optimization Modes

graph TD
    A["Base (PyTorch eager)"] -->|"+Triton kernel fusion"| B["Triton (~1.1x)"]
    A -->|"+CUDA Graph + Static Cache"| C["Faster (~4.0x)"]
    C -->|"+Triton kernel fusion"| D["Hybrid (~5.0x)"]
    A -->|"+TurboQuant KV"| A2["Base+TQ"]
    B -->|"+TurboQuant KV"| B2["Triton+TQ"]
    D -->|"+TurboQuant KV"| D2["Hybrid+TQ"]

    style D fill:#f96,stroke:#333,stroke-width:2px,color:#000

TurboQuant (+TQ) variants share the same INT4 KV-cache path, but the current full Tier 3 release gate passes only for Hybrid+TQ.

make bench-kernels  # Per-kernel micro-benchmarks (PyTorch vs Triton)
make bench-e2e      # End-to-end inference (all runners)
make bench          # Default suite (kernels + speed + fast quality + report)
make profile        # torch.profiler trace

Hardware & Methodology

| Item | Spec |
| --- | --- |
| GPU | NVIDIA RTX 5090 (Blackwell, sm_120, 32GB) |
| CUDA | 12.8 |
| PyTorch | nightly (cu128) |
| Triton | 3.2.0 |
| Model | Qwen3-TTS-12Hz-1.7B (1.7B params) |
| OS | WSL2 (Linux 5.15) |
| Python | 3.12 |
| Dtype | bfloat16 |
| Batch Size | 1 |

Kernel benchmarks: triton.testing.do_bench(), batch=1, seq_len=512, hidden=2048. E2E benchmarks: torch.cuda.Event timing, 3 warmup + 20 measured runs per text. RTF (Real-Time Factor) = audio_duration / generation_time. RTF > 1 means faster-than-real-time.
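
For reference, a kernel micro-benchmark in the same style (an eager-PyTorch RMSNorm at the stated shapes; absolute numbers vary by GPU):

import torch
import triton.testing

x = torch.randn(512, 2048, device="cuda", dtype=torch.bfloat16)  # seq_len=512, hidden=2048
w = torch.randn(2048, device="cuda", dtype=torch.bfloat16)

def eager_rmsnorm():
    v = x.float()
    return ((v * torch.rsqrt(v.pow(2).mean(-1, keepdim=True) + 1e-6)) * w.float()).to(x.dtype)

ms = triton.testing.do_bench(eager_rmsnorm)  # returns time in milliseconds
print(f"eager RMSNorm: {ms * 1e3:.1f} us")
# For E2E runs: RTF = audio_duration_s / generation_time_s; RTF > 1 is faster than real time.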

โšก Kernel Micro-Benchmarks

RTX 5090, bf16, batch=1, seq_len=512, hidden=2048. Run make bench-kernels to reproduce.

| Kernel | PyTorch (µs) | Triton (µs) | Speedup | Compile (s) | HBM Savings |
| --- | --- | --- | --- | --- | --- |
| RMSNorm | 40.9 | 7.4 | 5.51x | 0.34 | 4→1 trips |
| SwiGLU | 19.4 | 16.0 | 1.21x | 0.00 | 3→1 trips |
| M-RoPE | 367.9 | 37.3 | 9.87x | 0.02 | In-place |
| Fused Norm+Residual | 40.6 | 9.0 | 4.49x | 0.00 | 2→1 kernels |

๐ŸŽ๏ธ E2E Inference

RTX 5090, bf16, 2 texts (ko + en), 3 warmup + 20 runs each. Run make bench-e2e to reproduce.

| Mode | Load Time | Latency (ko) | Latency (en) | RTF (ko) | RTF (en) | vs Base | Peak VRAM |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Base (PyTorch) | 17.5s | 4,615 ms | 5,081 ms | 0.88x | 0.90x | 1.0x | 4.03 GB |
| Base+TQ | 8.3s | 9,030 ms | 5,745 ms | 0.82x | 0.79x | 0.7x | 4.07 GB |
| Triton | 7.9s | 4,130 ms | 4,462 ms | 1.00x | 1.00x | 1.1x | 4.03 GB |
| Triton+TQ | 7.4s | 8,045 ms | 5,877 ms | 0.93x | 0.88x | 0.7x | 4.09 GB |
| Faster | 9.2s | 1,136 ms | 1,265 ms | 3.49x | 3.52x | 4.0x | 4.28 GB |
| Hybrid (Faster+Triton) | 6.0s | 886 ms | 1,042 ms | 4.20x | 4.26x | 5.0x | 4.32 GB |
| Hybrid+TQ | 6.5s | 944 ms | 1,032 ms | 4.27x | 4.25x | 4.9x | 4.33 GB |

Triton/Triton+TQ/Hybrid/Hybrid+TQ use the default partial patch range [0, 24); the final 4 decoder layers stay in PyTorch for pronunciation stability.
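
Purely as an illustration of what that range means (the layer_range keyword below is hypothetical, not a confirmed apply_triton_kernels signature; see models/patching.py for the actual partial-patching interface):

from qwen3_tts_triton.models.patching import apply_triton_kernels

# Hypothetical keyword argument: patch decoder layers [0, 24) with Triton kernels
# and leave the final four layers (24-27) in eager PyTorch for pronunciation stability.
apply_triton_kernels(model, layer_range=(0, 24))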

๐ŸŽต Audio Quality (Tier 3)

Official release quality numbers use full mode as the canonical Tier 3 result.

| Runner | UTMOS | CER | Speaker Sim | Status |
| --- | --- | --- | --- | --- |
| Base (ref) | 3.40 ± 0.78 | 0.04 ± 0.06 | - | ref |
| Base+TQ (base+tq) | 3.17 ± 0.81 | 0.42 ± 2.02 | 0.82 | FAIL |
| Triton (triton) | 3.40 ± 0.76 | 0.04 ± 0.07 | 0.85 | PASS |
| Triton+TQ (triton+tq) | 3.04 ± 0.83 | 0.43 ± 1.49 | 0.83 | FAIL |
| Faster (faster) | 3.42 ± 0.75 | 0.04 ± 0.04 | 0.83 | PASS |
| Hybrid (hybrid) | 3.38 ± 0.78 | 0.04 ± 0.06 | 0.83 | PASS |
| Hybrid+TQ (hybrid+tq) | 3.32 ± 0.78 | 0.05 ± 0.07 | 0.83 | PASS |

Release caveats (full mode):

  • base+tq: FAIL - CER delta 0.3801 > 0.05; Mann-Whitney p=0.0340 < 0.05
  • triton+tq: FAIL - UTMOS delta 0.3565 > 0.3; CER delta 0.3865 > 0.05; Mann-Whitney p=0.0015 < 0.05

Run make eval-full to reproduce. Treat fast mode as a smoke check, not the release authority.

Disclaimer: Benchmarks measured on a single RTX 5090. Results vary with GPU model, driver version, system load, and input text length. Run make bench on your hardware for accurate numbers.

๐Ÿ“ Project Structure

qwen3-tts-triton/
โ”œโ”€โ”€ src/
โ”‚   โ””โ”€โ”€ qwen3_tts_triton/           # PyPI package
โ”‚       โ”œโ”€โ”€ __init__.py              # Public API + __version__
โ”‚       โ”œโ”€โ”€ py.typed                 # PEP 561 type marker
โ”‚       โ”œโ”€โ”€ kernels/                 # Triton GPU kernels
โ”‚       โ”‚   โ”œโ”€โ”€ rms_norm.py          # Fused RMSNorm
โ”‚       โ”‚   โ”œโ”€โ”€ swiglu.py            # Fused SwiGLU
โ”‚       โ”‚   โ”œโ”€โ”€ rope.py              # Fused M-RoPE
โ”‚       โ”‚   โ”œโ”€โ”€ fused_norm_residual.py # Fused Norm+Residual
โ”‚       โ”‚   โ””โ”€โ”€ turboquant.py        # TurboQuant INT4/INT3 KV cache
โ”‚       โ””โ”€โ”€ models/                  # Model runners & patching
โ”‚           โ”œโ”€โ”€ patching.py          # Monkey-patch logic (partial patching support)
โ”‚           โ”œโ”€โ”€ base_runner.py       # Standard PyTorch (+ TurboQuant option)
โ”‚           โ”œโ”€โ”€ triton_runner.py     # Triton-optimized
โ”‚           โ”œโ”€โ”€ faster_runner.py     # faster-qwen3-tts wrapper
โ”‚           โ””โ”€โ”€ triton_faster_runner.py # Hybrid (faster + Triton + TQ option)
โ”œโ”€โ”€ tests/                           # Verification tests
โ”‚   โ”œโ”€โ”€ kernels/                     # Tier 1: Kernel correctness
โ”‚   โ””โ”€โ”€ test_model_parity.py         # Tier 2: Model parity (2 pairs)
โ”œโ”€โ”€ benchmark/                       # Benchmarking suite
โ”‚   โ””โ”€โ”€ results/                     # Saved benchmark JSON outputs
โ”œโ”€โ”€ ui/                              # Streamlit dashboard
โ”œโ”€โ”€ docs/                            # Documentation
โ”œโ”€โ”€ pyproject.toml                   # Project config (UV + hatchling)
โ”œโ”€โ”€ uv.lock                          # Locked dependencies
โ””โ”€โ”€ Makefile                         # Development commands

๐Ÿ› ๏ธ Development

make format      # Ruff formatting
make lint        # Ruff linting
make lint-fix    # Ruff auto-fix
make test        # pytest (Tier 1)
make test-cov    # pytest + coverage
make check       # lint + test
make pre-commit  # All pre-commit hooks
make clean       # Clear caches

๐Ÿง  Qwen3-TTS Talker Architecture

| Parameter | Value |
| --- | --- |
| Model | Qwen3-TTS-12Hz-1.7B-CustomVoice |
| Hidden Size | 2048 |
| Attention Heads | 16 (GQA, kv_heads=8) |
| Head Dim | 128 |
| Intermediate Size | 6144 |
| Layers | 28 |
| RMS Norm Eps | 1e-6 |
| Position Encoding | M-RoPE (sections=[24,20,20]) |
| Activation | SwiGLU |
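
A back-of-envelope check from this table shows why INT4 KV-cache quantization pays off (ignoring per-group scale overhead):

# Per-token KV-cache footprint for the 28-layer Talker (GQA: 8 KV heads x 128 dims)
layers, kv_heads, head_dim = 28, 8, 128
bf16_bytes = 2 * layers * kv_heads * head_dim * 2  # K and V, 2 bytes each = 114,688 B
int4_bytes = bf16_bytes // 4                       # 4 bits vs 16 bits = 28,672 B
print(f"bf16: {bf16_bytes / 1024:.0f} KiB/token vs INT4: {int4_bytes / 1024:.0f} KiB/token")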

๐Ÿ”„ Compatibility

๐ŸŽค Voice Modes by Runner

| Feature | Base | Base+TQ | Triton | Triton+TQ | Faster | Hybrid | Hybrid+TQ |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Custom Voice | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| Voice Cloning | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| Voice Design | -- | -- | -- | -- | Yes | Yes | Yes |
| Streaming | -- | -- | -- | -- | Yes | Yes | Yes |
| Dynamic Shape | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| bfloat16 / float16 | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| TurboQuant KV | -- | Yes | -- | Yes | -- | -- | Yes |

๐Ÿ’ป Platform Support

| Platform | Supported |
| --- | --- |
| Linux | Yes |
| Windows (WSL2) | Yes |

๐Ÿ—บ๏ธ TODO

  • Docker deployment
  • TurboQuant INT4/INT3 KV cache quantization (Base+TQ, Triton+TQ, Hybrid+TQ modes)
  • Partial Patching โ€” selective layer patching for pronunciation accuracy
  • SageAttention integration โ€” INT8 quantized attention
  • ComfyUI-Qwen3-TTS-Triton โ€” ComfyUI custom node
  • Multi-GPU architecture testing (A100, H100, RTX 4090, etc.)

๐Ÿ“„ License

Apache-2.0

๐Ÿ™ Acknowledgments
