torch-text-similarity
Implementations of models and metrics for semantic text similarity. Includes fine-tuning and prediction of models. Thanks to @Andriy Mulyar for the elegant implementations; he has published a lot of useful code.
Installation
Install with pip:
pip install torch-text-similarity
Usage
Maps batches of sentence pairs to real-valued similarity scores in the range [0, 5].
import torch
from torch_text_similarity import TextSimilarityLearner
from torch_text_similarity.data import train_eval_sts_a_dataset

# Use the first GPU if one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

learner = TextSimilarityLearner(batch_size=10,
                                model_name='web-bert-similarity',
                                loss_func=torch.nn.MSELoss(),  # regression on the 0-5 score
                                learning_rate=5e-5,
                                weight_decay=0,
                                device=device)

# Build train/eval datasets from a CSV of scored sentence pairs.
train_dataset, eval_dataset = train_eval_sts_a_dataset(learner.bert_tokenizer, path='./data/train.csv')

# Fine-tune for one epoch.
learner.load_train_data(train_dataset)
learner.train(epoch=1)

# Score a batch of sentence pairs: one score per pair.
predictions = learner.predict([('The patient is sick.', 'Grass is green.'),
                               ('A prescription of acetaminophen 325 mg was given.', 'The patient was given Tylenol.')])
print(predictions)
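Because the predicted scores live on the same 0 to 5 scale as human similarity judgments, STS benchmarks are conventionally evaluated by correlating predicted scores with gold scores, most often with Pearson's r. The package's own metric functions are not shown in this description, so the following is a standalone, dependency-free sketch of that computation rather than the library's API; `pearson_r` is a hypothetical helper and the scores are made-up illustration values.

```python
import math
from typing import Sequence


def pearson_r(predicted: Sequence[float], gold: Sequence[float]) -> float:
    """Pearson correlation between predicted and gold similarity scores."""
    if len(predicted) != len(gold):
        raise ValueError("predicted and gold must have the same length")
    if len(predicted) < 2:
        raise ValueError("need at least two scored pairs")
    n = len(predicted)
    mean_p = sum(predicted) / n
    mean_g = sum(gold) / n
    # Covariance and variances, all without the 1/n factor (it cancels out).
    cov = sum((p - mean_p) * (g - mean_g) for p, g in zip(predicted, gold))
    var_p = sum((p - mean_p) ** 2 for p in predicted)
    var_g = sum((g - mean_g) ** 2 for g in gold)
    if var_p == 0 or var_g == 0:
        raise ValueError("correlation is undefined for constant inputs")
    return cov / math.sqrt(var_p * var_g)


# Hypothetical model scores and gold annotations on the 0-5 STS scale.
predicted = [0.4, 3.9, 2.1, 4.8]
gold = [0.0, 4.0, 2.5, 5.0]
print(round(pearson_r(predicted, gold), 3))
```

Pearson's r only measures linear agreement, so a model whose scores are shifted or rescaled relative to the gold scale can still score well; that is why it is usually reported alongside, not instead of, an error measure such as MSE.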
Making a submission to a semantic text similarity competition
import torch
import pandas as pd
from torch_text_similarity import TextSimilarityLearner
from torch_text_similarity.data import train_eval_sts_a_dataset

# Use the first GPU if one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

learner = TextSimilarityLearner(batch_size=10,
                                model_name='web-bert-similarity',
                                loss_func=torch.nn.MSELoss(),
                                learning_rate=5e-5,
                                weight_decay=0,
                                device=device)

# Fine-tune on the competition's training data.
train_dataset, eval_dataset = train_eval_sts_a_dataset(learner.bert_tokenizer, path='./data/train.csv')
learner.load_train_data(train_dataset)
learner.train(epoch=1)

# Score every (text_a, text_b) pair in the test set.
test_data = pd.read_csv('./data/test.csv')
preds_list = []
for _, row in test_data.iterrows():
    score = learner.predict([(row['text_a'], row['text_b'])])[0]
    preds_list.append(score)

# Write "id,label" rows with no header and no index column.
submission = pd.DataFrame({"id": range(len(preds_list)), "label": preds_list})
submission.to_csv('./submission.csv', index=False, header=False)
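The loop above calls the model once per test row, even though the first example shows that `learner.predict` accepts a list of several pairs. The following sketch scores pairs in fixed-size chunks and writes the same headerless `id,label` layout using only the standard library. `batched_scores`, `write_submission` and `fake_predict` are hypothetical helpers; in the README's setting you would pass `learner.predict` as `predict_fn` and build `pairs` from the `text_a` and `text_b` columns, assuming it returns one score per input pair.

```python
import csv
import io
from typing import Callable, List, Sequence, TextIO, Tuple

Pair = Tuple[str, str]


def batched_scores(predict_fn: Callable[[List[Pair]], Sequence[float]],
                   pairs: Sequence[Pair], batch_size: int = 32) -> List[float]:
    """Score sentence pairs in chunks of batch_size instead of one call per pair."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    scores: List[float] = []
    for start in range(0, len(pairs), batch_size):
        batch = list(pairs[start:start + batch_size])
        # Convert each score to a plain float (e.g. from a NumPy scalar).
        scores.extend(float(s) for s in predict_fn(batch))
    return scores


def write_submission(scores: Sequence[float], out: TextIO) -> None:
    """Write headerless 'id,label' rows, ids counting up from 0."""
    writer = csv.writer(out, lineterminator="\n")
    for i, score in enumerate(scores):
        writer.writerow([i, score])


def fake_predict(batch: List[Pair]) -> List[float]:
    # Hypothetical stand-in for learner.predict: a constant score per pair.
    return [2.5] * len(batch)


pairs = [("The patient is sick.", "Grass is green.")] * 5
buf = io.StringIO()
write_submission(batched_scores(fake_predict, pairs, batch_size=2), buf)
print(buf.getvalue())
```

Batching mainly saves per-call overhead (tokenization setup, host-to-device transfers); the batch size that fits depends on the model and GPU memory.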
More examples.
Datasets
The datasets used in the examples can be found on Google Drive:
File details
Details for the file torch_text_similarity-1.0.4.tar.gz
File metadata
- Download URL: torch_text_similarity-1.0.4.tar.gz
- Upload date:
- Size: 8.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: Python-urllib/3.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 949ca10ae8f87d6a265378fea5d206b8f7a8df72d917f5737a04571f9c116f78
MD5 | bbeaec10cadf08f45bf289ca6bee0500
BLAKE2b-256 | 9f9a4fd0e59b8075121fd34e02904e3ce8b6a946c866fa3b322d93acc8463644