
A PyTorch implementation of the BI-LSTM-CRF model

Project description

A PyTorch implementation of the BI-LSTM-CRF model.

Features:

  • Compared with the PyTorch BI-LSTM-CRF tutorial, the following improvements are made:
    • Full support for mini-batch computation
    • Fully vectorized implementation. In particular, all loops are removed from the "score sentence" algorithm, which dramatically improves training performance
    • CUDA supported
    • Very simple API for the CRF module
      • START/STOP tags are automatically added in the CRF
      • An inner linear layer is included, which transforms from feature space to tag space
  • Specialized for NLP sequence tagging tasks
  • Easy to train your own sequence tagging models
  • MIT License
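The CRF training objective behind these features can be sketched as follows. This is a minimal illustration under stated assumptions, not this package's code: the function name `crf_nll`, its argument layout, and the separate `start`/`stop` score vectors (standing in for the automatically added START/STOP tags) are hypothetical. The gold-path score ("score sentence") is computed for the whole batch with `gather` and advanced indexing, with no Python loop. The partition function still needs a loop over time steps (the forward algorithm), but each step is vectorized over the batch and over all tag pairs, so the same code runs on CUDA tensors unchanged.

```python
import torch


def crf_nll(emissions, tags, mask, trans, start, stop):
    """Mean negative log-likelihood of a linear-chain CRF (hypothetical sketch).

    emissions: (B, T, K) tag scores, e.g. the output of a linear layer
    tags:      (B, T) gold tag ids (int64)
    mask:      (B, T) bool, True for real tokens; assumed left-aligned
               (padding only at the end) with at least one real token per row
    trans:     (K, K) trans[i, j] = score of tag i followed by tag j
    start, stop: (K,) scores for leaving START / entering STOP
    """
    maskf = mask.float()
    lengths = mask.long().sum(dim=1)  # (B,)

    # Gold-path score, vectorized over batch and time (no Python loop).
    emit = emissions.gather(2, tags.unsqueeze(2)).squeeze(2)  # (B, T)
    step = trans[tags[:, :-1], tags[:, 1:]]  # (B, T-1) transition scores
    last_tag = tags.gather(1, (lengths - 1).unsqueeze(1)).squeeze(1)  # (B,)
    gold = (start[tags[:, 0]]
            + (emit * maskf).sum(dim=1)
            + (step * maskf[:, 1:]).sum(dim=1)  # transition into a pad is dropped
            + stop[last_tag])

    # Log partition function: forward algorithm, looping over time only.
    alpha = start + emissions[:, 0]  # (B, K)
    for t in range(1, emissions.size(1)):
        # alpha[b, i] + trans[i, j], reduced over previous tag i
        nxt = torch.logsumexp(alpha.unsqueeze(2) + trans, dim=1) + emissions[:, t]
        # padded positions keep the previous alpha unchanged
        alpha = torch.where(mask[:, t].unsqueeze(1), nxt, alpha)
    log_z = torch.logsumexp(alpha + stop, dim=1)  # (B,)

    return (log_z - gold).mean()
```

Because padded positions are excluded from both terms, sequences of different lengths can share one mini-batch, provided padding only appears at the end of each row.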

Installation

  • dependencies
  • install
    $ pip install bi-lstm-crf
    

Training

corpus

training

$ python -m bi_lstm_crf corpus_dir --model_dir "model_xxx"

training curve

import pandas as pd
import matplotlib.pyplot as plt

# the training losses are saved in the model_dir
df = pd.read_csv(".../model_dir/loss.csv")
df[["train_loss", "val_loss"]].ffill().plot(grid=True)
plt.show()

Prediction

from bi_lstm_crf.app import WordsTagger

model = WordsTagger(model_dir="xxx")
tags, sequences = model(["市领导到成都..."])  # CHAR-based model; input means "City leaders go to Chengdu..."
print(tags)  
# [["B", "B", "I", "B", "B-LOC", "I-LOC", "I-LOC", "I-LOC", "I-LOC", "B", "I", "B", "I"]]
print(sequences)
# [['市', '领导', '到', ('成都', 'LOC'), ...]]

# model([["市", "领导", "到", "成都", ...]])  # WORD-based model
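How the `tags` output relates to `sequences` can be illustrated with a small decoder. This is a hedged sketch, not `WordsTagger`'s implementation: the helper `tags_to_sequence`, its `sep` parameter, and the choice to start a new group on a stray `I` tag are assumptions; only the B / I / B-LOC / I-LOC scheme is taken from the example output above.

```python
def tags_to_sequence(tokens, tags, sep=""):
    """Group tokens into words and entities using B/I tags (hypothetical helper).

    "B" / "I" start / continue an ordinary word; "B-XXX" / "I-XXX" start /
    continue an entity of type XXX, which is emitted as a (text, "XXX") tuple.
    """
    groups = []  # each item: [list_of_tokens, label_or_None]
    for token, tag in zip(tokens, tags):
        prefix, _, label = tag.partition("-")
        label = label or None  # plain "B"/"I" tags carry no entity label
        # extend the current group only for an I tag with a matching label
        if prefix == "I" and groups and groups[-1][1] == label:
            groups[-1][0].append(token)
        else:
            groups.append([[token], label])

    sequence = []
    for toks, label in groups:
        text = sep.join(toks)
        sequence.append(text if label is None else (text, label))
    return sequence
```

For example, `tags_to_sequence(list("市领导到成都"), ["B", "B", "I", "B", "B-LOC", "I-LOC"])` gives `['市', '领导', '到', ('成都', 'LOC')]`. For a WORD-based model in a language written with spaces, `sep=" "` is likely the better choice.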

CRF Module

The CRF module can easily be embedded into other models:

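import torch.nn as nn
from bi_lstm_crf import CRF

# a BERT-CRF model for sequence tagging
class BertCrf(nn.Module):
    def __init__(self, bert, in_features, num_tags):
        super().__init__()
        # `bert`: any encoder (e.g. a wrapper around a pretrained BERT) mapping
        # token ids (batch, seq_len) to features (batch, seq_len, in_features)
        self.bert = bert
        # the CRF's inner linear layer maps in_features to num_tags
        self.crf = CRF(in_features, num_tags)

    def loss(self, xs, tags):
        features = self.bert(xs)
        masks = xs.gt(0)  # token id 0 is treated as padding
        loss = self.crf.loss(features, tags, masks)
        return loss

    def forward(self, xs):
        features = self.bert(xs)
        masks = xs.gt(0)
        scores, tag_seq = self.crf(features, masks)
        return scores, tag_seq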
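The `scores, tag_seq` pair returned by `self.crf(features, masks)` comes from decoding the best tag path; the PyTorch BI-LSTM-CRF tutorial does this with the Viterbi algorithm. Below is a minimal, unbatched pure-Python sketch of Viterbi decoding for one unpadded sequence. It is not this package's batched implementation; the function name `viterbi` and the parameterization (a transition matrix plus START/STOP score vectors) are assumptions for illustration.

```python
def viterbi(emissions, trans, start, stop):
    """Best tag path for one unpadded sequence (hypothetical sketch).

    emissions: list of T lists, each holding K tag scores
    trans[i][j]: score of tag i followed by tag j
    start[j] / stop[j]: scores for leaving START / entering STOP
    Returns (best_score, best_path).
    """
    K = len(start)
    score = [start[j] + emissions[0][j] for j in range(K)]
    backpointers = []
    for emit in emissions[1:]:
        prev_best, new_score = [], []
        for j in range(K):
            # best previous tag i given that the current tag is j
            i = max(range(K), key=lambda i: score[i] + trans[i][j])
            prev_best.append(i)
            new_score.append(score[i] + trans[i][j] + emit[j])
        backpointers.append(prev_best)
        score = new_score

    # add the transition into STOP, then follow back-pointers
    last = max(range(K), key=lambda j: score[j] + stop[j])
    best_score = score[last] + stop[last]
    path = [last]
    for prev_best in reversed(backpointers):
        path.append(prev_best[path[-1]])
    path.reverse()
    return best_score, path
```

Dynamic programming reduces the search over all K^T tag paths to O(T·K²) work; a batched version vectorizes the inner loop over `j` and the batch, as the masked training code does.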

References

  1. Zhiheng Huang, Wei Xu, and Kai Yu. 2015. Bidirectional LSTM-CRF Models for Sequence Tagging. arXiv:1508.01991.
  2. PyTorch tutorial. Advanced: Making Dynamic Decisions and the Bi-LSTM CRF.



Download files

Source Distribution

bi-lstm-crf-0.2.0.tar.gz (11.1 kB)

File details

Details for the file bi-lstm-crf-0.2.0.tar.gz.

File metadata

  • Download URL: bi-lstm-crf-0.2.0.tar.gz
  • Upload date:
  • Size: 11.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/46.0.0.post20200309 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6

File hashes

Hashes for bi-lstm-crf-0.2.0.tar.gz
Algorithm    Hash digest
SHA256       cbe90e68066cc56e38f0b16ccafa292d6dc7f18d41d5363339aaa605cd877167
MD5          28e7ab951cb57fe4407d4c287a7a6738
BLAKE2b-256  13f4cf0aabd5d3af7c8815bbce11038f9a9ec5b5c2e5d2f1a311e51b138e2313

