TGraphX
TGraphX is a PyTorch-based framework for building Graph Neural Networks (GNNs) that work with node and edge features of any dimension while retaining their spatial layout. The code is designed for flexibility, easy GPU acceleration, and rapid prototyping of new GNN ideas, especially those that need to preserve local spatial details (e.g., image or volumetric patches).
📄 Preprint: TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning
✏️ Authors: Arash Sajjadi, Mark Eramian
🗓️ Published on arXiv, April 2025
Note: TGraphX includes optional skip connections that help with stable gradient flow in deeper GNN stacks. The overall design is rooted in rigorous theoretical and practical foundations, aiming to unify convolutional neural networks (CNNs) with GNN-based relational reasoning.
Table of Contents
- Overview
- Key Features
- Architecture Highlights
- Future Works
- Installation
- Folder Structure
- Core Components
- Layers
- Models
- Configuration Options
- Advanced Topics
- Novelties and Contributions
- Conclusion
- Citations
- License
Overview
TGraphX provides a modular way to create GNNs by combining several components:
- Graph Representation
  A Graph class holds node features, edge indices, and optional edge features. Unlike traditional GNNs, where node features are vectors, TGraphX supports multi-dimensional features such as [C, H, W] tensors, making it particularly effective for vision tasks.
- Message Passing Layers
  Customizable layers process messages between nodes while preserving the spatial layout of features. In TGraphX:
  - ConvMessagePassing uses 1×1 convolutions on concatenated spatial features (e.g., Conv1×1(Concat(Xi, Xj, Eij))).
  - DeepCNNAggregator is a deep CNN (default 4 layers) that refines aggregated messages, keeping their spatial structure ([C, H, W]) intact.
- Models
  Pre-built models combine a CNN encoder with GNN layers:
  - CNN Encoder processes raw image patches into spatial feature maps (e.g., [C, H, W]).
  - Optional Pre-Encoder (e.g., ResNet-like) can be enabled to further refine raw patches before the main CNN encoder.
  - Unified CNN-GNN Model uses CNN encoders for local features and GNN layers for global relational reasoning, then pools the result for final classification.
  - An extra skip connection (if enabled) merges the raw CNN patch output with the GNN output for better gradient flow and more stable learning.
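To make this data model concrete, the sketch below builds a toy graph whose nodes carry [C, H, W] feature maps rather than flat vectors. The tensor names and shapes are illustrative only; they are not TGraphX's actual Graph constructor.

```python
import torch

# Illustrative stand-in for TGraphX's Graph container (hypothetical
# layout; the real class may differ). Each node keeps a [C, H, W]
# feature map instead of a flat vector.
num_nodes, C, H, W = 4, 16, 5, 5
node_features = torch.randn(num_nodes, C, H, W)

# Edges as a [2, E] index tensor: edge e goes from edge_index[0, e]
# to edge_index[1, e].
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

# Optional per-edge features, also spatial: [E, C_e, H, W].
edge_features = torch.randn(edge_index.shape[1], 4, H, W)

print(node_features.shape)  # spatial layout is kept per node
```

Because every node stores a full feature map, downstream layers can apply convolutions instead of vector MLPs.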
Key Features
- Support for Arbitrary Dimensions
  Handle vectors, 2D images, or even volumetric 3D patches as node features.
- Spatial Message Passing
  Messages preserve spatial dimensions (e.g., [C, H, W]), letting convolutional filters capture local patterns and avoiding destructive flattening of features.
- Deep Aggregation
  A deep CNN aggregator (with multiple 3×3 convolutions, batch normalization, dropout, and ReLU) refines messages across multiple hops, enabling better local–global fusion.
- Optional Pre-Encoder
  Pre-process raw patches with a ResNet-like module (or even load a pretrained ResNet-18) to enrich features before the main GNN pipeline.
- Flexible Data Loading
  TGraphX includes custom dataset and data loader classes (GraphDataset and GraphDataLoader) for direct graph-based batching.
- Configurable Skip Connections
  Enable or disable skip connections that pass CNN outputs directly into the final stages, improving gradient flow.
Architecture Highlights
Preserving Spatial Fidelity
Unlike conventional GNNs that flatten node features into vectors, TGraphX retains the full spatial layout [C, H, W] at each node. This ensures that local pixel-level (or voxel-level) structure, which is crucial for vision tasks, remains intact throughout the message passing process.
Convolution-Based Message Passing
TGraphX implements message passing via Conv1×1(Concat(Xi, Xj, Eij)). This approach:
- Respects spatial alignment (each spatial location in one node's feature map can directly interact with the same location in its neighbors' feature maps).
- Preserves the [C, H, W] dimensions, avoiding vector flattening.
- Optionally incorporates edge features Eij for more advanced relational cues (e.g., distances, bounding-box overlaps).
Deep CNN Aggregator with Residuals
Messages from neighbors are aggregated (summed or averaged) and then passed to a deep CNN aggregator that uses multiple 3×3 convolutions with residual skips. This design:
- Prevents the overwriting of original features by always adding Aggregator(mj) to the old node state Xj.
- Facilitates stable gradient flow in deep GNN stacks.
- Broadens the effective receptive field in feature space, capturing both local patches and more distant interactions.
End-to-End Differentiability
Every stage of TGraphX—patch extraction, optional pre-encoder, CNN encoder, graph construction, message passing, aggregation, and classification—remains fully differentiable in PyTorch. This end-to-end design simplifies model development, parameter tuning, and experimentation with novel GNN layers.
Future Works
- Scalability and Data Requirements
  Adapting TGraphX to higher-resolution inputs or massive datasets (e.g., MS COCO) may require further optimizations, including efficient graph construction or pruning strategies.
- Domain-Specific Customization
  Some tasks might not need full spatial fidelity at every message-passing step. Researchers could explore ways to selectively reduce resolution or apply specialized convolutions to different node subsets.
- Alternative Edge Definitions
  Learned adjacency or richer spatial features (e.g., IoU or geometric cues) can further improve performance in complex scenes.
- Multimodal and Real-Time Extensions
  Integrating TGraphX with sensor data or text embeddings could enable richer reasoning for applications like autonomous driving or real-time video surveillance.
Installation
- Clone the Repository
  git clone https://github.com/YourUsername/TGraphX.git
  cd TGraphX
- Set Up the Environment
  Use the provided environment.yml to create a conda environment:
  conda env create -f environment.yml
  conda activate tgraphx
- Install PyTorch
  Install a recent version of PyTorch (the GPU build if possible).
- Install Additional Dependencies
  pip install -r requirements.txt
- Editable Mode (Optional)
  pip install -e .
Folder Structure
TGraphX/
├── __init__.py
├── core/
│ ├── dataloader.py
│ ├── graph.py
│ └── utils.py
├── layers/
│ ├── aggregator.py
│ ├── attention_message.py
│ ├── base.py
│ ├── conv_message.py
│ └── safe_pool.py
├── models/
│ ├── cnn_encoder.py
│ ├── cnn_gnn_model.py
│ ├── graph_classifier.py
│ ├── node_classifier.py
│ └── pre_encoder.py
├── environment.yml
└── README.md
Core Components
Graph and Data Loading
- Graph & GraphBatch
  Represent individual graphs (nodes, edges) and batches of graphs. The batch version offsets node indices to avoid collisions, allowing parallel processing in PyTorch.
- GraphDataset & GraphDataLoader
  Custom dataset and data loader classes that streamline the creation of graph batches from a set of images, patches, or other structured data.
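The index-offset trick behind batching can be sketched as follows. This is a minimal stand-alone illustration of the idea (shift each graph's edge indices by the number of nodes that precede it), not TGraphX's actual GraphBatch implementation.

```python
import torch

def batch_graphs(node_feats_list, edge_index_list):
    """Concatenate several graphs into one big graph, offsetting each
    graph's edge indices so node ids do not collide."""
    offset = 0
    batched_edges = []
    for feats, edges in zip(node_feats_list, edge_index_list):
        batched_edges.append(edges + offset)  # shift node ids
        offset += feats.shape[0]
    return torch.cat(node_feats_list, dim=0), torch.cat(batched_edges, dim=1)

# Two toy graphs with [C, H, W] = [8, 5, 5] node features.
g1 = (torch.randn(3, 8, 5, 5), torch.tensor([[0, 1], [1, 2]]))
g2 = (torch.randn(2, 8, 5, 5), torch.tensor([[0], [1]]))

x, e = batch_graphs([g1[0], g2[0]], [g1[1], g2[1]])
print(x.shape)     # torch.Size([5, 8, 5, 5])
print(e.tolist())  # [[0, 1, 3], [1, 2, 4]]
```

The second graph's edge (0 → 1) becomes (3 → 4) in the batch, so both graphs can be processed in one forward pass.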
Utility Functions
- load_config
  Load YAML/JSON configuration files to keep hyperparameters consistent across experiments.
- get_device
  Utility to automatically detect and return the correct device (GPU or CPU).
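A device helper of this kind is typically a one-liner; the sketch below shows the common pattern, though TGraphX's own get_device may differ in detail.

```python
import torch

def get_device() -> torch.device:
    """Return a CUDA device when one is available, otherwise the CPU.
    Minimal sketch of the kind of helper TGraphX's get_device provides."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = get_device()
x = torch.zeros(2, 3, device=device)  # tensors can be created on it directly
```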
Layers
Base Layer
TensorMessagePassingLayer
An abstract base class that defines the interface (message, aggregate, and update steps) for all message passing. Crucially, it handles multi-dimensional node features (e.g., [C, H, W]).
Convolution-Based Message Passing
ConvMessagePassing
Concatenates source and target node feature maps (plus optional edge features) along the channel dimension and applies a 1×1 convolution: Mij = Conv1×1(Concat(Xi, Xj, Eij))
- Message Phase: each pair (i, j) of nodes exchanges messages computed by a 1×1 conv.
- Aggregation + Residual Update: after summing messages from neighbors, a deep CNN aggregator processes the sum, and the original node features are updated via a residual skip.
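The message phase and sum aggregation can be sketched in a few lines of PyTorch. This is an illustrative re-implementation of the Conv1×1(Concat(Xi, Xj, Eij)) idea, not the real ConvMessagePassing layer, whose signature may differ.

```python
import torch
import torch.nn as nn

class ConvMessage(nn.Module):
    """Sketch of convolution-based message passing: a 1x1 conv over the
    channel-wise concatenation of source, target, and edge features."""
    def __init__(self, node_ch, edge_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(2 * node_ch + edge_ch, out_ch, kernel_size=1)

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index                               # [E] each
        # Concatenate X_i, X_j, E_ij along channels; spatial layout is kept.
        m = torch.cat([x[src], x[dst], edge_attr], dim=1)   # [E, 2C+Ce, H, W]
        m = self.conv(m)                                    # [E, out_ch, H, W]
        # Sum messages into each target node (aggregation step).
        out = torch.zeros(x.shape[0], m.shape[1], *x.shape[2:])
        out.index_add_(0, dst, m)
        return out

x = torch.randn(4, 16, 5, 5)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
edge_attr = torch.randn(3, 4, 5, 5)
agg = ConvMessage(16, 4, 32)(x, edge_index, edge_attr)
print(agg.shape)  # torch.Size([4, 32, 5, 5])
```

Note how every spatial location of a node's map interacts only with the same location in its neighbor's map, which is what keeps the features aligned.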
Deep CNN Aggregator
DeepCNNAggregator
A stack of 3×3 convolutional layers with batch normalization, ReLU, and dropout. It refines the aggregated messages: X'_j = X_j + A(m_j)
where m_j is the sum of messages to node j. Residual connections ensure stable gradient flow.
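The residual update X'_j = X_j + A(m_j) can be sketched as follows. This is an illustrative stand-in for DeepCNNAggregator (the default is described as 4 layers with dropout; the real module may differ in detail).

```python
import torch
import torch.nn as nn

class ResidualAggregator(nn.Module):
    """Sketch of a deep CNN aggregator with a residual update,
    X'_j = X_j + A(m_j)."""
    def __init__(self, channels, num_layers=4):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers += [nn.Conv2d(channels, channels, 3, padding=1),
                       nn.BatchNorm2d(channels),
                       nn.ReLU()]
        self.body = nn.Sequential(*layers)

    def forward(self, x, messages):
        # Residual skip: the original node state is never overwritten.
        return x + self.body(messages)

x = torch.randn(4, 16, 5, 5)  # current node states X_j
m = torch.randn(4, 16, 5, 5)  # summed neighbor messages m_j
out = ResidualAggregator(16)(x, m)
print(out.shape)  # torch.Size([4, 16, 5, 5])
```

Because the aggregator's output is added to, rather than replacing, the node state, gradients have a short path back through the skip even in deep stacks.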
Attention-Based Message Passing
AttentionMessagePassing
An alternative that uses 1×1 convolutions to compute query, key, and value maps for each node. Spatial alignment is preserved while attention weights scale incoming messages. Useful for tasks needing dynamic connectivity or weighting.
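One way to realize this is sketched below: 1×1 convs produce Q/K/V maps, and a per-edge scalar derived from Q·K scales each incoming value map. The sigmoid gate here is a simplification standing in for the real layer's weighting scheme, which is not specified above and may differ.

```python
import torch
import torch.nn as nn

class AttentionMessage(nn.Module):
    """Sketch of attention-based message passing with spatially aligned
    query/key/value maps from 1x1 convolutions."""
    def __init__(self, channels):
        super().__init__()
        self.q = nn.Conv2d(channels, channels, 1)
        self.k = nn.Conv2d(channels, channels, 1)
        self.v = nn.Conv2d(channels, channels, 1)

    def forward(self, x, edge_index):
        src, dst = edge_index
        q, k, v = self.q(x), self.k(x), self.v(x)
        # One attention score per edge: spatial mean of Q_dst * K_src.
        score = (q[dst] * k[src]).mean(dim=(1, 2, 3), keepdim=True)
        weighted = torch.sigmoid(score) * v[src]  # scale the value maps
        out = torch.zeros_like(x)
        out.index_add_(0, dst, weighted)          # aggregate per target node
        return out

x = torch.randn(4, 16, 5, 5)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
out = AttentionMessage(16)(x, edge_index)
print(out.shape)  # torch.Size([4, 16, 5, 5])
```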
Safe Pooling
SafeMaxPool2d
A robust pooling module that checks whether the spatial dimensions [H, W] are large enough before applying max pooling. Prevents dimension mismatch errors in deeper aggregator stacks.
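The guard can be sketched as a thin wrapper around nn.MaxPool2d; this illustrates the idea behind SafeMaxPool2d, not its exact implementation.

```python
import torch
import torch.nn as nn

class SafePool(nn.Module):
    """Sketch of 'safe' max pooling: pool only when the feature map is
    large enough, otherwise pass it through unchanged."""
    def __init__(self, kernel_size=2):
        super().__init__()
        self.kernel_size = kernel_size
        self.pool = nn.MaxPool2d(kernel_size)

    def forward(self, x):
        h, w = x.shape[-2:]
        if h < self.kernel_size or w < self.kernel_size:
            return x  # too small to pool: skip instead of erroring
        return self.pool(x)

pool = SafePool(2)
print(pool(torch.randn(1, 8, 6, 6)).shape)  # torch.Size([1, 8, 3, 3])
print(pool(torch.randn(1, 8, 1, 1)).shape)  # torch.Size([1, 8, 1, 1]), skipped
```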
Models
CNN Encoder and Pre-Encoder
- CNNEncoder
  Converts raw patches ([C_in, patch_H, patch_W]) into spatial feature maps (e.g., [C_out, H2, W2]). Includes:
  - Multiple 3×3 conv blocks with BN, ReLU, and dropout.
  - Optional residual connections.
  - Safe max pooling if the spatial size remains large.
- Optional Pre-Encoder
  - If use_preencoder is True, a ResNet-like (or fully custom) module first processes each patch, returning refined features with the same spatial structure.
  - pretrained_resnet can load weights from a standard ResNet-18 for transfer learning.
Unified CNN‑GNN Model
CNN_GNN_Model
A single pipeline that:
- Splits the image into patches, optionally using the PreEncoder.
- Feeds patches into CNNEncoder to get [C, H, W] maps.
- Builds a graph where each node holds a [C, H, W] map.
- Applies multiple GNN layers (like ConvMessagePassing + DeepCNNAggregator).
- Optionally uses a skip connection to combine CNN outputs with GNN outputs.
- Performs final spatial pooling before classification.
Graph & Node Classification Models
- GraphClassifier
  Intended for graph-level tasks (e.g., classification of an entire image or object ensemble). Combines message passing with a final pooling layer (mean, max, or attention) over nodes, then feeds the result into a classifier.
- NodeClassifier
  Suitable for node-level tasks (e.g., labeling each patch or region). Stacks simpler message passing layers for classification on each node separately.
Configuration Options
TGraphX is highly configurable. Some key parameters include:
config = {
"cnn_params": {
"in_channels": 3,
"out_features": 64,
"num_layers": 2,
"hidden_channels": 64,
"dropout_prob": 0.3,
"use_batchnorm": True,
"use_residual": True,
"pool_layers": 2,
"debug": False,
"return_feature_map": True
},
"use_preencoder": False,
"pretrained_resnet": False,
"preencoder_params": {
"in_channels": 3,
"out_channels": 32,
"hidden_channels": 32
},
"gnn_in_dim": (64, 5, 5),
"gnn_hidden_dim": (128, 5, 5),
"num_classes": 10,
"num_gnn_layers": 4,
"gnn_dropout": 0.3,
"residual": True,
"aggregator_params": {
"num_layers": 4,
"dropout_prob": 0.3,
"use_batchnorm": True
}
}
- cnn_params: controls the CNN encoder architecture (e.g., channels, dropout, pooling).
- use_preencoder: boolean indicating whether to preprocess patches with a custom or pretrained module.
- pretrained_resnet: if True, loads pretrained ResNet-18 weights in the pre-encoder.
- gnn_in_dim, gnn_hidden_dim: shapes of the node features in GNN layers; each can be [C, H, W].
- num_gnn_layers: how many message passing layers to stack.
- aggregator_params: depth, dropout, and BN usage in the aggregator.
- residual: enables skip connections in the GNN layers.
Advanced Topics
Theoretical Insights
- Universal Approximation via Deep CNNs
  Stacking multiple convolutional layers with residual skips (in both the CNN encoder and the aggregator) enlarges the effective receptive field and helps approximate complex local feature maps.
- Residual Learning for Gradient Flow
  Residual connections in both the CNN encoder and aggregator mitigate vanishing gradients, allowing deeper structures to train effectively end-to-end.
- Spatial vs. Flattened Features
  Preserving the [C, H, W] layout at each node addresses a key limitation of conventional GNNs: the loss of local spatial semantics. TGraphX's design is grounded in the observation that many vision tasks require capturing fine-grained local details alongside global relational structures.
Possible Extensions
- Adaptive Edge Construction
  Dynamically compute adjacency based on patch similarity or learned attention, rather than fixed proximity thresholds.
- Mixed Modalities
  Combine image data with textual or numerical features by storing them as separate channels or separate GNN streams.
- Task-Specific Losses
  Add auxiliary losses (e.g., bounding-box IoU or segmentation overlap) for detection or segmentation tasks, integrated into the GNN training loop.
- Performance Optimizations
  Use group convolutions or low-rank factorization in the aggregator to reduce memory and computational overhead.
Novelties and Contributions
TGraphX departs from traditional GNN designs in several ways:
- Full Spatial Fidelity
  Each node in the graph remains a multi-dimensional feature map rather than a flattened vector, preserving local spatial relationships crucial for vision tasks.
- Convolution-Based Message Passing
  Employing 1×1 convolutions on [C, H, W] feature maps lets neighboring patches exchange information at every pixel location, ensuring alignment and detail retention.
- Deep Residual Aggregation
  Multiple 3×3 CNN layers in the aggregator, complete with batch normalization, ReLU, dropout, and skip connections, allow the model to fuse multi-hop messages in a stable, expressive manner.
- End-to-End Differentiability
  From raw image patches to final classification or detection outputs, all steps (CNN feature extraction, graph construction, message passing, and aggregator updates) are trained jointly, strengthening the synergy between local feature extraction and relational reasoning.
- Modular & Extensible
  - Allows easy substitution of the aggregator or attention-based message passing layers.
  - Accommodates multiple data modalities (image, volumetric, or otherwise).
  - Scales from small graphs (a few patches) to larger patch partitions for high-resolution images.
These innovations build on earlier GNN research while pushing further to retain all the valuable local details that are typically lost in flattened GNN nodes.
Conclusion
We have presented TGraphX, an architecture aimed at integrating convolutional neural networks (CNNs) and graph neural networks (GNNs) in a way that preserves spatial fidelity. By retaining multi-dimensional CNN feature maps as node representations and employing convolution-based message passing, TGraphX captures both local and global spatial context. Our experiments—particularly those involving detection refinement—demonstrate its potential to resolve detection discrepancies and refine localization accuracy in challenging vision tasks.
While we do not claim it to be universally optimal for all computer vision scenarios, TGraphX offers a flexible framework that other researchers can adapt or extend. This integration of CNN-based feature extraction with GNN-based relational reasoning is a promising direction for future AI and vision research.
Citations
@misc{sajjadi2025tgraphxtensorawaregraphneural,
title={TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning},
author={Arash Sajjadi and Mark Eramian},
year={2025},
eprint={2504.03953},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2504.03953},
}
License
TGraphX is released under the MIT License. See the LICENSE file for more details.
Enjoy exploring and developing your spatially-aware graph neural networks with TGraphX!
If you have any questions, suggestions, or want to contribute, feel free to open an issue or submit a pull request.