PyTorch module for differentiable BLEU score computation that supports end-to-end training
BLEU-Torch: Fast Differentiable BLEU Scores for PyTorch
A fully differentiable PyTorch implementation of BLEU scores, intended for training neural networks. Traditional BLEU implementations work with discrete tokens, which blocks gradient flow. bleu-torch operates directly on logits, so it can be used as a differentiable loss function in neural text generation models.
🚀 Key Features
- ⚡ Fully Vectorized: Batch processing with no Python loops
- 🔥 GPU Accelerated: Native PyTorch tensors with CUDA support
- 📈 Differentiable: Can be used as a loss function for training
- 🎯 Multiple Loss Types: Complement (1-BLEU) and log (-log(BLEU)) loss functions
- 📊 Proper Loss Bounds: Loss ∈ [0, 1] for complement loss, with 0 = perfect match
- 🧪 Thoroughly Tested: 17 comprehensive tests including overfit validation
- 🚄 High Performance: Efficient implementation for large-scale training
- 🌡️ Temperature Control: Gumbel Softmax with configurable temperature
📦 Installation
pip install bleu-torch
Or install from source:
git clone https://github.com/Ghost---Shadow/bleu-torch.git
cd bleu-torch
pip install -e .
💡 Quick Start
Basic Usage
import torch
from bleu_torch import DifferentiableBLEUModule, DifferentiableBLEULoss
# Initialize BLEU module
vocab_size = 1000
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bleu_module = DifferentiableBLEUModule(vocab_size=vocab_size, temperature=0.1).to(device)
loss_fn = DifferentiableBLEULoss(bleu_module, loss_type="complement").to(device)
# Training scenario with logits
logits = torch.randn(8, vocab_size, requires_grad=True, device=device) # (seq_len, vocab_size)
references = [
    torch.randint(0, vocab_size, (8,), device=device),   # Reference 1
    torch.randint(0, vocab_size, (10,), device=device),  # Reference 2
    torch.randint(0, vocab_size, (6,), device=device),   # Reference 3
]
# Compute loss and backpropagate
loss = loss_fn(logits, references)
loss.backward()
print(f"BLEU Loss: {loss.item():.4f}")
print(f"Gradient norm: {logits.grad.norm().item():.6f}")
Using as a Loss Function
# Assumes `model`, `optimizer`, and `dataloader` are already defined
for batch in dataloader:
    optimizer.zero_grad()  # clear gradients left over from the previous step
    logits = model(batch['input_ids'])  # (seq_len, vocab_size)
    # The BLEU loss is differentiable, so it backpropagates like any other loss
    loss = loss_fn(logits, batch['references'])
    loss.backward()
    optimizer.step()
Working with Neural Models
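import torch
import torch.nn as nn
from bleu_torch import DifferentiableBLEUModule, DifferentiableBLEULoss

class MyLanguageModel(nn.Module):
    def __init__(self, vocab_size, hidden_dim=512):
        super().__init__()
        # Embed token IDs; the encoder expects float features, not integer IDs
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=8),
            num_layers=6
        )
        self.output_proj = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids):
        # input_ids: (seq_len,) -> (seq_len, 1, hidden_dim), a batch of one sequence
        x = self.embed(input_ids).unsqueeze(1)
        x = self.transformer(x)
        return self.output_proj(x.squeeze(1))  # (seq_len, vocab_size)

# Training setup
model = MyLanguageModel(vocab_size=10000)
bleu_module = DifferentiableBLEUModule(vocab_size=10000, temperature=0.1)
bleu_loss = DifferentiableBLEULoss(bleu_module, loss_type="complement")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Train with BLEU loss (assumes `dataloader` yields one sequence per batch)
for batch in dataloader:
    optimizer.zero_grad()
    logits = model(batch['input_ids'])
    loss = bleu_loss(logits, batch['references'])
    loss.backward()
    optimizer.step()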
📋 API Reference
DifferentiableBLEUModule
Main class for computing differentiable BLEU scores.
DifferentiableBLEUModule(vocab_size: int, max_n: int = 4,
                         temperature: float = 1.0, smoothing: float = 1e-10)
Parameters
- vocab_size: Size of the vocabulary
- max_n: Maximum n-gram order (default: 4 for BLEU-4)
- temperature: Temperature for Gumbel Softmax during training (default: 1.0)
- smoothing: Small value for numerical stability (default: 1e-10)
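To make the roles of max_n and smoothing concrete, here is a minimal plain-Python sketch of the standard discrete BLEU definition: clipped n-gram precisions up to order max_n, combined by a geometric mean and scaled by a brevity penalty. It is a reference for the metric itself, not bleu-torch's implementation. The function names `ngrams` and `bleu` are illustrative, and using smoothing as a floor on each precision before taking the log is an assumption about how the stability term is applied.

```python
import math
from collections import Counter

def ngrams(tokens, n):
    """All contiguous n-grams of a token list, as tuples."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

def bleu(candidate, references, max_n=4, smoothing=1e-10):
    """Discrete BLEU for one candidate against several references (illustrative)."""
    if not candidate:
        return 0.0
    log_precisions = []
    for n in range(1, max_n + 1):
        cand_counts = Counter(ngrams(candidate, n))
        # Clip each candidate n-gram count by its maximum count in any reference
        max_ref = Counter()
        for ref in references:
            for gram, count in Counter(ngrams(ref, n)).items():
                max_ref[gram] = max(max_ref[gram], count)
        clipped = sum(min(c, max_ref[g]) for g, c in cand_counts.items())
        total = max(sum(cand_counts.values()), 1)
        # Assumption: smoothing acts as a floor so log(0) never occurs
        log_precisions.append(math.log(max(clipped / total, smoothing)))
    # Brevity penalty uses the reference length closest to the candidate length
    ref_len = min((abs(len(r) - len(candidate)), len(r)) for r in references)[1]
    bp = 1.0 if len(candidate) >= ref_len else math.exp(1 - ref_len / len(candidate))
    return bp * math.exp(sum(log_precisions) / max_n)

print(bleu([1, 2, 3, 4, 5], [[1, 2, 3, 4, 5]]))  # identical sequences
print(bleu([1, 2, 3, 4], [[1, 2, 3, 4, 5]]))     # shorter candidate, penalized
```

A candidate that matches a reference exactly scores 1.0, while a candidate that is a strict prefix of the reference keeps perfect precisions but is scaled down by the brevity penalty.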
Methods
forward(candidate_input, reference_ids_list)
- Computes BLEU score for a single candidate
- candidate_input: Either (seq_len, vocab_size) logits or (seq_len,) token IDs
- reference_ids_list: List of reference token ID tensors
- Returns: BLEU score tensor (scalar)
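One way to see how a single function can accept both logits and token IDs is to map both inputs to per-position distributions over the vocabulary and count n-grams in expectation. The sketch below does this for unigrams only. The helpers `to_distribution` and `soft_unigram_precision` are hypothetical names invented for this illustration, and the approach is an assumption about how a soft match count could work, not a description of bleu-torch's internals.

```python
import torch
import torch.nn.functional as F

def to_distribution(candidate_input, vocab_size):
    """Hypothetical helper: return (seq_len, vocab_size) rows that sum to 1."""
    if candidate_input.dim() == 1:
        # Token IDs become one-hot rows (a degenerate distribution)
        return F.one_hot(candidate_input, num_classes=vocab_size).float()
    # Logits become probabilities
    return F.softmax(candidate_input, dim=-1)

def soft_unigram_precision(candidate_input, reference_ids_list, vocab_size):
    """Clipped unigram precision computed on expected token counts."""
    probs = to_distribution(candidate_input, vocab_size)
    expected_counts = probs.sum(dim=0)  # (vocab_size,)
    # Maximum count of each token across all references (BLEU-style clipping)
    ref_counts = torch.stack([
        torch.bincount(ref, minlength=vocab_size) for ref in reference_ids_list
    ]).max(dim=0).values.float()
    clipped = torch.minimum(expected_counts, ref_counts)
    return clipped.sum() / probs.shape[0]

vocab_size = 5
references = [torch.tensor([1, 2, 3, 4]), torch.tensor([2, 2, 0])]

# Token IDs: token 2 appears three times but is clipped to two matches
ids_score = soft_unigram_precision(torch.tensor([2, 2, 2, 1]), references, vocab_size)
print(ids_score.item())

# Logits: the same function stays differentiable
logits = torch.randn(4, vocab_size, requires_grad=True)
soft_score = soft_unigram_precision(logits, references, vocab_size)
soft_score.backward()
print(logits.grad.shape)
```

With token IDs the result equals ordinary clipped unigram precision. With logits the same arithmetic runs on expected counts, so gradients reach the logits.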
batch_forward(candidate_inputs, reference_ids_batch)
- Computes BLEU scores for a batch of candidates
- Returns: Tensor of BLEU scores with shape (batch_size,)
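The temperature parameter listed above controls how sharp the Gumbel Softmax samples are. The sketch below uses PyTorch's own `torch.nn.functional.gumbel_softmax` to compare a low and a high temperature on the same logits. How bleu-torch calls it internally is not shown in this README, so treat the code as an illustration of the temperature effect only. The logits and the sample count are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# 1000 independent samples from the same 4-way logits
logits = torch.tensor([[2.0, 1.0, 0.1, -1.0]]).repeat(1000, 1)

# Low temperature: samples are close to one-hot vectors
cold = F.gumbel_softmax(logits, tau=0.1, hard=False, dim=-1)
# High temperature: samples are close to uniform
warm = F.gumbel_softmax(logits, tau=5.0, hard=False, dim=-1)

# Average of the largest probability in each sample
cold_peak = cold.max(dim=-1).values.mean().item()
warm_peak = warm.max(dim=-1).values.mean().item()
print(f"mean peak probability at tau=0.1: {cold_peak:.3f}")
print(f"mean peak probability at tau=5.0: {warm_peak:.3f}")
```

Lower temperatures give samples that behave more like discrete tokens but produce higher-variance gradients. Higher temperatures smooth the samples at the cost of a larger gap to discrete decoding.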
DifferentiableBLEULoss
Loss function wrapper with proper bounds for training.
DifferentiableBLEULoss(bleu_module: DifferentiableBLEUModule,
                       loss_type: str = "complement")
Parameters
- bleu_module: DifferentiableBLEUModule instance
- loss_type: Loss type, either "complement" (1-BLEU) or "log" (-log(BLEU))
Methods
forward(candidate_input, reference_ids_list)
- Computes BLEU-based loss with guaranteed minimum of 0
- Returns differentiable loss tensor
batch_forward(candidate_inputs, reference_ids_batch)
- Computes batch loss for multiple candidates
- Returns: Mean loss across the batch
🎯 Loss Function Details
The BLEU loss is designed for training neural networks:
# Single loss computation: loss ∈ [0, 1] for complement loss
loss = loss_fn(logits, references)
# Different loss types
complement_loss = DifferentiableBLEULoss(bleu_module, "complement") # loss = 1 - BLEU
log_loss = DifferentiableBLEULoss(bleu_module, "log") # loss = -log(BLEU)
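The two loss types differ mainly in how they behave for low BLEU scores. The plain-Python sketch below applies both transforms to a few BLEU values. The function `bleu_to_loss` is a hypothetical name for illustration, and the `eps` floor inside the log is an assumption to keep the log finite, not a documented detail of the library.

```python
import math

def bleu_to_loss(bleu, loss_type="complement", eps=1e-10):
    """Map a BLEU score in [0, 1] to a non-negative loss (illustrative)."""
    if loss_type == "complement":
        return 1.0 - bleu  # bounded: stays in [0, 1]
    if loss_type == "log":
        # Assumption: floor the score at eps so log(0) cannot occur
        return -math.log(max(bleu, eps))  # 0 at BLEU = 1, grows as BLEU -> 0
    raise ValueError(f"unknown loss_type: {loss_type!r}")

for score in (1.0, 0.5, 0.01):
    print(score, bleu_to_loss(score, "complement"), bleu_to_loss(score, "log"))
```

Both losses are 0 at a perfect match. The complement loss has a constant slope of -1 with respect to BLEU, while the log loss has slope -1/BLEU, so it pushes harder when the score is still close to zero.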
Loss Properties:
- ✅ Differentiable: Use with any PyTorch optimizer
- ✅ Proper Bounds: Always ≥ 0, with 0 = perfect match
- ✅ Intuitive: Lower loss = better BLEU scores
- ✅ Validated: Tested with overfit experiments reaching ~0.0 loss
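An overfit experiment of the kind mentioned above optimizes free logits against a fixed target and checks that the loss approaches 0. The sketch below shows that pattern with a stand-in loss, one minus the mean probability assigned to the target token at each position, so that it runs without bleu-torch installed. The stand-in is an assumption chosen for simplicity and is not BLEU. In a real check, the call to `surrogate_loss` would be replaced by the library's loss function.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, seq_len = 20, 6
target = torch.randint(0, vocab_size, (seq_len,))

# Free parameters: one row of logits per position, starting from uniform
logits = torch.zeros(seq_len, vocab_size, requires_grad=True)
optimizer = torch.optim.Adam([logits], lr=0.5)

def surrogate_loss(logits, target):
    """Stand-in loss (not BLEU): 1 - mean probability of the target tokens."""
    probs = F.softmax(logits, dim=-1)
    match = probs[torch.arange(len(target)), target].mean()
    return 1.0 - match

initial_loss = surrogate_loss(logits, target).item()
for _ in range(200):
    optimizer.zero_grad()
    loss = surrogate_loss(logits, target)
    loss.backward()
    optimizer.step()
final_loss = surrogate_loss(logits, target).item()

print(f"initial loss: {initial_loss:.4f}, final loss: {final_loss:.4f}")
```

If the loss does not fall toward 0 in a setup this small, the usual suspects are a broken gradient path or a loss whose minimum is not at a perfect match.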
Requirements
- Python >= 3.8
- PyTorch >= 1.9.0
- NumPy >= 1.19.0
Testing
Run the comprehensive test suite:
python -m pytest tests/ -v
# Or run directly:
python tests/test_bleu_torch.py
The test suite includes:
- 7 unit tests for basic functionality
- 5 comprehensive overfitting tests
- 5 training scenario tests
- GPU/CPU compatibility tests
License
MIT License - see LICENSE file for details.
File details
Details for the file bleu_torch-1.0.0.tar.gz.
File metadata
- Download URL: bleu_torch-1.0.0.tar.gz
- Upload date:
- Size: 13.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 59ce9d666344c932ccdae44a3afb917814e1b3c8051de0018f08d52b80b3601d |
| MD5 | 7b2a5929384030a5b89c5488fcf36a7d |
| BLAKE2b-256 | fbc624ddb564fcc7cd71fce76add1eccd52d7459e25b085074a67d6e24482511 |
File details
Details for the file bleu_torch-1.0.0-py3-none-any.whl.
File metadata
- Download URL: bleu_torch-1.0.0-py3-none-any.whl
- Upload date:
- Size: 10.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0a29e7cf153f9cce3d89fc5840c42211b3aa17a5aeb0b2a4c907ca52f2b3e580 |
| MD5 | 9dd4fc01262e1b5416f9f9ff6cc600d8 |
| BLAKE2b-256 | 4759b63c6de529f7b7ae1431f9a9b3c1b3a7f0af67736b9e186483dae8e31316 |