A Tree-LSTM model package for PyTorch

pytorch-tree-lstm

This repo contains a PyTorch implementation of the child-sum Tree-LSTM model (Tai et al. 2015) with vectorized tree evaluation and batching. The module has been tested with Python 3.6.6, PyTorch 0.4.0, and PyTorch 1.0.1.

High-level Approach

Efficient batching of tree data is complicated by the need to have evaluated all of a node's children before we can evaluate the node itself. To minimize the performance impact of this issue, we break the node evaluation process into steps such that at each step we evaluate all nodes for which all child nodes have been previously evaluated. This allows us to evaluate multiple nodes with each torch operation, increasing computation speeds by an order of magnitude over recursive approaches.

As an example, consider the following tree:

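[Figure: the example tree. Node 0 is the root with children 1 and 2; node 2 has a single child, node 3.]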

On the first step of the tree calculation, we can evaluate nodes 1 and 3 in parallel, as neither has any child nodes. At the second step we are able to evaluate node 2, as its child node 3 was evaluated previously. Lastly we evaluate node 0, which depends on nodes 1 and 2. In this way a four-node computation is reduced to three steps. Bigger trees with more leaf nodes will experience larger performance gains.
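The stepping rule described above can be sketched in plain Python. This is only an illustration of the scheduling idea using lists, not the package's implementation; the function name evaluation_steps is hypothetical.

```python
def evaluation_steps(adjacency_list, num_nodes):
    """Group node indexes by the step at which they can be evaluated."""
    children = {node: [] for node in range(num_nodes)}
    for parent, child in adjacency_list:
        children[parent].append(child)

    evaluated = set()
    steps = []
    while len(evaluated) < num_nodes:
        # A node is ready once every one of its children has been evaluated.
        ready = [
            node for node in range(num_nodes)
            if node not in evaluated
            and all(child in evaluated for child in children[node])
        ]
        if not ready:
            raise ValueError("adjacency_list does not describe a tree")
        steps.append(ready)
        evaluated.update(ready)
    return steps


# The four-node example tree: 0 -> 1, 0 -> 2, 2 -> 3.
steps = evaluation_steps([(0, 1), (0, 2), (2, 3)], num_nodes=4)
print(steps)  # [[1, 3], [2], [0]]
```

Each inner list holds the nodes that can be evaluated together in a single step.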

To facilitate this approach we encode the tree structure and features into four tensors. For a tree with N nodes, E edges, and F features per node, the required tensors are:

  • features - A size N x F tensor containing the features for each node.
  • adjacency_list - A size E x 2 tensor containing the node indexes of the parent node and child node for every connection in the tree.
  • node_order - A size N tensor containing the calculation step at which a node can be evaluated. Note that node data must be stored in the same order in features and node_order.
  • edge_order - A size E tensor containing the calculation step at which each entry in the adjacency_list is needed in order to retrieve the child nodes for a current node. Note that parent-child data must be stored in the same order in adjacency_list and edge_order.

node_order and edge_order hold redundant information derivable from the adjacency_list and features; however, precomputing these tensors gives a significant performance improvement due to the current lack of an efficient set intersection function in PyTorch 1.0. The order tensors can be generated using the treelstm.calculate_evaluation_orders function. calculate_evaluation_orders accepts the adjacency_list tensor and the length of the features tensor and returns the two order tensors:

import treelstm
node_order, edge_order = treelstm.calculate_evaluation_orders(adjacency_list, len(features))
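To make the meaning of the two order tensors concrete, here is a minimal pure-Python sketch of one way they could be derived from an adjacency list. It uses plain lists rather than tensors, the name evaluation_orders is hypothetical, and it is not claimed to match how treelstm.calculate_evaluation_orders works internally.

```python
def evaluation_orders(adjacency_list, num_nodes):
    """Return (node_order, edge_order) as plain Python lists."""
    children = [[] for _ in range(num_nodes)]
    for parent, child in adjacency_list:
        children[parent].append(child)

    node_order = [None] * num_nodes

    def order_of(node):
        # Leaves get step 0; a parent is evaluated one step after
        # its latest-evaluated child.
        if node_order[node] is None:
            node_order[node] = 1 + max(
                (order_of(child) for child in children[node]), default=-1
            )
        return node_order[node]

    for node in range(num_nodes):
        order_of(node)

    # An edge is needed at the step where its parent node is evaluated.
    edge_order = [node_order[parent] for parent, _ in adjacency_list]
    return node_order, edge_order


node_order, edge_order = evaluation_orders([(0, 1), (0, 2), (2, 3)], 4)
print(node_order)  # [2, 0, 1, 0]
print(edge_order)  # [2, 2, 1]
```

The recursion is fine for small trees; very deep trees would need an iterative version to stay within Python's recursion limit.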

The tensor representation of the example tree above would be:

features: tensor([[1., 0.],
                  [0., 1.],
                  [0., 0.],
                  [1., 1.]])

adjacency_list: tensor([[0, 1],
                        [0, 2],
                        [2, 3]])

node_order: tensor([2, 0, 1, 0])

edge_order: tensor([2, 2, 1])
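As a rough illustration of how the four tensors work together, the sketch below walks the example tree step by step using plain lists. The "cell" is a deliberately simplified stand-in (a node's state is the sum of its own features plus the sum of its children's states), not the Tree-LSTM equations, and child_sum_pass is a hypothetical name.

```python
features = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [1.0, 1.0]]
adjacency_list = [(0, 1), (0, 2), (2, 3)]
node_order = [2, 0, 1, 0]
edge_order = [2, 2, 1]


def child_sum_pass(features, adjacency_list, node_order, edge_order):
    state = [0.0] * len(features)
    for step in range(max(node_order) + 1):
        # Select the edges needed at this step and sum child states per parent.
        child_sum = [0.0] * len(features)
        for (parent, child), order in zip(adjacency_list, edge_order):
            if order == step:
                child_sum[parent] += state[child]
        # Evaluate every node scheduled for this step.
        for node, order in enumerate(node_order):
            if order == step:
                state[node] = sum(features[node]) + child_sum[node]
    return state


print(child_sum_pass(features, adjacency_list, node_order, edge_order))
# [4.0, 1.0, 2.0, 2.0]
```

In a tensor implementation the two inner loops would become masked index operations, which is where the speedup over node-by-node recursion comes from.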

Installation

The pytorch-tree-lstm package can be installed via pip:

pip install pytorch-tree-lstm

Once installed, the library can be imported via:

import treelstm

Usage

The file tree_lstm.py contains the TreeLSTM module. The module accepts the features, node_order, adjacency_list, and edge_order tensors detailed above as input.

These tensors can be batched together by concatenation (torch.cat()) with the exception of the adjacency_list. The adjacency_list contains indexes into the features tensor used to retrieve child features for performing sums over node children, and when batched together these indexes must be adjusted for the new position of the features in the batched tensors.
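The index adjustment can be sketched with plain lists. This is an illustrative stand-in under the assumption that each tree is a dictionary of lists; batch_trees is a hypothetical helper, not a function of the package.

```python
def batch_trees(trees):
    """Concatenate per-tree lists, shifting adjacency indexes by node offset."""
    batched = {"features": [], "node_order": [], "adjacency_list": [], "edge_order": []}
    offset = 0
    for tree in trees:
        batched["features"].extend(tree["features"])
        batched["node_order"].extend(tree["node_order"])
        batched["edge_order"].extend(tree["edge_order"])
        # Node i of this tree now lives at row (offset + i) of the batch.
        batched["adjacency_list"].extend(
            (parent + offset, child + offset)
            for parent, child in tree["adjacency_list"]
        )
        offset += len(tree["features"])
    return batched


four_node_tree = {
    "features": [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [1.0, 1.0]],
    "node_order": [2, 0, 1, 0],
    "adjacency_list": [(0, 1), (0, 2), (2, 3)],
    "edge_order": [2, 2, 1],
}
two_node_tree = {
    "features": [[0.5, 0.5], [1.0, 0.0]],
    "node_order": [1, 0],
    "adjacency_list": [(0, 1)],
    "edge_order": [1],
}

batched = batch_trees([four_node_tree, two_node_tree])
print(batched["adjacency_list"])  # [(0, 1), (0, 2), (2, 3), (4, 5)]
```

The second tree's single edge (0, 1) becomes (4, 5) because its nodes follow the four nodes of the first tree in the batched features.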

The treelstm.batch_tree_input function is provided to do this concatenation and adjustment. treelstm.batch_tree_input accepts a list of dictionaries containing fields features, node_order, adjacency_list, and edge_order and returns a dictionary containing those same fields with the individual dictionaries in the list concatenated together and the adjacency_list indexes adjusted, as well as a tree_sizes list storing the size of each tree in the batch. Given a PyTorch Dataset object that returns tree data as a dictionary of tensors with the above keys, treelstm.batch_tree_input is suitable for use as a collate_fn argument to the PyTorch DataLoader object:

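from torch.utils.data import DataLoader

import treelstm

# TreeDataset is a placeholder for your own Dataset, which returns each tree
# as a dictionary of tensors with the keys listed above.
train_data_generator = DataLoader(
    TreeDataset(),
    collate_fn=treelstm.batch_tree_input,
    batch_size=64
)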

Unbatching the batched tensors can be done via

torch.split(tensor, tree_sizes, dim=0)

Here tree_sizes is a list containing the number of nodes in each tree in the batch. This functionality is also provided by the treelstm.unbatch_tree_tensor function for convenience. As mentioned above, a tree_sizes list suitable for use by this function is generated by treelstm.batch_tree_input.
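For readers unfamiliar with splitting by a list of sizes, the following plain-Python sketch shows the same slicing on a list of per-node rows. The helper split_by_tree is hypothetical and only mirrors the idea; on real tensors the torch.split call above does the work.

```python
from itertools import accumulate


def split_by_tree(rows, tree_sizes):
    """Split a batched list of per-node rows back into one list per tree."""
    if sum(tree_sizes) != len(rows):
        raise ValueError("tree_sizes must sum to the number of rows")
    ends = list(accumulate(tree_sizes))
    starts = [0] + ends[:-1]
    return [rows[start:end] for start, end in zip(starts, ends)]


# Six node outputs from a batch holding a four-node and a two-node tree.
rows = ["a0", "a1", "a2", "a3", "b0", "b1"]
per_tree = split_by_tree(rows, [4, 2])
print(per_tree)  # [['a0', 'a1', 'a2', 'a3'], ['b0', 'b1']]
```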

Example

Example code that generates tensors for the four-node example tree above and trains a toy classification problem against the tree labels is available in the example_usage.py script.
