
Terry toolkit longformer chinese

Project description

Longformer-chinese

All work is based on Longformer (https://github.com/allenai/longformer).

Longformer-chinese provides: a Chinese pretrained model based on BERT, and an implementation for classification tasks.

WHAT'S DIFFERENT

Longformer-chinese is modified from the BERT framework, so its embedding layer differs slightly from the original Longformer. Load it with longformer.longformer:

from longformer.longformer import *
config = LongformerConfig.from_pretrained('schen/longformer-chinese-base-4096')
model = Longformer.from_pretrained('schen/longformer-chinese-base-4096', config=config)
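
A minimal end-to-end sketch of running the Chinese checkpoint on raw text. The use of transformers' BertTokenizer is an assumption (the checkpoint is BERT-based, so a BERT vocabulary is expected to ship with it); the padding and global-attention conventions simply mirror the English example further down this page.

import torch
from transformers import BertTokenizer
from longformer.longformer import Longformer, LongformerConfig
from longformer.sliding_chunks import pad_to_window_size

config = LongformerConfig.from_pretrained('schen/longformer-chinese-base-4096')
config.attention_mode = 'sliding_chunks'
model = Longformer.from_pretrained('schen/longformer-chinese-base-4096', config=config)
tokenizer = BertTokenizer.from_pretrained('schen/longformer-chinese-base-4096')  # assumption: BERT vocab is included with the checkpoint

input_ids = torch.tensor(tokenizer.encode("这是一段用于测试的中文长文档。")).unsqueeze(0)  # batch of size 1
attention_mask = torch.ones_like(input_ids)  # 1 = local attention
attention_mask[:, 0] = 2                     # 2 = global attention on the first ([CLS]) token

# pad to a multiple of the attention window, as required by 'sliding_chunks'
input_ids, attention_mask = pad_to_window_size(
        input_ids, attention_mask, config.attention_window[0], tokenizer.pad_token_id)

last_hidden_state = model(input_ids, attention_mask=attention_mask)[0]
print(last_hidden_state.size())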

Using the embedding model

from longformer.longformer import Longformer, LongformerConfig, LongformerEmbedding

# pass attention_mode='n2' to use regular full self-attention instead:
# model = LongformerEmbedding('schen/longformer-chinese-base-4096', attention_mode='n2')
model = LongformerEmbedding('schen/longformer-chinese-base-4096')

# to pad all the way to the model maximum, use max_length=model.tokenizer.model_max_length
inputs = model.tokenizer("Hello, my dog is cute", return_tensors="pt",
                         padding="max_length", truncation=True, max_length=40)

outputs = model(**inputs)
print("outputs", outputs)
print("outputs", outputs.keys())
print("outputs", outputs['last_hidden_state'].size())

Using schen/longformer-chinese-base-4096 downloads the pretrained model automatically through transformers; you can also download it yourself from https://huggingface.co/schen/longformer-chinese-base-4096 and pass the local directory instead.
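
For a local copy, pass the directory path (the path below is only a placeholder):

model = LongformerEmbedding('/path/to/longformer-chinese-base-4096')  # hypothetical local directory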

How to use

  1. Download the pretrained model
  2. Install the environment and code

    conda create --name longformer python=3.7
    conda activate longformer
    conda install cudatoolkit=10.0
    pip install git+https://github.com/allenai/longformer.git
    
  3. Run the model

    import torch
    from longformer.longformer import Longformer, LongformerConfig
    from longformer.sliding_chunks import pad_to_window_size
    from transformers import RobertaTokenizer
    
    config = LongformerConfig.from_pretrained('longformer-base-4096/') 
    # choose the attention mode 'n2', 'tvm' or 'sliding_chunks'
    # 'n2': for regular n2 attention
    # 'tvm': a custom CUDA kernel implementation of our sliding window attention
    # 'sliding_chunks': a PyTorch implementation of our sliding window attention
    config.attention_mode = 'sliding_chunks'
    
    model = Longformer.from_pretrained('longformer-base-4096/', config=config)
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    tokenizer.model_max_length = model.config.max_position_embeddings
    
    SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000)  # long input document
    
    input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0)  # batch of size 1
    
    # TVM code doesn't work on CPU. Uncomment this if `config.attention_mode = 'tvm'`
    # model = model.cuda(); input_ids = input_ids.cuda()
    
    # Attention mask values -- 0: no attention, 1: local attention, 2: global attention
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention
    # Set global attention based on the task. For example,
    #   classification: the <s> token
    #   QA: question tokens
    attention_mask[:, [1, 4, 21]] = 2
    
    # padding seqlen to the nearest multiple of 512. Needed for the 'sliding_chunks' attention
    input_ids, attention_mask = pad_to_window_size(
            input_ids, attention_mask, config.attention_window[0], tokenizer.pad_token_id)
    
    output = model(input_ids, attention_mask=attention_mask)[0]
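
    # The `output` above has shape (batch, seq_len, hidden). As a purely
    # illustrative follow-on (this linear head is made up for the example and
    # is not part of the released code), a classification-style use might pool
    # the <s> position:
    import torch.nn as nn

    cls_vector = output[:, 0, :]                           # (batch, hidden) -- the <s> token
    classifier = nn.Linear(model.config.hidden_size, 2)    # hypothetical 2-class head
    logits = classifier(cls_vector)
    print(logits.shape)                                    # torch.Size([1, 2])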
    

Model pretraining

The pretraining notebook in the Longformer repository demonstrates our procedure for training Longformer starting from the RoBERTa checkpoint. The same procedure can be followed to get a long version of other existing pretrained models.
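
A central part of that procedure is extending RoBERTa's 512 learned position embeddings to the longer maximum length by copying them repeatedly. The sketch below shows one way to do that copy; it omits details such as updating the position_ids buffer and the tokenizer's model_max_length, so treat it as an illustration rather than the exact notebook code:

    import torch
    from transformers import RobertaModel

    model = RobertaModel.from_pretrained('roberta-base')
    max_pos = 4096 + 2                                        # RoBERTa reserves position ids 0 and 1
    old = model.embeddings.position_embeddings.weight.data    # shape (514, hidden)
    new = old.new_empty(max_pos, old.size(1))

    new[:2] = old[:2]                                         # keep the two special positions
    k, step = 2, old.size(0) - 2                              # tile the 512 learned rows
    while k < max_pos:
        span = min(step, max_pos - k)
        new[k:k + span] = old[2:2 + span]
        k += span

    model.embeddings.position_embeddings = torch.nn.Embedding(max_pos, old.size(1), _weight=new)
    model.config.max_position_embeddings = max_pos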

TriviaQA

  • Training scripts: scripts/triviaqa.py
  • Pretrained large model: here (replicates leaderboard results)
  • Instructions: scripts/cheatsheet.txt

CUDA kernel

Our custom CUDA kernel is implemented in TVM. For now, the kernel only works on GPUs and Linux. We tested it on Ubuntu with Python 3.7, CUDA 10, and PyTorch >= 1.2.0. If it doesn't work for your environment, please create a new issue.

Compiling the kernel: We already include the compiled binaries of the CUDA kernel, so most users won't need to compile it, but if you are interested, check scripts/cheatsheet.txt for instructions.

Known issues

Please check the repo issues for a list of known issues that we are planning to address soon. If your issue is not discussed, please create a new one.

Citing

If you use Longformer in your research, please cite Longformer: The Long-Document Transformer.

@article{Beltagy2020Longformer,
  title={Longformer: The Long-Document Transformer},
  author={Iz Beltagy and Matthew E. Peters and Arman Cohan},
  journal={arXiv:2004.05150},
  year={2020},
}

Longformer is an open-source project developed by the Allen Institute for Artificial Intelligence (AI2). AI2 is a non-profit institute with the mission to contribute to humanity through high-impact AI research and engineering.

Download files


Source Distribution

longformer_chinese-0.0.0.8.tar.gz (18.3 kB)


File details

Details for the file longformer_chinese-0.0.0.8.tar.gz.

File metadata

  • Download URL: longformer_chinese-0.0.0.8.tar.gz
  • Upload date:
  • Size: 18.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.6.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.2 CPython/3.9.5

File hashes

Hashes for longformer_chinese-0.0.0.8.tar.gz:

  • SHA256: a0a788574664c3023b5b6dc0561744a700875a5a9bf0fbc7cf31e6fefcf6f97b
  • MD5: 768fea311f3d484b0321e12a233ef646
  • BLAKE2b-256: 9d04db8787dbf6ada9e83bb23b00d7f5e070b3deb7c535550d6240ed78ee830e

