fast SoulX-FlashTalk for RTX 4090
Project description
Fast-FlashTalk
基于 FlashTalk 的高性能推理优化版本,专为 RTX 4090 显卡优化,在保持生成质量的同时显著降低显存占用并提升推理速度,实测可达 2 倍加速。
优化项
1. DiT 动态参数加载
实现了 DiT 模型参数在 CPU/GPU 之间的动态调度。通过 num_persistent_param_in_dit 参数控制常驻 GPU 的参数量,超出部分自动 offload 到 CPU,推理时按需加载,从而在有限显存下运行 14B 参数的 DiT 模型。
2. GemLite A8W8 量化
使用 GemLite 对 DiT 模型进行 A8W8 int8 动态量化,将线性层的权重和激活值量化为 int8,大幅降低显存占用和计算开销。部分关键模块(time_embedding、head 等)被排除在量化之外以保证精度。
3. apply_rope 算子优化
使用 flash_attention 提供的 apply_rotary_emb 替代原始的逐元素复数运算实现 RoPE,利用 CUDA kernel 融合加速旋转位置编码的计算。
4. SageAttention
集成 SageAttention 替代标准注意力计算,支持 sageattn 和 sageattn_varlen 两种模式。对短序列(< 512)自动回退至 flash_attn 以保持最优性能。
5. T5 Cache
对 T5 编码器的推理结果使用 lru_cache 进行缓存(默认 maxsize=20),相同文本 prompt 的重复推理直接返回缓存结果,避免重复计算。
安装
需要 Python 3.11+、支持 CUDA 的 NVIDIA GPU(推荐显存 24GB 及以上以运行 14B 模型),以及 ffmpeg。
# Debian/Ubuntu
sudo apt update && sudo apt install -y ffmpeg
# macOS(Homebrew
brew install ffmpeg
# conda
conda install -c conda-forge ffmpeg
pip install fast-flashtalk
模型与数据准备
推理前需自行下载模型权重:
modelscope download Soul-AILab/SoulX-FlashTalk-14B --local_dir checkpoints/Soul-AILab/SoulX-FlashTalk-14B
modelscope download TencentGameMate/chinese-wav2vec2-base --local_dir checkpoints/TencentGameMate/chinese-wav2vec2-basechinese-wav2vec2-base
使用说明
首次创建 FlashTalkPipeline 时会从磁盘加载多路权重、完成量化与显存调度初始化;首次调用 generate 时还可能包含 CUDA 预热、部分算子首次执行等一次性开销,因此第一次运行整体会明显慢于后续同进程内的推理,属正常现象。同一进程内再次生成通常会快很多。
最小示例
from fast_flashtalk import Audio, FlashTalkPipeline, Image
checkpoint_dir = "path/to/SoulX-FlashTalk-14B"
wav2vec_dir = "path/to/chinese-wav2vec2-base"
pipeline = FlashTalkPipeline(
checkpoint_dir=checkpoint_dir,
wav2vec_dir=wav2vec_dir,
num_persistent_param_in_dit=15_000_000_000,
)
image = Image(uri="path/to/portrait.png")
audio = Audio(uri="path/to/speech.wav")
video = pipeline.generate(
input_prompt="人物与场景描述,用于引导画面风格与内容。",
audio=audio,
image=image,
)
# 添加音频与视频合并
#video.merge_audio(audio)
video.save("test.mp4")
Image、Audio 使用 uri 指向本地图片或音频文件;音频会按管线内配置的采样率重采样。
FlashTalkPipeline 主要参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
checkpoint_dir |
str |
必填 | FlashTalk 模型权重根目录 |
wav2vec_dir |
str |
必填 | Wav2Vec2 模型本地目录 |
num_persistent_param_in_dit |
int |
10_000_000_000 |
常驻 GPU 的 DiT 参数个数上限,显存紧张时适当调小 |
generate 主要参数
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
input_prompt |
str |
必填 | 文本提示,描述人物、场景、镜头等 |
audio |
Audio |
必填 | 驱动口型与节奏的音频 |
image |
Image |
必填 | 条件图像(人物/画面参考) |
audio_encode_mode |
"stream" | "once" |
"once" |
音频编码方式:once 整段编码后按块切分;stream 按流式块编码,更省内存 |
返回值类型为 osc_data.video.Video,可通过 .data 等属性访问帧数据(具体以 osc_data 文档为准)。
许可证与致谢
上游实现与权重归属请参考 FlashTalk 及相应模型许可。
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file fast_flashtalk-0.1.1.post0.tar.gz.
File metadata
- Download URL: fast_flashtalk-0.1.1.post0.tar.gz
- Upload date:
- Size: 196.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.11.1 {"installer":{"name":"uv","version":"0.11.1","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"22.04","id":"jammy","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
519c7bfe4d812e6f0876d474ae2a71845d197e06ec693cdfd55571cbbfdbe598
|
|
| MD5 |
54dcd26461fcf64e9173b8b0eafb8acf
|
|
| BLAKE2b-256 |
7bb549c4dbfcc1657039f9b602b21201df85e679e0011f9957ad8efdda117176
|
File details
Details for the file fast_flashtalk-0.1.1.post0-py3-none-any.whl.
File metadata
- Download URL: fast_flashtalk-0.1.1.post0-py3-none-any.whl
- Upload date:
- Size: 222.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.11.1 {"installer":{"name":"uv","version":"0.11.1","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"22.04","id":"jammy","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
301bac4bcc6ab8df8e0005cf37f5b81fc500ad2ec3d5dc24c2136090acb1467a
|
|
| MD5 |
77a908ee72d81c2e80eed4bb6f6d0229
|
|
| BLAKE2b-256 |
2ecdf6baaf1a95f833d5e00928da23bbe3373b357ede0e9b4648566569bf96a3
|