Conditional random field in PyTorch
Project description
# pytorch-crf
Conditional random field in [PyTorch](http://pytorch.org/).
## Description
This package provides an implementation of [conditional random field](https://en.wikipedia.org/wiki/Conditional_random_field) (CRF) in PyTorch. This implementation borrows mostly from [AllenNLP CRF module](https://github.com/allenai/allennlp/blob/master/allennlp/modules/conditional_random_field.py) with some modifications, most notably is using PyTorch’s fancy indexing instead of torch.gather and different way of doing Viterbi decoding.
## Requirements
Python 3.6
PyTorch 0.2
## Installation
You can install with pip
pip install pytorch-crf
Or, you can install from Github directly
pip install git+https://github.com/kmkurn/pytorch-crf#egg=pytorch_crf
## Examples
In the examples below, we will assume that these lines have been executed:
>>> import torch >>> from crf import CRF >>> seq_length, batch_size, num_tags = 3, 2, 5 >>> emissions = torch.autograd.Variable(torch.randn(seq_length, batch_size, num_tags)) >>> tags = torch.LongTensor([[0, 1], [2, 4], [3, 1]]) # (seq_length, batch_size) >>> model = CRF(num_tags) >>> # Initialize model parameters ... for p in model.parameters(): ... _ = torch.nn.init.uniform(p, -1, 1) ... >>>
### Forward computation
>>> model(emissions, tags) Variable containing: -10.0635 [torch.FloatTensor of size 1]
### Forward computation with mask
>>> mask = torch.autograd.Variable(torch.ByteTensor([[1, 1], [1, 1], [1, 0]])) # (seq_length, batch_size) >>> model(emissions, tags, mask=mask) Variable containing: -8.4981 [torch.FloatTensor of size 1]
### Decoding
>>> model.decode(emissions) [[3, 1, 3], [0, 1, 0]]
### Decoding with mask
>>> model.decode(emissions, mask=mask) [[3, 1, 3], [0, 1, 0]]
See tests/test_crf.py for more examples.
## License
MIT. See LICENSE.txt for details.
## Contributing
Contributions are welcome! Please follow these instructions to setup dependencies and running the tests and linter. Make a pull request once your contribution is ready.
### Installing dependencies
Make sure you setup a virtual environment with Python 3.6 and PyTorch installed. Then, install all the dependencies in requirements.txt file and install this package in development mode.
pip install -r requirements.txt pip install -e .
### Running tests
Run pytest in the project root directory.
### Running linter
Run flake8 in the project root directory. This will also run mypy, thanks to flake8-mypy package.
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distributions
Built Distribution
Hashes for pytorch_crf-0.1.1-py3-none-any.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 1dd8abd03a9b8b7a52225087e98acbefaa444c1a81b6471a0c7ebdea31e19e92 |
|
MD5 | 160a85c4ac3e2ae4d6854fff67a4481f |
|
BLAKE2b-256 | 1068cd94b6ac6e957aa0c2f8d4b072e4dec437be66896689054cc92f64307081 |