
To convert HuggingFace models to their LSG variant


LSG Attention: Extrapolation of pretrained Transformers to long sequences

ArXiv paper
Accepted as a conference paper in PAKDD 2023.

Requires transformers >= 4.36.1

Optional package to convert models:

pip install lsg-converter

Compatible models

For non-compatible models, the LSG attention layer is available on its own; see LSG-Layer and lsg_converter/attention_layers for more information.

This package can convert various HuggingFace checkpoints and architectures to their LSG (Local-Sparse-Global) variant to handle long sequences.
Available model types:

  • ALBERT ["albert"]
  • BART ["bart"] (encoder attention modified only)
  • BARThez ["barthez"] (encoder attention modified only)
  • BERT ["bert"]
  • CamemBERT ["camembert"]
  • DistilBERT ["distilbert"]
  • Electra ["electra"]
  • mBART ["mbart"] (not tested extensively, encoder attention modified only)
  • Pegasus ["pegasus"] (not tested extensively, encoder attention modified only)
  • RoBERTa ["roberta"]
  • XLM-RoBERTa ["xlm-roberta"]

Some converted checkpoints are available here.


Efficiency

Memory and training speed for a binary classification task with batches of 4 x 4096 tokens, using Adam on an RTX 8000.

Model                         Seconds per step   Memory w/ attn dropout   Memory w/o attn dropout
Longformer-base               3.22 s/step        34.38 GB                 32.83 GB
BigBird-RoBERTa-base          2.85 s/step        38.13 GB                 38.13 GB (no effect)
LSG-RoBERTa 256/0             1.40 s/step        32.92 GB                 24.80 GB
LSG-RoBERTa 128/128 (norm)    1.51 s/step        33.80 GB                 27.52 GB
LSG-RoBERTa 32/32 (norm)      1.20 s/step        24.53 GB                 22.53 GB


Convert checkpoint to LSG

Models can be converted with or without the lsg-converter package.

With package

BERT example with the lsg-converter package:

from lsg_converter import LSGConverter

converter = LSGConverter(max_sequence_length=4096)

# Example 1
model, tokenizer = converter.convert_from_pretrained("bert-base-uncased", num_global_tokens=7)
print(type(model))
# <class 'lsg_converter.bert.modeling_lsg_bert.LSGBertForMaskedLM'>

# Example 2
model, tokenizer = converter.convert_from_pretrained("bert-base-uncased", architecture="BertForSequenceClassification", use_auth_token=True)
print(type(model))
# <class 'lsg_converter.bert.modeling_lsg_bert.LSGBertForSequenceClassification'>

Without package

Use convert_checkpoint.py to convert a model (check model_type from config.json).
The architecture of the model is inferred from the config file, but you can specify a different one if the config is incorrect (which can happen for BART models), see python convert_checkpoint.py --help.
To test the converted model, add --run_test (experimental).

RoBERTa example (from RobertaForMaskedLM to RobertaForSequenceClassification) without package:

git clone https://github.com/ccdv-ai/convert_checkpoint_to_lsg.git
cd convert_checkpoint_to_lsg

export MODEL_TO_CONVERT=roberta-base
export MODEL_NAME=lsg-roberta-base
export MAX_LENGTH=4096

python convert_checkpoint.py \
    --initial_model $MODEL_TO_CONVERT \
    --model_name $MODEL_NAME \
    --model_kwargs "{'mask_first_token': true, 'sparsity_type': 'lsh', 'block_size': 32}" \
    --architecture RobertaForSequenceClassification \
    --max_sequence_length $MAX_LENGTH
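
If the conversion options already live in a Python dict, the --model_kwargs string can be generated instead of typed by hand. The sketch below only reproduces the quoting style shown in the command above (JSON literals such as true, with single quotes); the helper name to_model_kwargs_arg is hypothetical, and how convert_checkpoint.py parses the string is not assumed here.

```python
import json


def to_model_kwargs_arg(kwargs):
    """Format a dict in the style of the --model_kwargs example above.

    Hypothetical helper: it serializes to JSON, then swaps double quotes
    for single quotes. It does not handle values that contain quotes.
    """
    return json.dumps(kwargs).replace('"', "'")


arg = to_model_kwargs_arg(
    {"mask_first_token": True, "sparsity_type": "lsh", "block_size": 32}
)
print(arg)
# {'mask_first_token': true, 'sparsity_type': 'lsh', 'block_size': 32}
```

The printed value can then be pasted between the double quotes of --model_kwargs.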

Model Usage

Converted models work with the transformers Auto classes.

from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load created model
MODEL_NAME = "lsg-roberta-base"
SENTENCE = "This is a test sentence."

model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

inputs = tokenizer(SENTENCE, return_tensors="pt")
model(**inputs)

from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load created model
MODEL_NAME = "lsg-roberta-base"
SENTENCE = "This is a test sentence."

model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

inputs = tokenizer(SENTENCE, return_tensors="pt")
model(**inputs)

LSG-Layer

A self-contained implementation of LSG attention (with no additional parameters) is available in lsg_converter/attention_layers.
It does not support cross-attention, because the local context is ambiguous to define in that case.

Usage:

import torch
from lsg_converter.attention_layers import BlockLocalSelfAttention, LSGSelfAttention

# batch, num_heads, sequence length, hidden_size
n, h, t, d = 2, 4, 58, 32  

Q, K, V = torch.randn(n, h, t, d), torch.randn(n, h, t, d), torch.randn(n, h, t, d)
attention_mask = torch.zeros(n, 1, 1, t).float()

# Only block local attention with 1 global connection
attn = BlockLocalSelfAttention(block_size=16, compute_global_attention=True, is_causal=False, attention_dropout_prob=0.1)

# LSG attention with 1 global connection
attn = LSGSelfAttention(block_size=32, sparsity_factor=8, sparsity_type="bos_pooling", compute_global_attention=True, is_causal=True)

# expect (batch, num_heads, sequence_length, hidden_size) inputs,
# attention_mask is (batch, 1, 1, sequence_length) 
# causal mask is built on the fly but (batch, 1, sequence_length, sequence_length) mask is possible
outputs = attn(Q, K, V, attention_mask)

print(outputs.shape)
# torch.Size([2, 4, 58, 32])

Example: replacing Self Attention in GPT2 (from Huggingface):

from transformers.models.gpt2 import modeling_gpt2
from lsg_converter.attention_layers import BlockLocalSelfAttention

class GPT2BlockLocalAttention(modeling_gpt2.GPT2Attention):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__(config, is_cross_attention, layer_idx)
        self.attn = BlockLocalSelfAttention(block_size=32, compute_global_attention=True, is_causal=True)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        return self.attn(query, key, value, attention_mask), None
    
modeling_gpt2.GPT2Attention = GPT2BlockLocalAttention

Note that for generation (inference on causal modeling), full attention is used after the first step.
This may change in the future.

LSG Attention

LSG relies on 3 components: block local attention, sparse attention and prepended global attention.
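
To make the three components concrete, here is a toy sketch of which key positions a single query could see. It is not the package's implementation. It assumes the local part covers the query's block plus one neighbouring block on each side, and that the sparse part is a plain strided sample from the windows just outside the local range. The function lsg_connections and the "g0"-style labels for global tokens are hypothetical.

```python
def lsg_connections(query_pos, seq_len, block_size=4, sparse_block_size=4,
                    sparsity_factor=2, num_global_tokens=1):
    """Toy map of the keys one query attends to under an LSG-like pattern."""
    block = query_pos // block_size

    # Local: own block plus one neighbouring block per side (assumption).
    local_start = max(0, (block - 1) * block_size)
    local_end = min(seq_len, (block + 2) * block_size)
    local = list(range(local_start, local_end))

    # Sparse: strided sample from the windows just outside the local range
    # (assumption; the package offers several selection types).
    window = sparse_block_size * sparsity_factor
    left = range(max(0, local_start - window), local_start)
    right = range(local_end, min(seq_len, local_end + window))
    sparse = [p for p in list(left) + list(right) if p % sparsity_factor == 0]

    # Global: tokens prepended to the sequence, visible to every query.
    global_tokens = [f"g{i}" for i in range(num_global_tokens)]
    return {"global": global_tokens, "local": local, "sparse": sparse}


conn = lsg_connections(query_pos=17, seq_len=40)
print(conn["global"])  # ['g0']
print(conn["local"])   # [12, 13, ..., 23]
print(conn["sparse"])  # [4, 6, 8, 10, 24, 26, 28, 30]
```

In this toy setting the query at position 17 sees 21 keys instead of the 40 that full attention would use, and setting sparse_block_size=0 leaves only the local and global parts.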

Parameters

You can change various parameters, such as:

  • local block size (block_size=128)
  • sparse block size (sparse_block_size=128)
  • sparsity factor (sparsity_factor=2)
  • mask_first_token (mask first token since it is redundant with the first global token)
  • the number of global tokens (num_global_tokens=1)
  • see config.json file

Sparse selection type

There are 6 different sparse selection patterns. The best type is task dependent.
If sparse_block_size=0 or sparsity_type="none", only local attention is considered.
Note that for sequences with length < 2*block_size, the type has no effect.
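
The two conditions above can be written as a small check. This is a sketch that restates the rules from this section only; the function name sparse_selection_is_active is hypothetical and is not part of the package.

```python
def sparse_selection_is_active(seq_len, block_size, sparse_block_size, sparsity_type):
    """Return True when the chosen sparsity_type can affect the output."""
    # Sparse attention is switched off entirely in these two cases.
    if sparse_block_size == 0 or sparsity_type == "none":
        return False
    # For short inputs the selection type has no effect.
    if seq_len < 2 * block_size:
        return False
    return True


print(sparse_selection_is_active(4096, 128, 128, "lsh"))   # True
print(sparse_selection_is_active(200, 128, 128, "lsh"))    # False
print(sparse_selection_is_active(4096, 128, 0, "lsh"))     # False
```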

  • sparsity_type="bos_pooling" (new)
    • weighted average pooling using the BOS token
    • Works best in general, especially with a rather large sparsity_factor (8, 16, 32)
    • Additional parameters:
      • None
  • sparsity_type="norm", select highest norm tokens
    • Works best for a small sparsity_factor (2 to 4)
    • Additional parameters:
      • None
  • sparsity_type="pooling", use average pooling to merge tokens
    • Works best for a small sparsity_factor (2 to 4)
    • Additional parameters:
      • None
  • sparsity_type="lsh", use the LSH algorithm to cluster similar tokens
    • Works best for a large sparsity_factor (4+)
    • LSH relies on random projections, thus inference may differ slightly with different seeds
    • Additional parameters:
      • lsg_num_pre_rounds=1, pre merge tokens n times before computing centroids
  • sparsity_type="stride", use a striding mechanism per head
    • Each head will use different tokens strided by sparsity_factor
    • Not recommended if sparsity_factor > num_heads
  • sparsity_type="block_stride", use a striding mechanism per head
    • Each head will use blocks of tokens strided by sparsity_factor
    • Not recommended if sparsity_factor > num_heads
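
The selection types above can be illustrated with plain Python lists. This is a simplified sketch, not the package's tensor code. It assumes "pooling" averages groups of sparsity_factor consecutive tokens, "norm" keeps the top len/sparsity_factor tokens by L2 norm, and "stride" gives head h the offset h % sparsity_factor. All function names are hypothetical.

```python
import math


def select_pooling(tokens, factor):
    """Average each group of `factor` consecutive token vectors."""
    groups = [tokens[i:i + factor] for i in range(0, len(tokens), factor)]
    return [[sum(col) / len(g) for col in zip(*g)] for g in groups]


def select_norm(tokens, factor):
    """Keep the len(tokens) // factor tokens with the largest L2 norm, in order."""
    k = len(tokens) // factor
    norms = [math.sqrt(sum(x * x for x in t)) for t in tokens]
    top = sorted(range(len(tokens)), key=lambda i: norms[i], reverse=True)[:k]
    return [tokens[i] for i in sorted(top)]


def stride_positions(head, num_positions, factor):
    """Positions picked by one head: every `factor`-th token, offset by head."""
    return list(range(head % factor, num_positions, factor))


def stride_coverage(num_heads, num_positions, factor):
    """Union of the positions picked by all heads."""
    covered = set()
    for head in range(num_heads):
        covered.update(stride_positions(head, num_positions, factor))
    return sorted(covered)


tokens = [[1.0, 0.0], [0.0, 3.0], [2.0, 2.0], [0.5, 0.5]]
print(select_pooling(tokens, 2))  # [[0.5, 1.5], [1.25, 1.25]]
print(select_norm(tokens, 2))     # [[0.0, 3.0], [2.0, 2.0]]
print(stride_coverage(num_heads=4, num_positions=8, factor=2))  # [0, 1, ..., 7]
print(stride_coverage(num_heads=2, num_positions=8, factor=4))  # [0, 1, 4, 5]
```

The last line shows, under these assumptions, why a sparsity_factor larger than num_heads is discouraged for the strided types: some positions (here 2, 3, 6 and 7) are never selected by any head.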

Experiments

See experiments/
