A PyTorch implementation of the BI-LSTM-CRF model

Features:

  • Compared with the PyTorch BI-LSTM-CRF tutorial, the following improvements were made:
    • Full support for mini-batch computation
    • Fully vectorized implementation. In particular, all loops in the "score sentence" algorithm are removed, which dramatically improves training performance
    • CUDA support
    • Very simple APIs for the CRF module
      • START/STOP tags are automatically added in the CRF
      • An inner linear layer is included, which transforms from feature space to tag space
  • Specialized for NLP sequence tagging tasks
  • Easy to train your own sequence tagging models
  • MIT License
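
As a rough illustration of what the "score sentence" step computes, here is a minimal loop-based sketch for a single sentence. It is not the package's code: the function name, the dict-based score tables, and the `<START>`/`<STOP>` strings are all hypothetical, chosen for readability. A vectorized implementation would presumably compute the same sums with batched tensor indexing instead of a Python loop.

```python
# Illustrative sketch only; names and data layout are assumptions.
START, STOP = "<START>", "<STOP>"

def score_sentence(emissions, transitions, tags):
    """Score one tagged sentence under a linear-chain CRF.

    emissions:   list of dicts, emissions[t][tag] = emission score at step t
    transitions: dict, transitions[(prev_tag, next_tag)] = transition score
    tags:        gold tag sequence, without START/STOP
    """
    # Boundary tags are added here, so callers never pass them in.
    path = [START] + list(tags) + [STOP]
    score = 0.0
    for t, tag in enumerate(tags):
        score += transitions[(path[t], tag)]  # path[t] is the previous tag
        score += emissions[t][tag]
    score += transitions[(path[-2], STOP)]    # last tag -> STOP
    return score

# Toy scores for a two-token sentence tagged B, I.
emissions = [{"B": 1.0, "I": 0.5}, {"B": 0.25, "I": 2.0}]
transitions = {
    (START, "B"): 0.5, (START, "I"): -4.0,
    ("B", "B"): 0.0, ("B", "I"): 1.0,
    ("I", "B"): 0.5, ("I", "I"): 0.25,
    ("B", STOP): 0.0, ("I", STOP): 0.25,
}
gold_score = score_sentence(emissions, transitions, ["B", "I"])
print(gold_score)
```

The score is the sum of one transition and one emission per token, plus the final transition into STOP.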

Installation

  • dependencies
  • install
    $ pip install bi-lstm-crf
    

Training

corpus

training

$ python -m bi_lstm_crf corpus_dir --model_dir "model_xxx"

training curve

import pandas as pd
import matplotlib.pyplot as plt

# the training losses are saved in the model_dir
df = pd.read_csv(".../model_dir/loss.csv")
df[["train_loss", "val_loss"]].ffill().plot(grid=True)
plt.show()

Prediction

from bi_lstm_crf.app import WordsTagger

model = WordsTagger(model_dir="xxx")
tags, sequences = model(["市领导到成都..."])  # CHAR-based model
print(tags)  
# [["B", "B", "I", "B", "B-LOC", "I-LOC", "I-LOC", "I-LOC", "I-LOC", "B", "I", "B", "I"]]
print(sequences)
# [['市', '领导', '到', ('成都', 'LOC'), ...]]

# model([["市", "领导", "到", "成都", ...]])  # WORD-based model
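
The relationship between `tags` and `sequences` in the output above can be sketched as follows. This is a guess at the grouping rule based only on the printed example (plain "B"/"I" tags give bare strings, labelled tags such as "B-LOC"/"I-LOC" give `(text, label)` tuples). The helper name `tags_to_sequence` is hypothetical and is not part of the package.

```python
def tags_to_sequence(tokens, tags):
    """Group tokens into segments following B/I tags.

    A plain "B"/"I" tag yields a bare string; a labelled tag such as
    "B-LOC"/"I-LOC" yields a (text, label) tuple.
    """
    segments = []  # list of [text, label] pairs; label is None if unlabelled
    for token, tag in zip(tokens, tags):
        prefix, _, label = tag.partition("-")
        label = label or None
        # An "I" tag extends the previous segment only if the labels match.
        if prefix == "I" and segments and segments[-1][1] == label:
            segments[-1][0] += token
        else:
            segments.append([token, label])
    return [text if label is None else (text, label)
            for text, label in segments]

tokens = list("市领导到成都")
tags = ["B", "B", "I", "B", "B-LOC", "I-LOC"]
result = tags_to_sequence(tokens, tags)
print(result)
```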

CRF Module

The CRF module can be easily embedded into other models:

import torch.nn as nn
from bi_lstm_crf import CRF

# a BERT-CRF model for sequence tagging
class BertCrf(nn.Module):
    def __init__(self, ...):
        ...
        self.bert = BERT(...)
        self.crf = CRF(in_features, num_tags)

    def loss(self, xs, tags):
        features, = self.bert(xs)
        masks = xs.gt(0)  # True where token id > 0 (assumes 0 is the padding id)
        loss = self.crf.loss(features, tags, masks)
        return loss

    def forward(self, xs):
        features, = self.bert(xs)
        masks = xs.gt(0)
        scores, tag_seq = self.crf(features, masks)
        return scores, tag_seq
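
The `forward` call above returns a score and a tag sequence per input. A CRF typically finds the best sequence with Viterbi decoding, and the sketch below shows that idea for a single unpadded sentence in plain Python. It assumes, without confirmation from the package source, that this is how `tag_seq` is obtained. The function name and the dict-based score tables are hypothetical.

```python
def viterbi_decode(emissions, transitions, tagset,
                   start="<START>", stop="<STOP>"):
    """Return (best_score, best_path) for one sentence.

    emissions:   list of dicts, emissions[t][tag] = emission score at step t
    transitions: dict, transitions[(prev_tag, next_tag)] = transition score
    """
    # best[tag] = score of the best path ending in `tag` at the current step
    best = {tag: transitions[(start, tag)] + emissions[0][tag]
            for tag in tagset}
    backpointers = []
    for emit in emissions[1:]:
        step_best, step_back = {}, {}
        for tag in tagset:
            prev = max(tagset, key=lambda p: best[p] + transitions[(p, tag)])
            step_best[tag] = best[prev] + transitions[(prev, tag)] + emit[tag]
            step_back[tag] = prev
        best = step_best
        backpointers.append(step_back)
    # Close the path with the transition into the stop tag.
    last = max(tagset, key=lambda t: best[t] + transitions[(t, stop)])
    score = best[last] + transitions[(last, stop)]
    # Follow the backpointers from the last tag to the first.
    path = [last]
    for step_back in reversed(backpointers):
        path.append(step_back[path[-1]])
    path.reverse()
    return score, path

tagset = ["B", "I"]
emissions = [{"B": 1.0, "I": 0.5}, {"B": 0.25, "I": 2.0}]
transitions = {
    ("<START>", "B"): 0.5, ("<START>", "I"): -4.0,
    ("B", "B"): 0.0, ("B", "I"): 1.0,
    ("I", "B"): 0.5, ("I", "I"): 0.25,
    ("B", "<STOP>"): 0.0, ("I", "<STOP>"): 0.25,
}
best_score, best_path = viterbi_decode(emissions, transitions, tagset)
print(best_score, best_path)
```

In a batched setting, the masks would keep padded positions from contributing to the scores; the sketch sidesteps this by handling one sentence at a time.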

References

  1. Zhiheng Huang, Wei Xu, and Kai Yu. 2015. Bidirectional LSTM-CRF Models for Sequence Tagging. arXiv:1508.01991.
  2. PyTorch tutorial. Advanced: Making Dynamic Decisions and the Bi-LSTM CRF.

Download files

Source distribution: bi-lstm-crf-0.2.0.tar.gz (11.1 kB)
