Skip to main content

No project description provided

Project description

Torch-Struct: Structured Prediction Library

Tests Coverage Status

A library of tested, GPU implementations of core structured prediction algorithms for deep learning applications.

  • HMM / LinearChain-CRF
  • HSMM / SemiMarkov-CRF
  • Dependency Tree-CRF
  • PCFG Binary Tree-CRF
  • ...

Designed to be used as efficient batched layers in other PyTorch code.

Tutorial paper describing methodology.

Getting Started

!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
# Optional CUDA kernels for FastLogSemiring
!pip install -qU git+https://github.com/harvardnlp/genbmm
# For plotting.
!pip install -q matplotlib
import torch
from torch_struct import DependencyCRF, LinearChainCRF
import matplotlib.pyplot as plt
def show(x): plt.imshow(x.detach())
# Make some data.
vals = torch.zeros(2, 10, 10) + 1e-5
vals[:, :5, :5] = torch.rand(5)
vals[:, 5:, 5:] = torch.rand(5) 
dist = DependencyCRF(vals.log())
show(dist.log_potentials[0])

png

# Compute marginals
show(dist.marginals[0])

png

# Compute argmax
show(dist.argmax.detach()[0])

png

# Compute scoring and enumeration (forward / inside)
log_partition = dist.partition
max_score = dist.log_prob(dist.argmax)
# Compute samples 
show(dist.sample((1,)).detach()[0, 0])

png

# Padding/Masking built into library.
dist = DependencyCRF(vals, lengths=torch.tensor([10, 7]))
show(dist.marginals[0])
plt.show()
show(dist.marginals[1])

png

png

# Many other structured prediction approaches
chain = torch.zeros(2, 10, 10, 10) + 1e-5
chain[:, :, :, :] = vals.unsqueeze(-1).exp()
chain[:, :, :, :] += torch.eye(10, 10).view(1, 1, 10, 10) 
chain[:, 0, :, 0] = 1
chain[:, -1,9, :] = 1
chain = chain.log()

dist = LinearChainCRF(chain)
show(dist.marginals.detach()[0].sum(-1))

png

Library

Full docs: http://nlp.seas.harvard.edu/pytorch-struct/

Current distributions implemented:

  • LinearChainCRF
  • SemiMarkovCRF
  • DependencyCRF
  • NonProjectiveDependencyCRF
  • TreeCRF
  • NeuralPCFG / NeuralHMM

Each distribution includes:

  • Argmax, sampling, entropy, partition, masking, log_probs, k-max

Extensions:

  • Integration with torchtext, pytorch-transformers, dgl
  • Adapters for generative structured models (CFG / HMM / HSMM)
  • Common tree structured parameterizations TreeLSTM / SpanLSTM

Low-level API:

Everything implemented through semiring dynamic programming.

  • Log Marginals
  • Max and MAP computation
  • Sampling through specialized backprop
  • Entropy and first-order semirings.

Examples

Citation

@misc{alex2020torchstruct,
    title={Torch-Struct: Deep Structured Prediction Library},
    author={Alexander M. Rush},
    year={2020},
    eprint={2002.00876},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

Project details


Release history Release notifications | RSS feed

This version

0.5

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torch_struct-0.5.tar.gz (30.4 kB view details)

Uploaded Source

Built Distribution

torch_struct-0.5-py3-none-any.whl (34.7 kB view details)

Uploaded Python 3

File details

Details for the file torch_struct-0.5.tar.gz.

File metadata

  • Download URL: torch_struct-0.5.tar.gz
  • Upload date:
  • Size: 30.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.24.0 setuptools/44.0.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5

File hashes

Hashes for torch_struct-0.5.tar.gz
Algorithm Hash digest
SHA256 a40e234353f3b1a1df743100a781949b84b48877578340e77dc5066899c06495
MD5 c6658e4c1aff60598c346e5dc4afb5d6
BLAKE2b-256 f3e1bb2ad949a56c25014e1a7c0e2b0f5a7f670eff876fe733156afe2119eede

See more details on using hashes here.

File details

Details for the file torch_struct-0.5-py3-none-any.whl.

File metadata

  • Download URL: torch_struct-0.5-py3-none-any.whl
  • Upload date:
  • Size: 34.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.24.0 setuptools/44.0.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5

File hashes

Hashes for torch_struct-0.5-py3-none-any.whl
Algorithm Hash digest
SHA256 9e4c4e5a2317d01cef47bfb0786b4ece6aebe1fa23f5f8a6a91887a3aab3d020
MD5 80940d05e3a14b32d63a70f95327c8a8
BLAKE2b-256 b28c775b7e141f11d509d59d0d2d801337ff3ad0203bc1a40335ea83e1161ba7

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page