
An efficient TensorFlow 2 implementation of the edge-convolution layer EdgeConv, as used in e.g. ParticleNet.

The structure of the layer is as described in ‘ParticleNet: Jet Tagging via Particle Clouds’ https://arxiv.org/abs/1902.08570.
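
As a rough illustration of the operation described in the paper, the following pure-Python sketch computes one EdgeConv update for a single graph: for each node, an edge feature is built from the node itself and the difference to each of its neighbors, passed through a function, and averaged over the neighbors. The names edge_conv_step and toy_mlp are hypothetical and not part of medgeconv; toy_mlp merely stands in for the learned multi-layer perceptron, and the neighbor lists are assumed to be given and non-empty. The sketch leaves out the other parts of the block described in the paper.

```python
def toy_mlp(edge_feature):
    # Hypothetical stand-in for the learned MLP: sums the edge feature
    # vector and returns a single output feature.
    return [sum(edge_feature)]


def edge_conv_step(features, neighbors, h=toy_mlp):
    """One EdgeConv update for a single graph.

    features:  list of per-node feature lists
    neighbors: neighbors[i] lists the indices of the neighbors of node i
    """
    updated = []
    for i, x_i in enumerate(features):
        outputs = []
        for j in neighbors[i]:
            x_j = features[j]
            # Edge feature: the central node concatenated with (x_j - x_i).
            diff = [b - a for a, b in zip(x_i, x_j)]
            outputs.append(h(x_i + diff))
        # Aggregate over the neighbors with a mean, channel by channel.
        n = len(outputs)
        updated.append([sum(channel) / n for channel in zip(*outputs)])
    return updated


features = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
neighbors = [[1, 2], [0, 2], [0, 1]]
new_features = edge_conv_step(features, neighbors)
print(new_features)  # [[1.5], [1.0], [0.5]]
```

In the actual layer, the units argument (e.g. [64, 64, 64]) sets the sizes of the learned layers that toy_mlp replaces here.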

Instructions

Install via:

pip install medgeconv

Use e.g. like this:

import medgeconv

nodes = medgeconv.DisjointEdgeConvBlock(
    units=[64, 64, 64],
    next_neighbors=16,
)((nodes, coordinates))

Inputs to EdgeConv are two ragged tensors, nodes and coordinates:

  • nodes, shape (batchsize, None, n_features)

    Node features of the graph. The second dimension is the number of nodes, which can vary from graph to graph.

  • coordinates, shape (batchsize, None, n_coords)

    Features of each node used for calculating nearest neighbors.

Example: input for a batch of two graphs with 2 features per node, with all node features used as coordinates.

import tensorflow as tf

nodes = tf.ragged.constant([
   # graph 1: 2 nodes
   [[2., 4.],
    [2., 6.]],
   # graph 2: 4 nodes
   [[0., 2.],
    [3., 7.],
    [4., 0.],
    [1., 2.]],
], ragged_rank=1)

print(nodes.shape)  # output: (2, None, 2)

# using all node features as coordinates
coordinates = nodes
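
To make the role of the coordinates more concrete, here is a minimal pure-Python sketch of a nearest-neighbor search that runs separately within each graph of the batch, using the example values from above as nested lists. The helper knn_per_graph is hypothetical and is not the package's kernel. It assumes Euclidean distance, excludes the node itself, and returns fewer than k neighbors when a graph is too small; none of these choices is confirmed for medgeconv.

```python
def knn_per_graph(batch, k):
    """For each graph in the batch, return the indices of the k nearest
    neighbors of every node (indices are local to the graph)."""
    result = []
    for graph in batch:
        graph_neighbors = []
        for i, p in enumerate(graph):
            # Squared Euclidean distance to every other node of the same graph.
            dists = [
                (sum((a - b) ** 2 for a, b in zip(p, q)), j)
                for j, q in enumerate(graph)
                if j != i
            ]
            # Sort by distance; ties are broken by the lower node index.
            dists.sort()
            graph_neighbors.append([j for _, j in dists[:k]])
        result.append(graph_neighbors)
    return result


example_coordinates = [
    [[2., 4.], [2., 6.]],                      # graph 1: 2 nodes
    [[0., 2.], [3., 7.], [4., 0.], [1., 2.]],  # graph 2: 4 nodes
]
neighbor_indices = knn_per_graph(example_coordinates, k=2)
print(neighbor_indices)
# [[[1], [0]], [[3, 2], [3, 0], [3, 0], [0, 2]]]
```

Neighbors are never taken from another graph of the batch, which is why the node indices restart at 0 for each graph.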

Example

The full ParticleNet for n_features = n_coords = 2, with a dense layer of 2 neurons as the output, can be built like this:

import tensorflow as tf
import medgeconv

inp = (
    tf.keras.Input((None, 2), ragged=True),
    tf.keras.Input((None, 2), ragged=True),
)
x = medgeconv.DisjointEdgeConvBlock(
    units=[64, 64, 64],
    batchnorm_for_nodes=True,
    next_neighbors=16,
)(inp)

x = medgeconv.DisjointEdgeConvBlock(
    units=[128, 128, 128],
    next_neighbors=16,
)(x)

x = medgeconv.DisjointEdgeConvBlock(
    units=[256, 256, 256],
    next_neighbors=16,
    pooling=True,
)(x)

output = tf.keras.layers.Dense(2)(x)
model = tf.keras.Model(inp, output)

The last EdgeConv block has pooling=True. This attaches a global average pooling over the nodes at the end of the block, producing regular (non-ragged) tensors again.
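
The effect of the pooling step on the shapes can be sketched in plain Python: averaging over the node axis turns a ragged batch of per-node features into one fixed-length vector per graph. The function global_average_pool is a hypothetical illustration, not the layer used in the package.

```python
def global_average_pool(batch):
    """Average the node features of each graph over its nodes."""
    pooled = []
    for graph in batch:
        n_nodes = len(graph)
        # zip(*graph) iterates over the feature columns of the graph.
        pooled.append([sum(column) / n_nodes for column in zip(*graph)])
    return pooled


ragged_nodes = [
    [[2., 4.], [2., 6.]],                      # graph 1: 2 nodes
    [[0., 2.], [3., 7.], [4., 0.], [1., 2.]],  # graph 2: 4 nodes
]
pooled = global_average_pool(ragged_nodes)
print(pooled)  # [[2.0, 5.0], [2.0, 2.75]]
```

Both graphs end up with a vector of length 2, regardless of how many nodes they started with, so a regular dense layer can be applied afterwards.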

The model can then be used on ragged tensors:

nodes = tf.RaggedTensor.from_tensor(tf.ones((3, 17, 2)))
model.predict((nodes, nodes))

Loading models

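To load a saved model, pass medgeconv.custom_objects as the custom_objects argument:

import tensorflow as tf
import medgeconv

# path: location of the saved model
model = tf.keras.models.load_model(path, custom_objects=medgeconv.custom_objects)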

knn_graph kernel

This package includes a CUDA kernel for calculating the k nearest neighbors on a batch of graphs. It comes with a precompiled kernel for the version of TensorFlow specified in requirements.txt.

To compile it locally, e.g. for a different version of TensorFlow, go to medgeconv/tf_ops and adjust the compile.sh bash script. Running it will download the specified TensorFlow development Docker image and produce the file medgeconv/tf_ops/python/ops/_knn_graph_ops.so.

Publications

Results using this model architecture in the context of particle physics were presented at the ICRC 2021 conference (https://doi.org/10.22323/1.395.1048) and at VLVnT 2021 (https://arxiv.org/abs/2107.13375).
