Project description
Torch-Struct: Structured Prediction Library
A library of tested, GPU implementations of core structured prediction algorithms for deep learning applications.
- HMM / LinearChain-CRF
- HSMM / SemiMarkov-CRF
- Dependency Tree-CRF
- PCFG Binary Tree-CRF
- ...
Designed to be used as efficient batched layers in other PyTorch code.
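For instance, here is a minimal, hedged sketch of dropping a CRF layer on top of a neural encoder; the `scorer` head and random `encoder_states` are illustrative stand-ins, not part of the library, and the loss is only a demonstration that the quantities are batched and differentiable.

```python
import torch
import torch.nn as nn
from torch_struct import LinearChainCRF

batch, N, C, H = 8, 20, 5, 64                   # sequences, length, tags, hidden size
scorer = nn.Linear(H, C * C)                    # hypothetical scoring head (not library API)
encoder_states = torch.randn(batch, N - 1, H)   # stand-in for encoder output

# LinearChainCRF takes log-potentials of shape (batch, N-1, C, C).
log_potentials = scorer(encoder_states).view(batch, N - 1, C, C)
dist = LinearChainCRF(log_potentials)

# Everything is batched and differentiable, so distribution quantities can
# sit inside a training loss; here we score the (detached) MAP structure.
loss = -dist.log_prob(dist.argmax.detach()).mean()
loss.backward()
```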
A tutorial paper describing the methodology is available (arXiv:2002.00876).
Getting Started
!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
# Optional CUDA kernels for FastLogSemiring
!pip install -qU git+https://github.com/harvardnlp/genbmm
# For plotting.
!pip install -q matplotlib
import torch
from torch_struct import DependencyCRF, LinearChainCRF
import matplotlib.pyplot as plt
def show(x): plt.imshow(x.detach())
# Make some data.
vals = torch.zeros(2, 10, 10) + 1e-5
vals[:, :5, :5] = torch.rand(5)
vals[:, 5:, 5:] = torch.rand(5)
dist = DependencyCRF(vals.log())
show(dist.log_potentials[0])
# Compute marginals
show(dist.marginals[0])
# Compute argmax
show(dist.argmax.detach()[0])
# Compute scoring and enumeration (forward / inside)
log_partition = dist.partition
max_score = dist.log_prob(dist.argmax)
# Compute samples
show(dist.sample((1,)).detach()[0, 0])
# Padding/Masking built into library.
dist = DependencyCRF(vals, lengths=torch.tensor([10, 7]))
show(dist.marginals[0])
plt.show()
show(dist.marginals[1])
# Many other structured prediction approaches
chain = torch.zeros(2, 10, 10, 10) + 1e-5
chain[:, :, :, :] = vals.unsqueeze(-1).exp()
chain[:, :, :, :] += torch.eye(10, 10).view(1, 1, 10, 10)
chain[:, 0, :, 0] = 1
chain[:, -1, 9, :] = 1
chain = chain.log()
dist = LinearChainCRF(chain)
show(dist.marginals.detach()[0].sum(-1))
Library
Full docs: http://nlp.seas.harvard.edu/pytorch-struct/
Current distributions implemented:
- LinearChainCRF
- SemiMarkovCRF
- DependencyCRF
- NonProjectiveDependencyCRF
- TreeCRF
- NeuralPCFG / NeuralHMM
Each distribution includes (see the sketch after this list):
- Argmax, sampling, entropy, partition, masking, log_probs, k-max
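A minimal sketch of this shared interface, using LinearChainCRF over log-potentials of shape (batch, N-1, C, C); the attribute names below follow the linked docs, but treat anything not shown in Getting Started above (e.g. `entropy`) as an assumption to verify there.

```python
import torch
from torch_struct import LinearChainCRF

# Batch of 2 sequences of length 10 over 5 tags.
log_potentials = torch.randn(2, 9, 5, 5)
dist = LinearChainCRF(log_potentials)

print(dist.partition)               # (2,) log-partition per sequence
print(dist.argmax.shape)            # one-hot MAP structure, same shape as the potentials
print(dist.marginals.shape)         # edge marginals
print(dist.entropy)                 # (2,) entropy of each distribution
print(dist.log_prob(dist.argmax))   # log-probability of the MAP structure
print(dist.sample((3,)).shape)      # 3 samples per batch element
# k-max and length masking are also supported; see the full docs.
```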
Extensions:
- Integration with torchtext, pytorch-transformers, dgl
- Adapters for generative structured models (CFG / HMM / HSMM)
- Common tree structured parameterizations TreeLSTM / SpanLSTM
Low-level API:
Everything is implemented through semiring dynamic programming (a minimal illustration follows the list below).
- Log Marginals
- Max and MAP computation
- Sampling through specialized backprop
- Entropy and first-order semirings.
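As intuition for the semiring/backprop connection, here is a hedged sketch (not the library's internal code): because a structure's score is linear in the log-potentials, the gradient of the log-partition with respect to the log-potentials equals the edge marginals, so marginal computation is a backward pass over the same dynamic program.

```python
import torch
from torch_struct import LinearChainCRF

log_potentials = torch.randn(2, 9, 5, 5, requires_grad=True)
dist = LinearChainCRF(log_potentials)
marginals = dist.marginals            # library-computed edge marginals

# Backprop through the log-semiring forward pass; the resulting gradient
# should recover the same edge-marginal tensor.
dist.partition.sum().backward()
print(torch.allclose(log_potentials.grad, marginals, atol=1e-4))
```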
Examples
Citation
@misc{alex2020torchstruct,
title={Torch-Struct: Deep Structured Prediction Library},
author={Alexander M. Rush},
year={2020},
eprint={2002.00876},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
File details
Details for the file torch_struct-0.5.tar.gz.
File metadata
- Download URL: torch_struct-0.5.tar.gz
- Upload date:
- Size: 30.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.24.0 setuptools/44.0.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `a40e234353f3b1a1df743100a781949b84b48877578340e77dc5066899c06495` |
| MD5 | `c6658e4c1aff60598c346e5dc4afb5d6` |
| BLAKE2b-256 | `f3e1bb2ad949a56c25014e1a7c0e2b0f5a7f670eff876fe733156afe2119eede` |
File details
Details for the file torch_struct-0.5-py3-none-any.whl.
File metadata
- Download URL: torch_struct-0.5-py3-none-any.whl
- Upload date:
- Size: 34.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.24.0 setuptools/44.0.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `9e4c4e5a2317d01cef47bfb0786b4ece6aebe1fa23f5f8a6a91887a3aab3d020` |
| MD5 | `80940d05e3a14b32d63a70f95327c8a8` |
| BLAKE2b-256 | `b28c775b7e141f11d509d59d0d2d801337ff3ad0203bc1a40335ea83e1161ba7` |