
PyTorch port of BLEURT


bleurt-pytorch

Use BLEURT models in native PyTorch with Transformers.

Getting started

Install with:

pip install git+https://github.com/lucadiliello/bleurt-pytorch.git

Now load your favourite model with:

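import torch
from bleurt_pytorch import BleurtConfig, BleurtForSequenceClassification, BleurtTokenizer

# Load the config, model weights and tokenizer of the same checkpoint
config = BleurtConfig.from_pretrained('lucadiliello/BLEURT-20-D12')
model = BleurtForSequenceClassification.from_pretrained('lucadiliello/BLEURT-20-D12')
tokenizer = BleurtTokenizer.from_pretrained('lucadiliello/BLEURT-20-D12')

# Each reference is compared with the candidate at the same position
references = ["a bird chirps by the window", "this is a random sentence"]
candidates = ["a bird chirps by the window", "this looks like a random sentence"]

model.eval()  # switch to inference mode (disables dropout)
with torch.no_grad():  # no gradients are needed for scoring
    # Encode (reference, candidate) pairs, padding to the longest pair in the batch
    inputs = tokenizer(references, candidates, padding='longest', return_tensors='pt')
    # The model outputs one logit per pair: its BLEURT score
    res = model(**inputs).logits.flatten().tolist()
print(res)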
# [0.9604414105415344, 0.8080050349235535]
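The snippet above scores all pairs in a single forward pass. For longer lists, a minimal sketch that splits the pairs into fixed-size batches, so only one chunk is encoded and scored at a time, might look like this. `score_in_batches`, `make_bleurt_scorer` and `BatchScorer` are hypothetical helper names, not part of bleurt-pytorch; the scorer only reuses the tokenizer and model calls shown above.

```python
from typing import Callable, List, Sequence

# Hypothetical type: maps a batch of (reference, candidate) pairs to one score per pair.
BatchScorer = Callable[[Sequence[str], Sequence[str]], List[float]]


def score_in_batches(
    references: Sequence[str],
    candidates: Sequence[str],
    score_batch: BatchScorer,
    batch_size: int = 16,
) -> List[float]:
    """Score aligned reference/candidate lists in chunks of `batch_size` pairs."""
    if len(references) != len(candidates):
        raise ValueError("references and candidates must have the same length")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    scores: List[float] = []
    for start in range(0, len(references), batch_size):
        end = start + batch_size
        # Only this chunk is handed to the scorer, which bounds the size of each call.
        scores.extend(score_batch(references[start:end], candidates[start:end]))
    return scores


def make_bleurt_scorer(model, tokenizer) -> BatchScorer:
    """Wrap a BLEURT model/tokenizer pair (loaded as in the example above) as a BatchScorer."""
    import torch  # imported lazily so score_in_batches itself does not depend on torch

    model.eval()

    def score_batch(refs: Sequence[str], cands: Sequence[str]) -> List[float]:
        with torch.no_grad():
            inputs = tokenizer(list(refs), list(cands), padding='longest', return_tensors='pt')
            return model(**inputs).logits.flatten().tolist()

    return score_batch
```

For example, with the `model` and `tokenizer` loaded above: `scores = score_in_batches(references, candidates, make_bleurt_scorer(model, tokenizer), batch_size=8)`. The batch size of 8 is an arbitrary illustrative value; tune it to the memory you have available.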

You can find all BLEURT models adapted for PyTorch on the Hugging Face Hub. The recommended model is lucadiliello/BLEURT-20; however, it is very large and may require more memory and compute than you have available. BLEURT-20-D12 is smaller but works well enough for most comparisons.
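Because both checkpoints are loaded through the same classes, switching between them only changes the identifier passed to `from_pretrained`. A minimal sketch, assuming a simple boolean flag for resource-constrained environments; `pick_checkpoint` and `load_bleurt` are hypothetical helpers, not part of bleurt-pytorch.

```python
from typing import Any, Tuple

# Checkpoint identifiers taken from this README.
RECOMMENDED_CHECKPOINT = "lucadiliello/BLEURT-20"
SMALL_CHECKPOINT = "lucadiliello/BLEURT-20-D12"


def pick_checkpoint(low_resource: bool) -> str:
    """Return the smaller checkpoint when resources are limited, else the recommended one."""
    return SMALL_CHECKPOINT if low_resource else RECOMMENDED_CHECKPOINT


def load_bleurt(checkpoint: str) -> Tuple[Any, Any]:
    """Load a model/tokenizer pair for `checkpoint`, with the model in eval mode."""
    # Imported lazily so pick_checkpoint can be used without bleurt-pytorch installed.
    from bleurt_pytorch import BleurtForSequenceClassification, BleurtTokenizer

    model = BleurtForSequenceClassification.from_pretrained(checkpoint)
    tokenizer = BleurtTokenizer.from_pretrained(checkpoint)
    model.eval()
    return model, tokenizer
```

For example, `model, tokenizer = load_bleurt(pick_checkpoint(low_resource=True))` loads BLEURT-20-D12. Different checkpoints produce different scores, so it is safest to use one checkpoint consistently within a comparison.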

Credits

Download files

Source Distribution

bleurt-pytorch-0.0.1.tar.gz (19.8 kB)

Built Distribution

bleurt_pytorch-0.0.1-py3-none-any.whl (22.3 kB)
