
Conditional random field in PyTorch


Description

This package provides an implementation of a conditional random field (CRF) in PyTorch. This implementation borrows mostly from the AllenNLP CRF module, with some modifications.

Requirements

  • Python 3.6

  • PyTorch 0.4.1

Installation

You can install it with pip:

pip install pytorch-crf

Or, you can install it directly from GitHub:

pip install git+https://github.com/kmkurn/pytorch-crf#egg=pytorch_crf

Examples

In the examples below, we assume that the following lines have been executed:

>>> import torch
>>> from torchcrf import CRF
>>> seq_length, batch_size, num_tags = 3, 2, 5
>>> emissions = torch.randn(seq_length, batch_size, num_tags)
>>> tags = torch.tensor([[0, 1], [2, 4], [3, 1]], dtype=torch.long)  # (seq_length, batch_size)
>>> model = CRF(num_tags)
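All tensors above use a time-major layout: the first dimension is the time step and the second is the batch, so each column of tags is one sequence. The plain-Python sketch below (no PyTorch needed; to_batch_first is a hypothetical helper, not part of torchcrf) reads the two tag sequences out of that layout. On a real tensor, tags.transpose(0, 1) gives the same batch-first view.

```python
# Tags in the same (seq_length, batch_size) layout as the example above:
# row t holds the tag at time step t for every sequence in the batch.
tags = [[0, 1], [2, 4], [3, 1]]


def to_batch_first(rows):
    """Transpose a (seq_length, batch_size) nested list to (batch_size, seq_length)."""
    return [list(column) for column in zip(*rows)]


per_sequence = to_batch_first(tags)
print(per_sequence)  # [[0, 2, 3], [1, 4, 1]]
```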

Computing log likelihood

>>> model(emissions, tags)
tensor(-12.7431, grad_fn=<SumBackward0>)
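Here model(emissions, tags) returns the log likelihood of tags given emissions; the scalar result (note SumBackward0) suggests the per-sequence values are summed over the batch. For a linear-chain CRF, the log likelihood of one sequence is the score of its tag path minus the log partition function, i.e. the log-sum-exp of the scores of all possible paths, which the forward algorithm computes efficiently. A minimal plain-Python sketch of that computation for a single sequence, assuming illustrative start, transition, and end scores (start, trans, end and the helper names are hypothetical, not torchcrf attributes):

```python
import math
from itertools import product


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def path_score(emissions, tags, start, trans, end):
    """Unnormalized score of one tag path.

    emissions: seq_length lists of num_tags scores; tags: seq_length tag indices.
    """
    score = start[tags[0]] + emissions[0][tags[0]]
    for t in range(1, len(tags)):
        score += trans[tags[t - 1]][tags[t]] + emissions[t][tags[t]]
    return score + end[tags[-1]]


def log_partition(emissions, start, trans, end):
    """Forward algorithm: log of the sum of exp(score) over every tag path."""
    num_tags = len(start)
    # alpha[j]: log-sum-exp of the scores of all prefixes that end in tag j
    alpha = [start[j] + emissions[0][j] for j in range(num_tags)]
    for t in range(1, len(emissions)):
        alpha = [
            logsumexp([alpha[i] + trans[i][j] for i in range(num_tags)]) + emissions[t][j]
            for j in range(num_tags)
        ]
    return logsumexp([alpha[j] + end[j] for j in range(num_tags)])


def log_likelihood(emissions, tags, start, trans, end):
    """log p(tags | emissions) = path score - log partition function."""
    return path_score(emissions, tags, start, trans, end) - log_partition(
        emissions, start, trans, end
    )


# Toy problem with 2 tags and 3 time steps (made-up scores).
emissions = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.8]]
start, end = [0.1, -0.1], [0.0, 0.2]
trans = [[0.3, -0.5], [0.2, 0.1]]

ll = log_likelihood(emissions, [0, 1, 1], start, trans, end)
# Sanity check: the probabilities of all 2**3 tag paths sum to 1.
total = sum(
    math.exp(log_likelihood(emissions, list(p), start, trans, end))
    for p in product(range(2), repeat=3)
)
print(round(total, 6))  # 1.0
```

The forward algorithm costs O(seq_length * num_tags**2) instead of enumerating all num_tags**seq_length paths; the brute-force check is only practical for toy sizes.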

Computing log likelihood with mask

>>> mask = torch.tensor([[1, 1], [1, 1], [1, 0]], dtype=torch.uint8)  # (seq_length, batch_size)
>>> model(emissions, tags, mask=mask)
tensor(-10.8390, grad_fn=<SumBackward0>)
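In this mask, the last time step of the second sequence is marked as padding, so that sequence contributes only its first two tags. A plain-Python sketch of that idea, under stated assumptions: the ones in each mask column form a prefix, start and end scores are omitted for brevity, and per-sequence values are summed as the scalar output above suggests. The function names are hypothetical and this is not torchcrf's code.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sequence_log_likelihood(emissions, tags, trans):
    """log p(tags | emissions) for one unpadded sequence (no start/end scores)."""
    num_tags = len(trans)
    score = emissions[0][tags[0]]  # score of the given path
    alpha = list(emissions[0])  # forward log-scores over all paths
    for t in range(1, len(tags)):
        score += trans[tags[t - 1]][tags[t]] + emissions[t][tags[t]]
        alpha = [
            logsumexp([alpha[i] + trans[i][j] for i in range(num_tags)]) + emissions[t][j]
            for j in range(num_tags)
        ]
    return score - logsumexp(alpha)


def masked_log_likelihood(emissions, tags, mask, trans):
    """Sum over the batch of per-sequence log likelihoods, skipping padded steps.

    emissions: (seq_length, batch_size, num_tags); tags, mask: (seq_length, batch_size).
    Assumes the ones in each mask column form a prefix (padding only at the end).
    """
    total = 0.0
    for b in range(len(mask[0])):
        length = sum(row[b] for row in mask)
        seq_emissions = [emissions[t][b] for t in range(length)]
        seq_tags = [tags[t][b] for t in range(length)]
        total += sequence_log_likelihood(seq_emissions, seq_tags, trans)
    return total


emissions = [  # (seq_length=3, batch_size=2, num_tags=2), made-up scores
    [[0.5, -0.2], [0.3, 0.1]],
    [[0.1, 0.3], [-0.4, 0.6]],
    [[-0.4, 0.8], [9.9, -9.9]],  # last step of sequence 1 is padding
]
tags = [[0, 1], [1, 1], [1, 0]]
mask = [[1, 1], [1, 1], [1, 0]]
trans = [[0.3, -0.5], [0.2, 0.1]]

print(masked_log_likelihood(emissions, tags, mask, trans))
```

Because padded steps are sliced away before scoring, whatever values sit in the padded positions have no effect on the result.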

Decoding

>>> model.decode(emissions)
[[3, 1, 3], [0, 1, 0]]
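decode returns the highest-scoring tag path for each sequence in the batch. In a linear-chain CRF this is typically computed with the Viterbi algorithm: keep the best score for each tag at each time step together with a backpointer, then trace back from the best final tag. A minimal plain-Python sketch for one sequence under that assumption (viterbi and brute_force_decode are hypothetical helpers; the optional start and end scores default to zero):

```python
from itertools import product


def viterbi(emissions, trans, start=None, end=None):
    """Highest-scoring tag path for one sequence.

    emissions: seq_length lists of num_tags scores; trans[i][j]: score of tag i -> tag j.
    start/end: optional per-tag scores for the first/last position (default 0).
    """
    num_tags = len(emissions[0])
    start = start if start is not None else [0.0] * num_tags
    end = end if end is not None else [0.0] * num_tags
    # score[j]: best score of any prefix ending in tag j
    score = [start[j] + emissions[0][j] for j in range(num_tags)]
    backpointers = []
    for t in range(1, len(emissions)):
        # For each current tag j, remember which previous tag gave the best score.
        best_prev = [
            max(range(num_tags), key=lambda i: score[i] + trans[i][j])
            for j in range(num_tags)
        ]
        score = [
            score[best_prev[j]] + trans[best_prev[j]][j] + emissions[t][j]
            for j in range(num_tags)
        ]
        backpointers.append(best_prev)
    # Pick the best final tag, then follow the backpointers to recover the path.
    path = [max(range(num_tags), key=lambda j: score[j] + end[j])]
    for best_prev in reversed(backpointers):
        path.append(best_prev[path[-1]])
    return path[::-1]


def brute_force_decode(emissions, trans):
    """Reference answer by scoring every path (exponential; toy sizes only)."""
    num_tags = len(emissions[0])

    def score(p):
        s = emissions[0][p[0]]
        for t in range(1, len(p)):
            s += trans[p[t - 1]][p[t]] + emissions[t][p[t]]
        return s

    return list(max(product(range(num_tags), repeat=len(emissions)), key=score))


emissions = [[1.0, 0.2, -0.5], [0.1, 0.9, 0.3], [0.4, -0.2, 1.1]]  # made-up scores
trans = [[0.5, -0.3, 0.0], [0.2, 0.4, -0.6], [-0.1, 0.3, 0.2]]
print(viterbi(emissions, trans))  # [0, 0, 2]
```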

Decoding with mask

>>> model.decode(emissions, mask=mask)
[[3, 1, 3], [0, 1]]
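With the mask, decode returns a shorter path for the padded sequence (two tags instead of three). A plain-Python sketch of that behaviour, assuming the ones in each mask column form a prefix: decode each sequence only up to its unpadded length. The helper names are hypothetical, start and end scores are omitted for brevity, and this is not torchcrf's code.

```python
def viterbi(emissions, trans):
    """Highest-scoring tag path for one sequence (no start/end scores, for brevity)."""
    num_tags = len(trans)
    score = list(emissions[0])
    backpointers = []
    for t in range(1, len(emissions)):
        best_prev = [
            max(range(num_tags), key=lambda i: score[i] + trans[i][j])
            for j in range(num_tags)
        ]
        score = [
            score[best_prev[j]] + trans[best_prev[j]][j] + emissions[t][j]
            for j in range(num_tags)
        ]
        backpointers.append(best_prev)
    path = [max(range(num_tags), key=lambda j: score[j])]
    for best_prev in reversed(backpointers):
        path.append(best_prev[path[-1]])
    return path[::-1]


def masked_decode(emissions, mask, trans):
    """Decode a (seq_length, batch_size, num_tags) batch under a (seq_length, batch_size) mask.

    Assumes the ones in each mask column form a prefix (padding only at the end),
    so a sequence's length is just the number of ones in its column.
    """
    paths = []
    for b in range(len(mask[0])):
        length = sum(row[b] for row in mask)
        paths.append(viterbi([emissions[t][b] for t in range(length)], trans))
    return paths


emissions = [  # (seq_length=3, batch_size=2, num_tags=3), made-up scores
    [[1.0, 0.2, -0.5], [0.3, -0.1, 0.6]],
    [[0.1, 0.9, 0.3], [0.8, 0.2, -0.4]],
    [[0.4, -0.2, 1.1], [5.0, 5.0, 5.0]],  # last step of sequence 1 is padding
]
trans = [[0.5, -0.3, 0.0], [0.2, 0.4, -0.6], [-0.1, 0.3, 0.2]]
mask = [[1, 1], [1, 1], [1, 0]]

paths = masked_decode(emissions, mask, trans)
print([len(p) for p in paths])  # [3, 2]
```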

See tests/test_crf.py for more examples.

License

MIT. See LICENSE for details.

Contributing

Contributions are welcome! Please follow these instructions to set up the dependencies and run the tests and linter. Make a pull request once your contribution is ready.

Installing dependencies

Make sure you set up a virtual environment with Python 3.6 and PyTorch installed. Then, install all the dependencies listed in the requirements.txt file and install this package in development mode:

pip install -r requirements.txt
pip install -e .

Running tests

Run pytest in the project root directory.

Running linter

Run flake8 in the project root directory. This will also run mypy, thanks to the flake8-mypy package.


Download files

Source Distributions

No source distribution files are available for this release.

Built Distribution

pytorch_crf-0.6.0-py3-none-any.whl (9.9 kB)

Uploaded Python 3

File details

Details for the file pytorch_crf-0.6.0-py3-none-any.whl.

File metadata

  • Download URL: pytorch_crf-0.6.0-py3-none-any.whl
  • Upload date:
  • Size: 9.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.20.1 setuptools/40.6.2 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.6.3

File hashes

Hashes for pytorch_crf-0.6.0-py3-none-any.whl
  • SHA256: 4f1b53664a4b4af278c107d5d9b4a069aadb25536d6cbb2f257ede8dbb7ffa8f
  • MD5: a10495468d9e54fbfe8e11544fed048d
  • BLAKE2b-256: 37b01166f2373a5d7bf825a1c616d4d1feaef1f7407d771d1c3e14bf51fadbd3

