
torch-text-similarity

Implementations of models and metrics for semantic text similarity, including fine-tuning and prediction. Thanks to @Andriy Mulyar for his elegant implementations, which this project builds on.

Installation

Install with pip:

pip install torch-text-similarity
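
The examples below also require PyTorch (with a CUDA-enabled build if you want to run them on GPU as written; see pytorch.org for a build matching your setup). If PyTorch is not already installed:

pip install torch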

Use

The learner maps batches of sentence pairs to real-valued similarity scores in the range [0, 5]:

import torch

from torch_text_similarity import TextSimilarityLearner
from torch_text_similarity.data import train_eval_sts_a_dataset

learner = TextSimilarityLearner(batch_size=10,
                                model_name='web-bert-similarity',
                                loss_func=torch.nn.MSELoss(),
                                learning_rate=5e-5,
                                weight_decay=0,
                                device=torch.device('cuda:0'))

train_dataset, eval_dataset = train_eval_sts_a_dataset(learner.bert_tokenizer, path='./data/train.csv')

learner.load_train_data(train_dataset)
learner.train(epoch=1)

predictions = learner.predict([('The patient is sick.', 'Grass is green.'),
                               ('A prescription of acetaminophen 325 mg was given.', 'The patient was given Tylenol.')
                               ])

print(predictions)
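
The predictions are expected to be one real-valued score per input pair, returned in the same order as the pairs (an assumption based on the description above). A minimal sketch for pairing scores with their inputs:

pairs = [('The patient is sick.', 'Grass is green.'),
         ('A prescription of acetaminophen 325 mg was given.', 'The patient was given Tylenol.')]

# Assumes predict returns one score in [0, 5] per pair, in input order.
for (text_a, text_b), score in zip(pairs, learner.predict(pairs)):
    print(f'{float(score):.2f}  {text_a} | {text_b}')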

Make a submission to a semantic text similarity competition

import torch
import pandas as pd

from torch_text_similarity import TextSimilarityLearner
from torch_text_similarity.data import train_eval_sts_a_dataset

learner = TextSimilarityLearner(batch_size=10,
                                model_name='web-bert-similarity',
                                loss_func=torch.nn.MSELoss(),
                                learning_rate=5e-5,
                                weight_decay=0,
                                device=torch.device('cuda:0'))

train_dataset, eval_dataset = train_eval_sts_a_dataset(learner.bert_tokenizer, path='./data/train.csv')

learner.load_train_data(train_dataset)
learner.train(epoch=1)

test_data = pd.read_csv('./data/test.csv')
preds_list = []
for i, row in test_data.iterrows():
    text_a = row['text_a']
    text_b = row['text_b']
    preds = learner.predict([(text_a, text_b)])[0]
    preds_list.append(preds)

submission = pd.DataFrame({"id": range(len(preds_list)), "label": preds_list})
submission.to_csv('./submission.csv', index=False, header=False)
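
Because predict accepts a list of sentence pairs, the row-by-row loop above can also be written as a single batched call. A sketch, assuming predict returns one score per pair in input order:

pairs = list(zip(test_data['text_a'], test_data['text_b']))
preds_list = learner.predict(pairs)  # one score per (text_a, text_b) pair

submission = pd.DataFrame({"id": range(len(preds_list)), "label": preds_list})
submission.to_csv('./submission.csv', index=False, header=False)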

More examples.

Data sets

The data sets used in the examples can be found in Google Cloud Drive:
