Recurrent Attention Networks

RAN: Recurrent Attention Network

📢 This project is still a work in progress; its goal is to make long-document modeling easier.

RAN is released under the MIT license. Paper: https://arxiv.org/abs/2306.06843

Figure: The framework of RAN

⬇️ Installation

stable

python -m pip install -U rannet

latest

python -m pip install git+https://github.com/4AI/RAN.git

environment

  • ⭐ tensorflow>2.0,<=2.10: set the environment variable with export TF_KERAS=1
  • tensorflow>=1.14,<2.0: install Keras==2.3.1
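The two environments above are distinguished only by the installed TensorFlow version. The helper below is a hypothetical sketch (keras_setup and parse_version are not part of rannet) that maps a TensorFlow version string to the setup step listed for it, using the same bounds as the list. It assumes plain numeric versions without "rc" or "dev" suffixes and compares them numerically, as pip does for release numbers, so a patch release such as 2.10.1 falls outside `<=2.10`.

```python
def parse_version(version: str) -> tuple:
    # "2.10" -> (2, 10, 0); assumes a plain numeric version string
    parts = [int(p) for p in version.split(".")[:3]]
    return tuple(parts + [0] * (3 - len(parts)))


def keras_setup(tf_version: str) -> str:
    """Return the setup step from the environment list for a TensorFlow version."""
    v = parse_version(tf_version)
    if (2, 0, 0) < v <= (2, 10, 0):      # tensorflow>2.0,<=2.10
        return "export TF_KERAS=1"
    if (1, 14, 0) <= v < (2, 0, 0):      # tensorflow>=1.14,<2.0
        return "python -m pip install Keras==2.3.1"
    raise ValueError(f"TensorFlow {tf_version} is outside the supported ranges")


print(keras_setup("2.9.1"))
```

On TensorFlow 2.x, setting os.environ["TF_KERAS"] = "1" in Python before the first `import rannet` is likely equivalent to the shell export, assuming rannet reads the flag at import time.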

🏛️ Pretrained Models

V3 Models

🎯 compatible with: rannet>0.2.1

Lang | Google Drive | Baidu NetDrive
EN   | base         | base (code: udts)

Chinese V3 models are still being pretrained.

V2 Models

🎯 compatible with: rannet<=0.2.1

Lang | Google Drive | Baidu NetDrive
EN   | base         | base (code: djkj)
CN   | base, small  | base (code: e47w), small (code: mdmg)

V1 Models

V1 models are not publicly released.
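The compatibility notes above tie each checkpoint generation to a rannet version range: V3 needs rannet>0.2.1 and V2 needs rannet<=0.2.1. As a hypothetical helper (checkpoint_generation is not part of rannet), this sketch picks the generation that matches an installed rannet version, assuming a plain numeric version string.

```python
def checkpoint_generation(rannet_version: str) -> str:
    """Return "V3" or "V2" depending on which checkpoints this rannet version can load."""
    # Pad to (major, minor, patch) so "0.2" compares as (0, 2, 0)
    parts = [int(p) for p in rannet_version.split(".")[:3]]
    v = tuple(parts + [0] * (3 - len(parts)))
    # V3 checkpoints: rannet>0.2.1; V2 checkpoints: rannet<=0.2.1
    return "V3" if v > (0, 2, 1) else "V2"


print(checkpoint_generation("0.3.1"))
```

For example, rannet 0.3.1 maps to V3. Since the Chinese checkpoints currently exist only as V2 models, using them would mean installing a matching older release, e.g. python -m pip install "rannet<=0.2.1".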

🚀 Quick Tour

🈶 w/ pretrained models

Extract semantic features

Set return_sequences=False to get one semantic feature vector per input text.

import numpy as np
from rannet import RanNet, RanNetWordPieceTokenizer


vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)

rannet, rannet_model = RanNet.load_rannet(
    config_path=config_path,
    checkpoint_path=ckpt_path,
    return_sequences=False,
    apply_cell_transform=False,
    cell_pooling='mean'
)
text = 'input text'
tok = tokenizer.encode(text)  # tok.ids holds the WordPiece token ids
vec = rannet_model.predict(np.array([tok.ids]))  # batch of one text; vec has shape (1, D)
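One common way to use such feature vectors is to compare two texts by cosine similarity. The sketch below is generic NumPy, not a rannet API; vec_a and vec_b are hypothetical stand-ins for two outputs of rannet_model.predict, each of shape (1, D).

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Flatten (1, D) predictions to (D,) before comparing
    a, b = np.ravel(a), np.ravel(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# Hypothetical stand-ins for two rannet_model.predict outputs
vec_a = np.array([[1.0, 0.0, 1.0]])
vec_b = np.array([[1.0, 1.0, 0.0]])
print(cosine_similarity(vec_a, vec_b))
```

The function does not guard against all-zero vectors, which would divide by zero; add a check if your inputs can be empty.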

For classification tasks

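from tensorflow import keras  # TF 2.x with TF_KERAS=1; with TF 1.x + Keras 2.3.1 use `import keras` instead
from rannet import RanNet, RanNetWordPieceTokenizer

L = keras.layers  # short alias for the layers used below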


vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)

rannet, rannet_model = RanNet.load_rannet(
    config_path=config_path, checkpoint_path=ckpt_path, return_sequences=False)
output = rannet_model.output  # (B, D)
output = L.Dropout(0.1)(output)
output = L.Dense(2, activation='softmax')(output)
model = keras.models.Model(rannet_model.input, output)
model.summary()

For sequence tasks

from rannet import RanNet, RanNetWordPieceTokenizer


vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)

rannet, rannet_model = RanNet.load_rannet(
    config_path=config_path, checkpoint_path=ckpt_path, return_cell=False)
output = rannet_model.output  # (B, L, D)
rannet_model.summary()
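If you need a fixed-size vector from the per-token output (B, L, D) instead of setting return_sequences=False, a masked mean over the token axis is one generic option. This NumPy sketch is an illustration, not what rannet's cell_pooling does internally, and it assumes padding uses token id 0 (check vocab.txt for the actual padding id).

```python
import numpy as np


def masked_mean_pool(seq_output: np.ndarray, token_ids: np.ndarray, pad_id: int = 0) -> np.ndarray:
    """Average per-token vectors over non-padding positions: (B, L, D) -> (B, D)."""
    # pad_id=0 is an assumption about the vocabulary, not a rannet guarantee
    mask = (token_ids != pad_id).astype(seq_output.dtype)[..., None]  # (B, L, 1)
    summed = (seq_output * mask).sum(axis=1)                           # (B, D)
    counts = np.maximum(mask.sum(axis=1), 1.0)                         # avoid dividing by zero
    return summed / counts
```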

🈚 w/o pretrained models

RAN is a Keras layer, so you can embed it directly in your own network.

from rannet import RAN

ran = RAN(head_num=8,
          head_size=256,
          window_size=256,
          min_window_size=16,
          activation='swish',
          kernel_initializer='glorot_normal',
          apply_lm_mask=False,
          apply_seq2seq_mask=False,
          apply_memory_review=True,
          dropout_rate=0.0,
          cell_initializer_type='zero')
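# X: input tensor of shape (B, L, D), e.g. the output of an embedding layer
output, cell = ran(X)  # returns the sequence output and the cell state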
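RAN targets long text, and the layer takes a window_size argument. As a rough illustration of what a window means for a long token sequence, the sketch below splits token ids into consecutive, non-overlapping windows. It is a simplified assumption for intuition only: split_into_windows is hypothetical, and it does not reproduce how rannet actually handles windows or min_window_size.

```python
from typing import List


def split_into_windows(token_ids: List[int], window_size: int) -> List[List[int]]:
    """Split a token sequence into consecutive, non-overlapping windows.

    Illustrative only: the last window may be shorter than window_size.
    """
    if window_size <= 0:
        raise ValueError("window_size must be positive")
    return [token_ids[i:i + window_size] for i in range(0, len(token_ids), window_size)]


windows = split_into_windows(list(range(600)), 256)
print([len(w) for w in windows])
```

With window_size=256, a 600-token sequence yields two full windows and a shorter final window of 88 tokens.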

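w/ history

Set return_history=True and with_cell=True to carry context from one input to the next: the history returned for one text is passed back in as the cell input for the following text.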

import numpy as np
from rannet import RanNet, RanNetWordPieceTokenizer


vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)

rannet, rannet_model = RanNet.load_rannet(
    config_path=config_path,
    checkpoint_path=ckpt_path,
    return_sequences=False,
    apply_cell_transform=False,
    return_history=True,  # return history
    cell_pooling='mean',
    with_cell=True,  # with cell input
)
rannet_model.summary()

text = 'sentence 1'
tok = tokenizer.encode(text)
init_cell = np.zeros((1, 768))  # 768 is embedding size
vec, history = rannet_model.predict([np.array([tok.ids]), init_cell])

text2 = 'sentence 2'
tok = tokenizer.encode(text2)
vec2, history = rannet_model.predict([np.array([tok.ids]), history])  # input history of sentence 1
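The two-sentence example above generalizes to any number of segments: each call returns a vector plus a history, and that history becomes the next call's cell input. The sketch below is a hypothetical helper (encode_segments is not part of rannet); predict_fn stands in for a function that tokenizes one segment and calls rannet_model.predict with the current cell.

```python
from typing import Any, Callable, List, Tuple


def encode_segments(
    predict_fn: Callable[[Any, Any], Tuple[Any, Any]],
    segments: List[Any],
    init_cell: Any,
) -> Tuple[List[Any], Any]:
    """Encode segments in order, feeding each call's history into the next call."""
    history = init_cell
    vectors = []
    for segment in segments:
        vec, history = predict_fn(segment, history)  # history carries context forward
        vectors.append(vec)
    return vectors, history
```

With the model loaded as above, predict_fn could be `lambda text, cell: rannet_model.predict([np.array([tokenizer.encode(text).ids]), cell])`, with init_cell = np.zeros((1, 768)) as the starting cell.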

📚 Citation

If you use our code in your research, please cite our work:

@inproceedings{li-etal-2023-recurrent,
    title = "Recurrent Attention Networks for Long-text Modeling",
    author = "Li, Xianming  and
      Li, Zongxi  and
      Luo, Xiaotian  and
      Xie, Haoran  and
      Lee, Xing  and
      Zhao, Yingbin  and
      Wang, Fu Lee  and
      Li, Qing",
    booktitle = "Findings of the Association for Computational Linguistics: ACL 2023",
    month = jul,
    year = "2023",
    publisher = "Association for Computational Linguistics",
    pages = "3006--3019",
}

📬 Contact

For code problems, please open a GitHub issue; for questions about the paper, email xmlee97@gmail.com.

Download files

Source Distribution

rannet-0.3.1.tar.gz (31.8 kB)

Built Distribution

rannet-0.3.1-py3-none-any.whl (32.4 kB)

File details

Details for the file rannet-0.3.1.tar.gz.

File metadata

  • Download URL: rannet-0.3.1.tar.gz
  • Size: 31.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.10

File hashes

Hashes for rannet-0.3.1.tar.gz
Algorithm   | Hash digest
SHA256      | 26687a7c97ba90e339e60719d540794a5342b5f49911071d3c755ccea57a9273
MD5         | 91965aeca884c2b79aa70c628283687b
BLAKE2b-256 | 5242d1b42a8ba33846934dd0ad077b210adb716e85335c3541c90b5538c4a1db

File details

Details for the file rannet-0.3.1-py3-none-any.whl.

File metadata

  • Download URL: rannet-0.3.1-py3-none-any.whl
  • Size: 32.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.10

File hashes

Hashes for rannet-0.3.1-py3-none-any.whl
Algorithm   | Hash digest
SHA256      | bf11113f72ef4027577489cc47a080ca8cc243c67850e2e4d2a6bad1c5be0e61
MD5         | b7102e33650979a17c6ce3b699369761
BLAKE2b-256 | 66deca796f02398de73da07fdb248dba5dd01818bf3e188a9fdb61e72e770ed4
