A PyTorch implementation of Tree-structured Conditional Random Fields.

🌲 torch-treecrf

🗺️ Overview

Conditional Random Fields (CRFs) are a family of discriminative graphical models that can be used to model the dependencies between variables. The most common form is the linear-chain CRF, where a prediction depends on an observed variable as well as on the predictions before and after it (the context). Linear-chain CRFs are widely used in Natural Language Processing.

$$ P(Y | X) = \frac{1}{Z(X)} \prod_{i=1}^n{ \Psi_i(y_i, x_i) } \prod_{i=2}^n{ \Psi_{i-1,i}(y_{i-1}, y_i)} $$
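
As an illustration of the linear-chain formula above, here is a minimal dependency-free sketch that evaluates $P(Y | X)$ for a toy sequence. The potentials are random positive numbers, the transition potential is assumed to be shared across positions, and all names (`emission`, `transition`, `score`, `partition`) are hypothetical and not part of this package. The partition function $Z(X)$ is computed with the forward recursion and compared with a brute-force enumeration.

```python
import itertools
import math
import random

random.seed(0)
n, C = 4, 3  # sequence length and number of classes (arbitrary toy values)

# Strictly positive toy potentials; in practice these come from a trained model.
# emission[i][y]      plays the role of Psi_i(y_i, x_i)
# transition[a][b]    plays the role of Psi_{i-1,i}(y_{i-1}=a, y_i=b), shared across i (assumption)
emission = [[math.exp(random.gauss(0, 1)) for _ in range(C)] for _ in range(n)]
transition = [[math.exp(random.gauss(0, 1)) for _ in range(C)] for _ in range(C)]


def score(labels):
    """Unnormalized score of one labeling: product of all potentials."""
    s = emission[0][labels[0]]
    for i in range(1, n):
        s *= transition[labels[i - 1]][labels[i]] * emission[i][labels[i]]
    return s


def partition():
    """Z(X) via the forward recursion, in O(n * C^2) instead of O(C^n)."""
    alpha = list(emission[0])
    for i in range(1, n):
        alpha = [
            emission[i][y] * sum(alpha[prev] * transition[prev][y] for prev in range(C))
            for y in range(C)
        ]
    return sum(alpha)


Z = partition()

# Brute-force reference: sum the score of every possible labeling.
Z_brute = sum(score(y) for y in itertools.product(range(C), repeat=n))

prob = score((0, 2, 1, 1)) / Z  # P(Y | X) for one particular labeling
```

Both ways of computing $Z(X)$ should agree up to floating-point error; the forward recursion is what keeps the normalization tractable for longer sequences.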

In 2006, Tang et al.[1] introduced Tree-structured CRFs to model hierarchical relationships between predicted variables, allowing dependencies between a prediction variable and its parents and children.

$$ P(Y | X) = \frac{1}{Z(X)} \prod_{i=1}^{n}{ \Psi_i(y_i, x_i) } \prod_{j \in \mathcal{N}(i)}{ \Psi_{j,i}(y_j, y_i)} $$

This package implements a generic Tree-structured CRF layer in PyTorch. The layer can be stacked on top of a linear layer to implement a proper Tree-structured CRF, or on any other kind of model producing emission scores in log-space for every class of each label. Computation of marginals is implemented using Belief Propagation[2], allowing for exact inference on trees[3]:

$$ \begin{aligned} P(y_i | X) & = \frac{1}{Z(X)} \Psi_i(y_i, x_i) & \underbrace{\prod_{j \in \mathcal{C}(i)}{\mu_{j \to i}(y_i)}} & & \underbrace{\prod_{j \in \mathcal{P}(i)}{\mu_{j \to i}(y_i)}} \\ & = \frac{1}{Z(X)} \Psi_i(y_i, x_i) & \alpha_i(y_i) & & \beta_i(y_i) \\ \end{aligned} $$

where, for every node $i$, the messages from the parents $\mathcal{P}(i)$ and the children $\mathcal{C}(i)$ are computed recursively with the sum-product algorithm[4]:

$$ \begin{aligned} \forall j \in \mathcal{C}(i), \mu_{j \to i}(y_i) = \sum_{y_j}{ \Psi_{i,j}(y_i, y_j) \Psi_j(y_j, x_j) \prod_{k \in \mathcal{C}(j)}{\mu_{k \to j}(y_j)} } \\ \forall j \in \mathcal{P}(i), \mu_{j \to i}(y_i) = \sum_{y_j}{ \Psi_{i,j}(y_i, y_j) \Psi_j(y_j, x_j) \prod_{k \in \mathcal{P}(j)}{\mu_{k \to j}(y_j)} } \\ \end{aligned} $$
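
The child-to-parent recursion above can be sketched in plain Python on a small hypothetical tree. This is not the vectorized implementation of this package: the tree, the random potentials and the names (`children`, `unary`, `pairwise`, `message_up`) are assumptions made for illustration. Only the root marginal is computed, since the root has no parents and its $\beta$ term is therefore an empty product equal to 1; the result is checked against a brute-force sum over all labelings.

```python
import itertools
import math
import random

random.seed(0)
C = 2                                          # classes per label (toy value)
children = {0: [1, 2], 1: [3], 2: [], 3: []}   # hypothetical tree rooted at node 0
nodes = sorted(children)

# Strictly positive toy potentials.
# unary[i][y]             plays the role of Psi_i(y_i, x_i)
# pairwise[i, j][yi][yj]  plays the role of Psi_{i,j}(y_i, y_j) for a parent i and child j
unary = {i: [random.uniform(0.5, 2.0) for _ in range(C)] for i in nodes}
pairwise = {
    (i, j): [[random.uniform(0.5, 2.0) for _ in range(C)] for _ in range(C)]
    for i in nodes for j in children[i]
}


def message_up(j, i):
    """mu_{j -> i}(y_i) for a child j of i, computed recursively from the leaves."""
    incoming = [message_up(k, j) for k in children[j]]
    return [
        sum(
            pairwise[i, j][yi][yj] * unary[j][yj] * math.prod(m[yj] for m in incoming)
            for yj in range(C)
        )
        for yi in range(C)
    ]


# Root marginal: unary potential times the product of child messages, then normalize.
root_messages = [message_up(j, 0) for j in children[0]]
alpha = [math.prod(m[y] for m in root_messages) for y in range(C)]
unnormalized = [unary[0][y] * alpha[y] for y in range(C)]
Z = sum(unnormalized)
root_marginal = [u / Z for u in unnormalized]


def joint(labels):
    """Unnormalized joint score of a full labeling (brute-force reference)."""
    s = math.prod(unary[i][labels[i]] for i in nodes)
    for (i, j), psi in pairwise.items():
        s *= psi[labels[i]][labels[j]]
    return s


brute = [0.0] * C
for labels in itertools.product(range(C), repeat=len(nodes)):
    brute[labels[0]] += joint(labels)
brute_marginal = [b / sum(brute) for b in brute]
```

The recursion visits each edge once, so its cost grows linearly with the number of nodes, whereas the brute-force reference enumerates $C^N$ labelings.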

The implementation should be generic enough that any kind of directed acyclic graph can be used as a label hierarchy, not just trees.

🔧 Installing

Install the torch-treecrf package directly from PyPI, which hosts universal wheels that can be installed with pip:

$ pip install torch-treecrf

📋 Features

  • Encoding of directed graphs in an adjacency matrix, with $\mathcal{O}(1)$ retrieval of children and parents for any node, and $\mathcal{O}(N+E)$ storage.
  • Support for any acyclic hierarchy representable as a Directed Acyclic Graph and not just directed trees, allowing prediction of classes such as the Gene Ontology.
  • Multiclass output, provided all the target labels have the same number of classes: $Y \in \left\{ 0, \dots, C-1 \right\}^L$.
  • Minibatch support, with vectorized computation of the messages $\alpha_i(y_i)$ and $\beta_i(y_i)$.

💡 Example

To create a Tree-structured CRF, you must first define the tree encoding the relationships between variables. Let's build a simple CRF for a root variable with two children.

First, define an adjacency matrix $M$ representing the hierarchy, such that $M_{i,j}$ is $1$ if $j$ is a parent of $i$:

import torch

# Row i marks the parents of node i: nodes 1 and 2 both have node 0 as parent.
adjacency = torch.tensor([
    [0, 0, 0],
    [1, 0, 0],
    [1, 0, 0]
])
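
To make the adjacency convention concrete, the following sketch reads parents and children back from the matrix using plain PyTorch indexing. The helper names `parents` and `children` are hypothetical, and these dense row and column scans are only an illustration, not the constant-time encoding used internally by the package.

```python
import torch

# Same hierarchy as above: M[i, j] == 1 if j is a parent of i.
adjacency = torch.tensor([
    [0, 0, 0],
    [1, 0, 0],
    [1, 0, 0],
])


def parents(node):
    """Indices j with M[node, j] != 0, i.e. the parents of `node`."""
    return torch.nonzero(adjacency[node], as_tuple=True)[0].tolist()


def children(node):
    """Indices i with M[i, node] != 0, i.e. the children of `node`."""
    return torch.nonzero(adjacency[:, node], as_tuple=True)[0].tolist()


# Nodes without any parent are the roots of the hierarchy.
roots = [i for i in range(adjacency.shape[0]) if not parents(i)]
```

Here node 0 is the only root, and both other nodes list it as their single parent.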

Then, create a CRF with the right number of features for your feature space, as you would for a torch.nn.Linear module, to obtain a Torch model:

import torch_treecrf
crf = torch_treecrf.TreeCRF(n_features=30, hierarchy=adjacency)

If you only need the CRF layer, use the TreeCRFLayer module, which expects and outputs an emission tensor of shape $(\star, C, L)$, where $\star$ is the minibatch size, $L$ the number of labels and $C$ the number of classes per label.
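
As a sketch of what an emission tensor with that layout could look like, the snippet below builds one with plain PyTorch. It does not call TreeCRFLayer; the upstream model (a single `torch.nn.Linear`), the batch size and the variable names are assumptions chosen for illustration.

```python
import torch

batch_size, n_features = 8, 30   # 30 matches the example above; the batch size is arbitrary
n_classes, n_labels = 2, 3       # C and L for the three-node hierarchy

# Hypothetical upstream model: one linear layer scoring every (class, label) pair.
linear = torch.nn.Linear(n_features, n_classes * n_labels)

x = torch.randn(batch_size, n_features)
emissions = linear(x).view(batch_size, n_classes, n_labels)   # shape (*, C, L)

# Independent per-label log-probabilities, normalized over the class dimension.
# A CRF layer would additionally couple the labels through the hierarchy.
log_probs = torch.log_softmax(emissions, dim=1)
```

Normalizing over `dim=1` treats every label independently, which is the baseline that a tree-structured layer refines by passing messages between related labels.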

💭 Feedback

⚠️ Issue Tracker

Found a bug? Have an enhancement request? Head over to the GitHub issue tracker if you need to report or ask something. If you are filing a bug report, please include as much information as you can about the issue, and try to recreate the same bug in a simple, easily reproducible situation.

🏗️ Contributing

Contributions are more than welcome! See CONTRIBUTING.md for more details.

⚖️ License

This library is provided under the MIT License.

This library was developed by Martin Larralde during his PhD project at the European Molecular Biology Laboratory in the Zeller team.

📚 References

  • [1] Tang, Jie, Mingcai Hong, Juanzi Li, and Bangyong Liang. ‘Tree-Structured Conditional Random Fields for Semantic Annotation’. In The Semantic Web - ISWC 2006, edited by Isabel Cruz, Stefan Decker, Dean Allemang, Chris Preist, Daniel Schwabe, Peter Mika, Mike Uschold, and Lora M. Aroyo, 640–53. Lecture Notes in Computer Science. Berlin, Heidelberg: Springer, 2006. doi:10.1007/11926078_46.
  • [2] Pearl, Judea. ‘Reverend Bayes on Inference Engines: A Distributed Hierarchical Approach’. In Proceedings of the Second AAAI Conference on Artificial Intelligence, 133–136. AAAI’82. Pittsburgh, Pennsylvania: AAAI Press, 1982.
  • [3] Bach, Francis, and Guillaume Obozinski. ‘Sum Product Algorithm and Hidden Markov Model’, ENS Course Material, 2016. http://imagine.enpc.fr/%7Eobozinsg/teaching/mva_gm/lecture_notes/lecture7.pdf.
  • [4] Kschischang, Frank R., Brendan J. Frey, and Hans-Andrea Loeliger. ‘Factor Graphs and the Sum-Product Algorithm’. IEEE Transactions on Information Theory 47, no. 2 (February 2001): 498–519. doi:10.1109/18.910572.
