Project description
An efficient TensorFlow 2 implementation of the edge-convolution layer EdgeConv, as used in ParticleNet, for example.
The structure of the layer is as described in ‘ParticleNet: Jet Tagging via Particle Clouds’ https://arxiv.org/abs/1902.08570.
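To make the layer's structure more concrete, here is a minimal plain-Python sketch of the aggregation step that EdgeConv performs as described in the ParticleNet paper: every node combines its own features with the offsets to its neighbors, and the results are averaged over the neighbors. This is not the code of this package. The function names `edge_features`, `toy_kernel` and `edge_conv` are hypothetical, the learned dense layers are replaced by an identity function, and batch normalization, activations and the shortcut connection are left out.

```python
# Plain-Python sketch of one EdgeConv aggregation step (NOT the medgeconv code).
# Assumptions: neighbor indices are given, and the learned dense layers are
# replaced by the hypothetical stand-in `toy_kernel`.

def edge_features(x_i, x_j):
    """Concatenate the central node's features with the offset to a neighbor."""
    return list(x_i) + [b - a for a, b in zip(x_i, x_j)]

def toy_kernel(edge):
    """Hypothetical stand-in for the shared dense layers: the identity."""
    return edge

def edge_conv(nodes, neighbors):
    """Average the kernel output over each node's neighbors."""
    out = []
    for i, x_i in enumerate(nodes):
        edges = [toy_kernel(edge_features(x_i, nodes[j])) for j in neighbors[i]]
        out.append([sum(column) / len(edges) for column in zip(*edges)])
    return out

# Graph 2 of the input example further down: 4 nodes with 2 features each.
nodes = [[0., 2.], [3., 7.], [4., 0.], [1., 2.]]
# The two nearest neighbors of each node by Euclidean distance,
# written out by hand for this illustration.
neighbors = [[3, 2], [3, 0], [3, 0], [0, 2]]

updated = edge_conv(nodes, neighbors)
print(updated[0])  # [0.0, 2.0, 2.5, -1.0]
```

With the identity kernel, each output row holds the node's own features followed by the mean offset to its neighbors. In the real layer, the dense layers given by `units` take the place of `toy_kernel`.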
Instructions
Install via:
pip install medgeconv
Use it like this, for example:
import medgeconv
nodes = medgeconv.DisjointEdgeConvBlock(
    units=[64, 64, 64],
    next_neighbors=16,
)((nodes, coordinates))
EdgeConv takes two ragged tensors as input: nodes and coordinates.
- nodes, shape (batchsize, None, n_features)
Node features of the graph. The second dimension is the number of nodes, which can vary from graph to graph.
- coordinates, shape (batchsize, None, n_coords)
Features of each node that are used for calculating the nearest neighbors.
Example: input for a batch of two graphs with 2 features per node, with all node features used as coordinates.
import tensorflow as tf
nodes = tf.ragged.constant([
    # graph 1: 2 nodes
    [[2., 4.],
     [2., 6.]],
    # graph 2: 4 nodes
    [[0., 2.],
     [3., 7.],
     [4., 0.],
     [1., 2.]],
], ragged_rank=1)
print(nodes.shape) # output: (2, None, 2)
# using all node features as coordinates
coordinates = nodes
Example
The full ParticleNet for n_features = n_coords = 2, with a dense layer of 2 neurons as the output, can be built like this:
import tensorflow as tf
import medgeconv
inp = (
    tf.keras.Input((None, 2), ragged=True),
    tf.keras.Input((None, 2), ragged=True),
)
x = medgeconv.DisjointEdgeConvBlock(
    units=[64, 64, 64],
    batchnorm_for_nodes=True,
    next_neighbors=16,
)(inp)
x = medgeconv.DisjointEdgeConvBlock(
    units=[128, 128, 128],
    next_neighbors=16,
)(x)
x = medgeconv.DisjointEdgeConvBlock(
    units=[256, 256, 256],
    next_neighbors=16,
    pooling=True,
)(x)
output = tf.keras.layers.Dense(2)(x)
model = tf.keras.Model(inp, output)
The last EdgeConv layer has pooling=True. This attaches a global average pooling layer at the end, which averages over the nodes of each graph and produces regular, non-ragged tensors again.
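As an illustration of what this pooling step computes, the following plain-Python sketch averages the node features of every graph over its nodes, so that each graph ends up with one fixed-size feature vector regardless of its node count. The helper `global_average_pool` is a hypothetical name used only here; it is not part of medgeconv, and the package itself operates on TensorFlow tensors rather than lists.

```python
def global_average_pool(graphs):
    """Average the node features of every graph over its nodes."""
    pooled = []
    for graph in graphs:
        n_nodes = len(graph)
        # zip(*graph) iterates over features, each holding one value per node.
        pooled.append([sum(feature) / n_nodes for feature in zip(*graph)])
    return pooled

# Same values as the ragged input example: graphs with 2 and 4 nodes.
batch = [
    [[2., 4.], [2., 6.]],
    [[0., 2.], [3., 7.], [4., 0.], [1., 2.]],
]

pooled = global_average_pool(batch)
print(pooled)  # [[2.0, 5.0], [2.0, 2.75]]
```

Both graphs are reduced to a vector of length 2, which is why the result can be stored as a regular tensor of shape (batchsize, n_features) and passed to a dense layer.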
The model can then be used on ragged tensors:
nodes = tf.RaggedTensor.from_tensor(tf.ones((3, 17, 2)))
model.predict((nodes, nodes))
Loading models
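To load a saved model, pass medgeconv.custom_objects as the custom_objects argument:
import medgeconv
from tensorflow.keras.models import load_model
# path is the location of the saved model
model = load_model(path, custom_objects=medgeconv.custom_objects)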
knn_graph kernel
This package includes a CUDA kernel for calculating the k nearest neighbors on a batch of graphs. It comes with a precompiled kernel for the version of TensorFlow specified in requirements.txt.
To compile it locally, e.g. for a different version of TensorFlow, go to medgeconv/tf_ops and adjust the compile.sh bash script. Running it will download the specified TensorFlow dev Docker image and produce the file medgeconv/tf_ops/python/ops/_knn_graph_ops.so.
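For orientation, the result of a k-nearest-neighbor search on a batch of graphs can be sketched in plain Python as below. This is a brute-force reference for illustration only, not the CUDA kernel. The function name `knn_graph_reference` is hypothetical, and two details are assumptions rather than documented behavior of the package: a node is not counted as its own neighbor, and a graph with fewer than k + 1 nodes simply returns fewer neighbors.

```python
def knn_graph_reference(graphs, k):
    """Return, for every graph, the indices of each node's k nearest neighbors.

    Neighbors are only searched within the same graph (the graphs are disjoint).
    Assumption: a node is never its own neighbor.
    """
    result = []
    for coords in graphs:
        graph_neighbors = []
        for i, p in enumerate(coords):
            # Squared Euclidean distance to every other node of this graph.
            dists = [
                (sum((a - b) ** 2 for a, b in zip(p, q)), j)
                for j, q in enumerate(coords)
                if j != i
            ]
            dists.sort()
            graph_neighbors.append([j for _, j in dists[:k]])
        result.append(graph_neighbors)
    return result

# Coordinates taken from the ragged input example: graphs with 2 and 4 nodes.
coordinates = [
    [[2., 4.], [2., 6.]],
    [[0., 2.], [3., 7.], [4., 0.], [1., 2.]],
]

neighbors = knn_graph_reference(coordinates, k=2)
print(neighbors[0])  # [[1], [0]]
print(neighbors[1])  # [[3, 2], [3, 0], [3, 0], [0, 2]]
```

The brute-force search costs a number of distance evaluations that grows quadratically with the node count of each graph, which is the kind of workload a dedicated GPU kernel is meant to speed up.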
Publications
Results using this model architecture in the context of particle physics were presented at the ICRC 2021 conference (https://doi.org/10.22323/1.395.1048) and at VLVnT 2021 (https://arxiv.org/abs/2107.13375).
File details
Details for the file medgeconv-2.2.tar.gz.
File metadata
- Download URL: medgeconv-2.2.tar.gz
- Upload date:
- Size: 58.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.5.0 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 1e38ede56a2e3bcaa83bac09205d4323e7afd9159ca19e4854c0bbffac3a6082
MD5 | a24019fe64e2e31a23253d7ebdb805d1
BLAKE2b-256 | b931e9c7518fb9242812496caad419194be9f2e2664338831506f8d8bcb4253b
File details
Details for the file medgeconv-2.2-py3-none-any.whl.
File metadata
- Download URL: medgeconv-2.2-py3-none-any.whl
- Upload date:
- Size: 60.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.5.0 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | fb183e4848f784f2c730dbc574ada70fb1dcfe2a4506bea95f95d280f7409eb5
MD5 | 2d9376f14457d760c9feb77e1a209398
BLAKE2b-256 | a58746050e7af7a857fdf5d2a5f7db20b08cb9e0d36a05e4ad70b00d1a8d3d13