
Tensor-aware graph neural networks preserving spatial node feature layouts



TGraphX


📄 Preprint: TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning · Sajjadi & Eramian, arXiv 2025

TGraphX is a PyTorch library for graph neural networks whose node features are multi-dimensional tensors — such as [C, H, W] image-patch feature maps — rather than flat vectors. Convolutional message passing operates directly on spatial feature maps at each node, so local structure is never destroyed by flattening.


Colab tutorial

TGraphX includes a hands-on Google Colab notebook that installs the latest PyPI release and walks through the main API features interactively:


Open the TGraphX Colab Tutorial →

The notebook covers:

  • CPU/GPU environment checks
  • Vector node classification
  • 2-D spatial graph classification using image patches
  • 3-D volumetric graph classification
  • Graph-level and node-level regression
  • Edge prediction / edge classification
  • Layer zoo: ConvMessagePassing, TensorGATLayer, TensorGraphSAGELayer, TensorGINLayer, LinearMessagePassing, legacy AttentionMessagePassing
  • Edge weights and edge features
  • Dashboard-compatible run files

Honesty note: All tasks in the notebook use controlled synthetic data designed to verify installation, API behaviour, device compatibility, and gradient flow. They are not benchmark results and make no real-world performance or state-of-the-art claims.


The problem TGraphX solves

Standard GNN frameworks (PyG, DGL) expect flat vector node features. Flattening a [C, H, W] feature map into a [C·H·W] vector discards the spatial structure that makes CNNs effective. TGraphX keeps each node's representation as a tensor and applies 1×1 convolutions during message passing, so every neighbourhood aggregation step acts like a miniature CNN across neighbouring feature maps.
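
A minimal illustration of the difference, with arbitrary shapes:

import torch
import torch.nn as nn

fmap = torch.randn(8, 16, 4, 4)     # 8 nodes, each a [C=16, H=4, W=4] feature map
flat = fmap.flatten(start_dim=1)    # [8, 256]; pixel adjacency is lost
# Keeping [N, C, H, W] lets a 1x1 conv mix channels while preserving H and W
out = nn.Conv2d(16, 32, kernel_size=1)(fmap)   # [8, 32, 4, 4]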


What is currently implemented

All "tensor-aware" GNN layers in this list operate on spatial node feature maps [N, C, H, W] and preserve the spatial layout through message passing. They are adaptations of the canonical algorithms — not drop-in clones of PyTorch Geometric's vector-feature implementations.

| Component | Class | Input node shape | Notes |
|---|---|---|---|
| Graph data structure | Graph, GraphBatch | any | Validated; .to(device) |
| Tensor-aware GCN-style message passing | ConvMessagePassing | [N, C, H, W] | aggr="sum" or "mean"; 1×1 conv messages + deep CNN aggregator |
| Tensor-aware GAT (multi-head) | TensorGATLayer | [N, C, H, W] | True GAT: softmax over incoming edges per destination, per head; scalar attention per (edge, head) |
| Tensor-aware GraphSAGE | TensorGraphSAGELayer | [N, C, H, W] | Separate self / neighbour 1×1 Conv2d; mean or max aggregation; optional L2 normalise |
| Tensor-aware GIN / GINEConv | TensorGINLayer | [N, C, H, W] | (1+ε)·h_j + Σ h_i; default 1×1 Conv MLP, learnable ε, optional GINEConv edge term |
| Spatial-gating message passing (legacy) | AttentionMessagePassing | [N, C, H, W] or [N, D] | Per-edge sigmoid gating, not GAT; kept for backward compatibility. Use TensorGATLayer for true GAT. |
| Vector message passing | LinearMessagePassing | [N, D] | Base layer with linear projections |
| Custom layer base class | TensorMessagePassingLayer | any | Override message / update; base handles aggregation |
| CNN patch encoder | CNNEncoder | [N, C_in, pH, pW] | Outputs spatial feature maps [N, C_out, H', W'] |
| Optional pre-encoder | PreEncoder | [N, C_in, pH, pW] | Custom or pretrained ResNet-18 |
| Unified CNN-GNN model | CNN_GNN_Model | [N, C, pH, pW] | Takes pre-split patches; user supplies edge_index |
| Graph classification | GraphClassifier | [N, C, H, W] | Mean / sum / max readout |
| Node classification | NodeClassifier | [N, D] | Vector features only |
| Dataset & loader | GraphDataset, GraphDataLoader | — | Wraps torch.utils.data |
| Utilities | load_config, get_device | — | YAML/JSON config; CUDA → MPS → CPU |

Datasets, transforms, metrics, benchmarks (v0.2.9)

TGraphX includes a unified dataset registry, native synthetic datasets, folder-backed datasets, and optional adapters for torchvision, PyG, DGL, and OGB.

from tgraphx.datasets import list_datasets, dataset_info, get_dataset

print(list_datasets())                                # registered names
ds = get_dataset("synthetic:patch_graph", num_graphs=32, seed=0)
g = ds[0]
  • TGraphX does not redistribute third-party datasets — no bundled datasets ship in the wheel or sdist. External adapters delegate download and parsing to the upstream library; downloads happen only when the user passes download=True.
  • Cache layout: <root> (or $TGRAPHX_DATA, or ~/.cache/tgraphx/datasets). Inspect with tgraphx.datasets.cache_summary(); clear with clear_cache(...) (see the sketch after this list).
  • Transforms (tgraphx.transforms) are deterministic-when-seeded, non-mutating-by-default, and importing them does not pull in any optional dependency.
  • Metrics (tgraphx.metrics) are pure-PyTorch.
  • Benchmarks (benchmarks/benchmark_*.py) accept --small and --output JSON; their results are engineering reproducibility signals, not real-world performance claims.
  • Synthetic datasets are sanity / tutorial / trainability datasets, not benchmarks.
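
The cache helpers named above can be inspected like this (a sketch; clear_cache arguments are elided here, see docs/datasets.md):

from tgraphx.datasets import cache_summary, clear_cache

print(cache_summary())   # which datasets are cached, and where, under the active root
# clear_cache(...)       # removes cached files; arguments per docs/datasets.md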

Detailed write-ups in docs/datasets.md, docs/transforms.md, docs/metrics.md, docs/benchmarks.md, and the license/citation policy.

Optional dataset extras

| Adapter | Install |
|---|---|
| Native synthetic / folder | (no extra) |
| torchvision-backed datasets | (already a TGraphX base dependency) |
| PyG dataset adapter | pip install "tgraphx[pyg]" |
| DGL dataset adapter | follow upstream install (DGL wheels are platform-sensitive; we don't pin them) |
| OGB dataset adapter | pip install "tgraphx[ogb]" |
| Image folder dataset (PIL) | pip install "tgraphx[pillow]" |

Importing tgraphx, tgraphx.datasets, tgraphx.transforms, or tgraphx.metrics does not import torch_geometric / dgl / ogb.


Current scope and boundaries

TGraphX is a focused library for tensor-aware patch-graph GNNs. Here is what is stable, what is experimental, and what is intentionally out of scope. Full details are in docs/limitations.md and docs/roadmap.md.

Optional and experimental capabilities

| Feature | Status | Install / usage |
|---|---|---|
| TensorGATLayer(attention_mode="channel") | 🧪 Experimental | Constructor argument; per-channel attention |
| TensorGATLayer(chunk_size=K) two-pass chunked forward | ✅ Stable | forward(chunk_size=K) (log-sum-exp) |
| GraphTransformerLayer (vector node features) | 🧪 Experimental | from tgraphx.layers.graph_transformer import GraphTransformerLayer |
| Positional / structural encodings (degree, Laplacian, adjacency bias) | 🧪 Experimental | tgraphx.layers.transformer_encodings |
| HeteroGraph + HeteroGraphBatch + HeteroConv + classifiers | 🧪 Experimental | tgraphx.HeteroGraph, HeteroGraphBatch, tgraphx.layers.hetero.HeteroConv, tgraphx.models.hetero_models.* |
| TemporalGraphSequence + TemporalGraphBatch + readout + classifiers | 🧪 Experimental | tgraphx.TemporalGraphSequence, TemporalGraphBatch, tgraphx.layers.temporal_readout, tgraphx.models.temporal_models.* |
| Subgraph / k-hop / neighbour / random-walk sampling | ✅ Stable | tgraphx.sampling, tgraphx.SubgraphDataLoader, tgraphx.NeighborSamplerLoader |
| Hetero / temporal sampling (v0.2.8) | ✅ Stable | tgraphx.hetero_induced_subgraph, tgraphx.hetero_neighbor_sample, tgraphx.temporal_window_sample, tgraphx.temporal_window_sample_batch |
| Distributed (DDP) helpers | ✅ Stable | tgraphx.distributed; never auto-initialises DDP |
| MLflowLogger | ✅ Opt-in | pip install "tgraphx[mlflow]"; lazy mlflow import |
| PyG / DGL data converters (homogeneous + hetero) | ✅ Opt-in | tgraphx.interop (lazy imports) |
| Learned graph helpers | ✅ Stable | tgraphx.learned_graph |
| Patch helper padding="auto" | ✅ Stable | image_to_patches(imgs, ps, padding="auto") |
| Hardware monitoring dashboard | 🔒 Opt-in | pip install "tgraphx[monitoring]" |
| TensorBoard logging | 🔒 Opt-in | pip install "tgraphx[tracking]" |

Scope boundaries (true design limits)

These are intentional, not bugs. Detailed write-ups live in docs/limitations.md; the roadmap is in docs/roadmap.md.

  • Supported node feature ranks: [N, D], [N, C, H, W], and [N, C, D, H, W]. A universal arbitrary-rank layer that works across every existing layer is a v0.3 design discussion.
  • GAT per-pixel / per-voxel attention: naive score tensors would be O(E · K · H · W) — memory-prohibitive for typical spatial GNN workloads (for instance, 10 000 edges × 4 heads × 32×32 maps is ~41M scores, roughly 164 MB in fp32 per layer). Memory-safe variants (factorised / windowed / low-rank) are deferred until designed honestly. Per-channel attention is shipped as attention_mode="channel".
  • PyG / DGL drop-in compatibility: TGraphX is not a replacement. tgraphx.interop ships data converters; layer APIs differ.
  • Multi-GPU training framework: TGraphX provides rank-zero / world-size DDP helpers and a single-process smoke example. Production-grade automatic multi-GPU training is the user's responsibility (and is intentionally not bundled).
  • Recurrent temporal memory (TGN / TGAT-style): temporal workflows use a stateless snapshot-loop pattern. Memory-aware temporal architectures are an open design question.
  • Profiling and file writes: disabled by default; every logger, dashboard, and checkpoint write is opt-in.

Performance

Environment and hardware report

from tgraphx.performance import env_report, estimate_message_memory, recommended_device

print(env_report())                           # Python/PyTorch/CUDA/MPS info
print(env_report(include_hardware=True))      # + CPU/RAM/CUDA memory (needs psutil)
print(env_report(include_sensors=True))       # + GPU util/temp (needs pynvml)

dev = recommended_device()                    # CUDA > MPS > CPU

# Estimate peak message-buffer memory before running
m = estimate_message_memory(num_edges=1024, out_shape=(64, 8, 8))
print(f"~{m['total_mb']:.1f} MB  ({m['note']})")

Benchmarks

# Layer throughput (CPU-safe, all flags optional)
python benchmarks/benchmark_layers.py --layer gat --nodes 64 --edges 256 \
    --shape 8,4,4 --device cpu --iters 10

# CUDA + AMP + backward
python benchmarks/benchmark_layers.py --layer conv --nodes 256 --edges 2048 \
    --shape 32,8,8 --device cuda --amp 1 --backward 1

# Save JSON result
python benchmarks/benchmark_layers.py --layer gin --shape 8,4,4 \
    --output results/gin.json

# Graph builder timing
python benchmarks/benchmark_graph_builders.py --small    # CI-safe
python benchmarks/benchmark_graph_builders.py            # full

torch.compile and AMP

python examples/torch_compile_benchmark.py   # eager vs compiled, correctness check
python examples/mixed_precision_inference.py  # autocast forward demo (finite-output check)
python examples/memory_report.py             # env report + memory estimates

AMP policy:

| Backend | Recommended dtype | Status | Notes |
|---|---|---|---|
| CPU | bfloat16 | ✅ Tested | Covered by tests/test_amp_compile.py in full Linux CI |
| CUDA | float16 / bfloat16 | ⚠️ Best-effort | Behaviour fixed in v0.2.2; bfloat16 needs Ampere+ |
| MPS | — | ⚠️ Best-effort | PyTorch operator coverage varies; not in CI |

v0.2.2 fixes: broadcast_edge_weight casts edge weights to activation dtype; TensorGATLayer casts attention weights before index_add_; edge_softmax upcasts to fp32 for numerical stability and casts back. See docs/performance.md for full details.

Optional chunked forward (ConvMessagePassing)

Reduce peak edge-buffer memory by processing edges in chunks:

from tgraphx.layers.conv_message import ConvMessagePassing

layer = ConvMessagePassing(in_shape=(32, 8, 8), out_shape=(32, 8, 8), aggr="sum")
# Same output as unchunked; lower peak memory for large E
out = layer(x, edge_index, chunk_size=512)

Supported aggregations: "sum" and "mean". "max" falls back to the standard path with a warning. All four message-passing layers accept chunk_size in forward(), as sketched after this list:

  • ConvMessagePassing: sum / mean (max falls back).
  • TensorGraphSAGELayer: mean / max.
  • TensorGINLayer: sum.
  • TensorGATLayer: two-pass log-sum-exp. Output matches unchunked within float32 tolerance.
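
For example, on the SAGE layer (a sketch; constructor arguments as in the API reference below):

import torch
from tgraphx.layers import TensorGraphSAGELayer

x = torch.randn(64, 32, 8, 8)                   # [N, C, H, W]
edge_index = torch.randint(0, 64, (2, 1024))    # [2, E]
layer = TensorGraphSAGELayer(in_channels=32, out_channels=32, aggr="mean")
out = layer(x, edge_index, chunk_size=256)      # same output, lower peak edge-buffer memory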

Hardware compatibility

| Platform | Forward | AMP | torch.compile | CI coverage | Notes |
|---|---|---|---|---|---|
| CPU | ✅ | ⚠️ bfloat16 only | ✅ | Full CI (Ubuntu) | Compile overhead may dominate small graphs |
| CUDA | ✅ | ⚠️ float16 (op-dependent) | ✅ | Local tests only | index_add_ ops require dtype match; no GPU runners in CI |
| MPS (Apple Silicon) | ✅ | ⚠️ limited | ⚠️ | Smoke CI (macOS) | Best-effort; PyTorch op coverage varies |
| Linux | ✅ | ✅ | ✅ | Full CI (ubuntu-latest, Py 3.10/3.11/3.12) | Primary CI platform |
| Windows | ✅ | ✅ | ✅ | Smoke CI (Py 3.11) | Imports + build + dashboard CLI smoke |
| macOS | ✅ | ⚠️ limited | ⚠️ | Smoke CI (Py 3.11) | Same surface as Windows smoke |

Training utilities

TGraphX includes lightweight training helpers — not a full training framework. All logging and file writes are off by default.

Training loop helpers

from tgraphx.training import train_epoch, evaluate, fit
import torch.nn.functional as F

# One-line training loop
history = fit(
    model, train_loader, val_loader=val_loader,
    epochs=20,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    loss_fn=F.cross_entropy,
    log_level=1,           # print per-epoch summary
)
# → [{"epoch":0, "train_loss":0.9, "val_loss":0.85}, ...]

Supported batch formats: GraphBatch (with graph_labels / node_labels) and (Tensor, Tensor) tuples.
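
A minimal sketch of the tuple format (assuming the tuple path simply calls model(inputs); shapes are arbitrary):

import torch
import torch.nn as nn
import torch.nn.functional as F
from tgraphx.training import fit

model = nn.Linear(16, 5)
loader = [(torch.randn(32, 16), torch.randint(0, 5, (32,))) for _ in range(4)]
history = fit(
    model, loader, epochs=2,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    loss_fn=F.cross_entropy,
)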

Standalone helpers

from tgraphx.training import (
    set_seed,            # seeds torch / numpy / random
    count_parameters,    # trainable parameter count
    save_checkpoint,     # torch.save wrapper
    load_checkpoint,     # returns saved epoch number
    accuracy,            # multi-class argmax accuracy
    mean_absolute_error, mean_squared_error,
)

CSV logging (dashboard-compatible)

from tgraphx.tracking import CSVLogger

with CSVLogger("runs/my_run") as logger:
    history = fit(model, train_loader, ..., logger=logger)
# writes runs/my_run/metrics.csv with UTC timestamps

TensorBoard logging (optional)

from tgraphx.tracking import TensorBoardLogger  # lazy import

# pip install tensorboard  or  pip install "tgraphx[tracking]"
with TensorBoardLogger("runs/tb") as tb:
    history = fit(model, train_loader, ..., logger=tb)

Nothing is written unless you explicitly pass a logger.

MLflow logging (optional)

from tgraphx.tracking import MLflowLogger   # pip install mlflow

with MLflowLogger(run_name="my_run", experiment="gnn") as mlf:
    history = fit(model, train_loader, logger=mlf, ...)

Dashboard

TGraphX includes a local training dashboard — off by default, zero external dependencies, no telemetry.

Quick start

# Launch (localhost only, no token needed)
tgraphx-dashboard --logdir runs/demo
# → http://127.0.0.1:8765

# LAN access — explicit token required
tgraphx-dashboard --logdir runs/demo \
  --host 0.0.0.0 --token MY_SECRET_TOKEN

# Auto-generated token
tgraphx-dashboard --logdir runs/demo --host 0.0.0.0 --token auto

# Offline HTML snapshot — no server needed
tgraphx-dashboard --logdir runs/demo --export-html snapshot.html

# Python API — non-blocking background thread
from tgraphx.dashboard import launch_dashboard_background, export_dashboard_html
server = launch_dashboard_background("runs/demo", port=8765)
# ... training loop ...
server.shutdown()

# Offline snapshot
export_dashboard_html("runs/demo", "snapshot.html")

Dashboard features

| Section | Contents |
|---|---|
| Overview | Status chip, epoch progress, live loss, elapsed / ETA |
| Metrics | SVG line charts, window selector, EMA smoothing, per-chart CSV/SVG export |
| Graph | Graph summary, degree stats, graph_stats.json precomputed cards, SVG preview |
| Hardware | CPU/RAM/GPU/CUDA/MPS, power draw, thermal status (optional psutil/pynvml) |
| Logs | Scrollable metric table with CSV export |
| Config | run_metadata.json rendered safely |
| Tools | Copy URL, export buttons, refresh controls |
| TV mode | Full-screen large-font passive monitoring |

Key features:

  • Incremental updates — browser requests only new rows via ?since_row=N
  • Multi-run selector — point at a parent directory; select runs by name
  • Color-blind-safe palette — Okabe-Ito toggle, persisted in localStorage
  • Accessible — skip link, ARIA labels, focus-visible, reduced-motion support
  • Export — metrics CSV, per-chart CSV/SVG, print/save PDF, offline HTML snapshot
  • Responsive — phone, tablet, desktop, TV/large-monitor layouts
  • Pause/resume polling, configurable refresh interval

Security model

| Scenario | Token required? |
|---|---|
| --host 127.0.0.1 (default) | No |
| --host 0.0.0.0 + connecting from localhost | No |
| --host 0.0.0.0 + connecting from another device | Yes |
| Starting LAN mode without --token | Refused at startup |

Read-only · no external CDN · no telemetry · no token leakage in API responses · path-traversal protected.

Log files

metrics.csv          — epoch,train_loss,val_loss,... (ISO-8601 UTC timestamp column)
run_metadata.json    — run name, status, total_epochs, device, task (free-form dict)
graph_metadata.json  — optional graph summary + edge_index for preview (≤200 nodes)
graph_stats.json     — optional precomputed stats (write with write_graph_stats())

from tgraphx import write_graph_stats
write_graph_stats({"num_nodes": 100, "num_edges": 400, "density": 0.04},
                  "runs/demo/graph_stats.json")

Hardware monitoring (optional)

pip install "tgraphx[monitoring]"   # psutil + pynvml

Missing packages show a compact "unavailable" reason per row — no broken charts.


Privacy and local-first behavior

TGraphX is designed to be entirely local and private:

| Behavior | Default |
|---|---|
| Telemetry / analytics | None — never |
| Remote calls at import | None |
| Dashboard | Off — launch explicitly |
| CSV metric logging | Off — create CSVLogger explicitly |
| TensorBoard logging | Off — create TensorBoardLogger explicitly |
| Hardware monitoring | Off — pass include_hardware=True to env_report |
| Checkpoints | Off — call save_checkpoint explicitly |
| Graph serialization | Off — include edge_index in JSON manually |
| Background threads | None (unless launch_dashboard_background is called) |
| File writes | Only to paths you explicitly provide |
| Reads | Dashboard reads only inside --logdir |
| External CDN / assets | None — dashboard is fully self-contained |

No ~/.tgraphx directory or user-level config is created by default.


Installation

pip install tgraphx

Optional extras:

pip install "tgraphx[tracking]"    # TensorBoard integration
pip install "tgraphx[monitoring]"  # psutil + pynvml (dashboard hardware panel)
pip install "tgraphx[dev]"         # pytest, build, twine

For a specific PyTorch build (e.g. CPU-only or a particular CUDA version), install PyTorch before TGraphX:

# CPU-only example
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install tgraphx

Install from source:

git clone https://github.com/arashsajjadi/TGraphX.git
cd TGraphX
pip install -e ".[dev]"

See pytorch.org for GPU-specific install commands. A Conda environment file is provided at environment.yml.


Quickstart

Vector node features (simplest case):

import torch
from tgraphx import Graph, LinearMessagePassing

# 8 nodes, 32-dimensional vector features
x = torch.randn(8, 32)
src = torch.arange(8)
edge_index = torch.stack([src, (src + 1) % 8])

g = Graph(x, edge_index)
layer = LinearMessagePassing(in_shape=(32,), out_shape=(64,))
out = layer(g.node_features, g.edge_index)   # [8, 64]
out.sum().backward()

Spatial node features — [C, H, W] preserved through message passing:

import torch
from tgraphx import Graph, ConvMessagePassing

N, C, H, W = 6, 16, 8, 8
node_features = torch.randn(N, C, H, W)
src = torch.arange(N)
edge_index = torch.stack([src, (src + 1) % N])

g = Graph(node_features, edge_index)
layer = ConvMessagePassing(in_shape=(C, H, W), out_shape=(32, H, W))
out = layer(g.node_features, g.edge_index)   # [6, 32, 8, 8]
out.sum().backward()

Graph Builders

TGraphX ships pure-PyTorch graph builders that return edge_index tensors ([2, E], dtype=torch.long) ready for any GNN layer or Graph constructor. They create fixed, rule-based adjacency structures — they do not implement learned adjacency.

Builder API

| Function | Description | Key params |
|---|---|---|
| build_grid_graph(rows, cols) | 2-D 4-connected grid | directed, self_loops |
| build_grid_graph_3d(depth, rows, cols) | 3-D 6-connected grid | directed, self_loops |
| build_fully_connected_graph(num_nodes) | Complete graph | self_loops |
| build_knn_graph(coords, k) | k-nearest-neighbour | directed, self_loops |
| build_radius_graph(coords, radius) | All pairs within radius | directed, self_loops |
| build_iou_graph(boxes, threshold) | Bounding-box IoU ≥ threshold | directed, self_loops |
| build_random_graph(num_nodes, num_edges) | Uniform random sample | directed, self_loops, seed |

O(N²) warning: build_knn_graph and build_radius_graph call torch.cdist internally, which requires O(N²) time and memory. For large graphs (N > 10 000), use approximate-NN libraries instead.

O(N²) warning: build_fully_connected_graph emits N·(N−1) edges. Memory grows quadratically with node count.
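
For example, a kNN graph over patch centres (a sketch; parameter names from the table above):

import torch
from tgraphx import build_knn_graph

coords = torch.rand(100, 2)     # 100 node positions in the unit square
edge_index = build_knn_graph(coords, k=4, directed=True, self_loops=False)
# edge_index: [2, E], dtype=torch.long; ready for any layer or Graph(...)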

Grid graph quickstart

import torch
from tgraphx import Graph, build_grid_graph
from tgraphx.layers.conv_message import ConvMessagePassing

# 3×3 patch grid — 9 nodes, each with a [4, 8, 8] spatial feature
node_features = torch.randn(9, 4, 8, 8)
edge_index = build_grid_graph(3, 3, directed=False, self_loops=True)
# edge_index: [2, 33]  (24 neighbour + 9 self-loop edges)

g = Graph(node_features, edge_index)

layer = ConvMessagePassing(in_shape=(4, 8, 8), out_shape=(8, 8, 8))
out = layer(g.node_features, g.edge_index)   # [9, 8, 8, 8]

2-D image patch graph

import torch
from tgraphx import build_grid_graph, image_to_patches
from tgraphx.layers.gat import TensorGATLayer

# Extract 4×4 patches from a [B, C, H, W] image
images = torch.randn(2, 3, 8, 8)
patches = image_to_patches(images, patch_size=4)   # [2, 4, 3, 4, 4]

# Build the patch grid graph
edge_index = build_grid_graph(2, 2, directed=False, self_loops=True)

# Run GAT on one image's patches
x = patches[0]   # [4, 3, 4, 4]
gat = TensorGATLayer(in_channels=3, out_channels=8, num_heads=2, spatial_rank=2)
out = gat(x, edge_index)   # [4, 8, 4, 4]

3-D volume patch graph

import torch
from tgraphx import build_grid_graph_3d, volume_to_patches
from tgraphx.layers.gat import TensorGATLayer

# Extract 4×4×4 patches from a [B, C, D, H, W] volume
volumes = torch.randn(1, 2, 8, 8, 8)
patches = volume_to_patches(volumes, patch_size=4)   # [1, 8, 2, 4, 4, 4]

# Build the 3-D patch grid graph
edge_index = build_grid_graph_3d(2, 2, 2, directed=False, self_loops=True)

# Run GAT on one volume's patches
x = patches[0]   # [8, 2, 4, 4, 4]
gat = TensorGATLayer(in_channels=2, out_channels=4, num_heads=2, spatial_rank=3)
out = gat(x, edge_index)   # [8, 4, 4, 4, 4]

Patch helper API

| Function | Input | Output | Notes |
|---|---|---|---|
| patch_grid_shape(H, W, patch_size, stride) | — | (n_h, n_w) | Raises if not exactly covered |
| image_to_patches(images, patch_size, stride) | [B, C, H, W] | [B, P, C, ph, pw] | Row-major; matches grid node order |
| volume_patch_grid_shape(D, H, W, patch_size, stride) | — | (n_d, n_h, n_w) | Raises if not exactly covered |
| volume_to_patches(volumes, patch_size, stride) | [B, C, D, H, W] | [B, P, C, pd, ph, pw] | Depth-row-col order; matches 3-D grid |
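
A quick consistency check between the patch helpers and the grid builders (a sketch):

import torch
from tgraphx import build_grid_graph, image_to_patches, patch_grid_shape

images = torch.randn(2, 3, 8, 8)
n_h, n_w = patch_grid_shape(8, 8, 4)              # (2, 2); raises if 8 is not covered exactly
patches = image_to_patches(images, patch_size=4)  # [2, 4, 3, 4, 4], row-major patch order
edge_index = build_grid_graph(n_h, n_w)           # grid node order matches the patch order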

Concept: tensor-aware node features

In a standard GNN a node carries a vector x_i ∈ ℝ^d. In TGraphX a node carries a tensor X_i ∈ ℝ^{C×H×W}. Every layer follows the standard message-passing template:

M_{i→j} = φ( X_i, X_j, E_{i→j} )                    # per-edge messages
A_j     = AGG_{i ∈ N(j)} M_{i→j}                    # permutation-invariant aggregation
X'_j    = ψ( X_j, A_j )                             # update
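
In plain PyTorch, one sum-aggregated step of this template looks roughly as follows (an illustrative sketch, not library code):

import torch

def message_passing_step(x, edge_index, phi, psi):
    # x: [N, C, H, W]; edge_index: [2, E], row 0 = source i, row 1 = destination j
    src, dst = edge_index[0], edge_index[1]
    messages = phi(x[src], x[dst])                        # M_{i->j}, one per edge
    agg = x.new_zeros((x.size(0),) + messages.shape[1:])  # A_j accumulator
    agg.index_add_(0, dst, messages)                      # sum over incoming edges N(j)
    return psi(x, agg)                                    # X'_j

# e.g. identity messages with a residual update:
out = message_passing_step(
    torch.randn(6, 16, 8, 8),
    torch.tensor([[0, 1, 2], [1, 2, 0]]),
    phi=lambda x_i, x_j: x_i,
    psi=lambda x_j, a_j: x_j + a_j,
)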

Each layer instantiates this template with its own (φ, AGG, ψ):

| Layer | φ (message) | AGG | ψ (update) |
|---|---|---|---|
| ConvMessagePassing | Conv1×1(Concat(X_i, X_j[, E_ij])) | sum / mean / max | DeepCNNAggregator(A_j) (+ optional residual) |
| TensorGATLayer | α_{ij}^k · W^k X_i with α_{ij}^k = softmax_i(LeakyReLU(a_dst·pool(W^k X_j) + a_src·pool(W^k X_i) + b^k(e_ij))) | sum (weighted) | concat or mean over heads, optional residual |
| TensorGraphSAGELayer | W_neigh(X_i) (+ optional spatial cat or vector bias from e_ij) | mean / max | W_self(X_j) + AGG, optional L2 normalise |
| TensorGINLayer | X_i or ReLU(X_i + φ_e(e_ij)) (GINEConv) | sum | MLP((1+ε)·X_j + Σ_i M_ij) |
| LinearMessagePassing | Linear(Concat(x_i, x_j[, e_ij])) | sum / mean / max | identity (override to customise) |

Spatial dimensions H and W are preserved through every learned transform. All aggregations are permutation-invariant over the order of incoming edges, and every layer is permutation-equivariant over node reindexing (verified by tests/test_math.py).

The graph structure — which nodes are connected — is not learned by the model. Users supply edge_index based on domain knowledge (e.g., spatial proximity of patches, IoU overlap of bounding boxes, kNN on patch centres).

Edge feature formats per layer

| Layer | Vector [E, D_e] | Spatial [E, C_e, H, W] |
|---|---|---|
| ConvMessagePassing | ✗ | ✓ (concatenated along channels; channel count must equal node channel count) |
| TensorGATLayer | ✓ (additive attention bias on logits) | ⚠️ accepted; mean-pooled to scalar attention bias (no per-pixel attention) |
| TensorGraphSAGELayer | ✓ (additive channel bias post-W_neigh) | ✓ (concatenated to source) |
| TensorGINLayer | ✓ (broadcast bias before ReLU) | ✓ (1×1 Conv2d projection) |

TensorGATLayer spatial edge features: [E, C_e, H, W] (or [E, C_e, D, H, W] for 3-D nodes) are accepted and mean-pooled over spatial dims before the attention bias projection. Spatial dims do not need to match the node spatial dims. Mismatched rank (e.g. 5-D edges into a 2-D-configured GAT) raises NotImplementedError. Use TensorGraphSAGELayer or TensorGINLayer for full spatial edge-feature processing (no pooling).


Factory API

make_layer — create any layer by name

from tgraphx import make_layer

# 2-D spatial GAT with 2 attention heads
layer = make_layer("gat", in_shape=(8, 4, 4), out_shape=(16, 4, 4), heads=2, residual=True)

# Vector linear message passing
layer = make_layer("linear", in_shape=(32,), out_shape=(64,), aggr="mean")

| Name | Layer class | Shape support |
|---|---|---|
| "conv" | ConvMessagePassing | 2-D / 3-D spatial |
| "gat" | TensorGATLayer | 2-D / 3-D spatial |
| "sage" | TensorGraphSAGELayer | 2-D / 3-D spatial |
| "gin" | TensorGINLayer | 2-D / 3-D spatial |
| "linear" | LinearMessagePassing | vector only |
| "legacy_attention" | AttentionMessagePassing | vector / 2-D spatial |

build_model — complete task model

from tgraphx import build_model

# Node classification on vector features
model = build_model(
    task="node_classification",
    layer="linear",
    in_shape=(32,),
    hidden_shape=(64,),
    num_layers=3,
    num_classes=5,
)
out = model(x, edge_index)    # [N, 5]

# Graph classification on 2-D image patches
model = build_model(
    task="graph_classification",
    layer="gat",
    in_shape=(3, 4, 4),       # [C, ph, pw]
    hidden_shape=(8, 4, 4),
    num_layers=2,
    num_classes=10,
    heads=2,
    pooling="mean",
)
out = model(x, edge_index, batch=batch)   # [G, 10]

Task support matrix

| Task | Vector [N, D] | 2-D spatial [N, C, H, W] | 3-D volumetric [N, C, D, H, W] |
|---|---|---|---|
| node_classification | ✓ | ✓ | ✓ |
| node_regression | ✓ | ✓ | ✓ |
| graph_classification | ✓ | ✓ | ✓ |
| graph_regression | ✓ | ✓ | ✓ |
| edge_prediction | ✓ | ✓ (spatial pool) | ✓ (spatial pool) |
| link_prediction | deferred | deferred | deferred |

Spatial / volumetric tasks: after the GNN stack the factory applies global spatial average-pooling to flatten [N, C, *spatial] → [N, C] before the linear head. Spatial resolution is preserved inside every GNN layer and only collapsed at the final readout.
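
For reference, that readout is equivalent to a global spatial mean (a sketch, not the factory's internal code):

import torch

feats = torch.randn(9, 16, 4, 4)                       # [N, C, *spatial] after the GNN stack
pooled = feats.mean(dim=tuple(range(2, feats.dim())))  # [N, 16]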

build_model_from_config — config-driven construction

from tgraphx import build_model_from_config

# From a Python dict (no eval, no exec)
config = {
    "model": {
        "task": "graph_classification",
        "layer": "gat",
        "in_shape": [8, 4, 4],
        "hidden_shape": [16, 4, 4],
        "num_layers": 2,
        "num_classes": 3,
        "heads": 2,
        "residual": True,
        "dropout": 0.1,
    }
}
model = build_model_from_config(config)

# From a JSON file
model = build_model_from_config("config.json")

# From a YAML file (requires PyYAML)
model = build_model_from_config("config.yaml")

Examples

Run every fast example in one go:

python examples/run_all_fast_examples.py

Individual demos:

# Tensor-aware GCN-style spatial message passing
python examples/minimal_spatial_message_passing.py

# Graph classification — short training loop with synthetic data
python examples/minimal_graph_classifier.py

# Tensor-aware multi-head GAT (attention weights sum to 1 per destination per head)
python examples/tensor_gat_minimal.py

# Tensor-aware GraphSAGE (mean / max / with-edge-features variants)
python examples/tensor_graphsage_minimal.py

# Custom user-defined message-passing layer subclass
python examples/custom_message_passing.py

# Trainability sanity: tiny overfit per GNN family
python examples/tiny_overfit_tensor_gat.py
python examples/tiny_overfit_edge_features.py

# Deep 8-layer stack gradient sanity
python examples/gradient_sanity_stack.py

# Graph builders + patch helpers
python examples/directed_vs_undirected_graphs.py
python examples/image_patch_graph.py
python examples/volume_patch_graph.py
python examples/gnn_family_with_graph_builders.py

# Factory + model examples (numbered)
python examples/01_vector_node_classification.py
python examples/02_spatial_graph_classification.py
python examples/03_volumetric_graph_classification.py
python examples/04_config_based_model.py
python examples/05_edge_prediction.py

# Hetero / temporal / sampling / transformer demos
python examples/hetero_graph_batch_demo.py
python examples/hetero_graph_classifier_demo.py
python examples/temporal_graph_batch_demo.py
python examples/temporal_graph_classifier_demo.py
python examples/neighbor_sampling_demo.py
python examples/sampling_demo_v028.py        # random-walk + hetero + temporal sampling (v0.2.8)
python examples/graph_transformer_demo.py
python examples/gat_chunking_demo.py
python examples/v024_new_features.py

# Distributed-training helpers (single-process smoke)
python examples/ddp_training_smoke.py

# Performance / hardware
python examples/memory_report.py
python examples/mixed_precision_inference.py
python examples/torch_compile_benchmark.py

# Training + dashboard
python examples/training_minimal_fit.py
python examples/training_with_csvlogger.py
python examples/training_with_tensorboard.py
python examples/training_with_dashboard.py
python examples/checkpoint_save_load.py

API reference

Graph

from tgraphx import Graph

g = Graph(
    node_features,                 # torch.Tensor  [N, ...]              required
    edge_index=None,               # torch.LongTensor [2, E] or None     optional
    edge_weight=None,              # torch.Tensor [E]                    optional
    edge_features=None,            # torch.Tensor [E, ...]               optional
    node_labels=None,              # torch.Tensor [N, ...]               optional
    edge_labels=None,              # torch.Tensor [E, ...]               optional
    graph_label=None,              # torch.Tensor (any shape)            optional
    metadata=None,                 # dict                                optional
)
g.clone()                          # deep copy (tensors and metadata)
g.to("cuda", dtype=torch.float32)  # moves all tensor fields; dtype only
                                   #   applies to floating-point tensors
g.cpu(); g.cuda()                  # convenience aliases
g.add_self_loops(); g.remove_self_loops()
g.make_undirected(reduce="mean")   # symmetrize, coalesce duplicates
g.is_undirected()                  # structural check (ignores weights)
g.validate()                       # re-run validation after manual mutation
g.num_nodes, g.num_edges, g.feature_shape, g.edge_feature_shape
g.has_edges, g.has_edge_weight, g.has_edge_features

Supported per-node feature layouts: [N, D], [N, C, H, W], and storage-level [N, C, D, H, W]. Edge features mirror this: [E, D_e], [E, C_e, H, W], and storage-level [E, C_e, D, H, W].

Graph.__init__ raises immediately on:

  • non-Tensor inputs
  • node_features or edge_features with fewer than 2 dimensions
  • edge_index with wrong shape, wrong dtype (torch.long required), or out-of-range indices (see the example after this list)
  • edge_weight that is not 1-D, or whose length differs from E
  • per-edge tensors (edge_weight / edge_features / edge_labels) supplied without an edge_index
  • device mismatch between node_features and any other tensor field
  • length mismatch between node_features / edge_index and any per-node / per-edge tensor
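
For instance, out-of-range indices are rejected at construction time (the exception class is not pinned down here):

import torch
from tgraphx import Graph

x = torch.randn(8, 16)
bad_edges = torch.tensor([[0, 1], [1, 99]])   # node 99 does not exist for N = 8
try:
    Graph(x, edge_index=bad_edges)
except Exception as e:                        # catch broadly; see the list above for the checks
    print(type(e).__name__, e)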

GraphBatch

from tgraphx import GraphBatch

batch = GraphBatch([g1, g2, g3])
# batch.node_features  [N_total, ...]
# batch.edge_index     [2, E_total]   — indices offset per graph
# batch.edge_weight    [E_total]      — concatenated if every graph has it
# batch.edge_features  [E_total, ...] — concatenated if every graph has it
# batch.node_labels    [N_total, ...] — concatenated if every graph has it
# batch.edge_labels    [E_total, ...] — concatenated if every graph has it
# batch.graph_labels   [B, ...]       — stacked if every graph has it
# batch.metadata       list[Any]      — verbatim, length B (None entries OK)
# batch.batch          [N_total]      — graph membership (dtype=long)
batch.to("cuda")

All graphs must share the same per-node feature shape (and per-edge feature shape, when present). Passing graphs with different spatial sizes raises a ValueError with a descriptive message. Optional per-edge tensors must be present on every graph that has edges, or on none — mixing the two is rejected so per-edge data is never silently dropped.

ConvMessagePassing

from tgraphx.layers import ConvMessagePassing

layer = ConvMessagePassing(
    in_shape=(C, H, W),          # tuple: per-node input shape (spatial only)
    out_shape=(C_out, H, W),     # H and W must stay equal to in_shape's H, W
    aggr="sum",                  # "sum" (default) | "mean" | "max"
    use_edge_features=False,     # set True to concatenate edge tensors into messages
    aggregator_params=None,      # dict forwarded to DeepCNNAggregator; e.g.
                                 #   {"num_layers": 2, "dropout_prob": 0.1}
    residual=False,              # add skip connection when in_shape == out_shape
)
out = layer(node_features, edge_index)              # [N, C_out, H, W]
out = layer(node_features, edge_index, edge_features)  # with edge features

aggr="max" is supported via scatter_reduce_(reduce='amax'). When chunk_size is also set, aggr="max" falls back to the unchunked path with a warnings.warn. Use GraphClassifier(pooling="max") for graph-level max readout.

AttentionMessagePassing

from tgraphx.layers import AttentionMessagePassing

# Spatial path
layer = AttentionMessagePassing(in_shape=(C, H, W), out_shape=(C_out, H, W))

# Vector path (also supported)
layer = AttentionMessagePassing(in_shape=(D,), out_shape=(D_out,))

out = layer(node_features, edge_index)   # [N, C_out, H, W] or [N, D_out]

Important: This layer computes attn = sigmoid(q·k / √d) independently per edge. Attention weights are not normalised over each destination node's neighbourhood (no softmax). This differs from standard GAT. For true GAT, use TensorGATLayer below.
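
The difference in miniature (illustrative only, not library code):

import torch

scores = torch.randn(5)                # one raw score per edge
dst = torch.tensor([0, 0, 1, 1, 1])    # destination node of each edge

gate = torch.sigmoid(scores)           # AttentionMessagePassing: independent per-edge gate
alpha = torch.empty_like(scores)       # TensorGATLayer: normalised per destination
for j in dst.unique():
    mask = dst == j
    alpha[mask] = torch.softmax(scores[mask], dim=0)   # sums to 1 over j's incoming edges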

TensorGATLayer (true multi-head GAT)

True GAT-style attention adapted to spatial features. For every destination node j and every head k, attention weights satisfy Σ_i α_ij^k = 1.

from tgraphx.layers import TensorGATLayer

# 4 heads ร— 8 channels each = 32 output channels (heads concatenated)
layer = TensorGATLayer(
    in_channels=16,
    out_channels=32,        # divisible by num_heads when concat_heads=True
    num_heads=4,
    concat_heads=True,      # False → average heads, output is per-head channels
    negative_slope=0.2,     # LeakyReLU before edge softmax
    attn_dropout=0.0,       # dropout on attention weights (training only)
    residual=True,          # auto 1ร—1 projection if in/out channels differ
    bias=True,
    add_self_loops=False,   # True ensures every node has at least 1 in-edge
    use_edge_features=False,  # set True to enable EGAT-style vector edge bias
    edge_dim=None,            # required when use_edge_features=True
)
out = layer(x, edge_index)                                  # [N, 32, H, W]

# Inspect attention weights (e.g. for visualisation or testing):
out, attn = layer(x, edge_index, return_attention=True)
# attn shape: [E, num_heads]; sums to 1 over incoming edges per destination per head.

# EGAT-style vector edge attention bias (e.g. relative box coords, IoU,
# distances, scale ratio):
layer_e = TensorGATLayer(
    in_channels=16, out_channels=32, num_heads=4,
    use_edge_features=True, edge_dim=3,
)
out = layer_e(x, edge_index, edge_features=ef)   # ef: [E, 3]

Attention is scalar per (edge, head) in this implementation: the projected query and key feature maps are mean-pooled over H × W before being scored, while the value tensors keep their full spatial layout during aggregation. Per-pixel attention is not supported; per-channel attention is available as the experimental attention_mode="channel" (see Attention support).

Spatial edge features ([E, C_e, H, W] for spatial_rank=2; [E, C_e, D, H, W] for spatial_rank=3) are accepted: spatial dims are mean-pooled to a channel vector before the per-(edge, head) attention bias projection (spatial dims need not match node spatial dims). Use TensorGraphSAGELayer or TensorGINLayer for full spatial edge-feature processing without pooling.

TensorGraphSAGELayer

Tensor-aware GraphSAGE: h_j' = W_self(h_j) + W_neigh(AGG_i h_i).

from tgraphx.layers import TensorGraphSAGELayer

layer = TensorGraphSAGELayer(
    in_channels=16,
    out_channels=32,
    aggr="mean",            # "mean" or "max"
    normalize=False,        # True → L2-normalise output channel vector per pixel
    bias=True,
    residual=False,
    use_edge_features=False,
    edge_dim=None,             # required when use_edge_features=True
    edge_features_kind="spatial",  # or "vector" — see below
)
out = layer(x, edge_index)                                  # [N, 32, H, W]

# Spatial edge features [E, edge_dim, H, W] — concatenated to source.
layer_s = TensorGraphSAGELayer(
    in_channels=16, out_channels=32,
    use_edge_features=True, edge_dim=4, edge_features_kind="spatial",
)
out = layer_s(x, edge_index, edge_features=ef_spatial)

# Vector edge features [E, edge_dim] — projected to channel bias and added
# to W_neigh(h_src) before aggregation.
layer_v = TensorGraphSAGELayer(
    in_channels=16, out_channels=32,
    use_edge_features=True, edge_dim=3, edge_features_kind="vector",
)
out = layer_v(x, edge_index, edge_features=ef_vector)        # ef_vector: [E, 3]

Isolated nodes (no incoming edges) receive only the self transform — the neighbour aggregate is zero.

TensorGINLayer

Tensor-aware GIN / GINEConv: h_j' = MLP((1+ε)·h_j + Σ_i m_ij).

from tgraphx.layers import TensorGINLayer

# Default 1×1 Conv MLP (preserves spatial layout)
layer = TensorGINLayer(
    in_channels=16,
    out_channels=32,
    hidden_channels=24,     # defaults to out_channels
    eps=0.0,
    train_eps=False,        # set True to make ε a learnable scalar parameter
    use_batchnorm=False,
)
out = layer(x, edge_index)                                  # [N, 32, H, W]

# Custom MLP (any nn.Module mapping [N, in_channels, H, W] → [N, out_channels, H, W])
import torch.nn as nn
custom_mlp = nn.Sequential(
    nn.Conv2d(16, 24, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(24, 32, kernel_size=1),
)
layer = TensorGINLayer(in_channels=16, out_channels=32, mlp=custom_mlp)

# GINEConv-style spatial edge inclusion: messages = ReLU(h_src + φ(e_ij))
layer_s = TensorGINLayer(
    in_channels=16, out_channels=32,
    use_edge_features=True, edge_dim=4, edge_features_kind="spatial",
)
out = layer_s(x, edge_index, edge_features=ef_spatial)

# Vector edge features [E, edge_dim] — projected to [E, in_channels, 1, 1]
# and broadcast over H × W before ReLU.
layer_v = TensorGINLayer(
    in_channels=16, out_channels=32,
    use_edge_features=True, edge_dim=3, edge_features_kind="vector",
)
out = layer_v(x, edge_index, edge_features=ef_vector)        # ef_vector: [E, 3]

Custom layers via TensorMessagePassingLayer

import torch, torch.nn as nn
from tgraphx.layers import TensorMessagePassingLayer

class MyConv(TensorMessagePassingLayer):
    def __init__(self, c_in, c_out):
        super().__init__(in_shape=(c_in,), out_shape=(c_out,), aggr="mean")
        self.W_g = nn.Conv2d(c_in, c_out, kernel_size=1)
        self.W_v = nn.Conv2d(c_in, c_out, kernel_size=1)

    def message(self, src, dest, edge_attr):
        gate = torch.sigmoid(self.W_g(src + dest))
        return gate * self.W_v(src)

    def update(self, node_feature, aggregated_message):
        return aggregated_message

The base class handles per-edge gather and aggregation (sum or mean) for arbitrary trailing tensor shapes. See examples/custom_message_passing.py.
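
Usage of the custom layer above (a sketch; shapes chosen arbitrarily):

x = torch.randn(6, 16, 8, 8)                       # [N, C, H, W]
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # [2, E]
layer = MyConv(16, 32)
out = layer(x, edge_index)                         # [6, 32, 8, 8]
out.sum().backward()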

CNNEncoder

from tgraphx.models import CNNEncoder

enc = CNNEncoder(
    in_channels=3,
    out_features=64,
    num_layers=3,         # total Conv2d blocks
    hidden_channels=64,
    dropout_prob=0.3,
    use_batchnorm=True,
    use_residual=True,    # residual skip in intermediate blocks
    pool_layers=1,        # how many blocks include SafeMaxPool2d(2)
    return_feature_map=True,   # True → [N, out_features, H', W']
                               # False → [N, out_features] (global avg pool)
    pre_encoder=None,     # optional PreEncoder instance
)
features = enc(patches)   # patches: [N, in_channels, patch_H, patch_W]

GraphClassifier

from tgraphx.models import GraphClassifier

clf = GraphClassifier(
    in_shape=(C, H, W),
    hidden_shape=(C_hidden, H, W),
    num_classes=5,
    num_layers=2,
    aggr="sum",
    pooling="mean",        # "mean" | "sum" | "max"
)
logits = clf(
    node_features,         # [N, C, H, W]
    edge_index,            # [2, E]
    batch=batch_vector,    # [N] — required for graph-level output
    edge_features=None,    # optional
)                          # → [num_graphs, num_classes]

NodeClassifier

from tgraphx.models import NodeClassifier

nc = NodeClassifier(
    in_shape=(64,),        # vector features only
    hidden_shape=(128,),
    num_classes=3,
    num_layers=2,
)
logits = nc(node_features, edge_index)   # [N, num_classes]

CNN_GNN_Model

A full CNN → GNN → classify pipeline that accepts pre-split node patches.

from tgraphx.models import CNN_GNN_Model

model = CNN_GNN_Model(
    cnn_params=dict(
        in_channels=3,
        out_features=64,
        num_layers=2,
        hidden_channels=64,
        dropout_prob=0.0,
        use_batchnorm=False,
        use_residual=False,
        pool_layers=1,
        return_feature_map=True,
    ),
    gnn_in_dim=(64, 8, 8),      # must match CNN output shape exactly
    gnn_hidden_dim=(64, 8, 8),
    num_classes=10,
    num_gnn_layers=2,
    gnn_dropout=0.3,            # forwarded to DeepCNNAggregator
    residual=True,              # per-layer skip connection
    skip_cnn_to_classifier=False,
)

# raw_patches: pre-split by the user; shape [N, in_channels, pH, pW]
logits = model(raw_patches, edge_index)            # [N, num_classes]  (node-level)
logits = model(raw_patches, edge_index, batch=b)   # [G, num_classes]  (graph-level)

get_device

from tgraphx.core.utils import get_device

device = get_device()             # CUDA (if available) → MPS → CPU
device = get_device(device_id=1)  # specific CUDA device

Shape conventions

| Tensor | Shape | dtype | Notes |
|---|---|---|---|
| node_features | [N, D], [N, C, H, W], or [N, C, D, H, W] | float | Vector, 2-D spatial, or 3-D volumetric; N = number of nodes |
| node_features (vector) | [N, D] | float | For NodeClassifier / LinearMessagePassing |
| edge_index | [2, E] | torch.long | Row 0 = source nodes, row 1 = destination nodes |
| edge_features | [E, ...] | float | Optional; length must equal E |
| batch | [N] | torch.long | Maps each node to its graph index |

Device support

| Device | Status |
|---|---|
| CPU | ✅ Tested in full CI (Ubuntu 3.10 / 3.11 / 3.12) |
| NVIDIA CUDA | ✅ Locally tested (PyTorch 2.x); no GPU runners in CI |
| Apple Silicon MPS | ⚠️ Best-effort; macOS smoke CI imports + builds only |
| Multi-GPU (DDP) | 🧰 Helpers shipped (tgraphx.distributed); full automatic DDP training framework intentionally out of scope |

from tgraphx.core.utils import get_device

device = get_device()
model.to(device)
g.to(device)
batch.to(device)

Supported Python and PyTorch versions

| Python | PyTorch | Status |
|---|---|---|
| 3.10 | ≥ 1.13 | ✅ CI (ubuntu-latest) |
| 3.11 | ≥ 1.13 | ✅ CI (ubuntu-latest, macos-latest, windows-latest) |
| 3.12 | ≥ 1.13 | ✅ CI (ubuntu-latest) |
| 3.9 | ≥ 1.13 | Listed in classifiers; not in CI matrix — should work but unverified |
| 3.13 | ≥ 1.13 | Listed in classifiers; not in CI matrix — should work but unverified |

Support status

Legend

| Label | Meaning |
|---|---|
| ✅ Stable | Tested in CI; API is stable |
| 🧪 Experimental | Available but not yet guaranteed-stable |
| ⚠️ Best-effort | Works in practice; known constraints documented |
| ⏳ Planned | On roadmap for a future release |
| ❌ Not supported | Out of scope for the current release |
| 🔒 Opt-in | Disabled by default; explicitly enabled by the user |

Backend support

| Backend | Forward | AMP | torch.compile | CI coverage | Status | Notes |
|---|---|---|---|---|---|---|
| CPU | ✅ | ⚠️ bfloat16 | ✅ | Full CI (Ubuntu Py 3.10/3.11/3.12) | ✅ Stable | Compile overhead for small graphs |
| CUDA | ✅ | ⚠️ op-dependent | ✅ | Local only — no GPU runners | ✅ Stable | index_add_ requires dtype match under float16 |
| MPS (Apple Silicon) | ✅ | ⚠️ limited | ⚠️ partial | macOS smoke CI (import + build) | ⚠️ Best-effort | PyTorch operator coverage varies |
| Linux | ✅ | ✅ | ✅ | Full CI (ubuntu-latest) | ✅ Stable | Primary CI platform |
| Windows | ✅ | ✅ | ✅ | Smoke CI (Py 3.11) | ⚠️ Best-effort | Imports + build + twine + dashboard CLI |
| macOS | ✅ | ⚠️ limited | ⚠️ | Smoke CI (Py 3.11) | ⚠️ Best-effort | MPS path; same smoke as Windows |
| Multi-GPU (DDP) | 🧰 helpers | ⚠️ user-managed | ⚠️ user-managed | No CI | 🧰 Helpers only | Rank-zero / world-size / barrier; full multi-GPU framework out of scope |

โš ๏ธ Best-effort backend: MPS support depends on PyTorch operator coverage per release. CPU workflows are fully tested; MPS-specific AMP/compile paths may fall back or be skipped.

โš ๏ธ Windows/macOS smoke CI: automated import, build, twine-check, and dashboard CLI smoke on every pull request. Full pytest still runs on Ubuntu only.

Feature support

| Feature | Status | Notes |
|---|---|---|
| Vector node features [N, D] | ✅ Stable | LinearMessagePassing, "linear" factory |
| 2-D spatial node features [N, C, H, W] | ✅ Stable | All four spatial layers |
| 3-D volumetric node features [N, C, D, H, W] | ✅ Stable | spatial_rank=3 |
| Arbitrary-rank tensors (rank ≥ 4 trailing) | ⛔ Out of scope | Only vector, 2-D, and 3-D layouts ship — see Scope boundaries |
| Edge weights [E] | ✅ Stable | All layers |
| Vector edge features [E, D_e] | ✅ Stable | GAT, SAGE, GIN |
| Spatial edge features [E, C_e, H, W] | ⚠️ Best-effort | ConvMP (concat); GAT (mean-pooled scalar bias); SAGE/GIN (full) |
| Volumetric edge features [E, C_e, D, H, W] | ⚠️ Best-effort | Same as spatial; pass spatial_rank=3 |
| GraphTransformerLayer (vector node features) | 🧪 Experimental | tgraphx.layers.graph_transformer.GraphTransformerLayer; positional / Laplacian / edge-bias supported |
| Heterogeneous graphs (container + batch + HeteroConv + classifiers) | 🧪 Experimental | HeteroGraph, HeteroGraphBatch, HeteroConv, HeteroGraphClassifier, HeteroNodeClassifier; vector node features |
| Temporal graphs (container + batch + readout + classifiers) | 🧪 Experimental | TemporalGraphSequence, TemporalGraphBatch, temporal_readout, TemporalGraphClassifier, TemporalGraphRegressor; snapshot-loop pattern (no recurrent memory module) |
| Learned graph construction (soft adjacency, edge scorer) | ✅ Stable | tgraphx.learned_graph — discrete top-k is non-differentiable |
| PyG / DGL data converters (homogeneous + hetero) | ✅ Opt-in | tgraphx.interop — lazy imports; data-only, not API replacement |
| MLflowLogger | ✅ Opt-in | Lazy mlflow import; pip install "tgraphx[mlflow]" |
| Dashboard | 🔒 Opt-in | Launch explicitly; zero overhead when off |
| Offline dashboard export | ✅ Stable | --export-html or export_dashboard_html() |
| Multi-run dashboard | ✅ Stable | Point --logdir at parent directory |
| Hardware monitoring | 🔒 Opt-in | pip install "tgraphx[monitoring]" |
| TensorBoard logging | 🔒 Opt-in | pip install "tgraphx[tracking]"; TensorBoardLogger |

Scalability support

| Feature | Status | Notes |
|---|---|---|
| ConvMessagePassing chunked forward | ✅ Stable | aggr="sum" / "mean"; max falls back with warning |
| TensorGraphSAGELayer chunked forward | ✅ Stable | mean and max; pass chunk_size=K to forward() |
| TensorGINLayer chunked forward | ✅ Stable | sum aggregation; pass chunk_size=K to forward() |
| TensorGATLayer chunked forward | ✅ Stable | Two-pass log-sum-exp; pass chunk_size=K to forward(); output matches unchunked within float32 tolerance |
| build_grid_graph / build_grid_graph_3d | ✅ Stable | O(E) — scales well |
| build_random_graph | ✅ Stable | O(E) sample mode for large N |
| build_knn_graph / build_radius_graph | ⚠️ Best-effort | O(N²) time; chunk_size=K reduces peak memory to O(K×N) |
| build_fully_connected_graph | ⚠️ Best-effort | O(N²) edges; N > 5 000 emits warning |
| build_iou_graph | ⚠️ Best-effort | O(N²) IoU; chunk_size=K reduces peak memory to O(K×N) |
| Dashboard metrics API | ✅ Stable | Incremental ?since_row=N; --max-metric-rows cap; byte-seek tail-read |
| Subgraph / k-hop / neighbour sampling | ✅ Stable | tgraphx.sampling + SubgraphDataLoader / NeighborSamplerLoader |
| Random-walk sampling | ✅ Stable | random_walk_sample(graph, seeds, walk_length, …); deterministic with seed |
| Hetero sampling (induced + per-relation neighbour) | ✅ Stable | tgraphx.hetero_induced_subgraph, tgraphx.hetero_neighbor_sample |
| Temporal window sampling (sequence + batch) | ✅ Stable | tgraphx.temporal_window_sample, tgraphx.temporal_window_sample_batch |
| Distributed (DDP) helpers | ✅ Stable | tgraphx.distributed: rank-zero, world-size, barrier; never auto-initialises DDP |

โš ๏ธ Scalability warning: build_knn_graph, build_radius_graph, build_fully_connected_graph, and build_iou_graph use pairwise torch.cdist or enumerate all pairs. Memory and time grow as O(Nยฒ). A warnings.warn is emitted when node count exceeds the threshold (10 000 for kNN/radius, 5 000 for fully-connected/IoU). For large graphs use an approximate-NN library instead.

Attention support

| Feature | Status | Notes |
|---|---|---|
| Scalar attention per (edge, head) | ✅ Stable | Default attention_mode="scalar" |
| Per-channel attention per (edge, head, channel) | 🧪 Experimental | attention_mode="channel"; score tensor [E, K, C_head] |
| Vector edge attention bias | ✅ Stable | use_edge_features=True, edge_dim=D |
| Spatial edge attention bias (2-D / 3-D) | ⚠️ Best-effort | Accepted; mean-pooled to scalar before projection |
| Per-pixel attention | ⛔ Out of scope | Naive [E, K, H, W] score tensor is memory-prohibitive; deferred until a memory-safe variant is designed |
| Per-voxel attention | ⛔ Out of scope | Same reason as per-pixel; [E, K, D, H, W] score tensor is memory-prohibitive |

Limitations

TGraphX is a focused tensor-aware GNN library, not a drop-in replacement for PyTorch Geometric or DGL. Detailed write-ups live in docs/limitations.md; the high-level boundaries are:

  • Not a PyG / DGL drop-in. tgraphx.interop ships data converters (homogeneous + hetero, lazy imports), but layer APIs and call conventions differ.
  • AttentionMessagePassing is not GAT. It uses per-edge sigmoid gating without destination-wise softmax. Use TensorGATLayer for true multi-head GAT.
  • Scalar (default) and per-channel (attention_mode="channel", experimental) attention are the supported TensorGATLayer modes. Per-pixel / per-voxel attention is intentionally not implemented — naive scores [E, K, H, W] are memory-prohibitive and a memory-safe variant has not yet been designed.
  • GAT edge features (vector or matching-rank spatial → mean-pooled). TensorGATLayer accepts [E, edge_dim] vectors and matching-rank spatial tensors ([E, edge_dim, H, W] for spatial_rank=2, [E, edge_dim, D, H, W] for spatial_rank=3). Spatial tensors are mean-pooled before the per-(edge, head) attention bias projection. Mismatched-rank edges raise NotImplementedError. Use TensorGraphSAGELayer or TensorGINLayer for spatial edge features without pooling.
  • Supported node feature ranks: vector [N, D], 2-D spatial [N, C, H, W], 3-D volumetric [N, C, D, H, W]. Universal arbitrary-rank tensor support across every layer is a v0.3 design discussion.
  • Hetero / temporal workflows are experimental. HeteroConv, HeteroGraphClassifier, HeteroNodeClassifier, TemporalGraphClassifier, and TemporalGraphRegressor exist and are tested, but the surface is intentionally small (vector-feature hetero; stateless snapshot-loop temporal). Full TGN / TGAT-style recurrent memory and tensor-aware hetero classifiers are deferred.
  • Graph builders cover common structural patterns (grid, kNN, radius, IoU, fully connected, random); custom topology is up to the user. kNN / radius / IoU / fully-connected scale O(N²) and emit warnings on large N.
  • Patch helpers require exact-divisible dimensions by default; use padding="auto" to right-pad.
  • Multi-GPU. TGraphX provides DDP-aware helpers (tgraphx.distributed) and a single-process smoke example, not an automatic multi-GPU training framework.
  • Dashboard is a local-first lightweight monitor, not a TensorBoard replacement. See docs/dashboard.md.
  • Differentiability: all learned parameters are end-to-end differentiable; graph topology (edge_index) is user-supplied and not learned by the model.

GNN family coverage

| GNN family | Implemented? | Tested? | Limitations |
|---|---|---|---|
| Tensor-aware GCN-style (Conv message passing) | ✅ ConvMessagePassing | ✅ | 2-D [N, C, H, W] and 3-D [N, C, D, H, W]; edge features must be matching-rank spatial with channel count = node channel count |
| Tensor-aware GAT (multi-head) | ✅ TensorGATLayer | ✅ | 2-D and 3-D node features (spatial_rank=2/3); scalar attention per (edge, head); vector [E, D_e] and matching-rank spatial / volumetric edge features (mean-pooled); no per-pixel / per-voxel attention |
| Tensor-aware GraphSAGE | ✅ TensorGraphSAGELayer | ✅ | 2-D and 3-D node features (spatial_rank=2/3); mean / max only; no LSTM aggregator |
| Tensor-aware GIN / GINEConv | ✅ TensorGINLayer | ✅ | 2-D and 3-D node features (spatial_rank=2/3) |
| MPNN-style custom layer | ✅ TensorMessagePassingLayer base | ✅ (subclass test + example) | — |
| Edge-conditioned MP (spatial / volumetric) | ✅ ConvMessagePassing, TensorGATLayer (mean-pooled), TensorGraphSAGELayer, TensorGINLayer | ✅ | edge features [E, C_e, H, W] (2-D) or [E, C_e, D, H, W] (3-D), matching the layer's spatial_rank |
| Per-edge edge_weight ([E]) | ✅ all four message-passing layers, 2-D and 3-D | ✅ | scales messages before aggregation; in GAT, applied after the softmax-normalised attention |
| 3-D / volumetric node features | ✅ ConvMessagePassing, TensorGATLayer, TensorGraphSAGELayer, TensorGINLayer | ✅ | [N, C, D, H, W]; pass spatial_rank=3 to GAT/SAGE/GIN, or a (C, D, H, W) in_shape to ConvMessagePassing. DeepCNNAggregator is rank-aware. LinearMessagePassing covers vector [N, D] and is unaffected. |
| Edge-conditioned MP (vector) | ✅ TensorGATLayer, TensorGraphSAGELayer, TensorGINLayer | ✅ | edge features [E, D_e]; edge_features_kind="vector" |
| aggr="sum"\|"mean"\|"max" base | ✅ all three modes | ✅ hand-computed + backward | ConvMessagePassing aggr="max" routes through scatter_max |
| Graph Transformer (vector features) | 🧪 GraphTransformerLayer (+ degree / Laplacian / adjacency-bias encodings) | ✅ | global multi-head self-attention on [N, D]; O(N²); tensor-aware variant intentionally deferred |
| Heterogeneous graphs | 🧪 HeteroGraph, HeteroGraphBatch, HeteroConv, HeteroGraphClassifier, HeteroNodeClassifier | ✅ | vector features; relation-dispatch wrapper + per-type classifier; tensor-aware hetero classifiers intentionally deferred |
| Temporal graphs | 🧪 TemporalGraphSequence, TemporalGraphBatch, temporal_readout, TemporalGraphClassifier, TemporalGraphRegressor | ✅ | stateless snapshot loop + readout; recurrent (TGN / TGAT) memory intentionally deferred |
| Sampling (homogeneous) | ✅ induced_subgraph, edge_subgraph, k_hop_subgraph, sample_nodes, sample_edges, neighbor_sample, random_walk_sample | ✅ | all deterministic with a seed; per-call generator (no global RNG side effects) |
| Sampling loaders | ✅ SubgraphDataLoader, NeighborSamplerLoader | ✅ | plain Python iterables; deterministic with a seed |
| Sampling (hetero / temporal — v0.2.8) | ✅ hetero_induced_subgraph, hetero_neighbor_sample, temporal_window_sample, temporal_window_sample_batch | ✅ | per-relation fanouts; window slicing on equal- and variable-length temporal batches |
| Distributed (DDP) helpers | ✅ tgraphx.distributed (get_rank, get_world_size, is_rank_zero, rank_zero_print, rank_zero_only, barrier) | ✅ | never auto-initialises DDP; safe to import in single-process / CPU-only contexts |
| Learned graph construction | ✅ tgraphx.learned_graph | ✅ | soft adjacency, EdgeScorer (differentiable); top-k discrete edge selection (non-differentiable) |
| PyG / DGL converters (homogeneous + hetero) | ✅ tgraphx.interop | ✅ | data converters only (lazy imports); not an API replacement |
| MLflowLogger | ✅ tgraphx.tracking.MLflowLogger | ✅ | lazy mlflow import; opt-in via the tgraphx[mlflow] extra |
| Arbitrary-rank tensor support across every layer | ⛔ Out of scope | — | only vector, 2-D, and 3-D node-feature layouts ship — see Scope boundaries |
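
Putting two of the rows above together, the sketch below feeds volumetric [N, C, D, H, W] node features and a per-edge edge_weight of shape [E] through TensorGATLayer. The spatial_rank=3 and edge_weight semantics come straight from the table; the channel / head keyword names and the forward call are illustrative assumptions rather than verified signatures.

```python
# Hedged sketch: volumetric node features + per-edge weights (assumed kwargs).
import torch
from tgraphx import TensorGATLayer

x = torch.randn(5, 4, 8, 8, 8)                 # [N, C, D, H, W] volumetric feature maps
edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 0]])   # [2, E]
edge_weight = torch.rand(5)                    # [E]; scales messages after the attention softmax

layer = TensorGATLayer(in_channels=4, out_channels=4,   # assumed channel kwargs
                       heads=2, spatial_rank=3)         # spatial_rank=3 per the table
out = layer(x, edge_index, edge_weight=edge_weight)     # assumed forward signature
print(out.shape)  # a [N, C_out, D, H, W] map is expected; exact layout not verified here
```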

Project structure

TGraphX/
├── tgraphx/
│   ├── __init__.py          # public API re-exports
│   ├── core/                # Graph, GraphBatch, hetero & temporal containers + batches
│   │   ├── graph.py         # Graph, GraphBatch
│   │   ├── graph_utils.py   # edge topology helpers
│   │   ├── dataloader.py    # GraphDataset, GraphDataLoader
│   │   ├── hetero_graph.py  # HeteroGraph (typed nodes + edges)
│   │   ├── hetero_batch.py  # HeteroGraphBatch
│   │   ├── temporal.py      # TemporalGraphSequence
│   │   ├── temporal_batch.py  # TemporalGraphBatch
│   │   └── utils.py         # load_config, get_device
│   ├── layers/              # tensor-aware GNN layers + transformer + hetero/temporal
│   │   ├── base.py             # TensorMessagePassingLayer, LinearMessagePassing
│   │   ├── conv_message.py     # ConvMessagePassing
│   │   ├── attention_message.py  # AttentionMessagePassing (legacy sigmoid)
│   │   ├── gat.py              # TensorGATLayer (multi-head, scalar / channel modes, chunked)
│   │   ├── sage.py             # TensorGraphSAGELayer
│   │   ├── gin.py              # TensorGINLayer / GINEConv
│   │   ├── graph_transformer.py  # GraphTransformerLayer (vector)
│   │   ├── transformer_encodings.py  # degree / Laplacian / adjacency-bias encodings
│   │   ├── hetero.py           # HeteroConv (relation-dispatch wrapper)
│   │   ├── hetero_readout.py   # hetero_mean/sum/max/concat_pool
│   │   ├── temporal_readout.py # temporal_readout (last/mean/max + mask)
│   │   ├── factory.py          # make_layer()
│   │   ├── aggregator.py       # DeepCNNAggregator
│   │   ├── safe_pool.py        # SafeMaxPool2d
│   │   └── _scatter.py / _dim.py  # internal helpers
│   ├── models/              # task heads + factories
│   │   ├── factory.py          # build_model(), build_model_from_config()
│   │   ├── graph_classifier.py # GraphClassifier
│   │   ├── node_classifier.py  # NodeClassifier
│   │   ├── edge_predictor.py   # EdgePredictor
│   │   ├── regressors.py       # NodeRegressor, GraphRegressor
│   │   ├── hetero_models.py    # HeteroGraphClassifier, HeteroNodeClassifier
│   │   ├── temporal_models.py  # TemporalGraphClassifier, TemporalGraphRegressor
│   │   ├── cnn_encoder.py      # CNNEncoder
│   │   ├── cnn_gnn_model.py    # CNN_GNN_Model
│   │   └── pre_encoder.py      # PreEncoder (optional ResNet-18)
│   ├── graph_builders.py    # grid / kNN / radius / IoU / fully-connected / random + patch helpers
│   ├── learned_graph.py     # soft adjacency, EdgeScorer, top-k edges
│   ├── interop.py           # PyG / DGL converters (lazy imports; homogeneous + hetero)
│   ├── sampling.py          # induced / edge / k-hop / neighbour / random-walk sampling
│   ├── sampling_loaders.py  # SubgraphDataLoader, NeighborSamplerLoader
│   ├── hetero_sampling.py   # hetero_induced_subgraph, hetero_neighbor_sample (v0.2.8)
│   ├── temporal_sampling.py # temporal_window_sample, temporal_window_sample_batch (v0.2.8)
│   ├── distributed.py       # rank-zero / world-size / barrier helpers
│   ├── training.py          # train_epoch, evaluate, fit, set_seed, checkpointing, metrics
│   ├── tracking.py          # CSVLogger, TensorBoardLogger, MLflowLogger, write_graph_stats
│   ├── performance.py       # env_report, recommended_device, estimate_message_memory
│   └── dashboard/           # local-first training dashboard (off by default)
│       ├── app.py           # DashboardServer, export_dashboard_html, CLI main()
│       ├── __init__.py      # launch_dashboard, launch_dashboard_background
│       └── static/          # dashboard.css, dashboard.js (packaged in wheel)
├── tests/                   # 1230+ tests; CPU-safe, deterministic
├── examples/                # 30+ self-contained demos (see "Examples" section)
├── benchmarks/              # benchmark_layers.py / benchmark_graph_builders.py / benchmark_sampling.py
├── docs/                    # API reference, limitations, roadmap, performance, …
├── .github/workflows/       # tests.yml (Ubuntu full + macOS/Windows smoke), publish.yml
├── pyproject.toml
├── CHANGELOG.md
└── LICENSE
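
Because __init__.py re-exports the public API, the main classes are importable from the top-level package, while the helpers named in the tree can be reached through their submodules. A short sketch (import paths inferred from the layout above; top-level re-export of each class name is assumed rather than verified):

```python
# Import paths following the package layout above.
from tgraphx import Graph, GraphBatch                    # core containers (re-exported)
from tgraphx import ConvMessagePassing, TensorGATLayer   # tensor-aware layers
from tgraphx.training import fit, set_seed               # tgraphx/training.py
from tgraphx.distributed import is_rank_zero, barrier    # tgraphx/distributed.py
from tgraphx.sampling import k_hop_subgraph, neighbor_sample  # tgraphx/sampling.py
```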

Development

# Install with dev dependencies (pytest, build, twine)
pip install -e ".[dev]"

# Run the test suite (CPU tests always run; CUDA/MPS skipped if unavailable)
pytest

# Run a specific test file
pytest tests/test_layers.py -v

# Run the examples
python examples/minimal_spatial_message_passing.py
python examples/minimal_graph_classifier.py

Authorship

Software author and maintainer: Arash Sajjadi, PhD Candidate in Computer Science, University of Saskatchewan (arash.sajjadi@usask.ca)

Academic supervision: Mark Eramian, PhD Supervisor / Academic Advisor, University of Saskatchewan

Related preprint: TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning, Arash Sajjadi and Mark Eramian, arXiv:2504.03953

The software package is developed and maintained by Arash Sajjadi. Mark Eramian is Arash Sajjadi's PhD supervisor and a co-author of the related academic preprint. Software authorship and paper co-authorship are separate roles, and both are acknowledged above.


Citation

If you use TGraphX in your research, please cite:

@misc{sajjadi2025tgraphxtensorawaregraphneural,
      title={TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning},
      author={Arash Sajjadi and Mark Eramian},
      year={2025},
      eprint={2504.03953},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2504.03953},
}

License

TGraphX is released under the MIT License.


Questions, issues, or contributions are welcome — please open a GitHub issue or pull request.
