A simple CRF module written in PyTorch. The implementation is based on https://github.com/allenai/allennlp/blob/master/allennlp/modules/conditional_random_field.py
Project description
PyTorch Text CRF
This package contains a simple wrapper for using conditional random fields (CRFs). The code is based on the excellent AllenNLP implementation of CRFs.
Installation
pip install pytorch-text-crf
Usage
from crf.crf import ConditionalRandomField

# Initialization
crf = ConditionalRandomField(
    n_tags,                                    # Number of tags
    label_encoding="BIO",
    idx2tag={0: "B-GEO", 1: "I-GEO", 2: "O"}   # Index-to-tag mapping
)

# Log-likelihood estimation
log_likelihood = crf(logits, tags, mask)

# Decoding
best_tag_sequence = crf.best_viterbi_tag(logits, mask)
top_5_viterbi_tags = crf.viterbi_tags(logits, mask, top_k=5)
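For intuition, `best_viterbi_tag` solves the standard Viterbi dynamic program over per-token emission scores and tag-to-tag transition scores. The following is a minimal pure-Python sketch of that recurrence with toy scores; it illustrates the algorithm only and is not the package's API:

```python
def viterbi_decode(emissions, transitions):
    """Return (best tag path, its score) for one sentence.

    emissions:   list of [n_tags] score lists, one per token
    transitions: [n_tags][n_tags] matrix; transitions[i][j] is the
                 score of moving from tag i to tag j
    """
    n_tags = len(emissions[0])
    # score[t] = best score of any path ending in tag t so far
    score = list(emissions[0])
    backpointers = []
    for emission in emissions[1:]:
        step_scores = []
        step_back = []
        for j in range(n_tags):
            # Best previous tag to transition from, for current tag j
            best_prev = max(range(n_tags),
                            key=lambda i: score[i] + transitions[i][j])
            step_back.append(best_prev)
            step_scores.append(score[best_prev] + transitions[best_prev][j]
                               + emission[j])
        score = step_scores
        backpointers.append(step_back)
    # Follow backpointers from the best final tag
    best_last = max(range(n_tags), key=lambda t: score[t])
    path = [best_last]
    for step_back in reversed(backpointers):
        path.append(step_back[path[-1]])
    path.reverse()
    return path, max(score)

# Toy example with 2 tags: 0 = "O", 1 = "B-GEO"
emissions = [[2.0, 0.5], [0.5, 2.0], [2.0, 0.5]]
transitions = [[0.0, 0.0], [0.0, 0.0]]
print(viterbi_decode(emissions, transitions))  # ([0, 1, 0], 6.0)
```

The transition matrix is what lets a CRF penalize invalid sequences (e.g. "I-GEO" directly after "O" in BIO encoding) that a plain per-token classifier cannot rule out.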
LSTM CRF Implementation
Refer to https://github.com/iamsimha/pytorch-text-crf/blob/master/examples/pos_tagging/train.ipynb for a complete working implementation.
from crf.crf import ConditionalRandomField

class LSTMCRF(nn.Module):
    """
    An example implementation of a CRF layer on top of an LSTM.
    """
    def __init__(self):
        ...
        # Initialize the CRF model
        self.crf = ConditionalRandomField(
            n_class,               # Number of tags
            label_encoding="BIO",  # Label encoding format
            idx2tag=idx2tag        # Dict mapping index to a tag
        )

    def forward(self, inputs, tags):
        logits = self.lstm(inputs)  # logits dim: (batch_size, seq_length, num_tags)
        # Mask out padding tokens; pad_token_index is the id of the "<pad>" token.
        # mask dim: (batch_size, seq_length)
        mask = inputs != pad_token_index
        log_likelihood = self.crf(logits, tags, mask)
        # The log-likelihood is not normalized (it is not divided by the batch size)
        loss = -log_likelihood

        # Obtain the best sequence using Viterbi decoding
        best_tag_sequence = self.crf.best_viterbi_tag(logits, mask)

        # Convert the Viterbi paths to one-hot scores, matching the
        # shape of the LSTM predictions
        class_probabilities = logits * 0.0
        for i, instance_tags in enumerate(best_tag_sequence):
            for j, tag_id in enumerate(instance_tags[0][0]):
                class_probabilities[i, j, int(tag_id)] = 1
        return {"loss": loss, "class_probabilities": class_probabilities}
# Training
lstm_crf = LSTMCRF()
output = lstm_crf(sentences, tags)
loss = output["loss"]
optimizer.zero_grad()
loss.backward()
optimizer.step()
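Since `class_probabilities` is one-hot over tags, recovering the predicted tag strings is just an argmax per token followed by an `idx2tag` lookup. A pure-Python sketch with toy data (the helper name is an assumption for illustration, not part of the package's API):

```python
def probabilities_to_tags(class_probabilities, idx2tag):
    """Map per-token one-hot (or softmax) scores to tag strings.

    class_probabilities: [batch][seq_length][n_tags] nested lists
    idx2tag:             dict mapping tag index -> tag string
    """
    batch_tags = []
    for sentence in class_probabilities:
        tags = []
        for token_scores in sentence:
            # Index of the highest-scoring tag for this token
            best = max(range(len(token_scores)), key=token_scores.__getitem__)
            tags.append(idx2tag[best])
        batch_tags.append(tags)
    return batch_tags

idx2tag = {0: "B-GEO", 1: "I-GEO", 2: "O"}
# One sentence of three tokens, one-hot as produced by the forward pass
probs = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]]]
print(probabilities_to_tags(probs, idx2tag))  # [['B-GEO', 'I-GEO', 'O']]
```

In practice you would call `class_probabilities.argmax(dim=-1)` on the tensor directly; the loop above just spells out the same operation.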
File details
Details for the file pytorch-text-crf-0.1.tar.gz.
File metadata
- Download URL: pytorch-text-crf-0.1.tar.gz
- Upload date:
- Size: 10.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.7.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bab37a78d9d2c0f62c1ae82f3a8466af9c64df0d01bea0eb4f357c5b3e0ebdc7 |
| MD5 | 9d8847b47e2ae9a07cc6f1408b22aeb2 |
| BLAKE2b-256 | df651e6b577211ad2ffeee49352bf0b6b16ede29bae10c54f5a7f45aa423e958 |
File details
Details for the file pytorch_text_crf-0.1-py3-none-any.whl.
File metadata
- Download URL: pytorch_text_crf-0.1-py3-none-any.whl
- Upload date:
- Size: 14.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.7.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5000a5b68ed82fc8551362b6c0a6e25582553bccef4fe687e188de1b72ec7398 |
| MD5 | f1afca5693de15ae0a62b737dd1d1325 |
| BLAKE2b-256 | 17aea8f42b712dc3d16953d088be32b5e8a5c1515acb5006a4509fb909f0344f |