Memory- and compute-efficient DeBERTa models.
FlashDeBERTa 🦾 – Boost inference speed by 3-5x ⚡ and run DeBERTa models on long sequences 📚.
FlashDeBERTa is an optimized implementation of DeBERTa that uses flash attention to compute DeBERTa's disentangled attention mechanism. It significantly reduces memory usage and latency, especially on long sequences. Original DeBERTa checkpoints can be loaded and run on sequences of tens of thousands of tokens without retraining, while preserving their original accuracy.
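For reference, the original DeBERTa paper defines single-head disentangled attention as follows, where $H$ holds the content hidden states, $P$ the relative position embeddings ($2k$ rows), $k$ the maximum relative distance and $d$ the head dimension. This is the paper's formulation with its simple clipped distance $\delta$; DeBERTa-v2/v3 checkpoints additionally map distances to log-scale buckets, and the equations do not describe how FlashDeBERTa's kernel schedules the computation.

```latex
\begin{aligned}
Q^c &= H W_{q,c}, \quad K^c = H W_{k,c}, \quad V^c = H W_{v,c}, \quad Q^r = P W_{q,r}, \quad K^r = P W_{k,r} \\
\tilde{A}_{i,j} &= \underbrace{Q^c_i \, {K^c_j}^{\top}}_{\text{content-to-content}}
  + \underbrace{Q^c_i \, {K^r_{\delta(i,j)}}^{\top}}_{\text{content-to-position}}
  + \underbrace{K^c_j \, {Q^r_{\delta(j,i)}}^{\top}}_{\text{position-to-content}} \\
H_o &= \operatorname{softmax}\!\left(\frac{\tilde{A}}{\sqrt{3d}}\right) V^c \\
\delta(i,j) &= \begin{cases} 0 & i - j \le -k \\ 2k - 1 & i - j \ge k \\ i - j + k & \text{otherwise} \end{cases}
\end{aligned}
```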
Use Cases
DeBERTa remains one of the top-performing models for the following tasks:
- Named Entity Recognition: DeBERTa is the main backbone of models such as GLiNER, an efficient architecture for zero-shot information extraction.
- Text Classification: DeBERTa is highly effective for supervised and zero-shot classification, for example as the backbone of GLiClass.
- Reranking: DeBERTa performs competitively with other reranking models, making it a valuable component of many RAG systems.
[!warning] This project is under active development and may contain bugs. Please open an issue if you encounter a bug or have suggestions for improvement.
Installation
First, install the package:

```bash
pip install flashdeberta -U
```
Then import the appropriate model heads for your use case and initialize the model from pretrained checkpoints:
```python
from flashdeberta import FlashDebertaV2Model  # FlashDebertaV2ForSequenceClassification, FlashDebertaV2ForTokenClassification, etc.
from transformers import AutoTokenizer
import torch

# Load tokenizer and model onto the GPU
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
model = FlashDebertaV2Model.from_pretrained("microsoft/deberta-v3-base").to("cuda")
model.eval()

# Tokenize input text and move all tensors to the GPU
input_text = "Hello world!"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

# Model inference without gradient tracking
with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
```
To switch to the eager attention implementation, initialize the model as follows:

```python
model = FlashDebertaV2Model.from_pretrained("microsoft/deberta-v3-base", _attn_implementation="eager").to("cuda")
```
Kernel Tuning ⚙️
FlashDeBERTa automatically selects optimal kernel parameters for your GPU. Advanced users who want to fine-tune performance can override these defaults with environment variables:
```bash
# Configure forward pass
export FLASHDEBERTA_FWD_BLOCK_M=128
export FLASHDEBERTA_FWD_BLOCK_N=64
export FLASHDEBERTA_FWD_NUM_STAGES=3
export FLASHDEBERTA_FWD_NUM_WARPS=4
# Configure backward pass (optional)
export FLASHDEBERTA_BWD_BLOCK_M=64
export FLASHDEBERTA_BWD_BLOCK_N=64
export FLASHDEBERTA_BWD_NUM_STAGES=2
export FLASHDEBERTA_BWD_NUM_WARPS=4
python train.py
```
Alternatively, set them in Python before importing flashdeberta:

```python
import os

# Forward-pass overrides; these must be set before flashdeberta is imported
os.environ["FLASHDEBERTA_FWD_BLOCK_M"] = "128"
os.environ["FLASHDEBERTA_FWD_BLOCK_N"] = "64"
os.environ["FLASHDEBERTA_FWD_NUM_STAGES"] = "3"
os.environ["FLASHDEBERTA_FWD_NUM_WARPS"] = "4"

from flashdeberta import FlashDebertaV2Model
```
Note: all four parameters of a pass (BLOCK_M, BLOCK_N, NUM_STAGES, NUM_WARPS) must be set together to take effect. Typical values: BLOCK_M/BLOCK_N ∈ {32, 64, 128}, NUM_STAGES ∈ {1, 2, 3, 4}, NUM_WARPS ∈ {4, 8}.
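A small helper can catch partial overrides before a long run and enumerate the typical values for a tuning sweep. The sketch below is hypothetical and not part of flashdeberta: `read_kernel_config` and `typical_grid` are illustrative names, and the code assumes only the variable names and the all-four rule documented in this section. The value sets are the "typical" ones listed in the note, not limits the library is known to enforce.

```python
import itertools
import os
from typing import Dict, Iterator, Mapping, Optional

# Parameter suffixes documented for FLASHDEBERTA_{FWD,BWD}_*.
PARAMS = ("BLOCK_M", "BLOCK_N", "NUM_STAGES", "NUM_WARPS")

# "Typical values" from the note; an assumption, not a library constraint.
TYPICAL = {
    "BLOCK_M": (32, 64, 128),
    "BLOCK_N": (32, 64, 128),
    "NUM_STAGES": (1, 2, 3, 4),
    "NUM_WARPS": (4, 8),
}


def read_kernel_config(pass_name: str,
                       env: Optional[Mapping[str, str]] = None) -> Optional[Dict[str, int]]:
    """Hypothetical helper: return the override for "FWD" or "BWD", or None.

    Raises ValueError when only some of the four variables are set,
    because a partial override would not take effect.
    """
    env = os.environ if env is None else env
    keys = [f"FLASHDEBERTA_{pass_name}_{p}" for p in PARAMS]
    present = [k for k in keys if k in env]
    if not present:
        return None  # no override: automatic selection applies
    if len(present) != len(keys):
        missing = sorted(set(keys) - set(present))
        raise ValueError(f"partial override, missing: {missing}")
    return {p: int(env[k]) for p, k in zip(PARAMS, keys)}


def typical_grid(pass_name: str) -> Iterator[Dict[str, str]]:
    """Yield one env-var dict per combination of the typical values."""
    for values in itertools.product(*(TYPICAL[p] for p in PARAMS)):
        yield {f"FLASHDEBERTA_{pass_name}_{p}": str(v) for p, v in zip(PARAMS, values)}
```

The forward grid has 3 × 3 × 4 × 2 = 72 combinations. Because the variables should be set before flashdeberta is imported, a sweep would most likely launch each configuration in a fresh process, for example with `subprocess.run(["python", "train.py"], env={**os.environ, **cfg})`.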
Benchmarks
Although the content-to-position and position-to-content biases still require quadratic memory, the flash attention implementation brings overall memory requirements close to linear in sequence length. The gains are largest on long sequences: from 512 tokens onward FlashDeBERTa is more than 50% faster, and at 4k tokens it is over 5× faster than naive implementations.
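To make the two bias terms concrete, here is a naive pure-Python sketch of single-head disentangled attention following the DeBERTa paper's definition. It computes every one of the L × L scores (one row at a time), each the sum of a content-to-content, a content-to-position and a position-to-content term, so the work grows quadratically with L. The function names are illustrative and not part of flashdeberta; the projections are assumed to be already applied, and the clipped distance follows the original paper rather than the log-bucketed positions used by v2/v3 checkpoints.

```python
import math
from typing import List

Matrix = List[List[float]]


def dot(u: List[float], v: List[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def rel_index(i: int, j: int, k: int) -> int:
    """delta(i, j): relative distance i - j clipped into [0, 2k - 1]."""
    if i - j <= -k:
        return 0
    if i - j >= k:
        return 2 * k - 1
    return i - j + k


def disentangled_attention(qc: Matrix, kc: Matrix, vc: Matrix,
                           qr: Matrix, kr: Matrix, k: int) -> Matrix:
    """Single-head disentangled attention; qr/kr hold 2k projected position rows."""
    n, d = len(qc), len(qc[0])
    scale = 1.0 / math.sqrt(3 * d)  # three score terms, hence sqrt(3d)
    out: Matrix = []
    for i in range(n):
        # Row i of the L x L scores: c2c + c2p + p2c.
        scores = [
            (dot(qc[i], kc[j])
             + dot(qc[i], kr[rel_index(i, j, k)])   # content-to-position
             + dot(kc[j], qr[rel_index(j, i, k)]))  # position-to-content
            * scale
            for j in range(n)
        ]
        # Numerically stable softmax over the row.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        probs = [w / total for w in weights]
        # Weighted sum of the content value vectors.
        out.append([sum(p * vc[j][c] for j, p in enumerate(probs))
                    for c in range(len(vc[0]))])
    return out
```

On tiny inputs a sketch like this can act as a slow ground truth. With all position rows set to zero it reduces to ordinary scaled dot-product attention with a 1/√(3d) scale.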
Below is a breakdown of inference and training speedups, together with memory savings that grow with sequence length, for a DeBERTa-based model evaluated during inference and training across different batch sizes and sequence lengths.
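Why the memory savings grow with length can be seen from a back-of-the-envelope estimate (a calculation, not a measurement) of the batch × heads × L × L attention tensor that a naive implementation materializes per layer. The sketch assumes 12 heads (as in deberta-v3-base) and 2-byte fp16/bf16 elements, and counts only one such tensor; a naive DeBERTa layer also builds the content-to-position and position-to-content terms, and training keeps further activations for the backward pass. The helper names are illustrative.

```python
def attention_matrix_bytes(batch: int, heads: int, seq_len: int,
                           bytes_per_elem: int = 2) -> int:
    """Bytes for one layer's batch x heads x L x L score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


def mib(n_bytes: int) -> float:
    """Convert bytes to mebibytes."""
    return n_bytes / 2**20


for seq_len in (512, 1024, 2048, 4096):
    # Assumptions: 12 heads (deberta-v3-base), 2 bytes per element (fp16/bf16).
    size = attention_matrix_bytes(batch=1, heads=12, seq_len=seq_len)
    print(f"L={seq_len:5d}: {mib(size):7.1f} MiB per layer")
```

Each doubling of L quadruples the tensor, from 6 MiB at 512 tokens to 384 MiB at 4,096 tokens per layer and per sequence, whereas a flash-style kernel keeps only per-block tiles of it in on-chip memory.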
Future Work
- Implement backward kernels.
- Train DeBERTa models on 8,192-token sequences using high-quality data.
- Integrate FlashDeBERTa into GLiNER and GLiClass.
- Train multi-modal DeBERTa models.
Download files
File details
Details for the file flashdeberta-0.0.7.tar.gz.
File metadata
- Download URL: flashdeberta-0.0.7.tar.gz
- Upload date:
- Size: 41.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4f6bc35ea5b34781fd30d88c03c21bfad7aa3c60f6fc829f1a60af8085e3ab87 |
| MD5 | dba1f23dc821a734873238275ce5485b |
| BLAKE2b-256 | ce1f58989194b12e800b12fae998ea64f649905dbb8dd94754ce806095cf645d |
File details
Details for the file flashdeberta-0.0.7-py3-none-any.whl.
File metadata
- Download URL: flashdeberta-0.0.7-py3-none-any.whl
- Upload date:
- Size: 44.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 54a3db83173a8342256a7e20d62800c16509758c2ff6a7b80005adf6cdbeeba1 |
| MD5 | 6dd5acc0dcc98928589134ecf3e05a67 |
| BLAKE2b-256 | 5be49917300ab40ecbefde285d7d127022807a13c31baf8487fcdbb25a2230fd |