Skip to main content

fast SoulX-FlashTalk for RTX 4090

Project description

Fast-FlashTalk

基于 FlashTalk 的高性能推理优化版本,专为 RTX 4090 显卡优化,在保持生成质量的同时显著降低显存占用并提升推理速度,实测可达 2 倍加速

优化项

1. DiT 动态参数加载

实现了 DiT 模型参数在 CPU/GPU 之间的动态调度。通过 num_persistent_param_in_dit 参数控制常驻 GPU 的参数量,超出部分自动 offload 到 CPU,推理时按需加载,从而在有限显存下运行 14B 参数的 DiT 模型。

2. GemLite A8W8 量化

使用 GemLite 对 DiT 模型进行 A8W8 int8 动态量化,将线性层的权重和激活值量化为 int8,大幅降低显存占用和计算开销。部分关键模块(time_embedding、head 等)被排除在量化之外以保证精度。

3. apply_rope 算子优化

使用 flash_attention 提供的 apply_rotary_emb 替代原始的逐元素复数运算实现 RoPE,利用 CUDA kernel 融合加速旋转位置编码的计算。

4. SageAttention

集成 SageAttention 替代标准注意力计算,支持 sageattnsageattn_varlen 两种模式。对短序列(< 512)自动回退至 flash_attn 以保持最优性能。

5. T5 Cache

对 T5 编码器的推理结果使用 lru_cache 进行缓存(默认 maxsize=20),相同文本 prompt 的重复推理直接返回缓存结果,避免重复计算。

安装

需要 Python 3.11+、支持 CUDA 的 NVIDIA GPU(推荐显存 24GB 及以上以运行 14B 模型),以及 ffmpeg

# Debian/Ubuntu
sudo apt update && sudo apt install -y ffmpeg
# macOS(Homebrew
brew install ffmpeg
# conda
conda install -c conda-forge ffmpeg
pip install fast-flashtalk

模型与数据准备

推理前需自行准备:

资源 说明
FlashTalk 权重目录 FlashTalk / SoulX 发布结构一致,包含 DiT、VAE、CLIP、T5 等子目录与配置
Wav2Vec2 例如 chinese-wav2vec2-base,传入本地目录,from_pretrained(..., local_files_only=True) 加载

将上述路径分别传给 checkpoint_dirwav2vec_dir

使用说明

首次创建 FlashTalkPipeline 时会从磁盘加载多路权重、完成量化与显存调度初始化;首次调用 generate 时还可能包含 CUDA 预热、部分算子首次执行等一次性开销,因此第一次运行整体会明显慢于后续同进程内的推理,属正常现象。同一进程内再次生成通常会快很多。

最小示例

from fast_flashtalk import Audio, FlashTalkPipeline, Image

checkpoint_dir = "path/to/SoulX-FlashTalk-14B"
wav2vec_dir = "path/to/chinese-wav2vec2-base"

pipeline = FlashTalkPipeline(
    checkpoint_dir=checkpoint_dir,
    wav2vec_dir=wav2vec_dir,
    num_persistent_param_in_dit=15_000_000_000,
)

image = Image(uri="path/to/portrait.png")
audio = Audio(uri="path/to/speech.wav")

video = pipeline.generate(
    input_prompt="人物与场景描述,用于引导画面风格与内容。",
    audio=audio,
    image=image,
)
# 返回 osc_data.Video;未指定 video_save_path 时默认写入 sample_results/res_<时间戳>.mp4

ImageAudio 使用 uri 指向本地图片或音频文件;音频会按管线内配置的采样率重采样。

FlashTalkPipeline 主要参数

参数 类型 默认值 说明
checkpoint_dir str 必填 FlashTalk 模型权重根目录
wav2vec_dir str 必填 Wav2Vec2 模型本地目录
device str "cuda" 推理设备
cpu_offload bool True 是否将部分模块放在 CPU 上以节省显存
num_timesteps int 1000 扩散时间步数
use_timestep_transform bool True 是否对时间步做变换调度
num_persistent_param_in_dit int 10_000_000_000 常驻 GPU 的 DiT 参数个数上限,显存紧张时适当调小
quantize bool True 是否对 DiT 做 GemLite A8W8 量化
use_usp bool False 分布式 USP 策略(一般单机保持 False

generate 主要参数

参数 类型 默认值 说明
input_prompt str 必填 文本提示,描述人物、场景、镜头等
audio Audio 必填 驱动口型与节奏的音频
image Image 必填 条件图像(人物/画面参考)
audio_encode_mode "stream" | "once" "once" 音频编码方式:once 整段编码后按块切分;stream 按流式块编码,更省内存
video_save_path str | None None 输出 MP4 路径;为 None 时写入 sample_results/res_<时间戳>.mp4
merge_video_audio bool True 是否将生成画面与原始音频合并到输出文件
force_9_16 bool False 是否强制竖屏 9:16 输出

返回值类型为 osc_data.video.Video,可通过 .data 等属性访问帧数据(具体以 osc_data 文档为准)。

许可证与致谢

上游实现与权重归属请参考 FlashTalk 及相应模型许可。

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

fast_flashtalk-0.1.0.tar.gz (200.2 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

fast_flashtalk-0.1.0-py3-none-any.whl (225.9 kB view details)

Uploaded Python 3

File details

Details for the file fast_flashtalk-0.1.0.tar.gz.

File metadata

  • Download URL: fast_flashtalk-0.1.0.tar.gz
  • Upload date:
  • Size: 200.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.1 {"installer":{"name":"uv","version":"0.11.1","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"22.04","id":"jammy","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}

File hashes

Hashes for fast_flashtalk-0.1.0.tar.gz
Algorithm Hash digest
SHA256 648e1f09f11c49a1bfc00abf5441b17b1ebe0024a9f4ed3b1ce91fc4bb516636
MD5 3c18f1370112e71756ae8f1a50eb106b
BLAKE2b-256 caac77027bab459d47571cb5d9868b0c781de5885aa0c241fff8a14d3103053f

See more details on using hashes here.

File details

Details for the file fast_flashtalk-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: fast_flashtalk-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 225.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.1 {"installer":{"name":"uv","version":"0.11.1","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"22.04","id":"jammy","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}

File hashes

Hashes for fast_flashtalk-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 416bc7bd14b8f232711baef2df36dd2ca5c0eedad63e17df06c29615c8840f5d
MD5 85884005afeaa934633124a93290bbe4
BLAKE2b-256 6f426dda1b187347f059925af92d1b484ffb831bcb8af8a2ab1cfcbd751198bd

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page