Pretrained BERT models for encoding clinical trial documents into compact embeddings.
Trial2Vec
Wang, Zifeng and Sun, Jimeng. (2022). Trial2Vec: Zero-Shot Clinical Trial Document Similarity Search using Self-Supervision. Findings of EMNLP'22.
News
- 12/8/2022: Support `download_embedding`, which obtains only the pretrained embeddings and saves a lot of GPU/CPU memory! Please refer to this example for detailed use cases.
from trial2vec import download_embedding
t2v_emb = download_embedding()
- 10/27/2022: Support `word_vector` and `sentence_vector`!
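# `model` is a pretrained Trial2Vec model (see Usage below)
from trial2vec import Trial2Vec
model = Trial2Vec()
model.from_pretrained()
# sentence vectors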
inputs = ['I am a sentence', 'I am another sentence']
outputs = model.sentence_vector(inputs)
# torch.tensor w/ shape [2, 128]
# word vectors
inputs = ['I am a sentence', 'I am another sentence abcdefg xyz']
outputs = model.word_vector(inputs)
# {'word_embs': torch.tensor w/ shape [2, max_token, 128], 'mask': torch.tensor w/ shape [2, max_token]}
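The `mask` returned by `word_vector` marks which positions of `word_embs` hold real tokens, since shorter inputs are padded up to `max_token`. As a minimal sketch, assuming the mask is 1 (or `True`) for real tokens and 0 for padding, and that both tensors have been converted to nested Python lists with `.tolist()`, mask-aware mean pooling could look like the following. `masked_mean_pool` is a hypothetical helper for illustration; it is not part of the Trial2Vec API and is not necessarily how `sentence_vector` computes its output.

```python
from typing import List


def masked_mean_pool(word_embs: List[List[List[float]]],
                     mask: List[List[int]]) -> List[List[float]]:
    """Average token vectors per sentence, skipping padded positions.

    word_embs: [batch, max_token, dim] nested lists (e.g. tensor.tolist())
    mask:      [batch, max_token]; assumed truthy for real tokens, 0 for padding
    """
    pooled = []
    for tokens, token_mask in zip(word_embs, mask):
        dim = len(tokens[0])
        total = [0.0] * dim
        count = 0
        for vec, keep in zip(tokens, token_mask):
            if keep:  # only real tokens contribute to the average
                for i in range(dim):
                    total[i] += vec[i]
                count += 1
        # a fully padded row falls back to a zero vector instead of dividing by 0
        pooled.append([x / count for x in total] if count else total)
    return pooled


# Toy input: 2 sentences, max_token = 3, dim = 2
embs = [[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]],
        [[2.0, 0.0], [0.0, 0.0], [0.0, 0.0]]]
mask = [[1, 1, 0], [1, 0, 0]]
print(masked_mean_pool(embs, mask))  # [[2.0, 3.0], [2.0, 0.0]]
```

The padded vector `[9.0, 9.0]` in the first sentence is ignored because its mask entry is 0, which is exactly why the mask is returned alongside the word embeddings.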
Usage
Get the pretrained Trial2Vec model in three lines:
from trial2vec import Trial2Vec
model = Trial2Vec()
model.from_pretrained()
A Jupyter notebook example is available at https://github.com/RyanWangZf/Trial2Vec/blob/main/example/demo_trial2vec.ipynb.
How to install
Install the PyTorch version that matches your system by following https://pytorch.org/get-started/locally/.
Then install `Trial2Vec` with
# Recommended: always up to date, with small bugs fixed as they are found
pip install git+https://github.com/RyanWangZf/Trial2Vec.git
or
pip install trial2vec
Search similar trials
Use `Trial2Vec` to search for similar clinical trials:
# load demo data
from trial2vec import load_demo_data
data = load_demo_data()
# contains trial documents
test_data = {'x': data['x']}
# make prediction
pred = model.predict(test_data)
Encode trials
Use `Trial2Vec` to encode clinical trial documents:
test_data = {'x': df} # df: a DataFrame of trial documents, formatted like data['x']
emb = model.encode(test_data) # make inference
# or just look up the pre-encoded trial documents by nct_id
emb = [model[nct_id] for nct_id in test_data['x']['nct_id']]
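Once trials are encoded, similar trials can be found by comparing their embeddings. The following is a minimal standard-library sketch of cosine-similarity top-k retrieval, assuming you have collected the embeddings into a dict that maps each `nct_id` to a plain list of floats (for example via `.tolist()` on each vector). `cosine` and `top_k_similar` are hypothetical helpers, the IDs in the toy data are placeholders, and this is not necessarily how `model.predict` ranks trials internally.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k_similar(query_id: str,
                  embeddings: Dict[str, Sequence[float]],
                  k: int = 5) -> List[Tuple[str, float]]:
    """Rank every other trial by cosine similarity to `query_id`, best first."""
    query = embeddings[query_id]
    scores = [(nct_id, cosine(query, vec))
              for nct_id, vec in embeddings.items() if nct_id != query_id]
    scores.sort(key=lambda pair: pair[1], reverse=True)
    return scores[:k]


# Toy 3-d vectors with placeholder IDs; substitute your trial embeddings here.
toy = {
    "NCT_A": [1.0, 0.0, 0.0],
    "NCT_B": [0.9, 0.1, 0.0],
    "NCT_C": [0.0, 1.0, 0.0],
}
print(top_k_similar("NCT_A", toy, k=2))  # NCT_B ranks first, NCT_C last
```

This brute-force scan compares the query against every trial, which is fine for thousands of documents; for much larger collections an approximate nearest-neighbor index would be the usual replacement.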
Continue training
You can continue training the pretrained model on new trials as follows:
# format your trial documents the same way as `data`
data = load_demo_data()
model.fit(
{
'x':data['x'], # document dataframe
'fields':data['fields'], # attribute field columns
'ctx_fields':data['ctx_fields'], # context field columns
'tag': data['tag'], # nct_id is the unique tag for each trial
},
valid_data={
'x':data['x_val'],
'y':data['y_val']
},
)
# save
model.save_model('./finetuned-trial2vec')
Hashes for Trial2Vec-0.1.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 361c88f5a2fc7e74a5bcafceb409b109ce0ddce615721d412d960e684009a815
MD5 | 66e74cc9a6ae84878cc68e7d9ac3f733
BLAKE2b-256 | 0a0d15db1ee739865367d5bf4ea02426becaa489cb0188ffb33606db5fcdb9de