torch-text-similarity
Implementations of models and metrics for semantic text similarity, including model fine-tuning and prediction. Thanks to @Andriy Mulyar for the elegant implementations; he has published a lot of useful code.
Installation
Install with pip:
pip install torch-text-similarity
Usage
TextSimilarityLearner maps batches of sentence pairs to real-valued similarity scores in the range [0, 5].
import torch
from torch_text_similarity import TextSimilarityLearner
from torch_text_similarity.data import train_eval_sts_a_dataset
learner = TextSimilarityLearner(batch_size=10,
                                model_name='web-bert-similarity',
                                loss_func=torch.nn.MSELoss(),
                                learning_rate=5e-5,
                                weight_decay=0,
                                device=torch.device('cuda:0'))
train_dataset, eval_dataset = train_eval_sts_a_dataset(learner.bert_tokenizer, path='./data/train.csv')
learner.load_train_data(train_dataset)
learner.train(epoch=1)
predictions = learner.predict([('The patient is sick.', 'Grass is green.'),
                               ('A prescription of acetaminophen 325 mg was given.', 'The patient was given Tylenol.')])
print(predictions)
Make a submission to a semantic text similarity competition
import torch
import pandas as pd
from torch_text_similarity import TextSimilarityLearner
from torch_text_similarity.data import train_eval_sts_a_dataset
learner = TextSimilarityLearner(batch_size=10,
                                model_name='web-bert-similarity',
                                loss_func=torch.nn.MSELoss(),
                                learning_rate=5e-5,
                                weight_decay=0,
                                device=torch.device('cuda:0'))
train_dataset, eval_dataset = train_eval_sts_a_dataset(learner.bert_tokenizer, path='./data/train.csv')
learner.load_train_data(train_dataset)
learner.train(epoch=1)
test_data = pd.read_csv('./data/test.csv')
preds_list = []
# Score one (text_a, text_b) pair per row of the test CSV
for i, row in test_data.iterrows():
    text_a = row['text_a']
    text_b = row['text_b']
    preds = learner.predict([(text_a, text_b)])[0]
    preds_list.append(preds)
submission = pd.DataFrame({"id": range(len(preds_list)), "label": preds_list})
submission.to_csv('./submission.csv', index=False, header=False)
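The loop above calls learner.predict once per test row, although predict accepts a list of pairs. A minimal sketch of scoring the test set in batches instead is shown below. It assumes predict returns one score per input pair (as the [0] indexing in the loop suggests); predict_in_batches is a hypothetical helper, not part of torch_text_similarity, and the clamp to [0, 5] is an optional safeguard that changes nothing if the model already stays in range.

```python
from typing import Callable, List, Sequence, Tuple

Pair = Tuple[str, str]


def predict_in_batches(
    predict_fn: Callable[[List[Pair]], Sequence[float]],
    pairs: Sequence[Pair],
    batch_size: int = 32,
    low: float = 0.0,
    high: float = 5.0,
) -> List[float]:
    """Score (text_a, text_b) pairs in chunks and clamp each score to [low, high].

    predict_fn is any callable that takes a list of pairs and returns one
    score per pair, e.g. learner.predict (assumed behavior).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    scores: List[float] = []
    for start in range(0, len(pairs), batch_size):
        batch = list(pairs[start:start + batch_size])
        batch_scores = predict_fn(batch)
        # Clamp in case a score falls slightly outside the STS range.
        scores.extend(min(max(float(s), low), high) for s in batch_scores)
    return scores


# Usage with the objects from the example above (not executed here):
# pairs = list(zip(test_data['text_a'], test_data['text_b']))
# preds_list = predict_in_batches(learner.predict, pairs, batch_size=10)
```

Keeping batch_size equal to the one used for training is a reasonable default, since that size is already known to fit in GPU memory.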
More examples.
The data sets in the examples can be found in Google Cloud Drive:
Source Distribution
Hashes for torch_text_similarity-1.0.1.tar.gz
Algorithm | Hash digest
---|---
SHA256 | e6a53bb503c4711ccf3fb6ba7c80f7d588f36f3dd56e919790daf7ebcf8af94e
MD5 | c92bc077e08cbf34ff3e6ff24a5ac1f4
BLAKE2b-256 | 7773b156df7d183144b1c51be9d945ae27662418a2535d6c2c3f9629e603085b
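To check that a downloaded torch_text_similarity-1.0.1.tar.gz matches the SHA256 digest in the table, you can hash the file locally. The following is a minimal sketch using only the standard library; sha256_of and verify are illustrative helper names, and the local file path is an assumption.

```python
import hashlib

EXPECTED_SHA256 = "e6a53bb503c4711ccf3fb6ba7c80f7d588f36f3dd56e919790daf7ebcf8af94e"


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest, reading the file in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()


# Usage (assumes the archive was downloaded to the current directory):
# print(verify("torch_text_similarity-1.0.1.tar.gz"))
```

Reading in fixed-size chunks keeps memory use constant regardless of archive size. pip can also enforce the digest itself in hash-checking mode, via a requirements line of the form `torch-text-similarity==1.0.1 --hash=sha256:<digest>`.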