A PyTorch implementation of Tree-structured Conditional Random Fields.

🌲 torch-treecrf

🗺️ Overview

Conditional Random Fields (CRFs) are a family of discriminative probabilistic graphical models that can be used to model the dependencies between variables. The most common form of CRF is the linear-chain CRF, where a prediction depends on an observed variable as well as on the predictions immediately before and after it (its context). Linear-chain CRFs are widely used in Natural Language Processing.

$$ P(Y | X) = \frac{1}{Z(X)} \prod_{i=1}^n{ \Psi_i(y_i, x_i) } \prod_{i=2}^n{ \Psi_{i-1,i}(y_{i-1}, y_i)} $$
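To make the role of the normalization term $Z(X)$ concrete, here is a minimal plain-Python sketch (not the package's PyTorch code) that computes $Z(X)$ for a short linear chain, both by enumerating every labeling and with the forward recursion. All potential values are made up, `unary[i][y]` stands for $\Psi_i(y, x_i)$ already evaluated on the observed input, and sharing one `pairwise` table across all positions is an assumption made for brevity.

```python
import itertools

# Hypothetical potentials for n = 3 positions and C = 2 classes per position.
# unary[i][y] stands for Psi_i(y, x_i), already evaluated on the observed x_i.
unary = [[1.0, 2.0], [0.5, 1.5], [3.0, 1.0]]
# pairwise[a][b] stands for Psi_{i-1,i}(a, b); sharing it across positions is an assumption.
pairwise = [[2.0, 0.5], [0.5, 2.0]]


def score(ys):
    """Unnormalized product of all potentials for one labeling ys."""
    s = 1.0
    for i, y in enumerate(ys):
        s *= unary[i][y]
        if i > 0:
            s *= pairwise[ys[i - 1]][y]
    return s


def partition_brute_force(n, c):
    """Z(X) by enumerating all c**n labelings."""
    return sum(score(ys) for ys in itertools.product(range(c), repeat=n))


def partition_forward(n, c):
    """Z(X) with the forward recursion, in O(n * c**2) operations."""
    alpha = list(unary[0])
    for i in range(1, n):
        alpha = [
            unary[i][b] * sum(alpha[a] * pairwise[a][b] for a in range(c))
            for b in range(c)
        ]
    return sum(alpha)


n, c = len(unary), len(unary[0])
z = partition_forward(n, c)
print(z, partition_brute_force(n, c))
print(score((1, 1, 0)) / z)  # P(Y = (1, 1, 0) | X)
```

The forward recursion sums out one position at a time, which avoids the $C^n$ enumeration that is only practical for toy examples.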

In 2006, Tang et al.[1] introduced Tree-structured CRFs to model hierarchical relationships between predicted variables, allowing dependencies between a prediction variable and its parents and children.

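$$ P(Y | X) = \frac{1}{Z(X)} \prod_{i=1}^{n}{ \left( \Psi_i(y_i, x_i) \prod_{j \in \mathcal{N}(i)}{ \Psi_{j,i}(y_j, y_i)} \right) } $$

where $\mathcal{N}(i)$ denotes the labels adjacent to $i$ in the hierarchy (its parents and children).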
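For the tree-structured case, the sketch below evaluates $P(Y \mid X)$ by brute force on a hypothetical hierarchy of three labels (a root with two children), assuming that each parent-child potential is applied once per edge of the hierarchy. The names `parents`, `unary` and `edge` and all values are illustrative only. Enumerating all $C^L$ labelings is exponential in the number of labels, which is why message passing is used in practice.

```python
import itertools

# Hypothetical hierarchy: label 0 is the root, labels 1 and 2 are its children.
parents = {0: [], 1: [0], 2: [0]}
C = 2  # classes per label

# unary[i][y] stands for Psi_i(y, x_i), already evaluated on the observed input.
unary = [[1.0, 3.0], [2.0, 1.0], [0.5, 1.5]]
# edge[a][b] stands for Psi_{j,i}(a, b) with parent j in class a and child i in
# class b; one table shared by every edge, applied once per edge (an assumption).
edge = [[3.0, 1.0], [0.2, 3.0]]


def score(ys):
    """Unnormalized product of unary and parent-child potentials."""
    s = 1.0
    for i, y in enumerate(ys):
        s *= unary[i][y]
        for j in parents[i]:
            s *= edge[ys[j]][y]
    return s


configs = list(itertools.product(range(C), repeat=len(unary)))  # all C**L labelings
Z = sum(score(ys) for ys in configs)
probs = {ys: score(ys) / Z for ys in configs}

print(max(probs, key=probs.get))  # most probable joint labeling
```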

This package implements a generic Tree-structured CRF layer in PyTorch. The layer can be stacked on top of a linear layer to implement a proper Tree-structured CRF, or on any other kind of model producing emission scores in log-space for every class of each label. Computation of marginals is implemented using Belief Propagation[2], allowing for exact inference on trees[3]:

$$ \begin{aligned} P(y_i | X) & = \frac{1}{Z(X)} \Psi_i(y_i, x_i) \underbrace{\prod_{j \in \mathcal{C}(i)}{\mu_{j \to i}(y_i)}}_{\alpha_i(y_i)} \underbrace{\prod_{j \in \mathcal{P}(i)}{\mu_{j \to i}(y_i)}}_{\beta_i(y_i)} \\ & = \frac{1}{Z(X)} \Psi_i(y_i, x_i) \, \alpha_i(y_i) \, \beta_i(y_i) \end{aligned} $$

where, for every node $i$, the messages from its parents $\mathcal{P}(i)$ and from its children $\mathcal{C}(i)$ are computed recursively with the sum-product algorithm[4]:

$$ \begin{aligned} \forall j \in \mathcal{C}(i), \quad \mu_{j \to i}(y_i) & = \sum_{y_j}{ \Psi_{i,j}(y_i, y_j) \, \Psi_j(y_j, x_j) \prod_{k \in \mathcal{C}(j)}{\mu_{k \to j}(y_j)} } \\ \forall j \in \mathcal{P}(i), \quad \mu_{j \to i}(y_i) & = \sum_{y_j}{ \Psi_{i,j}(y_i, y_j) \, \Psi_j(y_j, x_j) \prod_{k \in \mathcal{P}(j)}{\mu_{k \to j}(y_j)} } \end{aligned} $$
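The upward message (from children to parent) can be sketched in plain Python on a three-label tree. Since the layer consumes emission scores in log-space, this sketch works with log-potentials and a `logsumexp` helper. It computes $\alpha$ for the root, whose marginal needs no $\beta$ term because the root has no parents, and checks the result against brute-force enumeration. It is an illustrative, non-vectorized sketch with hypothetical scores and helper names, not the package's implementation.

```python
import itertools
import math

# Hypothetical three-label tree: label 0 is the root, labels 1 and 2 its children.
children = {0: [1, 2], 1: [], 2: []}
parents = {0: [], 1: [0], 2: [0]}
C = 2  # classes per label

# Hypothetical log-space scores: log_unary[i][y] = log Psi_i(y, x_i), and
# log_edge[a][b] = log Psi_{i,j}(a, b) for a parent i in class a and a child j
# in class b (assumed shared by every edge).
log_unary = [[0.0, 1.1], [0.7, 0.0], [-0.7, 0.4]]
log_edge = [[1.1, 0.0], [-1.6, 1.1]]


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_alpha(i):
    """log alpha_i(y_i): sum over children j of log mu_{j->i}(y_i)."""
    out = [0.0] * C
    for j in children[i]:
        alpha_j = log_alpha(j)  # messages j receives from its own children
        for yi in range(C):
            out[yi] += logsumexp(
                [log_edge[yi][yj] + log_unary[j][yj] + alpha_j[yj] for yj in range(C)]
            )
    return out


# The root has no parents, so beta_0(y_0) = 1 and its marginal needs alpha_0 only.
alpha_root = log_alpha(0)
root_scores = [log_unary[0][y] + alpha_root[y] for y in range(C)]
log_z = logsumexp(root_scores)
root_marginal = [math.exp(s - log_z) for s in root_scores]


def log_score(ys):
    """Unnormalized log-score of a full labeling, for the brute-force check."""
    s = sum(log_unary[i][y] for i, y in enumerate(ys))
    s += sum(log_edge[ys[p]][ys[i]] for i in parents for p in parents[i])
    return s


configs = list(itertools.product(range(C), repeat=len(log_unary)))
log_z_brute = logsumexp([log_score(ys) for ys in configs])
brute_root = [
    sum(math.exp(log_score(ys) - log_z_brute) for ys in configs if ys[0] == y)
    for y in range(C)
]
print(root_marginal, brute_root)
```

Working in log-space keeps products of many small potentials from underflowing, and `logsumexp` replaces each sum over $y_j$.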

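The implementation is meant to be generic enough that any directed acyclic graph (DAG) can be used as a label hierarchy, not just a tree, although belief propagation is only guaranteed to be exact when the hierarchy, viewed as an undirected graph, contains no cycles.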

🔧 Installing

Install the torch-treecrf package directly from PyPI, which hosts universal wheels that can be installed with pip:

```console
$ pip install torch-treecrf
```

📋 Features

  • Encoding of directed graphs in an adjacency matrix, with $\mathcal{O}(1)$ retrieval of children and parents for any node, and $\mathcal{O}(N+E)$ storage.
  • Support for any hierarchy representable as a directed acyclic graph (DAG), not just directed trees, allowing prediction of labels from ontologies such as the Gene Ontology.
  • Multiclass output, provided all the target labels have the same number of classes $C$: $Y \in \left\{ 0, \dots, C-1 \right\}^L$.
  • Minibatch support, with vectorized computation of the messages $\alpha_i(y_i)$ and $\beta_i(y_i)$.
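As an illustration of how parents and children can be looked up from an adjacency matrix, the plain-Python sketch below uses the convention $M_{i,j} = 1$ when $j$ is a parent of $i$. It precomputes parent and child lists once, so each later lookup is a constant-time list access, and derives a topological order with Kahn's algorithm, which any message-passing schedule on a DAG needs. The helper names are hypothetical, and this is not necessarily how the package stores its hierarchy.

```python
from collections import deque


def parents_and_children(adjacency):
    """Precompute parent and child lists from a dense 0/1 matrix where
    adjacency[i][j] == 1 means that node j is a parent of node i."""
    n = len(adjacency)
    parents = [[j for j in range(n) if adjacency[i][j]] for i in range(n)]
    children = [[i for i in range(n) if adjacency[i][j]] for j in range(n)]
    return parents, children


def topological_order(parents, children):
    """Kahn's algorithm: every node appears after all of its parents.
    Raises ValueError if the hierarchy contains a cycle (i.e. is not a DAG)."""
    indegree = [len(p) for p in parents]
    queue = deque(i for i, d in enumerate(indegree) if d == 0)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for child in children[node]:
            indegree[child] -= 1
            if indegree[child] == 0:
                queue.append(child)
    if len(order) != len(parents):
        raise ValueError("hierarchy contains a cycle")
    return order


# A root (node 0) with two children (nodes 1 and 2).
adjacency = [[0, 0, 0], [1, 0, 0], [1, 0, 0]]
parents, children = parents_and_children(adjacency)
print(parents)    # [[], [0], [0]]
print(children)   # [[1, 2], [], []]
print(topological_order(parents, children))  # [0, 1, 2]
```

Visiting nodes in reverse topological order gives a children-first schedule for the messages behind $\alpha_i(y_i)$, while the forward order suits the messages behind $\beta_i(y_i)$.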

💡 Example

To create a Tree-structured CRF, you must first define the tree encoding the relationships between variables. Let's build a simple CRF for a root variable with two children:

First, define an adjacency matrix $M$ representing the hierarchy, such that $M_{i,j}$ is $1$ if $j$ is a parent of $i$:

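```python
import torch

# adjacency[i, j] == 1 if node j is a parent of node i:
# node 0 is the root, nodes 1 and 2 are its children.
adjacency = torch.tensor([
    [0, 0, 0],
    [1, 0, 0],
    [1, 0, 0]
])
```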

Then, create a CRF with the number of input features matching your feature space, as you would for a torch.nn.Linear module, to obtain a PyTorch model:

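```python
import torch_treecrf

# n_features is the size of the input feature vector, as for torch.nn.Linear
crf = torch_treecrf.TreeCRF(n_features=30, hierarchy=adjacency)
```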

If you only need the CRF layer, use the TreeCRFLayer module, which expects and outputs an emission tensor of shape $(\star, C, L)$, where $\star$ is the minibatch size, $L$ the number of labels and $C$ the number of classes per label.

💭 Feedback

⚠️ Issue Tracker

Found a bug? Have an enhancement request? Head over to the GitHub issue tracker if you need to report or ask something. If you are filing a bug report, please include as much information as you can about the issue, and try to recreate the bug in a minimal, easily reproducible example.

🏗️ Contributing

Contributions are more than welcome! See CONTRIBUTING.md for more details.

⚖️ License

This library is provided under the MIT License.

This library was developed by Martin Larralde during his PhD project at the European Molecular Biology Laboratory in the Zeller team.

📚 References

  • [1] Tang, Jie, Mingcai Hong, Juanzi Li, and Bangyong Liang. ‘Tree-Structured Conditional Random Fields for Semantic Annotation’. In The Semantic Web - ISWC 2006, edited by Isabel Cruz, Stefan Decker, Dean Allemang, Chris Preist, Daniel Schwabe, Peter Mika, Mike Uschold, and Lora M. Aroyo, 640–53. Lecture Notes in Computer Science. Berlin, Heidelberg: Springer, 2006. doi:10.1007/11926078_46.
  • [2] Pearl, Judea. ‘Reverend Bayes on Inference Engines: A Distributed Hierarchical Approach’. In Proceedings of the Second AAAI Conference on Artificial Intelligence, 133–136. AAAI’82. Pittsburgh, Pennsylvania: AAAI Press, 1982.
  • [3] Bach, Francis, and Guillaume Obozinski. ‘Sum Product Algorithm and Hidden Markov Model’. ENS Course Material, 2016. http://imagine.enpc.fr/%7Eobozinsg/teaching/mva_gm/lecture_notes/lecture7.pdf.
  • [4] Kschischang, Frank R., Brendan J. Frey, and Hans-Andrea Loeliger. ‘Factor Graphs and the Sum-Product Algorithm’. IEEE Transactions on Information Theory 47, no. 2 (February 2001): 498–519. doi:10.1109/18.910572.
