
Project description


An efficient TensorFlow 2 implementation of the edge convolution layer EdgeConv, as used in, e.g., ParticleNet.

The structure of the layer follows the description in ‘ParticleNet: Jet Tagging via Particle Clouds’ (https://arxiv.org/abs/1902.08570). Graphs often have a varying number of nodes. This implementation works on the disjoint union of all graphs in a batch, so memory-intensive operations are performed only on the actual nodes and not on the padding. This makes it faster when the number of nodes varies between the graphs in a batch.
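To illustrate the disjoint union, the NumPy sketch below keeps only the valid nodes of a padded batch and records which graph each node came from. It is an illustration of the idea, not medgeconv's internal code; the helper name dense_to_disjoint is hypothetical.

```python
import numpy as np


def dense_to_disjoint(nodes, is_valid):
    """Flatten padded per-graph node arrays into one disjoint node list.

    nodes:    (batchsize, n_nodes_max, n_features)
    is_valid: (batchsize, n_nodes_max), 1 for actual nodes, 0 for padding
    Returns the actual nodes only, shape (n_total, n_features), and the
    index of the graph each node belongs to, shape (n_total,).
    """
    mask = np.asarray(is_valid).astype(bool)
    flat_nodes = np.asarray(nodes)[mask]   # boolean mask keeps actual nodes, row-major order
    graph_ids = np.nonzero(mask)[0]        # batch index of every kept node, same order
    return flat_nodes, graph_ids


# Two graphs padded to 3 nodes: the first has 2 actual nodes, the second has 1.
nodes = np.arange(12, dtype=float).reshape(2, 3, 2)
is_valid = np.array([[1, 1, 0], [1, 0, 0]])
flat_nodes, graph_ids = dense_to_disjoint(nodes, is_valid)
print(flat_nodes.shape)  # (3, 2): only the 3 actual nodes remain
print(graph_ids)         # [0 0 1]
```

Any per-node operation (for example a dense layer) can then run on flat_nodes without spending work on padded entries, while graph_ids preserves the graph membership.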

Install e.g. via:

pip install git+https://github.com/StefReck/MEdgeConv.git#egg=MEdgeConv

Use e.g. like this:

import medgeconv

# nodes, is_valid, coordinates: dense input tensors as described below
nodes = medgeconv.DisjointEdgeConvBlock(
    units=[64, 64, 64],
    next_neighbors=16,
    to_disjoint=True,
    pooling=True)((nodes, is_valid, coordinates))
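As a rough picture of what a single EdgeConv pass computes, here is a NumPy sketch for one graph following the ParticleNet description. For every node it finds the k nearest neighbors in coordinate space, builds edge features [x_i, x_j - x_i], passes them through a shared MLP and averages over the neighbors. Here k presumably plays the role of next_neighbors and the list of layer widths that of units. The function names, random weights and small sizes are hypothetical. Excluding a node from its own neighbors is an assumption. The sketch omits the batch normalization and shortcut connection of the ParticleNet block and does not describe how medgeconv actually implements the layer.

```python
import numpy as np


def knn_indices(coords, k):
    """Indices of the k nearest neighbors of every node (excluding the node itself)."""
    diff = coords[:, None, :] - coords[None, :, :]  # (n, n, n_coords) pairwise differences
    dist2 = np.sum(diff ** 2, axis=-1)              # squared Euclidean distances
    np.fill_diagonal(dist2, np.inf)                 # assumption: a node is not its own neighbor
    return np.argsort(dist2, axis=1)[:, :k]         # (n, k)


def edge_conv(nodes, coords, weights, k):
    """One EdgeConv pass on a single graph (sketch after the ParticleNet paper).

    Edge features [x_i, x_j - x_i] go through an MLP (dense + ReLU per
    layer, shared by all edges) and are averaged over the k neighbors.
    """
    idx = knn_indices(coords, k)                    # (n, k)
    x_i = np.repeat(nodes[:, None, :], k, axis=1)   # (n, k, f) central node, repeated
    x_j = nodes[idx]                                # (n, k, f) neighbor features
    h = np.concatenate([x_i, x_j - x_i], axis=-1)   # (n, k, 2f) edge features
    for w, b in weights:                            # shared MLP applied to every edge
        h = np.maximum(h @ w + b, 0.0)
    return h.mean(axis=1)                           # (n, units[-1]) mean over neighbors


rng = np.random.default_rng(0)
n_nodes, n_features, n_coords, k = 10, 4, 2, 3
units = [8, 8]                                      # hypothetical, small version of units=[64, 64, 64]
nodes = rng.normal(size=(n_nodes, n_features))
coords = rng.normal(size=(n_nodes, n_coords))
weights, fan_in = [], 2 * n_features
for u in units:
    weights.append((rng.normal(size=(fan_in, u)), np.zeros(u)))
    fan_in = u
out = edge_conv(nodes, coords, weights, k)
print(out.shape)  # (10, 8)
```

The pairwise-distance matrix here is built for a single graph. This is why working per graph (or on a disjoint union) rather than on a padded batch avoids computing distances to padded nodes.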

The inputs to EdgeConv are three dense tensors: nodes, is_valid and coordinates.

  • nodes, shape (batchsize, n_nodes_max, n_features)

    Node features of the graph, padded to a fixed size.

  • is_valid, shape (batchsize, n_nodes_max)

    1 for an actual node, 0 for a padded node.

  • coordinates, shape (batchsize, n_nodes_max, n_coords)

    Features of each node that are used to find the nearest neighbors.
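One way to build the three dense inputs from graphs of different sizes is sketched below with NumPy. It writes the actual nodes first and zero padding after them. The helper pad_graphs, the float32 dtype and the choice of reusing two node features as coordinates are assumptions for illustration.

```python
import numpy as np


def pad_graphs(features, coords, n_nodes_max):
    """Stack variable-size graphs into the dense tensors nodes, is_valid, coordinates.

    features: list of arrays, each (n_nodes_i, n_features)
    coords:   list of arrays, each (n_nodes_i, n_coords)
    Actual nodes are written first; the remaining slots stay zero (padding).
    """
    bs = len(features)
    n_features = features[0].shape[1]
    n_coords = coords[0].shape[1]
    nodes = np.zeros((bs, n_nodes_max, n_features), dtype=np.float32)
    coordinates = np.zeros((bs, n_nodes_max, n_coords), dtype=np.float32)
    is_valid = np.zeros((bs, n_nodes_max), dtype=np.float32)
    for i, (f, c) in enumerate(zip(features, coords)):
        n = len(f)
        if n > n_nodes_max:
            raise ValueError(f"graph {i} has {n} nodes > n_nodes_max={n_nodes_max}")
        nodes[i, :n] = f
        coordinates[i, :n] = c
        is_valid[i, :n] = 1.0
    return nodes, is_valid, coordinates


rng = np.random.default_rng(1)
sizes = [5, 2, 7]
feats = [rng.normal(size=(n, 4)) for n in sizes]
crds = [f[:, :2] for f in feats]  # hypothetical: first two features double as coordinates
nodes, is_valid, coordinates = pad_graphs(feats, crds, n_nodes_max=8)
print(nodes.shape, is_valid.shape, coordinates.shape)  # (3, 8, 4) (3, 8) (3, 8, 2)
```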

With to_disjoint=True, the dense input tensors are transformed into the disjoint union, and the output is disjoint as well. pooling=True attaches a node-wise global average pooling layer at the end, which averages over the nodes of each graph.
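On the disjoint representation, a global average over the nodes of each graph is a segment mean: sum the features per graph id and divide by the node count. The NumPy sketch below uses a hypothetical helper segment_mean. TensorFlow provides the equivalent tf.math.segment_mean, but this is not necessarily how medgeconv implements its pooling.

```python
import numpy as np


def segment_mean(flat_nodes, graph_ids, n_graphs):
    """Average the node features of every graph in a disjoint batch.

    flat_nodes: (n_total, n_features); graph_ids: (n_total,) with values < n_graphs.
    Returns (n_graphs, n_features): one feature vector per graph.
    """
    sums = np.zeros((n_graphs, flat_nodes.shape[1]))
    np.add.at(sums, graph_ids, flat_nodes)             # unbuffered scatter-add per graph
    counts = np.bincount(graph_ids, minlength=n_graphs)
    return sums / np.maximum(counts, 1)[:, None]       # guard against graphs without nodes


flat_nodes = np.array([[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]])
graph_ids = np.array([0, 0, 1])
pooled = segment_mean(flat_nodes, graph_ids, n_graphs=2)
print(pooled)  # per-graph means: [[2, 3], [10, 20]]
```

Because padded nodes never enter flat_nodes, no masking is needed to keep them out of the average.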

A full model could look like this:

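import tensorflow as tf
import medgeconv

# Hypothetical sizes; the batch size has to be fixed (see the remarks below)
bs, n_nodes_max, n_features, n_coords = 64, 100, 7, 2
nodes = tf.keras.Input(shape=(n_nodes_max, n_features), batch_size=bs)
is_valid = tf.keras.Input(shape=(n_nodes_max,), batch_size=bs)
coordinates = tf.keras.Input(shape=(n_nodes_max, n_coords), batch_size=bs)

inp = (nodes, is_valid, coordinates)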
x = medgeconv.DisjointEdgeConvBlock(
    units=[64, 64, 64],
    to_disjoint=True,
    batchnorm_for_nodes=True)(inp)

x = medgeconv.DisjointEdgeConvBlock(
    units=[128, 128, 128])(x)

x = medgeconv.DisjointEdgeConvBlock(
    units=[256, 256, 256],
    pooling=True)(x)

output = tf.keras.layers.Dense(2)(x)
model = tf.keras.Model(inp, output)

To load a saved model, pass the custom_objects:

import tensorflow as tf
import medgeconv

model = tf.keras.models.load_model(path, custom_objects=medgeconv.custom_objects)

Remarks:

  • The batch size has to be fixed (i.e. use Input(batch_size=bs, …)).

  • In the nodes array, the valid nodes have to come first, followed by the padded nodes.
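Whether a batch satisfies the second remark can be checked, and fixed if needed, with a small NumPy helper. Both function names are hypothetical. The reordering uses a stable sort, so the valid nodes keep their relative order.

```python
import numpy as np


def valid_nodes_first(is_valid):
    """Return True if each row of is_valid has the form 1, ..., 1, 0, ..., 0."""
    v = np.asarray(is_valid, dtype=float)
    # A row must never go from padded (0) back to valid (1), i.e. it is non-increasing.
    return bool(np.all(np.diff(v, axis=1) <= 0))


def sort_valid_first(nodes, is_valid, coordinates):
    """Move the valid nodes of every graph to the front, keeping their order."""
    is_valid = np.asarray(is_valid)
    key = -np.asarray(is_valid, dtype=float)           # valid (-1) sorts before padded (0)
    order = np.argsort(key, axis=1, kind="stable")
    rows = np.arange(is_valid.shape[0])[:, None]       # broadcast batch index over nodes
    return nodes[rows, order], is_valid[rows, order], coordinates[rows, order]


is_valid = np.array([[0, 1, 1], [1, 0, 1]])
nodes = np.arange(6, dtype=float).reshape(2, 3, 1)
coordinates = nodes.copy()
print(valid_nodes_first(is_valid))  # False
nodes, is_valid, coordinates = sort_valid_first(nodes, is_valid, coordinates)
print(valid_nodes_first(is_valid))  # True
```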

Download files

Source Distribution

medgeconv-0.2.2.tar.gz (10.8 kB)

File details

Details for the file medgeconv-0.2.2.tar.gz.

File metadata

  • Download URL: medgeconv-0.2.2.tar.gz
  • Upload date:
  • Size: 10.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/50.0.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.6.7

File hashes

Hashes for medgeconv-0.2.2.tar.gz

Algorithm     Hash digest
SHA256        005269b5c36217fafb9e1e5e22ea850c4403f2ee2760e8a2f1a20457e266c926
MD5           b82c6ee394d546f583fd6e01cbaebe1e
BLAKE2b-256   b2dbcb1cbc36f5b6227ed12c6e22df1815f0dffdaef52142aa141970f8ae50e8

