Recurrent Attention Networks
Project description
RAN: Recurrent Attention Network
📢 This project is still under active development, with the goal of making long-document modeling easier.
⬇️ Installation
stable
python -m pip install -U rannet
latest
python -m pip install git+https://github.com/4AI/RAN.git
environment
- ⭐ tensorflow>2.0,<=2.10 🤗 (set the environment variable export TF_KERAS=1)
- tensorflow>=1.14,<2.0 🤗 with Keras==2.3.1
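With the TensorFlow 2.x setup, the TF_KERAS variable can also be set from Python instead of the shell, as long as it happens before rannet is imported. A minimal sketch (only the variable name comes from the instructions above; the rest is plain standard-library code):
import os
# must be set before importing rannet so the tf.keras backend is used
os.environ['TF_KERAS'] = '1'
from rannet import RanNet, RanNetWordPieceTokenizer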
🏛️ Pretrained Models
V3 Models
🎯 compatible with: rannet>0.2.1
Lang | Google Drive | Baidu NetDrive |
---|---|---|
EN | base | base[code: udts] |
Chinese models are still being pretrained...
V2 Models
🎯 compatible with: rannet<=0.2.1
Lang | Google Drive | Baidu NetDrive |
---|---|---|
EN | base | base[code: djkj] |
CN | base / small | base[code: e47w] / small[code: mdmg] |
V1 Models
V1 models are not publicly released.
🚀 Quick Tour
🈶 w/ pretrained models
Extract semantic feature
Set return_sequences=False to extract a sentence-level semantic feature.
import numpy as np
from rannet import RanNet, RanNetWordPieceTokenizer
vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)
rannet, rannet_model = RanNet.load_rannet(
    config_path=config_path,
    checkpoint_path=ckpt_path,
    return_sequences=False,
    apply_cell_transform=False,
    cell_pooling='mean'
)
text = 'input text'
tok = tokenizer.encode(text)
vec = rannet_model.predict(np.array([tok.ids]))
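As a quick sanity check of the extracted features, two texts can be compared with cosine similarity. A minimal sketch that reuses tokenizer and rannet_model from the snippet above (the encode helper and example strings are illustrative, not part of rannet):
def encode(text):
    tok = tokenizer.encode(text)
    return rannet_model.predict(np.array([tok.ids]))[0]  # 1D semantic vector

v1 = encode('the first document')
v2 = encode('the second document')
cos = float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))
print('cosine similarity:', cos)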
For the classification task
from tensorflow import keras  # with TF_KERAS=1 (TensorFlow 2.x environment)
from tensorflow.keras import layers as L
from rannet import RanNet, RanNetWordPieceTokenizer
vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)
rannet, rannet_model = RanNet.load_rannet(
    config_path=config_path, checkpoint_path=ckpt_path, return_sequences=False)
output = rannet_model.output  # (B, D)
output = L.Dropout(0.1)(output)
output = L.Dense(2, activation='softmax')(output)
model = keras.models.Model(rannet_model.input, output)
model.summary()
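The classification head above trains like any Keras model. A minimal sketch with toy data (the optimizer, loss, and placeholder example are assumptions, not prescribed by rannet):
import numpy as np
# one toy example: token ids from the tokenizer above, plus an integer label
x = np.array([tokenizer.encode('some text').ids])
y = np.array([0])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x, y, epochs=1, batch_size=1)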
For the sequence task
from rannet import RanNet, RanNetWordPieceTokenizer
vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)
rannet, rannet_model = RanNet.load_rannet(
    config_path=config_path, checkpoint_path=ckpt_path, return_cell=False)
output = rannet_model.output # (B, L, D)
rannet_model.summary()
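For token-level prediction (e.g., sequence labeling), a per-position head can be stacked on the (B, L, D) output. A minimal sketch (the tag count and the Dense head are illustrative additions, not part of rannet):
from tensorflow import keras
from tensorflow.keras import layers as L
num_tags = 9  # placeholder tag-set size
tag_probs = L.Dense(num_tags, activation='softmax')(rannet_model.output)  # (B, L, num_tags)
tagger = keras.models.Model(rannet_model.input, tag_probs)
tagger.summary()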
🈚 w/o pretrained models
Embed the RAN (a Keras layer) into your network.
from rannet import RAN
ran = RAN(head_num=8,
          head_size=256,
          window_size=256,
          min_window_size=16,
          activation='swish',
          kernel_initializer='glorot_normal',
          apply_lm_mask=False,
          apply_seq2seq_mask=False,
          apply_memory_review=True,
          dropout_rate=0.0,
          cell_initializer_type='zero')
output, cell = ran(X)  # X is your (B, L, D) input tensor, e.g. the output of an embedding layer
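Putting the layer into a small end-to-end model looks roughly like this. A minimal sketch assuming the sequence output is (B, L, D) as in the examples above (vocabulary size, embedding size, class count, and the pooling head are placeholders; only the RAN layer comes from rannet):
from tensorflow import keras
from tensorflow.keras import layers as L
from rannet import RAN
vocab_size, embed_dim, num_classes = 30000, 256, 2  # placeholder sizes
tokens = L.Input(shape=(None,), dtype='int32')
x = L.Embedding(vocab_size, embed_dim)(tokens)  # (B, L, D)
ran = RAN(head_num=8,
          head_size=256,
          window_size=256,
          min_window_size=16,
          activation='swish',
          kernel_initializer='glorot_normal',
          apply_lm_mask=False,
          apply_seq2seq_mask=False,
          apply_memory_review=True,
          dropout_rate=0.0,
          cell_initializer_type='zero')
seq_output, cell = ran(x)
pooled = L.GlobalAveragePooling1D()(seq_output)  # (B, D)
probs = L.Dense(num_classes, activation='softmax')(pooled)
model = keras.models.Model(tokens, probs)
model.summary()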
w/ history
import numpy as np
from rannet import RanNet, RanNetWordPieceTokenizer
vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)
rannet, rannet_model = RanNet.load_rannet(
    config_path=config_path,
    checkpoint_path=ckpt_path,
    return_sequences=False,
    apply_cell_transform=False,
    return_history=True,  # return history
    cell_pooling='mean',
    with_cell=True,  # with cell input
)
rannet_model.summary()
text = 'sentence 1'
tok = tokenizer.encode(text)
init_cell = np.zeros((1, 768)) # 768 is embedding size
vec, history = rannet_model.predict([np.array([tok.ids]), init_cell])
text2 = 'sentence 2'
tok = tokenizer.encode(text2)
vec2, history = rannet_model.predict([np.array([tok.ids]), history]) # input history of sentence 1
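The same pattern scales to long documents by splitting the token ids into chunks and threading each chunk's history into the next call. A minimal sketch reusing tokenizer, rannet_model, and np from the snippet above (the chunk size and example text are placeholders):
chunk_size = 512  # placeholder chunk length in tokens
ids = tokenizer.encode('a very long document ...').ids
history = np.zeros((1, 768))  # initial cell, same shape as init_cell above
for start in range(0, len(ids), chunk_size):
    chunk = np.array([ids[start:start + chunk_size]])
    vec, history = rannet_model.predict([chunk, history])
# vec is the representation of the last chunk, conditioned on the accumulated history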
📚 Citation
If you use our code in your research, please cite our work:
@inproceedings{li-etal-2023-recurrent,
    title = "Recurrent Attention Networks for Long-text Modeling",
    author = "Li, Xianming and
      Li, Zongxi and
      Luo, Xiaotian and
      Xie, Haoran and
      Lee, Xing and
      Zhao, Yingbin and
      Wang, Fu Lee and
      Li, Qing",
    booktitle = "Findings of the Association for Computational Linguistics: ACL 2023",
    month = jul,
    year = "2023",
    publisher = "Association for Computational Linguistics",
    pages = "3006--3019",
}
📬 Contact
Please contact us as follows: 1) for code problems, create a GitHub issue; 2) for paper problems, email xmlee97@gmail.com.
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
rannet-0.3.1.tar.gz (31.8 kB, see file details below)
Built Distribution
rannet-0.3.1-py3-none-any.whl (32.4 kB, see file details below)
File details
Details for the file rannet-0.3.1.tar.gz.
File metadata
- Download URL: rannet-0.3.1.tar.gz
- Upload date:
- Size: 31.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.10
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 26687a7c97ba90e339e60719d540794a5342b5f49911071d3c755ccea57a9273 |
MD5 | 91965aeca884c2b79aa70c628283687b |
BLAKE2b-256 | 5242d1b42a8ba33846934dd0ad077b210adb716e85335c3541c90b5538c4a1db |
File details
Details for the file rannet-0.3.1-py3-none-any.whl.
File metadata
- Download URL: rannet-0.3.1-py3-none-any.whl
- Upload date:
- Size: 32.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.10
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | bf11113f72ef4027577489cc47a080ca8cc243c67850e2e4d2a6bad1c5be0e61 |
MD5 | b7102e33650979a17c6ce3b699369761 |
BLAKE2b-256 | 66deca796f02398de73da07fdb248dba5dd01818bf3e188a9fdb61e72e770ed4 |