Equivariant convolution on molecule structure point clouds

Project description

This repository contains the Cosmo neural network lift and convolution layers. For a usage example and reproduction of the results of the RECOMB 2026 submission "Gaining mechanistic insight from geometric deep learning on molecule structures through equivariant convolution", see https://github.com/BorgwardtLab/RECOMB2026Cosmo.

Installation: pip install cosmic-torch or pip install git+https://github.com/BorgwardtLab/Cosmo

Cosmo

Cosmo is a neural network architecture based on message passing on geometric graphs of molecule structures. It applies a convolutional filter by translating it to each vertex and rotating it towards each neighbor; the resulting feature activation (message) is passed to the neighbor the filter was pointed at. Stacking multiple Cosmo layers therefore allows large geometric patterns to be modeled with a template-matching objective. A Cosmo network is equivariant to translation and rotation, and highly interpretable: its weight matrices can be linearly combined and its filter poses can be reconstructed geometrically. For more details, please see the paper.
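To build intuition for why pointing the filter at a neighbor yields rotation and translation equivariance, here is a minimal, self-contained 2D sketch in plain Python (a toy Gaussian template, not Cosmo's actual filter): the matching score is unchanged when the whole point cloud is rotated and translated.

```python
import math

def rot(theta, p):
    """Rotate 2D point p by angle theta."""
    c, s = math.cos(theta), math.sin(theta)
    return (c * p[0] - s * p[1], s * p[0] + c * p[1])

def filter_response(center, neighbors, toward):
    """Translate a toy filter to `center`, rotate it toward `toward`,
    and accumulate a template-matching score over the neighborhood."""
    theta = math.atan2(toward[1] - center[1], toward[0] - center[0])
    score = 0.0
    for n in neighbors:
        off = (n[0] - center[0], n[1] - center[1])
        lx, ly = rot(-theta, off)          # offset in the filter's frame
        score += math.exp(-((lx - 1.0) ** 2 + ly ** 2))  # Gaussian template
    return score

pts = [(0.0, 0.0), (1.0, 0.2), (-0.5, 0.8)]
s1 = filter_response(pts[0], pts[1:], pts[1])

# Apply a global rotation and translation to the whole point cloud.
phi, t = 0.7, (3.0, -2.0)
moved = [(rot(phi, p)[0] + t[0], rot(phi, p)[1] + t[1]) for p in pts]
s2 = filter_response(moved[0], moved[1:], moved[1])
assert abs(s1 - s2) < 1e-9   # score is invariant to the global motion
```

Because the filter's frame is anchored at the vertex and oriented toward the neighbor, the global pose of the molecule cancels out of the local coordinates.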

Example Usage

Cosmo layers operate on lifted geometric graphs. These are computed from an adjacency matrix of the data, either given (e.g. by atomic bond connectivity) or constructed (e.g. by k-NN):

adj = torch_geometric.nn.knn_graph(coords, k, batch_index)

where coords are the input point coordinates, k is a hyperparameter, and batch_index assigns each node to an instance in the batch (following the computation model of PyG, which we highly recommend using).
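For intuition, a batch-aware k-NN graph could be built like this pure-Python sketch, which mirrors the semantics (not the implementation) of torch_geometric.nn.knn_graph:

```python
import math

def toy_knn_graph(coords, k, batch_index):
    """Connect each node to its k nearest neighbors within the same
    batch instance; returns (source, target) edge pairs."""
    edges = []
    for i, p in enumerate(coords):
        cands = sorted(
            (math.dist(p, q), j)
            for j, q in enumerate(coords)
            if j != i and batch_index[j] == batch_index[i]
        )
        edges += [(j, i) for _, j in cands[:k]]  # neighbor j -> node i
    return edges

# Two instances in one batch: a 3-point molecule and a 2-point molecule.
coords = [(0, 0), (1, 0), (0, 1), (10, 10), (11, 10)]
batch = [0, 0, 0, 1, 1]
edges = toy_knn_graph(coords, 2, batch)
```

Note that no edge ever crosses instance boundaries, which is exactly what the batch_index argument guarantees.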

Given coordinates, node features (e.g. one-hot encoded atom type), and the adjacency we can lift the input graph:

L = Lift2D()(features, coords, adj, batch_index) # or Lift3D()
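Conceptually, the lift turns each directed edge into its own element that carries the source/target node indices and the coordinates of the source's neighborhood relative to the source. A rough, hypothetical illustration in plain Python (field names mimic L.source / L.target / L.hood_coords; the actual Lift2D/Lift3D layers may compute more, e.g. canonical local frames):

```python
def lift_sketch(coords, adj):
    """Hypothetical lift: one element per directed edge (i -> j),
    carrying i's neighbor offsets. Illustration only, not the real Lift."""
    nbrs = {}
    for i, j in adj:                      # adj: (source, target) pairs
        nbrs.setdefault(i, []).append(j)
    source, target, hood_coords = [], [], []
    for i, j in adj:
        source.append(i)
        target.append(j)
        hood_coords.append([
            tuple(a - b for a, b in zip(coords[n], coords[i]))
            for n in nbrs[i]
        ])
    return source, target, hood_coords

coords = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
adj = [(0, 1), (1, 0), (0, 2), (2, 0)]
src, tgt, hood = lift_sketch(coords, adj)
```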

The returned L namespace contains everything needed by the subsequent Cosmo layers:

features = layer(L.source, L.target, L.features, L.hood_coords)

After the Cosmo layers we need to undo the lift operation (lowering) to obtain features on the input graph. This is done by aggregating the edge/triangle features to the nodes, which yields a standard graph object that can be further computed on with PyG layers, for example.
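Conceptually, lowering with agg="max" is a scatter-max over the lifted-to-node index map. A plain-Python sketch (illustrative stand-in, not the actual Lower implementation):

```python
def toy_lower_max(features, lifted2node, num_nodes):
    """Scatter-max: aggregate each lifted element's feature vector
    into the node it maps to (lifted2node[e] = node index of element e)."""
    dim = len(features[0])
    out = [[float("-inf")] * dim for _ in range(num_nodes)]
    for feat, node in zip(features, lifted2node):
        for d in range(dim):
            out[node][d] = max(out[node][d], feat[d])
    return out

# Three lifted elements mapping onto two nodes.
feats = [[1.0, 0.0], [2.0, -1.0], [0.5, 3.0]]
node_feats = toy_lower_max(feats, [0, 0, 1], 2)
```

Instance-level lowering works the same way, with lifted2inst in place of lifted2node.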

node_features = Lower(agg="max")(features, L.lifted2node, num_nodes)

Or, if features should be aggregated directly to the instance (graph) level:

graph_features = Lower(agg="max")(features, L.lifted2inst, num_instances)

An entire Cosmo network for a node classification task could look like this:

from cosmic import *
import torch.nn as nn

class CosmoModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.lift = Lift3D()
        self.lower = Lower()
        self.cosmo_layers = nn.ModuleList([
            NeuralFieldCosmo(in_channels=5, out_channels=128, dim=3),
            NeuralFieldCosmo(in_channels=128, out_channels=128, dim=3),
            NeuralFieldCosmo(in_channels=128, out_channels=10, dim=3)
        ])

    def forward(self, node_features, coords, adj, batch_index, num_nodes):
        L = self.lift(node_features, coords, adj, batch_index)
        features = L.features
        for layer in self.cosmo_layers:
            features = layer(L.source, L.target, features, L.hood_coords)
        node_features = self.lower(features, L.lifted2node, num_nodes)
        # classic GNN layers or an MLP head could follow here
        return node_features

Citation

TBD

License

TBD
