
Memory- and compute-efficient DeBERTa models.

Project description

FlashDeBERTa 🦾 – Boost inference speed by 3-5x ⚡ and run DeBERTa models on long sequences 📚.

FlashDeBERTa is an optimized implementation of DeBERTa that uses flash attention to compute its disentangled attention mechanism. It significantly reduces memory usage and latency, especially on long sequences. Original DeBERTa checkpoints can be loaded and run on sequences of tens of thousands of tokens without retraining, while keeping their original accuracy.
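For orientation, DeBERTa's disentangled attention scores sum three terms (content-to-content, content-to-position, position-to-content) and scale by 1/sqrt(3d). The sketch below is a naive, quadratic-memory, pure-Python version for a single head. It assumes the simple clamped relative distance δ(i, j) with span k from the DeBERTa paper and ignores any position bucketing that V2/V3 checkpoints may apply. All function names are hypothetical; it shows what is computed, not how FlashDeBERTa's fused kernel computes it.

```python
import math
from typing import List

Matrix = List[List[float]]


def dot(a: List[float], b: List[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def relative_bucket(i: int, j: int, k: int) -> int:
    """Clamp the distance i - j into [0, 2k - 1], as in the DeBERTa paper."""
    if i - j <= -k:
        return 0
    if i - j >= k:
        return 2 * k - 1
    return i - j + k


def disentangled_scores(q_c: Matrix, k_c: Matrix, q_r: Matrix, k_r: Matrix, k: int) -> Matrix:
    """Naive N x N disentangled attention scores for one head.

    q_c, k_c: content queries/keys, shape (N, d).
    q_r, k_r: relative-position queries/keys, shape (2k, d).
    """
    n, d = len(q_c), len(q_c[0])
    scale = 1.0 / math.sqrt(3 * d)  # three score terms are summed
    scores = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            c2c = dot(q_c[i], k_c[j])                            # content -> content
            c2p = dot(q_c[i], k_r[relative_bucket(i, j, k)])     # content -> position
            p2c = dot(k_c[j], q_r[relative_bucket(j, i, k)])     # position -> content
            scores[i][j] = (c2c + c2p + p2c) * scale
    return scores


def softmax_rows(scores: Matrix) -> Matrix:
    """Row-wise, numerically stable softmax."""
    out = []
    for row in scores:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out
```

Flash attention, in general, avoids holding the full N × N probability matrix by processing keys in tiles with an online softmax; this reference version stores it explicitly, which is where the quadratic memory of a naive implementation comes from.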

Use Cases

DeBERTa remains one of the top-performing models for the following tasks:

  • Named Entity Recognition: It serves as the main backbone for models such as GLiNER, an efficient architecture for zero-shot information extraction.
  • Text Classification: DeBERTa is highly effective for supervised and zero-shot classification, as demonstrated by GLiClass.
  • Reranking: The model is competitive with other reranking models, making it a valuable component in many RAG systems.

Warning: This project is under active development and may contain bugs. Please open an issue if you encounter bugs or have suggestions for improvements.

Installation

First, install the package:

pip install flashdeberta -U

Then import the appropriate model heads for your use case and initialize the model from pretrained checkpoints:

from flashdeberta import FlashDebertaV2Model  # FlashDebertaV2ForSequenceClassification, FlashDebertaV2ForTokenClassification, etc.
from transformers import AutoTokenizer
import torch

# Load tokenizer and model (eval mode disables dropout)
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
model = FlashDebertaV2Model.from_pretrained("microsoft/deberta-v3-base").to('cuda').eval()

# Tokenize input text
input_text = "Hello world!"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to('cuda')

# Model inference (no gradients needed)
with torch.no_grad():
    outputs = model(input_ids)

# Token-level hidden states, shape (batch_size, seq_len, hidden_size)
last_hidden_state = outputs[0]

To switch to the eager attention implementation, initialize the model as follows:

model = FlashDebertaV2Model.from_pretrained("microsoft/deberta-v3-base", _attn_implementation='eager').to('cuda')

Kernel Tuning ⚙️

FlashDeBERTa automatically selects optimal kernel parameters based on your GPU. Advanced users who want to fine-tune performance can override these defaults with environment variables:

# Configure forward pass
export FLASHDEBERTA_FWD_BLOCK_M=128
export FLASHDEBERTA_FWD_BLOCK_N=64
export FLASHDEBERTA_FWD_NUM_STAGES=3
export FLASHDEBERTA_FWD_NUM_WARPS=4

# Configure backward pass (optional)
export FLASHDEBERTA_BWD_BLOCK_M=64
export FLASHDEBERTA_BWD_BLOCK_N=64
export FLASHDEBERTA_BWD_NUM_STAGES=2
export FLASHDEBERTA_BWD_NUM_WARPS=4

python train.py

Alternatively, set them directly in Python before importing flashdeberta:

import os
os.environ['FLASHDEBERTA_FWD_BLOCK_M'] = '128'
os.environ['FLASHDEBERTA_FWD_BLOCK_N'] = '64'
os.environ['FLASHDEBERTA_FWD_NUM_STAGES'] = '3'
os.environ['FLASHDEBERTA_FWD_NUM_WARPS'] = '4'

from flashdeberta import FlashDebertaV2Model

Note: All four parameters for a pass (BLOCK_M, BLOCK_N, NUM_STAGES, NUM_WARPS) must be set together to take effect. Typical values: BLOCK_M/BLOCK_N ∈ {32, 64, 128}, NUM_STAGES ∈ {1, 2, 3, 4}, NUM_WARPS ∈ {4, 8}.
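The all-or-nothing rule can be made explicit in your own launch scripts. This is a minimal sketch with two hypothetical helpers (not FlashDeBERTa's internal code): `export_kernel_config` sets all four variables for a pass at once, and `resolve_kernel_config` mirrors the documented behavior by applying overrides only when all four are present and otherwise falling back to defaults. The default values in any example are placeholders, not the library's per-GPU choices.

```python
import os
from typing import Dict, Mapping, Optional

# Parameter suffixes documented for each pass (FWD or BWD)
PARAMS = ("BLOCK_M", "BLOCK_N", "NUM_STAGES", "NUM_WARPS")


def export_kernel_config(pass_name: str, values: Dict[str, int]) -> None:
    """Set all four variables for one pass; call before importing flashdeberta."""
    missing = set(PARAMS) - set(values)
    if missing:
        raise ValueError(f"missing parameters: {sorted(missing)}")
    for p in PARAMS:
        os.environ[f"FLASHDEBERTA_{pass_name}_{p}"] = str(values[p])


def resolve_kernel_config(
    pass_name: str,
    defaults: Dict[str, int],
    env: Optional[Mapping[str, str]] = None,
) -> Dict[str, int]:
    """Return overrides for one pass only if all four are set, else the defaults."""
    env = os.environ if env is None else env
    names = {p: f"FLASHDEBERTA_{pass_name}_{p}" for p in PARAMS}
    if all(name in env for name in names.values()):
        return {p: int(env[name]) for p, name in names.items()}
    return dict(defaults)  # a partial override is ignored entirely
```

Raising on a partial configuration in `export_kernel_config` surfaces mistakes early, instead of silently running with the automatic defaults.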

Benchmarks

Although the context-to-position and position-to-context bias terms still require quadratic memory, our flash attention implementation brings overall memory requirements close to linear. The gains grow with sequence length: from 512 tokens onward, FlashDeBERTa is more than 50% faster, and at 4k tokens it is over 5 times faster than the naive implementation.
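For a rough sense of why quadratic memory matters at these lengths, the back-of-the-envelope sketch below estimates one materialized attention-score matrix of shape (heads, N, N), as a naive implementation would hold it. It assumes fp16 (2 bytes per element) and 12 heads (the head count of DeBERTa-v3-base), and it counts a single layer for a single sequence, ignoring weights, other activations, and framework overhead. The function name is hypothetical.

```python
def attention_matrix_bytes(seq_len: int, num_heads: int = 12, bytes_per_elem: int = 2) -> int:
    """Bytes for one layer's (num_heads, seq_len, seq_len) score matrix, one sequence."""
    return num_heads * seq_len * seq_len * bytes_per_elem


# Each doubling of the sequence length quadruples the estimate
for n in (512, 1024, 2048, 4096, 8192):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"{n:>5} tokens: {mib:8.1f} MiB per layer")
```

Under these assumptions the estimate grows from 6 MiB at 512 tokens to 384 MiB at 4,096 tokens, before multiplying by batch size.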

[Figure: benchmarking]

The figure below breaks down inference and training speedups, together with memory savings, for a DeBERTa-based model across different batch sizes and sequence lengths.

[Figure: gliclass_benchmarking]

Future Work

  • Implement backward kernels.
  • Train DeBERTa models on 8,192-token sequences using high-quality data.
  • Integrate FlashDeBERTa into GLiNER and GLiClass.
  • Train multi-modal DeBERTa models.

Download files

Download the file for your platform.

Source Distribution

flashdeberta-0.0.7.tar.gz (41.6 kB)

Built Distribution

flashdeberta-0.0.7-py3-none-any.whl (44.7 kB)

File details

Details for the file flashdeberta-0.0.7.tar.gz.

File metadata

  • Download URL: flashdeberta-0.0.7.tar.gz
  • Size: 41.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.2

File hashes

Hashes for flashdeberta-0.0.7.tar.gz
  • SHA256: 4f6bc35ea5b34781fd30d88c03c21bfad7aa3c60f6fc829f1a60af8085e3ab87
  • MD5: dba1f23dc821a734873238275ce5485b
  • BLAKE2b-256: ce1f58989194b12e800b12fae998ea64f649905dbb8dd94754ce806095cf645d


File details

Details for the file flashdeberta-0.0.7-py3-none-any.whl.

File metadata

  • Download URL: flashdeberta-0.0.7-py3-none-any.whl
  • Size: 44.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.2

File hashes

Hashes for flashdeberta-0.0.7-py3-none-any.whl
  • SHA256: 54a3db83173a8342256a7e20d62800c16509758c2ff6a7b80005adf6cdbeeba1
  • MD5: 6dd5acc0dcc98928589134ecf3e05a67
  • BLAKE2b-256: 5be49917300ab40ecbefde285d7d127022807a13c31baf8487fcdbb25a2230fd
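To check a downloaded file against the SHA256 digests listed above, Python's standard `hashlib` module is sufficient. This is a minimal sketch; the helper names `sha256_of` and `verify` are our own, and the file name in the example comment is the wheel listed on this page.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Stream a file through SHA-256 in chunks and return the hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected_sha256: str) -> bool:
    """Compare a file's digest with a published one (case-insensitive)."""
    return sha256_of(path) == expected_sha256.strip().lower()


# Example:
# verify(Path("flashdeberta-0.0.7-py3-none-any.whl"),
#        "54a3db83173a8342256a7e20d62800c16509758c2ff6a7b80005adf6cdbeeba1")
```

pip can also enforce digests itself through its hash-checking mode, for example a requirements line such as `flashdeberta==0.0.7 --hash=sha256:<digest>` installed with `--require-hashes`.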

