An efficient TensorFlow 2 implementation of the edge convolution layer EdgeConv, as used in, e.g., ParticleNet.

The structure of the layer follows ‘ParticleNet: Jet Tagging via Particle Clouds’ (https://arxiv.org/abs/1902.08570). Graphs in a batch often have varying numbers of nodes, so dense inputs have to be padded to a common size. By working on the disjoint union of the graphs in a batch, this implementation performs memory-intensive operations only on the actual nodes and not on the padded ones.
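The paper describes EdgeConv as computing, for every node x_i and each of its k nearest neighbors x_j, an edge feature from x_i and the offset x_j - x_i, passing it through a shared MLP, and averaging the results over the neighbors. The NumPy sketch below shows only that core computation, under simplifying assumptions: a single dense layer with ReLU stands in for the MLP, and the batch normalization and shortcut connection that ParticleNet also uses are left out. `edge_conv` and its parameters are hypothetical and not part of the medgeconv API.

```python
import numpy as np


def edge_conv(x, neighbors, weight, bias):
    """Simplified single-layer EdgeConv (hypothetical helper, not medgeconv's code).

    x:         (n_nodes, n_features) node features
    neighbors: (n_nodes, k) indices of the k nearest neighbors of each node
    weight:    (2 * n_features, units) weights of one dense layer (assumption)
    bias:      (units,)
    """
    k = neighbors.shape[1]
    x_i = np.repeat(x[:, None, :], k, axis=1)      # (n, k, f) central node, repeated
    x_j = x[neighbors]                             # (n, k, f) neighbor features
    # Edge features as described in ParticleNet: central node and relative offset.
    edges = np.concatenate([x_i, x_j - x_i], axis=-1)  # (n, k, 2f)
    hidden = np.maximum(edges @ weight + bias, 0.0)    # shared dense layer + ReLU per edge
    return hidden.mean(axis=1)                         # mean aggregation over the k edges
```

Because the neighbor indices point into a flat node array, the same function would also apply to a disjoint batch of graphs, as long as no neighbor index crosses a graph boundary.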

Instructions

Install via:

pip install medgeconv

Use it, for example, like this:

import medgeconv

# nodes, is_valid and coordinates are the dense input tensors described below
nodes = medgeconv.DisjointEdgeConvBlock(
    units=[64, 64, 64],
    next_neighbors=16,
    to_disjoint=True,
    pooling=True,
)((nodes, is_valid, coordinates))

The inputs to EdgeConv are three dense tensors: nodes, is_valid and coordinates.

  • nodes, shape (batchsize, n_nodes_max, n_features)

    Node features of the graph, padded to fixed size. Valid nodes have to come first, then the padded nodes.

  • is_valid, shape (batchsize, n_nodes_max)

    1 for actual node, 0 for padded node.

  • coordinates, shape (batchsize, n_nodes_max, n_coords)

    Features of each node used for calculating nearest neighbors.
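If the graphs start out as a list of arrays with different numbers of nodes, they first have to be packed into this padded layout. The following is a minimal NumPy sketch of such a packing step; `pad_graphs` is a hypothetical helper that is not part of medgeconv, and the float32 dtypes are an assumption. It writes the valid nodes of each graph first and leaves zeros in the padded slots.

```python
import numpy as np


def pad_graphs(graphs, n_nodes_max=None):
    """Pack variable-size graphs into dense (nodes, is_valid) arrays (hypothetical helper).

    graphs: list of arrays with shape (n_nodes_i, n_features).
    Valid nodes come first in each row, padded nodes (filled with zeros) follow.
    """
    if n_nodes_max is None:
        n_nodes_max = max(len(g) for g in graphs)
    n_features = graphs[0].shape[1]
    # float32 is an assumed dtype, chosen to match typical TensorFlow inputs
    nodes = np.zeros((len(graphs), n_nodes_max, n_features), dtype="float32")
    is_valid = np.zeros((len(graphs), n_nodes_max), dtype="float32")
    for i, graph in enumerate(graphs):
        n = len(graph)
        if n > n_nodes_max:
            raise ValueError(f"graph {i} has {n} nodes, more than n_nodes_max={n_nodes_max}")
        nodes[i, :n] = graph   # valid nodes first
        is_valid[i, :n] = 1.0  # 1 marks an actual node, 0 a padded one
    return nodes, is_valid
```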

Examples

Example for batchsize = 2, n_nodes_max = 4, n_features = 2:

import numpy as np

nodes = np.array([
   [[2., 4.],
    [2., 6.],
    [0., 0.],  # <-- these nodes are padded, their
    [0., 0.]],  #           value doesn't matter

   [[0., 2.],
    [3., 7.],
    [4., 0.],
    [1., 2.]],
])

is_valid = np.array([
    [1, 1, 0, 0],  # <-- 0 defines these nodes as padded
    [1, 1, 1, 1],
])

coordinates = nodes  # here, the node features also serve as coordinates

Setting to_disjoint=True transforms the dense tensors into the disjoint union of the graphs. Since the output of each block is also disjoint, this only needs to be done once, in the first block. Setting pooling=True attaches a node-wise global average pooling layer at the end, which produces dense tensors again.
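Conceptually, the disjoint union stacks the valid nodes of all graphs into one flat array and remembers which graph each node came from; global average pooling then averages the nodes of each graph back into one row per graph. The NumPy sketch below illustrates both steps. `to_disjoint` and `global_average_pool` are hypothetical functions written for illustration, not medgeconv's TensorFlow implementation.

```python
import numpy as np


def to_disjoint(nodes, is_valid):
    """Flatten padded dense tensors into a disjoint union (illustrative sketch).

    Returns the valid nodes stacked into one (n_valid_total, n_features) array
    and, for each of them, the index of the graph it belongs to.
    """
    mask = is_valid.astype(bool)   # (batchsize, n_nodes_max)
    flat_nodes = nodes[mask]       # padded nodes are dropped, row-major order
    graph_ids = np.nonzero(mask)[0]  # row index of each valid node = graph index
    return flat_nodes, graph_ids


def global_average_pool(flat_nodes, graph_ids, batchsize):
    """Average the node features of each graph into a dense (batchsize, n_features) array."""
    sums = np.zeros((batchsize, flat_nodes.shape[1]), dtype=float)
    np.add.at(sums, graph_ids, flat_nodes)  # scatter-add node features per graph
    counts = np.bincount(graph_ids, minlength=batchsize)
    return sums / np.maximum(counts, 1)[:, None]  # avoid division by zero for empty graphs
```

With the example arrays above, the disjoint representation holds 6 nodes instead of 8, and pooling yields one feature vector per graph.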

A full model could look like this:

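import tensorflow as tf
import medgeconv

# Symbolic Keras inputs with the shapes described above
# (sizes taken from the example: n_nodes_max = 4, n_features = 2, n_coords = 2)
n_nodes_max, n_features, n_coords = 4, 2, 2
nodes = tf.keras.Input(shape=(n_nodes_max, n_features))
is_valid = tf.keras.Input(shape=(n_nodes_max,))
coordinates = tf.keras.Input(shape=(n_nodes_max, n_coords))

inp = (nodes, is_valid, coordinates)
# First block: convert the dense inputs to the disjoint union once
x = medgeconv.DisjointEdgeConvBlock(
    units=[64, 64, 64],
    to_disjoint=True,
    batchnorm_for_nodes=True,
)(inp)

# Intermediate block: input and output stay disjoint
x = medgeconv.DisjointEdgeConvBlock(
    units=[128, 128, 128],
)(x)

# Last block: pool over the nodes to get dense tensors again
x = medgeconv.DisjointEdgeConvBlock(
    units=[256, 256, 256],
    pooling=True,
)(x)

output = tf.keras.layers.Dense(2)(x)
model = tf.keras.Model(inp, output)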

To load a saved model, pass medgeconv.custom_objects:

import tensorflow as tf
import medgeconv

# path: location of the saved model
model = tf.keras.models.load_model(path, custom_objects=medgeconv.custom_objects)

knn_graph kernel

This package includes a CUDA kernel for calculating the k nearest neighbors on a batch of graphs. It comes with a precompiled kernel for the version of TensorFlow specified in requirements.txt.

To compile it locally, e.g. for a different version of TensorFlow, go to medgeconv/tf_ops and run make clean followed by make. This produces the file medgeconv/tf_ops/python/ops/_knn_graph_ops.so. For details on how to set up the Docker environment for compiling, see https://github.com/tensorflow/custom-op
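As a rough CPU reference for what such a kernel computes, the following brute-force NumPy sketch finds the k nearest neighbors of every node in a disjoint batch while restricting the search to nodes of the same graph. `knn_graph` here is a hypothetical function, not the package's op; excluding each node from its own neighbor list is an assumption, and the sketch uses memory quadratic in the total number of nodes, which a dedicated kernel is meant to avoid.

```python
import numpy as np


def knn_graph(coordinates, graph_ids, k):
    """Brute-force k nearest neighbors within each graph of a disjoint batch (sketch).

    coordinates: (n_nodes_total, n_coords) coordinates of all valid nodes
    graph_ids:   (n_nodes_total,) index of the graph each node belongs to
    Returns (n_nodes_total, k) indices into the flat node array.
    Assumes every graph has more than k nodes.
    """
    diff = coordinates[:, None, :] - coordinates[None, :, :]
    dist = np.sum(diff ** 2, axis=-1)          # pairwise squared distances
    other_graph = graph_ids[:, None] != graph_ids[None, :]
    dist[other_graph] = np.inf                 # forbid edges between different graphs
    np.fill_diagonal(dist, np.inf)             # assumption: a node is not its own neighbor
    return np.argsort(dist, axis=1, kind="stable")[:, :k]
```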

Download files

Source Distribution

medgeconv-1.0.1.tar.gz (49.6 kB)

File details

Details for the file medgeconv-1.0.1.tar.gz.

File metadata

  • Download URL: medgeconv-1.0.1.tar.gz
  • Upload date:
  • Size: 49.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.7.3 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.59.0 CPython/3.6.7

File hashes

Hashes for medgeconv-1.0.1.tar.gz
Algorithm Hash digest
SHA256 635ff5c1265e89065998be2af3571918ece2d706d5e6331b399c74a903a92504
MD5 6b3526fa3e933ceb95a9c1c5e6614fa3
BLAKE2b-256 5d183e7e74c2b3480266b52047f6f3cc7f0d2e8739e765cac6a0398015c97fdf
