Recurrent Attention Networks
Project description
RAN: Recurrent Attention Network
📢 This project is under active development, with the goal of making long-document modeling easier.
⬇️ Installation
stable
python -m pip install -U rannet
latest
python -m pip install git+https://github.com/4AI/RAN.git
environment
- ⭐ tensorflow>2.0,<=2.10 🤗 export TF_KERAS=1 (see the snippet below)
- tensorflow>=1.14,<2.0 🤗 Keras==2.3.1
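When using the TF >2.0 path, TF_KERAS=1 is typically read at import time, so it should be set before rannet is imported. A minimal sketch of doing this from Python (equivalent to the shell export above):

import os
os.environ['TF_KERAS'] = '1'  # same effect as `export TF_KERAS=1`; set before importing rannet

import tensorflow as tf
import rannet  # imported after the environment variable is set

print(tf.__version__)  # expected to be >2.0 and <=2.10 for this path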
🏛️ Pretrained Models
V3 Models
🎯 compatible with: rannet>0.2.1
Lang | Google Drive | Baidu NetDrive |
---|---|---|
EN | base | base[code: udts] |
Chinese models are still being pretrained...
V2 Models
🎯 compatible with: rannet<=0.2.1
Lang | Google Drive | Baidu NetDrive |
---|---|---|
EN | base | base[code: djkj] |
CN | base / small | base[code: e47w] / small[code: mdmg] |
V1 Models
V1 models are not publicly released.
🚀 Quick Tour
🈶 w/ pretrained models
Extract semantic feature
Set return_sequences=False to extract the semantic feature of the input text.
import numpy as np
from rannet import RanNet, RanNetWordPieceTokenizer
vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)
rannet, rannet_model = RanNet.load_rannet(
config_path=config_path,
checkpoint_path=ckpt_path,
return_sequences=False,
apply_cell_transform=False,
cell_pooling='mean'
)
text = 'input text'
tok = tokenizer.encode(text)
vec = rannet_model.predict(np.array([tok.ids]))  # semantic feature, shape (1, D) for a single text
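A quick illustration of using the extracted feature, e.g. comparing two texts by cosine similarity. This is a sketch building on the snippet above; the helper function and sample texts are illustrative, not part of the library.

def encode(text):
    tok = tokenizer.encode(text)
    return rannet_model.predict(np.array([tok.ids]))[0]  # (D,) feature vector

a = encode('first input text')
b = encode('second input text')
cos_sim = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
print(cos_sim)  # higher means more semantically similar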
For the classification task
from tensorflow import keras  # assuming the TF >2.0 path with TF_KERAS=1
from tensorflow.keras import layers as L
from rannet import RanNet, RanNetWordPieceTokenizer
vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)
rannet, rannet_model = RanNet.load_rannet(
config_path=config_path, checkpoint_path=ckpt_path, return_sequences=False)
output = rannet_model.output # (B, D)
output = L.Dropout(0.1)(output)
output = L.Dense(2, activation='softmax')(output)
model = keras.models.Model(rannet_model.input, output)
model.summary()
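To train the classifier, compile and fit the model on padded token ids. A minimal sketch; the toy data, padding length, and hyperparameters are illustrative, not from the original project.

import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences

texts = ['an example of class one', 'an example of class zero']  # toy data
labels = np.array([1, 0])

ids = [tokenizer.encode(t).ids for t in texts]
x = pad_sequences(ids, maxlen=128, padding='post')  # pad/truncate token ids to a fixed length

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x, labels, batch_size=2, epochs=1)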
For the sequence task
from rannet import RanNet, RanNetWordPieceTokenizer
vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)
rannet, rannet_model = RanNet.load_rannet(
config_path=config_path, checkpoint_path=ckpt_path, return_cell=False)
output = rannet_model.output # (B, L, D)
rannet_model.summary()
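The per-token output can feed a token-level head, e.g. for sequence labeling. A sketch building on the snippet above; the tag-set size and the Dense head are illustrative.

from tensorflow import keras
from tensorflow.keras import layers as L

num_tags = 5  # illustrative tag-set size
seq_output = rannet_model.output  # (B, L, D)
tag_probs = L.Dense(num_tags, activation='softmax')(seq_output)  # (B, L, num_tags)
tagger = keras.models.Model(rannet_model.input, tag_probs)
tagger.summary()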
🈚 w/o pretrained models
Embed the RAN (a Keras layer) into your network.
from rannet import RAN
ran = RAN(head_num=8,
head_size=256,
window_size=256,
min_window_size=16,
activation='swish',
kernel_initializer='glorot_normal',
apply_lm_mask=False,
apply_seq2seq_mask=False,
apply_memory_review=True,
dropout_rate=0.0,
cell_initializer_type='zero')
output, cell = ran(X)  # X: a (batch, seq_len, dim) input tensor
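A sketch of wiring the layer into a small Keras model with an embedding front end; the vocabulary size and embedding dimension are illustrative.

from tensorflow import keras
from tensorflow.keras import layers as L

token_ids = L.Input(shape=(None,), dtype='int32')  # (B, L) token ids
X = L.Embedding(input_dim=30000, output_dim=768)(token_ids)  # (B, L, D)
output, cell = ran(X)  # reuse the `ran` layer constructed above
model = keras.models.Model(token_ids, output)
model.summary()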
w/ history
import numpy as np
from rannet import RanNet, RanNetWordPieceTokenizer
vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)
rannet, rannet_model = RanNet.load_rannet(
config_path=config_path,
checkpoint_path=ckpt_path,
return_sequences=False,
apply_cell_transform=False,
return_history=True, # return history
cell_pooling='mean',
with_cell=True, # with cell input
)
rannet_model.summary()
text = 'sentence 1'
tok = tokenizer.encode(text)
init_cell = np.zeros((1, 768)) # 768 is embedding size
vec, history = rannet_model.predict([np.array([tok.ids]), init_cell])
text2 = 'sentence 2'
tok = tokenizer.encode(text2)
vec2, history = rannet_model.predict([np.array([tok.ids]), history]) # input history of sentence 1
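The same pattern extends to long documents processed segment by segment, carrying the history forward. A sketch assuming the tokenizer and model loaded above; the segmentation is illustrative.

segments = ['first part of a long document ...',
            'second part ...',
            'third part ...']
history = np.zeros((1, 768))  # start from a zero cell, as above
vec = None
for seg in segments:
    tok = tokenizer.encode(seg)
    vec, history = rannet_model.predict([np.array([tok.ids]), history])
# `vec` now represents the last segment, conditioned on the history of the earlier ones.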
📚 Citation
If you use our code in your research, please cite our work:
@inproceedings{li-etal-2023-recurrent,
title = "Recurrent Attention Networks for Long-text Modeling",
author = "Li, Xianming and
Li, Zongxi and
Luo, Xiaotian and
Xie, Haoran and
Lee, Xing and
Zhao, Yingbin and
Wang, Fu Lee and
Li, Qing",
booktitle = "Findings of the Association for Computational Linguistics: ACL 2023",
month = jul,
year = "2023",
publisher = "Association for Computational Linguistics",
pages = "3006--3019",
}
📬 Contact
Please contact us: 1) for code problems, create a GitHub issue; 2) for paper problems, email xmlee97@gmail.com.
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
rannet-0.3.1.tar.gz (31.8 kB)
Built Distribution
rannet-0.3.1-py3-none-any.whl (32.4 kB)
File details
Details for the file rannet-0.3.1.tar.gz.
File metadata
- Download URL: rannet-0.3.1.tar.gz
- Upload date:
- Size: 31.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.10
File hashes
Algorithm | Hash digest
---|---
SHA256 | 26687a7c97ba90e339e60719d540794a5342b5f49911071d3c755ccea57a9273
MD5 | 91965aeca884c2b79aa70c628283687b
BLAKE2b-256 | 5242d1b42a8ba33846934dd0ad077b210adb716e85335c3541c90b5538c4a1db
File details
Details for the file rannet-0.3.1-py3-none-any.whl.
File metadata
- Download URL: rannet-0.3.1-py3-none-any.whl
- Upload date:
- Size: 32.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.10
File hashes
Algorithm | Hash digest
---|---
SHA256 | bf11113f72ef4027577489cc47a080ca8cc243c67850e2e4d2a6bad1c5be0e61
MD5 | b7102e33650979a17c6ce3b699369761
BLAKE2b-256 | 66deca796f02398de73da07fdb248dba5dd01818bf3e188a9fdb61e72e770ed4