A simple CRF module written in PyTorch. The implementation is based on https://github.com/allenai/allennlp/blob/master/allennlp/modules/conditional_random_field.py
Project description
PyTorch Text CRF
This package contains a simple wrapper for using conditional random fields (CRFs). The code is based on the excellent AllenNLP implementation of CRF.
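For intuition, the log-likelihood a linear-chain CRF is trained on is the score of the gold tag path minus the log-sum-exp of the scores of all possible paths (the log partition function), which the forward algorithm computes without enumerating paths. The plain-Python sketch below illustrates this under simplifying assumptions: a single unbatched sequence, emission scores and a transition matrix as nested lists, and no start/end transitions. It is not this package's code; all function names are hypothetical.

```python
import itertools
import math
from typing import List, Sequence

Matrix = List[List[float]]


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def path_score(emissions: Matrix, transitions: Matrix, tags: List[int]) -> float:
    """Unnormalized score of one tag path: emission scores plus tag-to-tag transition scores."""
    score = emissions[0][tags[0]]
    for t in range(1, len(tags)):
        score += transitions[tags[t - 1]][tags[t]] + emissions[t][tags[t]]
    return score


def log_partition(emissions: Matrix, transitions: Matrix) -> float:
    """Forward algorithm: log-sum-exp of path_score over all tag paths in O(n * T^2)."""
    num_tags = len(emissions[0])
    alpha = list(emissions[0])  # alpha[j]: log-sum-exp over all prefixes ending in tag j
    for t in range(1, len(emissions)):
        alpha = [
            logsumexp([alpha[i] + transitions[i][j] for i in range(num_tags)]) + emissions[t][j]
            for j in range(num_tags)
        ]
    return logsumexp(alpha)


def log_likelihood(emissions: Matrix, transitions: Matrix, tags: List[int]) -> float:
    """log p(tags | emissions) = gold path score - log partition function."""
    return path_score(emissions, transitions, tags) - log_partition(emissions, transitions)


def brute_force_log_partition(emissions: Matrix, transitions: Matrix) -> float:
    """Enumerate every path explicitly; feasible only for tiny inputs, used as a cross-check."""
    num_tags = len(emissions[0])
    paths = itertools.product(range(num_tags), repeat=len(emissions))
    return logsumexp([path_score(emissions, transitions, list(p)) for p in paths])


# Toy example: 3 time steps, 3 tags
emissions = [[1.0, 0.5, -0.2], [0.3, 1.2, 0.0], [-0.5, 0.1, 0.9]]
transitions = [[0.2, 0.8, -0.1], [-0.3, 0.5, 0.4], [0.0, -0.2, 0.3]]
print(log_likelihood(emissions, transitions, [0, 1, 2]))
```

The forward algorithm and the brute-force enumeration give the same partition function, but only the former scales to real sentence lengths. The package itself works on batched tensors with a mask and, like AllenNLP, may also include start/end transition scores.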
Installation
pip install pytorch-text-crf
Usage
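import torch
from crf.crf import ConditionalRandomField
# Toy sizes for illustration
batch_size, seq_length, n_tags = 2, 4, 3
# Initialization
crf = ConditionalRandomField(
    n_tags,
    label_encoding="BIO",
    idx2tag={0: "B-GEO", 1: "I-GEO", 2: "O"},  # Index-to-tag mapping
)
logits = torch.randn(batch_size, seq_length, n_tags)  # Emission scores, e.g. LSTM output
tags = torch.randint(0, n_tags, (batch_size, seq_length))  # Gold tag indices
mask = torch.ones(batch_size, seq_length, dtype=torch.bool)  # False marks padding positions
# Likelihood estimation
log_likelihood = crf(logits, tags, mask)
# Decoding
best_tag_sequence = crf.best_viterbi_tag(logits, mask)
top_5_viterbi_tags = crf.viterbi_tags(logits, mask, top_k=5)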
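Passing label_encoding="BIO" together with idx2tag presumably lets the CRF rule out tag transitions that are invalid under the BIO scheme, as AllenNLP's allowed_transitions helper does. The plain-Python sketch below illustrates only the rule itself (an I-X tag may follow only B-X or I-X, and a sequence may not start with I-X). The function names are hypothetical, this is not the package's internal code, and it assumes the outside tag is spelled "O".

```python
from typing import Dict, List, Tuple


def split_tag(tag: str) -> Tuple[str, str]:
    """Split "B-GEO" into ("B", "GEO"); the outside tag "O" has an empty entity label."""
    if tag == "O":
        return "O", ""
    prefix, _, entity = tag.partition("-")
    return prefix, entity


def bio_transition_allowed(from_tag: str, to_tag: str) -> bool:
    """BIO rule: I-X may only follow B-X or I-X; O and B-X may follow any tag."""
    to_prefix, to_entity = split_tag(to_tag)
    if to_prefix in ("O", "B"):
        return True
    from_prefix, from_entity = split_tag(from_tag)
    return from_prefix in ("B", "I") and from_entity == to_entity


def can_start_with(tag: str) -> bool:
    """A valid BIO sequence cannot begin with an I- tag."""
    return split_tag(tag)[0] != "I"


def allowed_bio_transitions(idx2tag: Dict[int, str]) -> List[Tuple[int, int]]:
    """Every (from_index, to_index) pair that the BIO scheme permits."""
    return [
        (i, j)
        for i, from_tag in idx2tag.items()
        for j, to_tag in idx2tag.items()
        if bio_transition_allowed(from_tag, to_tag)
    ]


idx2tag = {0: "B-GEO", 1: "I-GEO", 2: "O"}
allowed = allowed_bio_transitions(idx2tag)
print(allowed)  # every pair except (2, 1), i.e. O -> I-GEO is forbidden
```

During decoding, a CRF can enforce such a list by giving forbidden transitions a very large negative score, so Viterbi never selects them.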
LSTM CRF Implementation
Refer to https://github.com/iamsimha/pytorch-text-crf/blob/master/examples/pos_tagging/train.ipynb for a complete working implementation.
import torch
from torch import nn
from crf.crf import ConditionalRandomField
class LSTMCRF(nn.Module):
    """
    An example implementation of a CRF layer on top of an LSTM.
    """
    def __init__(self):
        super().__init__()
        ...
        ...
        # Initialize the CRF layer
        self.crf = ConditionalRandomField(
            n_class,  # Number of tags
            label_encoding="BIO",  # Label encoding format
            idx2tag=idx2tag  # Dict mapping a tag index to a tag
        )

    def forward(self, inputs, tags):
        logits = self.lstm(inputs)  # logits dim: (batch_size, seq_length, num_tags)
        # True for real tokens, False for padding. pad_idx (hypothetical) is the
        # index of the padding token. mask dim: (batch_size, seq_length)
        mask = inputs != pad_idx
        log_likelihood = self.crf(logits, tags, mask)
        # The log likelihood is summed over the batch, not averaged
        # (it is not divided by the batch size).
        loss = -log_likelihood
        # Obtain the best tag sequence with Viterbi decoding
        best_tag_sequence = self.crf.best_viterbi_tag(logits, mask)
        # Convert the decoded paths into one-hot class_probabilities so the
        # output matches that of a plain LSTM tagger
        class_probabilities = torch.zeros_like(logits)
        for i, instance_tags in enumerate(best_tag_sequence):
            for j, tag_id in enumerate(instance_tags[0][0]):
                class_probabilities[i, j, int(tag_id)] = 1
        return {"loss": loss, "class_probabilities": class_probabilities}

# Training
lstm_crf = LSTMCRF()
optimizer = torch.optim.Adam(lstm_crf.parameters())
output = lstm_crf(sentences, tags)  # sentences: padded token indices; tags: gold tag indices
loss = output["loss"]
optimizer.zero_grad()
loss.backward()
optimizer.step()
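The loop over best_tag_sequence in forward() turns each decoded path into one-hot class_probabilities. The plain-Python sketch below shows how Viterbi decoding finds that highest-scoring path and how the one-hot rows can be built, assuming a single unbatched sequence, padding only at the end of the mask, at least one unmasked position, and no start/end transitions. viterbi_decode and to_one_hot are hypothetical names, not functions of this package.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def viterbi_decode(emissions: Matrix, transitions: Matrix,
                   mask: List[bool]) -> Tuple[List[int], float]:
    """Highest-scoring tag path over the unmasked prefix of one sequence.

    Assumes the mask looks like [True, ..., True, False, ...] with at least one True.
    """
    length = sum(mask)
    num_tags = len(emissions[0])
    best = list(emissions[0])  # best[j]: score of the best prefix path ending in tag j
    backpointers: List[List[int]] = []
    for t in range(1, length):
        step_pointers, new_best = [], []
        for j in range(num_tags):
            # Best previous tag for landing on tag j at step t
            prev = max(range(num_tags), key=lambda i: best[i] + transitions[i][j])
            step_pointers.append(prev)
            new_best.append(best[prev] + transitions[prev][j] + emissions[t][j])
        best = new_best
        backpointers.append(step_pointers)
    # Follow the back-pointers from the best final tag to recover the path
    last = max(range(num_tags), key=lambda j: best[j])
    path = [last]
    for step_pointers in reversed(backpointers):
        path.append(step_pointers[path[-1]])
    path.reverse()
    return path, best[last]


def to_one_hot(path: List[int], seq_length: int, num_tags: int) -> Matrix:
    """One-hot rows for the decoded tags; padded positions stay all zeros."""
    rows = [[0.0] * num_tags for _ in range(seq_length)]
    for position, tag_id in enumerate(path):
        rows[position][tag_id] = 1.0
    return rows


# Toy example: 4 positions (the last is padding), 3 tags, zero transition scores
emissions = [[2.0, 0.0, 0.0], [0.0, 0.0, 3.0], [0.0, 1.0, 0.0], [5.0, 5.0, 5.0]]
transitions = [[0.0] * 3 for _ in range(3)]
mask = [True, True, True, False]
path, score = viterbi_decode(emissions, transitions, mask)
print(path, score)  # [0, 2, 1] 6.0
class_probabilities = to_one_hot(path, seq_length=len(mask), num_tags=3)
```

With zero transition scores Viterbi reduces to a per-position argmax; non-zero (or constrained) transitions are what let the CRF pick a globally consistent sequence instead.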
Download files
Source Distribution
pytorch-text-crf-0.1.tar.gz (10.0 kB)
Built Distribution
Hashes for pytorch_text_crf-0.1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 5000a5b68ed82fc8551362b6c0a6e25582553bccef4fe687e188de1b72ec7398
MD5 | f1afca5693de15ae0a62b737dd1d1325
BLAKE2b-256 | 17aea8f42b712dc3d16953d088be32b5e8a5c1515acb5006a4509fb909f0344f