
TGATE-V2: Faster Diffusion Through Temporal Attention Decomposition.


TGATE

TGATE accelerates inference for [PixArtAlphaPipeline], [PixArtSigmaPipeline], [StableDiffusionPipeline], [StableDiffusionXLPipeline], and [StableVideoDiffusionPipeline] by skipping the self-attention and cross-attention calculations once they converge. More details can be found in the technical report.
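The core idea can be sketched in a few lines: cache the attention output and, once the denoising step passes a gate step, reuse the cached value instead of recomputing it. The sketch below is illustrative only, not the library's implementation; `GatedAttention` and `attn_fn` are hypothetical names.

```python
class GatedAttention:
    """Illustrative sketch of gated attention caching (not TGATE's actual code)."""

    def __init__(self, attn_fn, gate_step):
        self.attn_fn = attn_fn      # the expensive attention computation
        self.gate_step = gate_step  # step after which the output is assumed converged
        self.cache = None

    def __call__(self, hidden_states, step):
        if step < self.gate_step or self.cache is None:
            # Before the gate: compute attention normally and refresh the cache.
            self.cache = self.attn_fn(hidden_states)
        # At or after the gate: return the cached (converged) attention output.
        return self.cache
```

Because attention dominates the per-step cost, skipping it after the gate step reduces both MACs and latency while leaving the early, semantically important steps untouched.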


📖 Quick Start

🛠️ Installation

Start by installing TGATE:

pip install tgate

Requirements

  • pytorch>=2.0.0
  • diffusers>=0.29.0
  • DeepCache==0.1.1
  • transformers
  • accelerate

🌟 Usage

Accelerate PixArtAlphaPipeline with TGATE:

import torch
from diffusers import PixArtAlphaPipeline

pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", 
    torch_dtype=torch.float16,
)

+ from tgate import TgatePixArtAlphaLoader
+ gate_step = 15
+ sp_interval = 3
+ fi_interval = 1
+ warm_up = 2
+ inference_step = 25
+ pipe = TgatePixArtAlphaLoader(pipe).to("cuda")

+ image = pipe.tgate(
+         "An alpaca made of colorful building blocks, cyberpunk.",
+         gate_step=gate_step,
+         sp_interval=sp_interval,
+         fi_interval=fi_interval,
+         warm_up=warm_up,   
+         num_inference_steps=inference_step,
+ ).images[0]
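For intuition on the parameters above: per the technical report, inference is split into a semantics-planning phase (steps before `gate_step`) and a fidelity-improving phase (after it). A plausible reading is that `sp_interval` and `fi_interval` control how often attention is actually recomputed in each phase, with the first `warm_up` steps always computed. The helper below is a hypothetical illustration of that schedule, not the library's exact logic:

```python
def recompute_schedule(num_steps, gate_step, sp_interval, fi_interval, warm_up):
    """Hypothetical mapping of each denoising step to recompute (True) or reuse (False)."""
    schedule = []
    for t in range(num_steps):
        if t < warm_up:
            schedule.append(True)  # warm-up: always compute
        elif t < gate_step:
            # semantics-planning phase: recompute every sp_interval steps
            schedule.append((t - warm_up) % sp_interval == 0)
        else:
            # fidelity-improving phase: recompute every fi_interval steps
            schedule.append((t - gate_step) % fi_interval == 0)
    return schedule

# With the settings above (gate_step=15, sp_interval=3, fi_interval=1,
# warm_up=2, 25 steps), this sketch recomputes attention on 17 of 25 steps.
```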

Accelerate PixArtSigmaPipeline with TGATE:

import torch
from diffusers import PixArtSigmaPipeline

pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", 
    torch_dtype=torch.float16,
)

+ from tgate import TgatePixArtSigmaLoader
+ gate_step = 15
+ sp_interval = 3
+ fi_interval = 1
+ warm_up = 2
+ inference_step = 25
+ pipe = TgatePixArtSigmaLoader(pipe).to("cuda")

+ image = pipe.tgate(
+         "an astronaut sitting in a diner, eating fries, cinematic, analog film.",
+         gate_step=gate_step,
+         sp_interval=sp_interval,
+         fi_interval=fi_interval,
+         warm_up=warm_up,   
+         num_inference_steps=inference_step,
+ ).images[0]

Accelerate StableDiffusionXLPipeline with TGATE:

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers import DPMSolverMultistepScheduler

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)

+ from tgate import TgateSDXLLoader
+ gate_step = 10
+ sp_interval = 5
+ fi_interval = 1
+ warm_up = 2
+ inference_step = 25
+ pipe = TgateSDXLLoader(pipe).to("cuda")

pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

+ image = pipe.tgate(
+         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
+         gate_step=gate_step,
+         sp_interval=sp_interval,
+         fi_interval=fi_interval,
+         warm_up=warm_up,  
+         num_inference_steps=inference_step
+ ).images[0]

Accelerate StableDiffusionXLPipeline with DeepCache and TGATE:

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers import DPMSolverMultistepScheduler

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)

+ from tgate import TgateSDXLDeepCacheLoader
+ gate_step = 10
+ sp_interval = 1
+ fi_interval = 1
+ warm_up = 0
+ inference_step = 25
+ pipe = TgateSDXLDeepCacheLoader(
+        pipe,
+        cache_interval=3,
+        cache_branch_id=0,
+ ).to("cuda")

pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)


+ image = pipe.tgate(
+         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
+         gate_step=gate_step,
+         sp_interval=sp_interval,
+         fi_interval=fi_interval,
+         warm_up=warm_up,  
+         num_inference_steps=inference_step
+ ).images[0]

Accelerate latent-consistency/lcm-sdxl with TGATE:

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers import UNet2DConditionModel, LCMScheduler
from diffusers import DPMSolverMultistepScheduler

unet = UNet2DConditionModel.from_pretrained(
    "latent-consistency/lcm-sdxl",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

+ from tgate import TgateSDXLLoader
+ gate_step = 1
+ sp_interval = 1
+ fi_interval = 1
+ warm_up = 0
+ inference_step = 4
+ pipe = TgateSDXLLoader(pipe, lcm=True).to("cuda")

+ image = pipe.tgate(
+         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
+         gate_step=gate_step,
+         sp_interval=sp_interval,
+         fi_interval=fi_interval,
+         warm_up=warm_up,  
+         num_inference_steps=inference_step,
+ ).images[0]

TGATE also supports StableDiffusionPipeline, PixArt-alpha/PixArt-LCM-XL-2-1024-MS, and StableVideoDiffusionPipeline. More details can be found here.

📄 Results

Model                    MACs      Latency   Zero-shot 10K-FID on MS-COCO
SD-XL                    149.438T  53.187s   24.164
SD-XL w/ TGATE           95.988T   31.643s   22.917
Pixart-Alpha             107.031T  61.502s   37.983
Pixart-Alpha w/ TGATE    73.971T   36.650s   36.390
Pixart-Sigma             107.766T  60.467s   34.278
Pixart-Sigma w/ TGATE    74.420T   36.449s   32.927
DeepCache (SD-XL)        57.888T   19.931s   25.678
DeepCache w/ TGATE       43.868T   14.666s   24.511
LCM (SD-XL)              11.955T   3.805s    26.357
LCM w/ TGATE             11.171T   3.533s    26.902
LCM (Pixart-Alpha)       8.563T    4.733s    35.989
LCM w/ TGATE             7.623T    4.543s    35.843

The FID is computed with PytorchFID on images generated from MS-COCO captions.

The latency is measured on an NVIDIA GTX 1080 Ti with diffusers v0.28.2.

The MACs and Params are calculated with calflops.
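The latency column translates directly into relative savings; a quick sanity check over the table's numbers:

```python
# Relative latency reduction from enabling TGATE, computed from the results
# table above: (baseline, with-TGATE) latencies in seconds.
latency = {
    "SD-XL": (53.187, 31.643),
    "Pixart-Alpha": (61.502, 36.650),
    "Pixart-Sigma": (60.467, 36.449),
    "DeepCache (SD-XL)": (19.931, 14.666),
    "LCM (SD-XL)": (3.805, 3.533),
}
for model, (base, tgate) in latency.items():
    saving = 100 * (1 - tgate / base)
    print(f"{model}: {saving:.1f}% lower latency with TGATE")
# SD-XL, Pixart-Alpha, and Pixart-Sigma all land around 40%; the already
# heavily accelerated LCM pipeline sees a smaller additional gain.
```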

Citation

If you find our work inspiring or use our codebase in your research, please consider giving a star ⭐ and a citation.

@article{tgate_v2,
  title={Faster Diffusion via Temporal Attention Decomposition},
  author={Liu, Haozhe and Zhang, Wentian and Xie, Jinheng and Faccio, Francesco and Xu, Mengmeng and Xiang, Tao and Shou, Mike Zheng and Perez-Rua, Juan-Manuel and Schmidhuber, J{\"u}rgen},
  journal={arXiv preprint arXiv:2404.02747},
  year={2024}
}
