🌲 torch-treecrf
A PyTorch implementation of Tree-structured Conditional Random Fields.
🗺️ Overview
Conditional Random Fields (CRF) are a family of discriminative graphical learning models that can be used to model the dependencies between variables. The most common form is the linear-chain CRF, where each prediction depends on an observed variable as well as on the predictions immediately before and after it (its context). Linear-chain CRFs are widely used in Natural Language Processing.
$$ P(Y | X) = \frac{1}{Z(X)} \prod_{i=1}^n{ \Psi_i(y_i, x_i) } \prod_{i=2}^n{ \Psi_{i-1,i}(y_{i-1}, y_i)} $$
In 2006, Tang et al.[1] introduced Tree-structured CRFs to model hierarchical relationships between predicted variables, allowing dependencies between a prediction variable and its parents and children.
$$ P(Y | X) = \frac{1}{Z(X)} \prod_{i=1}^{n} \left( \Psi_i(y_i, x_i) \prod_{j \in \mathcal{N}(i)}{ \Psi_{j,i}(y_j, y_i)} \right) $$
This package implements a generic Tree-structured CRF layer in PyTorch. The layer can be stacked on top of a linear layer to implement a proper Tree-structured CRF, or on top of any other kind of model producing emission scores in log-space for every class of each label. Computation of marginals is implemented using Belief Propagation[2], allowing for exact inference on trees[3]:
$$
\begin{aligned}
P(y_i | X)
& = \frac{1}{Z(X)} \Psi_i(y_i, x_i)
    \underbrace{\prod_{j \in \mathcal{C}(i)}{\mu_{j \to i}(y_i)}}_{\alpha_i(y_i)}
    \underbrace{\prod_{j \in \mathcal{P}(i)}{\mu_{j \to i}(y_i)}}_{\beta_i(y_i)} \\
& = \frac{1}{Z(X)} \Psi_i(y_i, x_i) \, \alpha_i(y_i) \, \beta_i(y_i)
\end{aligned}
$$
where for every node $i$, the message from the parents $\mathcal{P}(i)$ and the children $\mathcal{C}(i)$ is computed recursively with the sum-product algorithm[4]:
$$
\begin{aligned}
\forall j \in \mathcal{C}(i), \quad \mu_{j \to i}(y_i) & = \sum_{y_j}{ \Psi_{i,j}(y_i, y_j) \, \Psi_j(y_j, x_j) \prod_{k \in \mathcal{C}(j)}{\mu_{k \to j}(y_j)} } \\
\forall j \in \mathcal{P}(i), \quad \mu_{j \to i}(y_i) & = \sum_{y_j}{ \Psi_{i,j}(y_i, y_j) \, \Psi_j(y_j, x_j) \prod_{k \in \mathcal{P}(j)}{\mu_{k \to j}(y_j)} }
\end{aligned}
$$
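To make the recursions concrete, here is a minimal, non-vectorized sketch that transcribes the two message equations above on a toy three-node tree (a root with two children, as in the example further below). The potentials are arbitrary placeholder values, and this is only an illustration of the math, not how the package computes marginals internally:

```python
import torch

# Toy tree: node 0 is the root, nodes 1 and 2 are its children.
parents  = {0: [], 1: [0], 2: [0]}
children = {0: [1, 2], 1: [], 2: []}
C = 2  # number of classes per label

# Arbitrary positive placeholder potentials: unary[i, y] stands for
# Psi_i(y_i, x_i), and pairwise[y_i, y_j] for Psi_{i,j}(y_i, y_j)
# (shared across all edges here, purely for brevity).
unary = torch.rand(3, C) + 0.1
pairwise = torch.rand(C, C) + 0.1

def message(j, i, direction):
    """Compute mu_{j->i}(y_i) following the recursions above."""
    neighbours = children[j] if direction == "up" else parents[j]
    inner = unary[j].clone()
    for k in neighbours:
        inner = inner * message(k, j, direction)
    # sum over y_j of Psi_{i,j}(y_i, y_j) * Psi_j(y_j, x_j) * prod(...)
    return pairwise @ inner

def marginal(i):
    """P(y_i | X) as Psi_i * alpha_i (children) * beta_i (parents), normalized."""
    alpha = torch.ones(C)
    for j in children[i]:
        alpha = alpha * message(j, i, "up")
    beta = torch.ones(C)
    for j in parents[i]:
        beta = beta * message(j, i, "down")
    p = unary[i] * alpha * beta
    return p / p.sum()

print(marginal(0))  # marginal distribution of the root label over its C classes
```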
The implementation should be generic enough that any kind of directed acyclic graph (DAG) can be used as a label hierarchy, not just trees.
🔧 Installing
Install the torch-treecrf package directly from PyPI, which hosts universal wheels that can be installed with pip:

```console
$ pip install torch-treecrf
```
📋 Features
- Encoding of directed graphs in an adjacency matrix, with $\mathcal{O}(1)$ retrieval of children and parents for any node, and $\mathcal{O}(N+E)$ storage (illustrated in the sketch after this list).
- Support for any acyclic hierarchy representable as a Directed Acyclic Graph and not just directed trees, allowing prediction of classes such as the Gene Ontology.
- Multiclass output, provided all the target labels have the same number of classes: $Y \in \left\{ 0, \dots, C \right\}^L$.
- Minibatch support, with vectorized computation of the messages $\alpha_i(y_i)$ and $\beta_i(y_i)$.
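The first point can be pictured as follows: a dense 0/1 adjacency matrix is scanned once to precompute per-node parent and child index lists, so that later lookups are constant-time. This is only a sketch of the idea; the hierarchy encoding shipped with the package may store things differently.

```python
import torch

# M[i, j] == 1 means j is a parent of i (same convention as the example below).
adjacency = torch.tensor([
    [0, 0, 0],
    [1, 0, 0],
    [1, 0, 0],
])

# One pass over the matrix caches the index lists...
parents  = [row.nonzero().flatten() for row in adjacency]
children = [col.nonzero().flatten() for col in adjacency.t()]

# ...so retrieving the parents or children of any node is then O(1).
print(parents[1])   # tensor([0]): node 0 is the parent of node 1
print(children[0])  # tensor([1, 2]): nodes 1 and 2 are the children of node 0
```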
💡 Example
To create a Tree-structured CRF, you must first define the tree encoding the relationships between variables. Let's build a simple CRF for a root variable with two children:
First, define an adjacency matrix $M$ representing the hierarchy, such that $M_{i,j}$ is $1$ if $j$ is a parent of $i$:
```python
import torch
import torch_treecrf

# adjacency matrix M, where M[i, j] == 1 means j is a parent of i:
# node 0 is the root, nodes 1 and 2 are its children
adjacency = torch.tensor([
    [0, 0, 0],
    [1, 0, 0],
    [1, 0, 0],
])
```
Then, create the CRF with the right number of features for your feature space, like you would for a torch.nn.Linear module, to obtain a Torch model:

```python
crf = torch_treecrf.TreeCRF(n_features=30, hierarchy=adjacency)
```
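The resulting model is an ordinary torch.nn.Module and can be called on a minibatch of feature vectors. Continuing from the snippet above, the sketch below uses a made-up batch of random features; the exact shape and meaning of the output should be checked against the API documentation.

```python
X = torch.randn(16, 30)   # minibatch of 16 examples with 30 features each
with torch.no_grad():
    Y = crf(X)            # forward pass; see the API docs for the exact output shape
```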
If you wish to use the CRF layer only, use the TreeCRFLayer module, which expects and outputs an emission tensor of shape $(\star, C, L)$, where $\star$ is the minibatch size, $L$ the number of labels and $C$ the number of classes per label.
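For instance, a standalone use of the layer could look like the sketch below, reusing the adjacency matrix from the example above. The constructor arguments shown (the hierarchy and the number of classes per label) are assumptions made for illustration, not the documented signature; check the API reference for the exact arguments.

```python
# Hypothetical construction: 3 labels arranged as in the hierarchy above,
# with C = 2 classes per label (argument names are assumed, not documented here).
layer = torch_treecrf.TreeCRFLayer(hierarchy=adjacency, n_classes=2)

emissions = torch.randn(16, 2, 3)   # log-space emission scores of shape (batch, C, L)
marginals = layer(emissions)        # output tensor with the same (batch, C, L) shape
```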
💭 Feedback
⚠️ Issue Tracker
Found a bug? Have an enhancement request? Head over to the GitHub issue tracker if you need to report or ask something. If you are filing a bug report, please include as much information as you can about the issue, and try to recreate the same bug in a simple, easily reproducible situation.
🏗️ Contributing
Contributions are more than welcome! See CONTRIBUTING.md for more details.
⚖️ License
This library is provided under the MIT License.
This library was developed by Martin Larralde during his PhD project at the European Molecular Biology Laboratory in the Zeller team.
📚 References
- [1] Tang, Jie, Mingcai Hong, Juanzi Li, and Bangyong Liang. ‘Tree-Structured Conditional Random Fields for Semantic Annotation’. In The Semantic Web - ISWC 2006, edited by Isabel Cruz, Stefan Decker, Dean Allemang, Chris Preist, Daniel Schwabe, Peter Mika, Mike Uschold, and Lora M. Aroyo, 640–53. Lecture Notes in Computer Science. Berlin, Heidelberg: Springer, 2006. doi:10.1007/11926078_46.
- [2] Pearl, Judea. ‘Reverend Bayes on Inference Engines: A Distributed Hierarchical Approach’. In Proceedings of the Second AAAI Conference on Artificial Intelligence, 133–136. AAAI’82. Pittsburgh, Pennsylvania: AAAI Press, 1982.
- [3] Bach, Francis, and Guillaume Obozinski. ‘Sum Product Algorithm and Hidden Markov Model’, ENS Course Material, 2016. http://imagine.enpc.fr/%7Eobozinsg/teaching/mva_gm/lecture_notes/lecture7.pdf.
- [4] Kschischang, Frank R., Brendan J. Frey, and Hans-Andrea Loeliger. ‘Factor Graphs and the Sum-Product Algorithm’. IEEE Transactions on Information Theory 47, no. 2 (February 2001): 498–519. doi:10.1109/18.910572.