Linear Conditional Random Field implementation in PyTorch

Linear CRF

Description

This repository hosts my PyTorch implementation of the Linear Conditional Random Field model, which is available via PyPI. There are a number of similar packages, but having benchmarked the alternatives, I can say that this one is as fast or faster.

Origin story

About a year ago I needed this for a project I was working on, and stumbled upon the de facto official implementation (at least by star count), which you can check here. Although the code was already quite good, I decided to optimize it for my needs, and ended up with significant gains. I opened a PR in order to share them, but quickly realized two things. The first was that I had committed way too much, and should have consulted with the author first. The second was that some changes actually went against the wishes of the author, as they didn't meet his readability standard. Anyway, long story short, I got demotivated, then forgot about it, and the PR of shame is still open >_<"

However, the story does not stop here! As I needed a CRF for another project recently, I decided to clean up my code, and ended up optimizing it even further. Since I like how it looks, and have learned my lesson, I decided to release it, and voilà!

Installation

With Python 3.6 or higher:

pip install linear-crf-torch

The model is not compatible with PyTorch versions older than 1.3, as I use features added in that version. The changes required to support older versions are minimal, so I don't plan to include them.

Usage

The example below shows the basic usage:

import torch
from linear_crf import LinearCRF

seq_length = 3
batch_size = 4
num_tags = 5

model = LinearCRF(num_tags)

emissions = torch.randn(seq_length, batch_size, num_tags)
labels = torch.randint(num_tags, (seq_length, batch_size))

# Compute the average negative log-likelihood
loss = model(emissions, labels)
print(f"loss: {loss:.4f}")

# Viterbi decoding
tags = model.decode(emissions)
for i, x in enumerate(tags):
    print(f"tags for sequence {i}: {x}") 
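For readers unfamiliar with Viterbi decoding, here is a minimal pure-Python sketch of the idea for a single sequence. It is an illustration only, not the code used in this package: the viterbi function and its arguments are hypothetical names, and start and end scores, batching and masking are all left out.

```python
def viterbi(emissions, transitions):
    """Return the highest-scoring tag sequence for one sequence.

    emissions:   emissions[t][j] is the score of tag j at step t.
    transitions: transitions[i][j] is the score of moving from tag i to tag j.
    Both are plain nested lists (hypothetical layout, for illustration only).
    """
    num_tags = len(transitions)
    # Best score of any path ending in each tag at the first step.
    scores = list(emissions[0])
    backpointers = []
    for step in emissions[1:]:
        new_scores, pointers = [], []
        for j in range(num_tags):
            # Best previous tag to come from when landing on tag j.
            best_prev = max(range(num_tags),
                            key=lambda i: scores[i] + transitions[i][j])
            pointers.append(best_prev)
            new_scores.append(scores[best_prev] + transitions[best_prev][j] + step[j])
        scores = new_scores
        backpointers.append(pointers)
    # Start from the best final tag and follow the backpointers to the start.
    path = [max(range(num_tags), key=lambda i: scores[i])]
    for pointers in reversed(backpointers):
        path.append(pointers[path[-1]])
    path.reverse()
    return path


emissions = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
no_preference = [[0.0, 0.0], [0.0, 0.0]]
sticky = [[0.0, -5.0], [-5.0, 0.0]]  # switching tags is heavily penalized

print(viterbi(emissions, no_preference))  # [0, 1, 0]
print(viterbi(emissions, sticky))         # [0, 0, 0]
```

The second call shows why a CRF differs from taking the argmax of each step independently: the transition scores can override a locally better tag.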

A few caveats:

  • I followed the PyTorch convention of putting the batch dimension after the sequence dimension, but you can set batch_first=True in the constructor if you wish to pass data the other way around.
  • Unlike similar packages, no input validation is performed - I think the documentation should be enough to avoid any bugs.
  • Using the impossible_starts, impossible_transitions and impossible_ends parameters in the constructor, you can make it impossible for certain tags to appear at the start or end of the sequences, and make transitions from one tag to another impossible.
  • In the forward pass, the loss is normalized by the number of non-masked elements. In my view it doesn't make sense to normalize in any other way, nor to use the sum directly.
  • Gradients are disabled during decoding.
  • Masking is only supported from the right, meaning if you mask the left part of a sentence (e.g. [0, 0, 1, 1]) the computations will be incorrect.
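Since no input validation is performed, it can be worth checking masks yourself before training. The sketch below is plain Python with hypothetical helper names (none of them are part of this package): it builds right-padded masks from sequence lengths, rejects left or interior masking, and shows what normalizing by the number of non-masked elements means. Masks are written one row per sequence for readability, and how the mask is passed to the model is not covered here.

```python
def mask_from_lengths(lengths, seq_length):
    """Build one right-padded 0/1 mask row per sequence."""
    return [[1 if t < n else 0 for t in range(seq_length)] for n in lengths]


def is_right_padded(mask_row):
    """True if the row is a run of 1s followed only by 0s, e.g. [1, 1, 0, 0]."""
    seen_padding = False
    for keep in mask_row:
        if not keep:
            seen_padding = True
        elif seen_padding:
            # A kept position after padding means left or interior masking.
            return False
    return True


def token_average(total_nll, mask):
    """Divide a summed negative log-likelihood by the number of kept positions."""
    kept = sum(sum(row) for row in mask)
    return total_nll / kept


mask = mask_from_lengths([3, 1], seq_length=4)
print(mask)                           # [[1, 1, 1, 0], [1, 0, 0, 0]]
print(is_right_padded([1, 1, 0, 0]))  # True
print(is_right_padded([0, 0, 1, 1]))  # False
print(token_average(8.0, mask))       # 2.0
```

Dividing by the number of kept positions keeps the loss comparable across batches with different amounts of padding.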

License

MIT

Benchmarks

Over here.

Contributing

All help is welcome, as long as you open an issue beforehand to talk about it :)

Reference

Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data, by John Lafferty, Andrew McCallum and Fernando C.N. Pereira

