Recurrent Attention Networks

RAN: Recurrent Attention Network

📢 This project is still under development. Its goal is to make long-document modeling easier.

RAN is released under the MIT license. Pull requests are welcome (http://makeapullrequest.com). Paper: https://arxiv.org/abs/2306.06843

⬇️ Installation

stable

python -m pip install -U rannet

latest

python -m pip install git+https://github.com/4AI/RAN.git

environment

  • ⭐ tensorflow>2.0,<=2.10, with the environment variable set: export TF_KERAS=1
  • tensorflow>=1.14,<2.0, with Keras==2.3.1
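
When a shell `export` is inconvenient (for example in a notebook), the same variable can be set from Python. This is a minimal sketch; it assumes the backend choice is read from the environment when `rannet` is imported, so the assignment has to come before that import.

```python
import os

# Equivalent of `export TF_KERAS=1` for the current process.
# Assumption: the variable is read at import time, so run this
# before `import rannet`.
os.environ['TF_KERAS'] = '1'

print(os.environ['TF_KERAS'])
```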

🏛️ Pretrained Models

Lang | Google Drive | Baidu NetDrive
EN   | base         | base [code: djkj]
CN   | base, small  | base [code: e47w], small [code: mdmg]

🚀 Quick Tour

🈶 w/ pretrained models

Extract semantic features

Set apply_cell_transform=False to extract semantic features.

import numpy as np
from rannet import RanNet, RanNetWordPieceTokenizer


vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)

rannet, rannet_model = RanNet.load_rannet(
    config_path=config_path,
    checkpoint_path=ckpt_path,
    return_sequences=False,
    apply_cell_transform=False
)
text = 'input text'
tok = tokenizer.encode(text)
vec = rannet_model.predict(np.array([tok.ids]))
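
To encode several texts in one call, the id sequences need a common length before they can be stacked into an array. The helper below is a sketch, not part of `rannet`: `pad_batch` is a hypothetical name, and the padding id `0` is an assumption that should be checked against the vocabulary in use.

```python
from typing import List


def pad_batch(batch: List[List[int]], pad_id: int = 0) -> List[List[int]]:
    """Right-pad every id sequence to the length of the longest one."""
    max_len = max((len(ids) for ids in batch), default=0)
    return [ids + [pad_id] * (max_len - len(ids)) for ids in batch]


# Hypothetical token ids standing in for `tokenizer.encode(text).ids`.
batch = pad_batch([[101, 7, 8, 102], [101, 9, 102]])
print(batch)
```

The padded list can then be wrapped in `np.array(...)` and passed to `rannet_model.predict`, in the same way as the single-text example.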

For the classification task

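from rannet import RanNet, RanNetWordPieceTokenizer
# Assumption: the recommended TensorFlow 2 setup (TF_KERAS=1), so tf.keras is the backend.
from tensorflow import keras

L = keras.layers  # alias used for the layers below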


vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)

rannet, rannet_model = RanNet.load_rannet(
    config_path=config_path, checkpoint_path=ckpt_path, return_sequences=False)
output = rannet_model.output  # (B, D)
output = L.Dropout(0.1)(output)
output = L.Dense(2, activation='softmax')(output)
model = keras.models.Model(rannet_model.input, output)
model.summary()
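
The softmax head returns one probability row per input. A small sketch of turning those rows into label names follows; `decode_predictions` and the label names are hypothetical and only illustrate the argmax step.

```python
from typing import List, Sequence


def decode_predictions(probs: Sequence[Sequence[float]],
                       labels: Sequence[str]) -> List[str]:
    """Pick the label with the highest probability in each row."""
    return [labels[max(range(len(row)), key=row.__getitem__)] for row in probs]


# Made-up probabilities shaped like the output of a 2-class softmax layer.
probs = [[0.9, 0.1], [0.2, 0.8]]
preds = decode_predictions(probs, ['negative', 'positive'])
print(preds)
```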

For the sequence task

from rannet import RanNet, RanNetWordPieceTokenizer


vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)

rannet, rannet_model = RanNet.load_rannet(
    config_path=config_path, checkpoint_path=ckpt_path, return_gpc=False)
output = rannet_model.output  # (B, L, D)
rannet_model.summary()

🈚 w/o pretrained models

Embed the RAN (a Keras layer) into your network.

from rannet import RAN

ran = RAN(head_num=8,
          head_size=256,
          window_size=256,
          min_window_size=16,
          activation='swish',
          kernel_initializer='glorot_normal',
          apply_lm_mask=False,
          apply_seq2seq_mask=False,
          apply_memory_review=True,
          dropout_rate=0.0,
          cell_initializer_type='zero')
# X is the input sequence tensor produced by the earlier layers of your network.
output, cell = ran(X)
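
The `window_size` argument suggests that a long input is handled window by window. The sketch below only illustrates the general idea of cutting a sequence into fixed-size windows. It is not RAN's implementation, `split_into_windows` is a hypothetical helper, and the role of `min_window_size` is not modeled.

```python
from typing import List, Sequence


def split_into_windows(ids: Sequence[int], window_size: int) -> List[List[int]]:
    """Cut a sequence into consecutive windows; the last one may be shorter."""
    if window_size <= 0:
        raise ValueError('window_size must be positive')
    return [list(ids[i:i + window_size]) for i in range(0, len(ids), window_size)]


windows = split_into_windows(list(range(10)), window_size=4)
print(windows)
```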

📚 Citation

If you use our code in your research, please cite our work:

@inproceedings{li-etal-2023-ran,
    title = "Recurrent Attention Networks for Long-text Modeling",
    author = "Li, Xianming and Li, Zongxi and Luo, Xiaotian and Xie, Haoran and Lee, Xing and Zhao, Yingbin and Wang, Fu Lee and Li, Qing",
    booktitle = "Findings of the Association for Computational Linguistics: ACL 2023",
    year = "2023",
    publisher = "Association for Computational Linguistics"
}

📬 Contact

For code problems, please create a GitHub issue. For questions about the paper, email xmlee97@gmail.com.
