Torch-Struct: Structured Prediction Library
A library of tested, GPU implementations of core structured prediction algorithms for deep learning applications.
- HMM / LinearChain-CRF
- HSMM / SemiMarkov-CRF
- Dependency Tree-CRF
- PCFG Binary Tree-CRF
- ...
Designed to be used as efficient batched layers in other PyTorch code.
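The structures listed above are all scored by dynamic programming. As a library-independent illustration (plain Python, not torch-struct's implementation), here is the forward algorithm that computes the log-partition of a linear-chain CRF in O(N·C²) time, checked against brute-force enumeration. The indexing `log_potentials[n][z_next][z_prev]` for the transition from position n to n+1 is an assumption made for this sketch.

```python
import itertools
import math
import random


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def sequence_score(log_potentials, seq):
    """Sum of edge scores along a label sequence of length N (assumed indexing)."""
    return sum(log_potentials[n][seq[n + 1]][seq[n]] for n in range(len(seq) - 1))


def chain_log_partition(log_potentials):
    """Forward algorithm: log-sum over all label sequences in O(N * C^2)."""
    num_states = len(log_potentials[0])
    # alpha[c] = log-sum of the scores of all prefixes ending in state c.
    alpha = [0.0] * num_states
    for edge in log_potentials:
        alpha = [logsumexp([alpha[p] + edge[c][p] for p in range(num_states)])
                 for c in range(num_states)]
    return logsumexp(alpha)


def brute_force_log_partition(log_potentials):
    """Enumerate all C^N sequences; only feasible for tiny problems."""
    num_states = len(log_potentials[0])
    n_positions = len(log_potentials) + 1
    scores = [sequence_score(log_potentials, seq)
              for seq in itertools.product(range(num_states), repeat=n_positions)]
    return logsumexp(scores)


random.seed(0)
N, C = 5, 3
log_potentials = [[[random.gauss(0.0, 1.0) for _ in range(C)] for _ in range(C)]
                  for _ in range(N - 1)]
print(chain_log_partition(log_potentials), brute_force_log_partition(log_potentials))
```

The forward pass touches each edge once, whereas enumeration grows as C^N, which is why the dynamic program is the practical choice.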
The accompanying paper (see Citation below) describes the methodology.
Getting Started
```python
# In a Jupyter/Colab notebook; drop the leading "!" when running in a shell.
!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
# Optional CUDA kernels for FastLogSemiring
!pip install -qU git+https://github.com/harvardnlp/genbmm
# For plotting.
!pip install -q matplotlib
```
```python
import torch
from torch_struct import DependencyCRF, LinearChainCRF
import matplotlib.pyplot as plt


def show(x):
    # Draw a 2-D tensor as an image in a new figure.
    plt.figure()
    plt.imshow(x.detach())


# Make some data: a batch of 2 matrices of size 10 x 10 holding positive potentials.
vals = torch.zeros(2, 10, 10) + 1e-5
vals[:, :5, :5] = torch.rand(5)
vals[:, 5:, 5:] = torch.rand(5)
# The distribution takes log-potentials.
dist = DependencyCRF(vals.log())
show(dist.log_potentials[0])
# Compute marginals
show(dist.marginals[0])
# Compute argmax
show(dist.argmax.detach()[0])
# Compute scoring and enumeration (forward / inside)
log_partition = dist.partition
# Log-probability of the argmax structure (its score minus the log-partition)
max_score = dist.log_prob(dist.argmax)
# Compute samples
show(dist.sample((1,)).detach()[0, 0])
```
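Above, `dist.log_prob(dist.argmax)` is the log-probability of the highest-scoring structure, that is, its score minus the log-partition. The following plain-Python sketch shows the same relationship on a linear chain (rather than the dependency CRF in the example): replacing logsumexp with max in the forward recursion yields the Viterbi score, and backpointers recover the argmax sequence. The helper names and the `log_potentials[n][z_next][z_prev]` indexing are our own assumptions.

```python
import itertools
import math
import random


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def sequence_score(log_potentials, seq):
    """Sum of edge scores along a label sequence (assumed [n][z_next][z_prev] indexing)."""
    return sum(log_potentials[n][seq[n + 1]][seq[n]] for n in range(len(seq) - 1))


def chain_forward(log_potentials, reduce):
    """Forward recursion; reduce=logsumexp gives log Z, reduce=max gives the best score."""
    num_states = len(log_potentials[0])
    alpha = [0.0] * num_states
    for edge in log_potentials:
        alpha = [reduce([alpha[p] + edge[c][p] for p in range(num_states)])
                 for c in range(num_states)]
    return reduce(alpha)


def chain_argmax(log_potentials):
    """Viterbi decoding: max-sum recursion plus backpointers."""
    num_states = len(log_potentials[0])
    delta = [0.0] * num_states
    backpointers = []
    for edge in log_potentials:
        # best_prev[c] = best previous state when the next state is c.
        best_prev = [max(range(num_states), key=lambda p: delta[p] + edge[c][p])
                     for c in range(num_states)]
        delta = [delta[best_prev[c]] + edge[c][best_prev[c]] for c in range(num_states)]
        backpointers.append(best_prev)
    seq = [max(range(num_states), key=lambda c: delta[c])]
    for best_prev in reversed(backpointers):
        seq.append(best_prev[seq[-1]])
    return seq[::-1]


def chain_log_prob(log_potentials, seq):
    """log p(seq) = score(seq) - log Z."""
    return sequence_score(log_potentials, seq) - chain_forward(log_potentials, logsumexp)


random.seed(0)
N, C = 5, 3
log_potentials = [[[random.gauss(0.0, 1.0) for _ in range(C)] for _ in range(C)]
                  for _ in range(N - 1)]
best = chain_argmax(log_potentials)
print(best, chain_log_prob(log_potentials, best))
```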
```python
# Padding/Masking built into library: the second item uses only its first 7 positions.
dist = DependencyCRF(vals.log(), lengths=torch.tensor([10, 7]))
show(dist.marginals[0])
plt.show()
show(dist.marginals[1])
plt.show()
```
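The sketch below illustrates what masking by `lengths` should mean, without claiming to mirror torch-struct's internals: for an item of length L, only its first L - 1 transitions contribute, so any values stored in the padding are ignored. It loops over the batch in plain Python, whereas a GPU implementation would vectorize this with masks; all function names are hypothetical.

```python
import math
import random


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def chain_log_partition(log_potentials, length=None):
    """Forward algorithm that ignores transitions past `length` positions.

    Padding entries beyond the valid length can hold arbitrary values;
    they never enter the computation.
    """
    num_states = len(log_potentials[0])
    n_edges = len(log_potentials) if length is None else length - 1
    alpha = [0.0] * num_states
    for n in range(n_edges):
        edge = log_potentials[n]
        alpha = [logsumexp([alpha[p] + edge[c][p] for p in range(num_states)])
                 for c in range(num_states)]
    return logsumexp(alpha)


def batched_log_partition(batch, lengths):
    """Apply the masked forward algorithm to each padded item in a batch."""
    return [chain_log_partition(lp, length) for lp, length in zip(batch, lengths)]


random.seed(1)
N, C = 6, 2
lengths = [6, 4]
batch = [[[[random.gauss(0.0, 1.0) for _ in range(C)] for _ in range(C)]
          for _ in range(N - 1)] for _ in lengths]
# Fill the padding of the shorter item with large junk values; they must not matter.
for n in range(lengths[1] - 1, N - 1):
    batch[1][n] = [[100.0] * C for _ in range(C)]
print(batched_log_partition(batch, lengths))
```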
```python
# Many other structured prediction approaches, e.g. a linear-chain CRF
# with log-potentials of shape batch x (N - 1) x C x C (here N - 1 = 10, C = 10).
chain = torch.zeros(2, 10, 10, 10) + 1e-5
chain[:, :, :, :] = vals.unsqueeze(-1).exp()
chain[:, :, :, :] += torch.eye(10, 10).view(1, 1, 10, 10)
chain[:, 0, :, 0] = 1
chain[:, -1, 9, :] = 1
chain = chain.log()
dist = LinearChainCRF(chain)
show(dist.marginals.detach()[0].sum(-1))
plt.show()
```
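Edge marginals are the gradient of the log-partition with respect to each log-potential. The sketch below computes them explicitly with the classical forward-backward algorithm in plain Python (instead of autograd) and checks them against brute force. Under the assumed `log_potentials[n][z_next][z_prev]` indexing, summing `mu[n][c]` over the last index gives P(z_{n+1} = c), analogous to the `.sum(-1)` in the example above.

```python
import itertools
import math
import random


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def sequence_score(log_potentials, seq):
    return sum(log_potentials[n][seq[n + 1]][seq[n]] for n in range(len(seq) - 1))


def edge_marginals(log_potentials):
    """Forward-backward: mu[n][c][p] = P(z_{n+1} = c, z_n = p)."""
    num_states = len(log_potentials[0])
    # alpha[n][p]: log-sum of prefix scores ending in state p at position n.
    alpha = [[0.0] * num_states]
    for edge in log_potentials:
        prev = alpha[-1]
        alpha.append([logsumexp([prev[p] + edge[c][p] for p in range(num_states)])
                      for c in range(num_states)])
    # beta[n][c]: log-sum of suffix scores starting from state c at position n.
    beta = [[0.0] * num_states]
    for edge in reversed(log_potentials):
        nxt = beta[0]
        beta.insert(0, [logsumexp([edge[c][p] + nxt[c] for c in range(num_states)])
                        for p in range(num_states)])
    log_z = logsumexp(alpha[-1])
    return [[[math.exp(alpha[n][p] + edge[c][p] + beta[n + 1][c] - log_z)
              for p in range(num_states)]
             for c in range(num_states)]
            for n, edge in enumerate(log_potentials)]


def brute_force_edge_marginals(log_potentials):
    """Accumulate P(seq) into every edge the sequence uses."""
    num_states, n_edges = len(log_potentials[0]), len(log_potentials)
    seqs = list(itertools.product(range(num_states), repeat=n_edges + 1))
    log_z = logsumexp([sequence_score(log_potentials, s) for s in seqs])
    mu = [[[0.0] * num_states for _ in range(num_states)] for _ in range(n_edges)]
    for s in seqs:
        weight = math.exp(sequence_score(log_potentials, s) - log_z)
        for n in range(n_edges):
            mu[n][s[n + 1]][s[n]] += weight
    return mu


random.seed(3)
N, C = 4, 3
log_potentials = [[[random.gauss(0.0, 1.0) for _ in range(C)] for _ in range(C)]
                  for _ in range(N - 1)]
mu = edge_marginals(log_potentials)
# Node marginals P(z_{n+1} = c): sum over the previous state.
node_marginals = [[sum(row) for row in mu_n] for mu_n in mu]
print(node_marginals)
```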
Library
Full docs: http://nlp.seas.harvard.edu/pytorch-struct/
Current distributions implemented:
- LinearChainCRF
- SemiMarkovCRF
- DependencyCRF
- NonProjectiveDependencyCRF
- TreeCRF
- NeuralPCFG / NeuralHMM
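Non-projective dependency trees (as in `NonProjectiveDependencyCRF`) cannot be summed with a chart-based dynamic program; the standard tool is the directed matrix-tree theorem, where the partition function is the determinant of a Laplacian built from arc weights. The sketch below is the multi-root variant (the root may take several children), which may not match the library's default setting; it uses a hand-written determinant and is checked against brute-force enumeration of head assignments. All names are hypothetical.

```python
import itertools
import math
import random


def determinant(matrix):
    """Determinant via Gaussian elimination with partial pivoting."""
    a = [row[:] for row in matrix]
    n = len(a)
    det = 1.0
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(a[r][col]))
        if a[pivot][col] == 0.0:
            return 0.0
        if pivot != col:
            a[col], a[pivot] = a[pivot], a[col]
            det = -det
        det *= a[col][col]
        for r in range(col + 1, n):
            factor = a[r][col] / a[col][col]
            for k in range(col, n):
                a[r][k] -= factor * a[col][k]
    return det


def tree_partition(weights):
    """Sum over all (multi-root) dependency trees of the product of arc weights.

    weights[h][m] > 0 is the potential of arc h -> m; node 0 is the root.
    """
    n = len(weights) - 1  # number of words
    laplacian = [[0.0] * n for _ in range(n)]
    for m in range(1, n + 1):
        for h in range(n + 1):
            if h == m:
                continue
            laplacian[m - 1][m - 1] += weights[h][m]  # total incoming weight of m
            if h >= 1:
                laplacian[h - 1][m - 1] -= weights[h][m]
    return determinant(laplacian)


def is_tree(heads):
    """heads[m - 1] is the head of word m; every word must reach the root."""
    for m in range(1, len(heads) + 1):
        seen, node = set(), m
        while node != 0:
            if node in seen:
                return False
            seen.add(node)
            node = heads[node - 1]
    return True


def brute_force_tree_partition(weights):
    n = len(weights) - 1
    choices = [[h for h in range(n + 1) if h != m] for m in range(1, n + 1)]
    return sum(math.prod(weights[h][m] for m, h in enumerate(heads, start=1))
               for heads in itertools.product(*choices) if is_tree(heads))


random.seed(2)
n_words = 4
weights = [[math.exp(random.gauss(0.0, 1.0)) for _ in range(n_words + 1)]
           for _ in range(n_words + 1)]
print(tree_partition(weights), brute_force_tree_partition(weights))
```

The determinant costs O(n³), while the number of head assignments grows as n^n.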
Each distribution includes:
- Argmax, sampling, entropy, partition, masking, log_probs, k-max
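Sampling from a chain can be done exactly with forward filtering, backward sampling: run the forward recursion once, draw the last label in proportion to its forward score, then walk backwards drawing each earlier label given the one after it. This is a plain-Python sketch of that classical method with our own helper names and the assumed `log_potentials[n][z_next][z_prev]` indexing; the low-level API below notes that torch-struct itself samples through specialized backprop, which is not what is shown here.

```python
import collections
import itertools
import math
import random


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def sequence_score(log_potentials, seq):
    return sum(log_potentials[n][seq[n + 1]][seq[n]] for n in range(len(seq) - 1))


def draw(log_weights, rng):
    """Sample an index with probability proportional to exp(log_weights)."""
    m = max(log_weights)
    weights = [math.exp(w - m) for w in log_weights]
    return rng.choices(range(len(log_weights)), weights=weights)[0]


def sample_chain(log_potentials, n_samples, rng):
    """Forward filtering (once), then backward sampling for each draw."""
    num_states = len(log_potentials[0])
    alphas = [[0.0] * num_states]
    for edge in log_potentials:
        prev = alphas[-1]
        alphas.append([logsumexp([prev[p] + edge[c][p] for p in range(num_states)])
                       for c in range(num_states)])
    samples = []
    for _ in range(n_samples):
        seq = [draw(alphas[-1], rng)]  # last label
        for n in range(len(log_potentials) - 1, -1, -1):
            nxt = seq[-1]  # label at position n + 1
            seq.append(draw([alphas[n][p] + log_potentials[n][nxt][p]
                             for p in range(num_states)], rng))
        samples.append(tuple(reversed(seq)))
    return samples


random.seed(4)
N, C = 3, 2
log_potentials = [[[random.gauss(0.0, 1.0) for _ in range(C)] for _ in range(C)]
                  for _ in range(N - 1)]
samples = sample_chain(log_potentials, 20000, random.Random(0))
counts = collections.Counter(samples)
print(counts.most_common(3))
```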
Extensions:
- Integration with torchtext, pytorch-transformers, dgl
- Adapters for generative structured models (CFG / HMM / HSMM)
- Common tree structured parameterizations TreeLSTM / SpanLSTM
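One natural way to adapt a generative HMM into a linear-chain CRF is to fold initial, transition, and emission log-probabilities into the edge log-potentials, so that a sequence's score equals log P(x, z) and the log-partition equals the log-likelihood log P(x). The sketch below is our own illustration of that folding (it is not the library's adapter code) and assumes strictly positive probabilities, at least two observations, and the `log_potentials[n][z_next][z_prev]` indexing.

```python
import itertools
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def hmm_to_chain(init, transition, emission, observations):
    """Fold HMM log-probabilities into chain log-potentials lp[n][z_next][z_prev].

    transition[p][c] = P(z_{n+1} = c | z_n = p); emission[c][x] = P(x | z = c).
    The initial-state and first-emission terms are added to the first edge.
    """
    num_states = len(init)
    log_potentials = []
    for n in range(len(observations) - 1):
        x_next = observations[n + 1]
        edge = [[math.log(transition[p][c]) + math.log(emission[c][x_next])
                 for p in range(num_states)]
                for c in range(num_states)]
        if n == 0:
            x_first = observations[0]
            for c in range(num_states):
                for p in range(num_states):
                    edge[c][p] += math.log(init[p]) + math.log(emission[p][x_first])
        log_potentials.append(edge)
    return log_potentials


def chain_log_partition(log_potentials):
    """Forward algorithm over the folded potentials."""
    num_states = len(log_potentials[0])
    alpha = [0.0] * num_states
    for edge in log_potentials:
        alpha = [logsumexp([alpha[p] + edge[c][p] for p in range(num_states)])
                 for c in range(num_states)]
    return logsumexp(alpha)


def hmm_log_likelihood_brute_force(init, transition, emission, observations):
    """log P(x), summing the joint P(x, z) over every hidden sequence z."""
    total = 0.0
    for z in itertools.product(range(len(init)), repeat=len(observations)):
        prob = init[z[0]] * emission[z[0]][observations[0]]
        for n in range(1, len(observations)):
            prob *= transition[z[n - 1]][z[n]] * emission[z[n]][observations[n]]
        total += prob
    return math.log(total)


# Toy parameters: 2 hidden states, 3 observation symbols.
init = [0.6, 0.4]
transition = [[0.7, 0.3], [0.2, 0.8]]
emission = [[0.5, 0.4, 0.1], [0.1, 0.3, 0.6]]
observations = [0, 2, 1, 2]
log_potentials = hmm_to_chain(init, transition, emission, observations)
print(chain_log_partition(log_potentials))
```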
Low-level API:
Everything implemented through semiring dynamic programming.
- Log Marginals
- Max and MAP computation
- Sampling through specialized backprop
- Entropy and first-order semirings.
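The semiring view can be made concrete with one chain recursion that is parameterized by a semiring: the log semiring yields the log-partition, the max semiring yields the best path score, and a first-order expectation semiring over pairs (Z, r), with r = Σ exp(s)·s, yields the entropy as log Z - r / Z. These class names are hypothetical, and the expectation semiring here works in real space for readability, so it only suits small scores; a production version would work in log space.

```python
import itertools
import math
import random


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


class LogSemiring:
    """(logsumexp, +): the chain DP returns the log-partition."""
    one = 0.0

    @staticmethod
    def lift(score):
        return score

    @staticmethod
    def times(a, b):
        return a + b

    @staticmethod
    def plus(xs):
        return logsumexp(xs)


class MaxSemiring(LogSemiring):
    """(max, +): the chain DP returns the highest path score."""

    @staticmethod
    def plus(xs):
        return max(xs)


class EntropySemiring:
    """First-order expectation semiring over pairs (p, r) in real space.

    A path with total score s accumulates (exp(s), exp(s) * s).
    """
    one = (1.0, 0.0)

    @staticmethod
    def lift(score):
        return (math.exp(score), math.exp(score) * score)

    @staticmethod
    def times(a, b):
        return (a[0] * b[0], a[0] * b[1] + b[0] * a[1])

    @staticmethod
    def plus(xs):
        return (sum(x[0] for x in xs), sum(x[1] for x in xs))


def chain_dp(semiring, log_potentials):
    """One forward recursion, generic in the semiring (lp[n][z_next][z_prev])."""
    num_states = len(log_potentials[0])
    alpha = [semiring.one] * num_states
    for edge in log_potentials:
        alpha = [semiring.plus([semiring.times(alpha[p], semiring.lift(edge[c][p]))
                                for p in range(num_states)])
                 for c in range(num_states)]
    return semiring.plus(alpha)


def chain_entropy(log_potentials):
    z, r = chain_dp(EntropySemiring, log_potentials)
    return math.log(z) - r / z


def sequence_score(log_potentials, seq):
    return sum(log_potentials[n][seq[n + 1]][seq[n]] for n in range(len(seq) - 1))


random.seed(5)
N, C = 4, 3
log_potentials = [[[random.gauss(0.0, 1.0) for _ in range(C)] for _ in range(C)]
                  for _ in range(N - 1)]
print(chain_dp(LogSemiring, log_potentials),
      chain_dp(MaxSemiring, log_potentials),
      chain_entropy(log_potentials))
```

Only `plus`, `times`, `one`, and `lift` change between semirings; the recursion itself is shared, which is the point of the semiring design.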
Citation
```bibtex
@misc{alex2020torchstruct,
    title={Torch-Struct: Deep Structured Prediction Library},
    author={Alexander M. Rush},
    year={2020},
    eprint={2002.00876},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```