
A simple CRF module written in PyTorch. The implementation is based on https://github.com/allenai/allennlp/blob/master/allennlp/modules/conditional_random_field.py

Project description

PyTorch Text CRF

This package contains a simple wrapper for using conditional random fields (CRFs). The code is based on the excellent AllenNLP implementation of CRFs.

Installation

pip install pytorch-text-crf

Usage

from crf.crf import ConditionalRandomField

# Initialization
crf = ConditionalRandomField(n_tags,
                             label_encoding="BIO",
                             idx2tag={0: "B-GEO", 1: "I-GEO", 2: "O"} # Index-to-tag mapping; "O" is the outside tag
                             )
# Likelihood estimation
log_likelihood = crf(logits, tags, mask)

# Decoding
best_tag_sequence = crf.best_viterbi_tag(logits, mask)
top_5_viterbi_tags = crf.viterbi_tags(logits, mask, top_k=5)

LSTM CRF Implementation

Refer to https://github.com/iamsimha/pytorch-text-crf/blob/master/examples/pos_tagging/train.ipynb for a complete working implementation.

import torch.nn as nn

from crf.crf import ConditionalRandomField

class LSTMCRF(nn.Module):
    """
    An example implementation of a CRF model on top of an LSTM.
    """
    def __init__(self):
        super().__init__()
        ...
        ...
        # Initialize the CRF model
        self.crf = ConditionalRandomField(
            n_class, # Number of tags
            label_encoding="BIO", # Label encoding format
            idx2tag=idx2tag # Dict mapping index to a tag
        )

    def forward(self, inputs, tags):
        logits = self.lstm(inputs) # logits dim:(batch_size, seq_length, num_tags)
        mask = inputs != "<pad token>" # mask for ignoring pad tokens. mask dim: (batch_size, seq_length)
        log_likelihood = self.crf(logits, tags, mask)
        loss = -log_likelihood # Log likelihood is not normalized (It is not divided by the batch size).

        # To obtain the best sequence using viterbi decoding
        best_tag_sequence = self.crf.best_viterbi_tag(logits, mask)

        # To obtain an output shaped like the LSTM prediction, convert the
        # best tag sequence into one-hot class probabilities
        class_probabilities = logits * 0.0
        for i, instance_tags in enumerate(best_tag_sequence):
            for j, tag_id in enumerate(instance_tags[0][0]):
                class_probabilities[i, j, int(tag_id)] = 1
        return {"loss": loss, "class_probabilities": class_probabilities} 

# Training
lstm_crf = LSTMCRF()
output = lstm_crf(sentences, tags)
loss = output["loss"]
loss.backward()
optimizer.step()

Project details


Release history

This version

0.1

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pytorch-text-crf-0.1.tar.gz (10.0 kB)

Uploaded Source

Built Distribution

pytorch_text_crf-0.1-py3-none-any.whl (14.0 kB)

Uploaded Python 3

File details

Details for the file pytorch-text-crf-0.1.tar.gz.

File metadata

  • Download URL: pytorch-text-crf-0.1.tar.gz
  • Upload date:
  • Size: 10.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.7.3

File hashes

Hashes for pytorch-text-crf-0.1.tar.gz
Algorithm Hash digest
SHA256 bab37a78d9d2c0f62c1ae82f3a8466af9c64df0d01bea0eb4f357c5b3e0ebdc7
MD5 9d8847b47e2ae9a07cc6f1408b22aeb2
BLAKE2b-256 df651e6b577211ad2ffeee49352bf0b6b16ede29bae10c54f5a7f45aa423e958


File details

Details for the file pytorch_text_crf-0.1-py3-none-any.whl.

File metadata

  • Download URL: pytorch_text_crf-0.1-py3-none-any.whl
  • Upload date:
  • Size: 14.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.7.3

File hashes

Hashes for pytorch_text_crf-0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 5000a5b68ed82fc8551362b6c0a6e25582553bccef4fe687e188de1b72ec7398
MD5 f1afca5693de15ae0a62b737dd1d1325
BLAKE2b-256 17aea8f42b712dc3d16953d088be32b5e8a5c1515acb5006a4509fb909f0344f

