
A simple CRF module written in PyTorch. The implementation is based on the AllenNLP implementation of CRF.


PyTorch Text CRF

This package provides a simple wrapper for using conditional random fields (CRFs) in PyTorch. The code is based on the excellent AllenNLP implementation of CRF.


pip install pytorch-text-crf


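import torch

from crf.crf import ConditionalRandomField

# Placeholder inputs for illustration (batch_size=2, seq_length=4)
n_tags = 3
logits = torch.randn(2, 4, n_tags)         # (batch_size, seq_length, n_tags)
tags = torch.randint(0, n_tags, (2, 4))    # (batch_size, seq_length)
mask = torch.ones(2, 4, dtype=torch.bool)  # (batch_size, seq_length), False at padding

# Initialization
crf = ConditionalRandomField(
    n_tags,
    idx2tag={0: "B-GEO", 1: "I-GEO", 2: "O"},  # Index to tag mapping
)

# Likelihood estimation
log_likelihood = crf(logits, tags, mask)

# Decoding
best_tag_sequence = crf.best_viterbi_tag(logits, mask)
top_5_viterbi_tags = crf.viterbi_tags(logits, mask, top_k=5)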
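To make the two operations above more concrete, here is a minimal, dependency-free sketch of what a linear-chain CRF computes for a single unpadded sequence: the log-likelihood of a tag path (path score minus the log-partition from the forward algorithm) and the best path via Viterbi decoding. This is not the package's code. The function names, the plain-list inputs, and the example scores are assumptions made for illustration, and batching, masking, and start/end transitions are left out.

```python
import math
from itertools import product


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def path_score(emissions, transitions, tags):
    """Unnormalized score of one tag path: emission plus transition scores."""
    score = emissions[0][tags[0]]
    for t in range(1, len(tags)):
        score += transitions[tags[t - 1]][tags[t]] + emissions[t][tags[t]]
    return score


def log_partition(emissions, transitions):
    """Forward algorithm: log of the summed exp-scores over all tag paths."""
    n_tags = len(transitions)
    alpha = list(emissions[0])
    for t in range(1, len(emissions)):
        alpha = [
            emissions[t][j]
            + log_sum_exp([alpha[i] + transitions[i][j] for i in range(n_tags)])
            for j in range(n_tags)
        ]
    return log_sum_exp(alpha)


def log_likelihood(emissions, transitions, tags):
    """Log-probability of `tags` under the CRF."""
    return path_score(emissions, transitions, tags) - log_partition(emissions, transitions)


def viterbi(emissions, transitions):
    """Return the highest-scoring tag path and its unnormalized score."""
    n_tags = len(transitions)
    score = list(emissions[0])
    backpointers = []
    for t in range(1, len(emissions)):
        new_score, pointers = [], []
        for j in range(n_tags):
            best_i = max(range(n_tags), key=lambda i: score[i] + transitions[i][j])
            pointers.append(best_i)
            new_score.append(score[best_i] + transitions[best_i][j] + emissions[t][j])
        score = new_score
        backpointers.append(pointers)
    last = max(range(n_tags), key=lambda j: score[j])
    path = [last]
    for pointers in reversed(backpointers):
        path.append(pointers[path[-1]])
    path.reverse()
    return path, score[last]


# Made-up scores: 3 time steps, 3 tags
emissions = [[1.0, 0.2, 0.1], [0.3, 1.5, 0.2], [0.1, 0.4, 1.2]]
transitions = [[0.0, 0.8, -0.5], [-0.3, 0.2, 0.6], [0.4, -1.0, 0.1]]

best_path, best_score = viterbi(emissions, transitions)
gold_ll = log_likelihood(emissions, transitions, best_path)

# Sanity check: the probabilities of all 3**3 tag paths sum to 1
total = sum(
    math.exp(log_likelihood(emissions, transitions, p))
    for p in product(range(3), repeat=3)
)
```

The negative of `log_likelihood` is the quantity a training loop would minimize, and `viterbi` plays the role that decoding plays at prediction time.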

LSTM CRF Implementation

Refer to for a complete working implementation.

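import torch
import torch.nn as nn

from crf.crf import ConditionalRandomField


class LSTMCRF(nn.Module):
    """An example implementation of a CRF model on top of an LSTM."""

    def __init__(self, vocab_size, n_class, idx2tag, pad_idx=0, emb_dim=100, hidden_dim=128):
        super().__init__()
        self.pad_idx = pad_idx
        # Illustrative encoder: embedding -> BiLSTM -> per-token tag scores
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.hidden2tag = nn.Linear(2 * hidden_dim, n_class)
        # Initialize the CRF model
        self.crf = ConditionalRandomField(
            n_class,  # Number of tags
            label_encoding="BIO",  # Label encoding format
            idx2tag=idx2tag,  # Dict mapping index to a tag
        )

    def forward(self, inputs, tags):
        lstm_out, _ = self.lstm(self.embedding(inputs))
        logits = self.hidden2tag(lstm_out)  # logits dim: (batch_size, seq_length, num_tags)
        mask = inputs != self.pad_idx  # Mask for ignoring pad tokens. mask dim: (batch_size, seq_length)
        log_likelihood = self.crf(logits, tags, mask)
        loss = -log_likelihood  # The log-likelihood is not normalized (it is not divided by the batch size).

        # To obtain the best sequence using Viterbi decoding
        best_tag_sequence = self.crf.best_viterbi_tag(logits, mask)

        # To obtain output shaped like the LSTM prediction, one-hot encode the decoded tags
        class_probabilities = torch.zeros_like(logits)
        for i, instance_tags in enumerate(best_tag_sequence):
            for j, tag_id in enumerate(instance_tags[0][0]):
                class_probabilities[i, j, int(tag_id)] = 1
        return {"loss": loss, "class_probabilities": class_probabilities}


# Training
# sentences and tags: LongTensors of token ids and tag ids, shape (batch_size, seq_length)
lstm_crf = LSTMCRF(vocab_size, n_class, idx2tag)
output = lstm_crf(sentences, tags)
loss = output["loss"]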
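The `label_encoding="BIO"` and `idx2tag` arguments suggest that the CRF restricts tag transitions to those valid under the labeling scheme, as the AllenNLP CRF it is based on does. The following dependency-free sketch shows what such a constraint looks like for BIO tags. It is an assumption about the intent, not the package's own code: the function name `allowed_bio_transitions` is hypothetical, and constraints on the first and last tag of a sequence are omitted.

```python
def allowed_bio_transitions(idx2tag):
    """Return the set of (from_idx, to_idx) pairs that are valid under BIO."""
    allowed = set()
    for i, from_tag in idx2tag.items():
        for j, to_tag in idx2tag.items():
            # An I-X tag may only follow B-X or I-X of the same entity type X.
            if to_tag.startswith("I-"):
                entity = to_tag[2:]
                if from_tag not in ("B-" + entity, "I-" + entity):
                    continue
            # B-X and O may follow any tag.
            allowed.add((i, j))
    return allowed


idx2tag = {0: "B-GEO", 1: "I-GEO", 2: "O"}
allowed = allowed_bio_transitions(idx2tag)

# Every transition not in `allowed` would be ruled out during decoding
forbidden = {(i, j) for i in idx2tag for j in idx2tag} - allowed
```

With this tag set, the only forbidden transition is from `O` directly to `I-GEO`, which is why the outside tag must be spelled as the letter `O` for the scheme to be recognized.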

Download files


Files for pytorch-text-crf, version 0.1

Filename                                Size     File type  Python version
pytorch_text_crf-0.1-py3-none-any.whl   14.0 kB  Wheel      py3
pytorch-text-crf-0.1.tar.gz             10.0 kB  Source     None
