Tensor-aware graph neural networks preserving spatial node feature layouts
TGraphX
Preprint: TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning · Sajjadi & Eramian, arXiv 2025
TGraphX is a PyTorch library for graph neural networks whose node features are multi-dimensional tensors, such as [C, H, W] image-patch feature maps, rather than flat vectors. Convolutional message passing operates directly on spatial feature maps at each node, so local structure is never destroyed by flattening.
Colab tutorial
TGraphX includes a hands-on Google Colab notebook that installs the latest PyPI release and walks through the main API features interactively:
Open the TGraphX Colab Tutorial
The notebook covers:
- CPU/GPU environment checks
- Vector node classification
- 2-D spatial graph classification using image patches
- 3-D volumetric graph classification
- Graph-level and node-level regression
- Edge prediction / edge classification
- Layer zoo: ConvMessagePassing, TensorGATLayer, TensorGraphSAGELayer, TensorGINLayer, LinearMessagePassing, legacy AttentionMessagePassing
- Edge weights and edge features
- Dashboard-compatible run files
Honesty note: All tasks in the notebook use controlled synthetic data designed to verify installation, API behaviour, device compatibility, and gradient flow. They are not benchmark results and make no real-world performance or state-of-the-art claims.
The problem TGraphX solves
Standard GNN frameworks (PyG, DGL) expect flat vector node features. Flattening a [C, H, W] feature map into a [C·H·W] vector discards the spatial structure that makes CNNs effective. TGraphX keeps each node's representation as a tensor and applies 1×1 convolutions during message passing, so every neighbourhood aggregation step acts like a miniature CNN across neighbouring feature maps.
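To make the 1×1 convolution idea concrete, here is a minimal pure-Python sketch on nested lists. It is not TGraphX code; the helper name `conv1x1` is hypothetical. It mixes channels independently at every pixel, so the output keeps the same H×W layout as the input.

```python
def conv1x1(x, weight):
    """Apply a 1x1 convolution to x[C_in][H][W] with weight[C_out][C_in] (no bias)."""
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    return [
        [
            [sum(weight[o][c] * x[c][i][j] for c in range(c_in)) for j in range(w)]
            for i in range(h)
        ]
        for o in range(len(weight))
    ]

# Two input channels on a 2x2 spatial grid.
x = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[10.0, 20.0], [30.0, 40.0]],
]
# One output channel that adds the two input channels pixel by pixel.
y = conv1x1(x, [[1.0, 1.0]])
print(y)  # [[[11.0, 22.0], [33.0, 44.0]]]
```

Each output pixel depends only on the same pixel across input channels, which is why the spatial grid survives the transform.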
What is currently implemented
All "tensor-aware" GNN layers in this list operate on spatial node feature
maps [N, C, H, W] and preserve the spatial layout through message
passing. They are adaptations of the canonical algorithms, not
drop-in clones of PyTorch Geometric's vector-feature implementations.
| Component | Class | Input node shape | Notes |
|---|---|---|---|
| Graph data structure | Graph, GraphBatch | any | Validated; .to(device) |
| Tensor-aware GCN-style message passing | ConvMessagePassing | [N, C, H, W] | aggr="sum" or "mean"; 1×1 conv messages + deep CNN aggregator |
| Tensor-aware GAT (multi-head) | TensorGATLayer | [N, C, H, W] | True GAT: softmax over incoming edges per destination, per head; scalar attention per (edge, head) |
| Tensor-aware GraphSAGE | TensorGraphSAGELayer | [N, C, H, W] | Separate self / neighbour 1×1 Conv2d; mean or max aggregation; optional L2 normalise |
| Tensor-aware GIN / GINEConv | TensorGINLayer | [N, C, H, W] | (1+ε)·h_j + Σ h_i; default 1×1 Conv MLP, learnable ε, optional GINEConv edge term |
| Spatial-gating message passing (legacy) | AttentionMessagePassing | [N, C, H, W] or [N, D] | Per-edge sigmoid gating, not GAT. Kept for backward compatibility; use TensorGATLayer for true GAT. |
| Vector message passing | LinearMessagePassing | [N, D] | Base layer with linear projections |
| Custom layer base class | TensorMessagePassingLayer | any | Override message / update; base handles aggregation |
| CNN patch encoder | CNNEncoder | [N, C_in, pH, pW] | Outputs spatial feature maps [N, C_out, H', W'] |
| Optional pre-encoder | PreEncoder | [N, C_in, pH, pW] | Custom or pretrained ResNet-18 |
| Unified CNN-GNN model | CNN_GNN_Model | [N, C, pH, pW] | Takes pre-split patches; user supplies edge_index |
| Graph classification | GraphClassifier | [N, C, H, W] | Mean / sum / max readout |
| Node classification | NodeClassifier | [N, D] | Vector features only |
| Dataset & loader | GraphDataset, GraphDataLoader | n/a | Wraps torch.utils.data |
| Utilities | load_config, get_device | n/a | YAML/JSON config; CUDA→MPS→CPU |
Datasets, transforms, metrics, benchmarks (v0.2.9)
TGraphX includes a unified dataset registry, native synthetic datasets, folder-backed datasets, and optional adapters for torchvision, PyG, DGL, and OGB.
from tgraphx.datasets import list_datasets, dataset_info, get_dataset
print(list_datasets()) # registered names
ds = get_dataset("synthetic:patch_graph", num_graphs=32, seed=0)
g = ds[0]
- TGraphX does not redistribute third-party datasets: no bundled datasets ship in the wheel or sdist. External adapters delegate download and parsing to the upstream library; downloads happen only when the user passes download=True.
- Cache layout: <root> (or $TGRAPHX_DATA, or ~/.cache/tgraphx/datasets). Inspect with tgraphx.datasets.cache_summary(); clear with clear_cache(...).
- Transforms (tgraphx.transforms) are deterministic-when-seeded, non-mutating-by-default, and importing them does not pull in any optional dependency.
- Metrics (tgraphx.metrics) are pure-PyTorch.
- Benchmarks (benchmarks/benchmark_*.py) accept --small and --output JSON; their results are engineering reproducibility signals, not real-world performance claims.
- Synthetic datasets are sanity / tutorial / trainability datasets, not benchmarks.
Detailed write-ups in docs/datasets.md, docs/transforms.md, docs/metrics.md, docs/benchmarks.md, and the license/citation policy.
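The cache-location fallback listed above can be pictured with a small standard-library sketch. The function name `resolve_cache_root` and the exact precedence (explicit root, then `$TGRAPHX_DATA`, then the default directory) are assumptions read off the order in which the README lists them; this is not the library's own code.

```python
import os
from pathlib import Path

def resolve_cache_root(root=None, env=None):
    """Hypothetical helper: explicit root, then $TGRAPHX_DATA, then the default."""
    env = os.environ if env is None else env
    if root is not None:
        return Path(root).expanduser()
    if env.get("TGRAPHX_DATA"):
        return Path(env["TGRAPHX_DATA"]).expanduser()
    return Path.home() / ".cache" / "tgraphx" / "datasets"

# An explicit root wins over the environment variable (assumed precedence).
print(resolve_cache_root("data/cache", env={"TGRAPHX_DATA": "/tmp/tgx"}))
print(resolve_cache_root(env={"TGRAPHX_DATA": "/tmp/tgx"}))
print(resolve_cache_root(env={}))
```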
Optional dataset extras
| Adapter | Install |
|---|---|
| Native synthetic / folder | (no extra) |
| torchvision-backed datasets | (already a TGraphX base dependency) |
| PyG dataset adapter | pip install "tgraphx[pyg]" |
| DGL dataset adapter | follow upstream install (DGL wheels are platform-sensitive; we don't pin them) |
| OGB dataset adapter | pip install "tgraphx[ogb]" |
| Image folder dataset (PIL) | pip install "tgraphx[pillow]" |
Importing tgraphx, tgraphx.datasets, tgraphx.transforms, or
tgraphx.metrics does not import torch_geometric / dgl / ogb.
Current scope and boundaries
TGraphX is a focused library for tensor-aware patch-graph GNNs. Here is what is stable, what is experimental, and what is intentionally out of scope. Full details are in docs/limitations.md and docs/roadmap.md.
Optional and experimental capabilities
| Feature | Status | Install / usage |
|---|---|---|
| TensorGATLayer(attention_mode="channel") | 🧪 Experimental | Constructor argument; per-channel attention |
| TensorGATLayer(chunk_size=K) two-pass chunked forward | ✅ Stable | forward(chunk_size=K) (log-sum-exp) |
| GraphTransformerLayer (vector node features) | 🧪 Experimental | from tgraphx.layers.graph_transformer import GraphTransformerLayer |
| Positional / structural encodings (degree, Laplacian, adjacency bias) | 🧪 Experimental | tgraphx.layers.transformer_encodings |
| HeteroGraph + HeteroGraphBatch + HeteroConv + classifiers | 🧪 Experimental | tgraphx.HeteroGraph, HeteroGraphBatch, tgraphx.layers.hetero.HeteroConv, tgraphx.models.hetero_models.* |
| TemporalGraphSequence + TemporalGraphBatch + readout + classifiers | 🧪 Experimental | tgraphx.TemporalGraphSequence, TemporalGraphBatch, tgraphx.layers.temporal_readout, tgraphx.models.temporal_models.* |
| Subgraph / k-hop / neighbour / random-walk sampling | ✅ Stable | tgraphx.sampling, tgraphx.SubgraphDataLoader, tgraphx.NeighborSamplerLoader |
| Hetero / temporal sampling (v0.2.8) | ✅ Stable | tgraphx.hetero_induced_subgraph, tgraphx.hetero_neighbor_sample, tgraphx.temporal_window_sample, tgraphx.temporal_window_sample_batch |
| Distributed (DDP) helpers | ✅ Stable | tgraphx.distributed; never auto-initialises DDP |
| MLflowLogger | Opt-in | pip install "tgraphx[mlflow]"; lazy mlflow import |
| PyG / DGL data converters (homogeneous + hetero) | Opt-in | tgraphx.interop (lazy imports) |
| Learned graph helpers | ✅ Stable | tgraphx.learned_graph |
| Patch helper padding="auto" | ✅ Stable | image_to_patches(imgs, ps, padding="auto") |
| Hardware monitoring dashboard | Opt-in | pip install "tgraphx[monitoring]" |
| TensorBoard logging | Opt-in | pip install "tgraphx[tracking]" |
Scope boundaries (true design limits)
These are intentional, not bugs. Detailed write-ups live in docs/limitations.md; the roadmap is in docs/roadmap.md.
- Supported node feature ranks: [N, D], [N, C, H, W], and [N, C, D, H, W]. A universal arbitrary-rank layer that works across every existing layer is a v0.3 design discussion.
- GAT per-pixel / per-voxel attention: naive score tensors would be O(E · K · H · W), which is memory-prohibitive for typical spatial GNN workloads. Memory-safe variants (factorised / windowed / low-rank) are deferred until designed honestly. Per-channel attention is shipped as attention_mode="channel".
- PyG / DGL drop-in compatibility: TGraphX is not a replacement. tgraphx.interop ships data converters; layer APIs differ.
- Multi-GPU training framework: TGraphX provides rank-zero / world-size DDP helpers and a single-process smoke example. Production-grade automatic multi-GPU training is the user's responsibility (and is intentionally not bundled).
- Recurrent temporal memory (TGN / TGAT-style): temporal workflows use a stateless snapshot-loop pattern. Memory-aware temporal architectures are an open design question.
- Profiling and file writes: disabled by default; every logger, dashboard, and checkpoint write is opt-in.
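A quick back-of-the-envelope calculation shows why per-pixel attention scores are treated as out of scope. The sketch below only multiplies tensor dimensions; the sizes (E = 2048 edges, K = 4 heads, 32×32 maps, float32) are illustrative assumptions, not measurements.

```python
def attention_score_bytes(num_edges, heads, spatial=(), bytes_per_elem=4):
    """Bytes needed for one attention-score tensor of shape [E, K, *spatial]."""
    n = num_edges * heads
    for s in spatial:
        n *= s
    return n * bytes_per_elem

E, K, H, W = 2048, 4, 32, 32
scalar = attention_score_bytes(E, K)             # one score per (edge, head)
per_pixel = attention_score_bytes(E, K, (H, W))  # one score per (edge, head, pixel)
print(scalar / 2**10, "KiB vs", per_pixel / 2**20, "MiB")  # 32.0 KiB vs 32.0 MiB
```

The per-pixel tensor is H·W times larger for a single layer's scores, before counting gradients or intermediate buffers.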
Performance
Environment and hardware report
from tgraphx.performance import env_report, estimate_message_memory, recommended_device
print(env_report()) # Python/PyTorch/CUDA/MPS info
print(env_report(include_hardware=True)) # + CPU/RAM/CUDA memory (needs psutil)
print(env_report(include_sensors=True)) # + GPU util/temp (needs pynvml)
dev = recommended_device() # CUDA > MPS > CPU
# Estimate peak message-buffer memory before running
m = estimate_message_memory(num_edges=1024, out_shape=(64, 8, 8))
print(f"~{m['total_mb']:.1f} MB ({m['note']})")
Benchmarks
# Layer throughput (CPU-safe, all flags optional)
python benchmarks/benchmark_layers.py --layer gat --nodes 64 --edges 256 \
--shape 8,4,4 --device cpu --iters 10
# CUDA + AMP + backward
python benchmarks/benchmark_layers.py --layer conv --nodes 256 --edges 2048 \
--shape 32,8,8 --device cuda --amp 1 --backward 1
# Save JSON result
python benchmarks/benchmark_layers.py --layer gin --shape 8,4,4 \
--output results/gin.json
# Graph builder timing
python benchmarks/benchmark_graph_builders.py --small # CI-safe
python benchmarks/benchmark_graph_builders.py # full
torch.compile and AMP
python examples/torch_compile_benchmark.py # eager vs compiled, correctness check
python examples/mixed_precision_inference.py # autocast forward demo (finite-output check)
python examples/memory_report.py # env report + memory estimates
AMP policy:
| Backend | Recommended dtype | Status | Notes |
|---|---|---|---|
| CPU | bfloat16 | ✅ Tested | Covered by tests/test_amp_compile.py in full Linux CI |
| CUDA | float16 / bfloat16 | ⚠️ Best-effort | Behaviour fixed in v0.2.2; bfloat16 needs Ampere+ |
| MPS | n/a | ⚠️ Best-effort | PyTorch operator coverage varies; not in CI |
v0.2.2 fixes: broadcast_edge_weight casts edge weights to activation dtype;
TensorGATLayer casts attention weights before index_add_; edge_softmax
upcasts to fp32 for numerical stability and casts back. See
docs/performance.md for full details.
Optional chunked forward (ConvMessagePassing)
Reduce peak edge-buffer memory by processing edges in chunks:
from tgraphx.layers.conv_message import ConvMessagePassing
layer = ConvMessagePassing(in_shape=(32, 8, 8), out_shape=(32, 8, 8), aggr="sum")
# Same output as unchunked; lower peak memory for large E
out = layer(x, edge_index, chunk_size=512)
Supported aggregations: "sum" and "mean". "max" falls back to the
standard path with a warning. All four message-passing layers accept
chunk_size in forward():
- ConvMessagePassing: sum / mean (max falls back).
- TensorGraphSAGELayer: mean / max.
- TensorGINLayer: sum.
- TensorGATLayer: two-pass log-sum-exp. Output matches unchunked within float32 tolerance.
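The two-pass log-sum-exp idea mentioned for the GAT layer can be sketched in plain Python. This is an independent illustration, not TGraphX's implementation: the function name `edge_softmax_chunked` is hypothetical, scores are assumed finite, and in real tensor code each chunk would be a vectorised slice rather than an inner loop.

```python
import math

def edge_softmax_chunked(scores, dst, num_nodes, chunk_size):
    """Softmax over incoming edges per destination, visiting edges in chunks."""
    # Pass 1: running max and rescaled sum of exponentials per destination.
    m = [-math.inf] * num_nodes
    s = [0.0] * num_nodes
    for start in range(0, len(scores), chunk_size):
        for e in range(start, min(start + chunk_size, len(scores))):
            j, z = dst[e], scores[e]
            new_m = max(m[j], z)
            s[j] = s[j] * math.exp(m[j] - new_m) + math.exp(z - new_m)
            m[j] = new_m
    # Pass 2: normalise each edge with the statistics of its destination.
    return [math.exp(z - m[j]) / s[j] for z, j in zip(scores, dst)]

scores = [0.5, 2.0, -1.0, 1000.0, 1001.0]  # large values would overflow a naive exp
dst = [0, 0, 0, 1, 1]
full = edge_softmax_chunked(scores, dst, num_nodes=2, chunk_size=len(scores))
chunked = edge_softmax_chunked(scores, dst, num_nodes=2, chunk_size=2)
print(sum(chunked[:3]), sum(chunked[3:]))  # each destination's weights sum to ~1
```

Only per-destination statistics are carried between chunks, so peak memory scales with the chunk size rather than with the total number of edges.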
Hardware compatibility
| Platform | Forward | AMP | torch.compile | CI coverage | Notes |
|---|---|---|---|---|---|
| CPU | ✅ | ⚠️ bfloat16 only | ✅ | Full CI (Ubuntu) | Compile overhead may dominate small graphs |
| CUDA | ✅ | ⚠️ float16 (op-dependent) | ✅ | Local tests only | index_add_ ops require dtype match; no GPU runners in CI |
| MPS (Apple Silicon) | ✅ | ⚠️ limited | ⚠️ | Smoke CI (macOS) | Best-effort; PyTorch op coverage varies |
| Linux | ✅ | ✅ | ✅ | Full CI (ubuntu-latest, Py 3.10/3.11/3.12) | Primary CI platform |
| Windows | ✅ | ✅ | ✅ | Smoke CI (Py 3.11) | Imports + build + dashboard CLI smoke |
| macOS | ✅ | ⚠️ limited | ⚠️ | Smoke CI (Py 3.11) | Same surface as Windows smoke |
Training utilities
TGraphX includes lightweight training helpers, not a full training framework. All logging and file writes are off by default.
Training loop helpers
from tgraphx.training import train_epoch, evaluate, fit
import torch.nn.functional as F
# One-line training loop
history = fit(
model, train_loader, val_loader=val_loader,
epochs=20,
optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
loss_fn=F.cross_entropy,
log_level=1, # print per-epoch summary
)
# → [{"epoch":0, "train_loss":0.9, "val_loss":0.85}, ...]
Supported batch formats: GraphBatch (with graph_labels / node_labels)
and (Tensor, Tensor) tuples.
Standalone helpers
from tgraphx.training import (
set_seed, # seeds torch / numpy / random
count_parameters, # trainable parameter count
save_checkpoint, # torch.save wrapper
load_checkpoint, # returns saved epoch number
accuracy, # multi-class argmax accuracy
mean_absolute_error, mean_squared_error,
)
CSV logging (dashboard-compatible)
from tgraphx.tracking import CSVLogger
with CSVLogger("runs/my_run") as logger:
    history = fit(model, train_loader, ..., logger=logger)
# writes runs/my_run/metrics.csv with UTC timestamps
TensorBoard logging (optional)
from tgraphx.tracking import TensorBoardLogger # lazy import
# pip install tensorboard or pip install "tgraphx[tracking]"
with TensorBoardLogger("runs/tb") as tb:
    history = fit(model, train_loader, ..., logger=tb)
Nothing is written unless you explicitly pass a logger.
MLflow logging (optional)
from tgraphx.tracking import MLflowLogger # pip install mlflow
with MLflowLogger(run_name="my_run", experiment="gnn") as mlf:
    history = fit(model, train_loader, ..., logger=mlf)
Dashboard
TGraphX includes a local training dashboard: off by default, zero external dependencies, no telemetry.
Quick start
# Launch (localhost only, no token needed)
tgraphx-dashboard --logdir runs/demo
# → http://127.0.0.1:8765
# LAN access: explicit token required
tgraphx-dashboard --logdir runs/demo \
  --host 0.0.0.0 --token MY_SECRET_TOKEN
# Auto-generated token
tgraphx-dashboard --logdir runs/demo --host 0.0.0.0 --token auto
# Offline HTML snapshot: no server needed
tgraphx-dashboard --logdir runs/demo --export-html snapshot.html
# Python API: non-blocking background thread
from tgraphx.dashboard import launch_dashboard_background, export_dashboard_html
server = launch_dashboard_background("runs/demo", port=8765)
# ... training loop ...
server.shutdown()
# Offline snapshot
export_dashboard_html("runs/demo", "snapshot.html")
Dashboard features
| Section | Contents |
|---|---|
| Overview | Status chip, epoch progress, live loss, elapsed / ETA |
| Metrics | SVG line charts, window selector, EMA smoothing, per-chart CSV/SVG export |
| Graph | Graph summary, degree stats, graph_stats.json precomputed cards, SVG preview |
| Hardware | CPU/RAM/GPU/CUDA/MPS, power draw, thermal status (optional psutil/pynvml) |
| Logs | Scrollable metric table with CSV export |
| Config | run_metadata.json rendered safely |
| Tools | Copy URL, export buttons, refresh controls |
| TV mode | Full-screen large-font passive monitoring |
Key features:
- Incremental updates: browser requests only new rows via ?since_row=N
- Multi-run selector: point at a parent directory; select runs by name
- Color-blind-safe palette: Okabe-Ito toggle, persisted in localStorage
- Accessible: skip link, ARIA labels, focus-visible, reduced-motion support
- Export: metrics CSV, per-chart CSV/SVG, print/save PDF, offline HTML snapshot
- Responsive: phone, tablet, desktop, TV/large-monitor layouts
- Pause/resume polling, configurable refresh interval
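The incremental-update idea (fetch only rows the client has not seen yet) can be illustrated with the standard library. This is a sketch of the general pattern, not the dashboard's server code; `rows_since` is a hypothetical name and the two-column CSV is made up for the example.

```python
import csv
import io

def rows_since(csv_text, since_row=0):
    """Return (new_rows, next_cursor) for data rows after the first `since_row`."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    return rows[since_row:], len(rows)

log = "epoch,train_loss\n0,0.90\n1,0.70\n"
first, cursor = rows_since(log)          # initial load: both rows
log += "2,0.55\n"                        # the training loop appends a row
new, cursor = rows_since(log, since_row=cursor)
print(new)  # [{'epoch': '2', 'train_loss': '0.55'}]
```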
Security model
| Scenario | Token required? |
|---|---|
| --host 127.0.0.1 (default) | No |
| --host 0.0.0.0 + connecting from localhost | No |
| --host 0.0.0.0 + connecting from another device | Yes |
| Starting LAN mode without --token | Refused at startup |
Read-only · no external CDN · no telemetry · no token leakage in API responses · path-traversal protected.
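The token rules in the table can be expressed as two small checks. The sketch below is an illustration of that policy using only the standard library; the function names are hypothetical and the dashboard's real request handling may differ.

```python
import hmac
import ipaddress

def validate_startup(host, token):
    """Refuse to bind a non-loopback host when no token is configured."""
    if not ipaddress.ip_address(host).is_loopback and not token:
        raise ValueError("LAN mode requires --token")

def request_allowed(client_ip, server_token, supplied_token=None):
    """Loopback clients pass without a token; others must present the right one."""
    if ipaddress.ip_address(client_ip).is_loopback:
        return True
    if not server_token or supplied_token is None:
        return False
    # Constant-time comparison avoids leaking the token through timing.
    return hmac.compare_digest(server_token, supplied_token)

validate_startup("127.0.0.1", None)                      # fine: localhost only
print(request_allowed("127.0.0.1", "s3cret"))            # True
print(request_allowed("192.168.1.20", "s3cret"))         # False
print(request_allowed("192.168.1.20", "s3cret", "s3cret"))  # True
```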
Log files
metrics.csv → epoch,train_loss,val_loss,... (ISO-8601 UTC timestamp column)
run_metadata.json → run name, status, total_epochs, device, task (free-form dict)
graph_metadata.json → optional graph summary + edge_index for preview (≤200 nodes)
graph_stats.json → optional precomputed stats (write with write_graph_stats())
from tgraphx import write_graph_stats
write_graph_stats({"num_nodes": 100, "num_edges": 400, "density": 0.04},
"runs/demo/graph_stats.json")
Hardware monitoring (optional)
pip install "tgraphx[monitoring]" # psutil + pynvml
Missing packages show a compact "unavailable" reason per row, with no broken charts.
Privacy and local-first behavior
TGraphX is designed to be entirely local and private:
| Behavior | Default |
|---|---|
| Telemetry / analytics | None, never |
| Remote calls at import | None |
| Dashboard | Off; launch explicitly |
| CSV metric logging | Off; create CSVLogger explicitly |
| TensorBoard logging | Off; create TensorBoardLogger explicitly |
| Hardware monitoring | Off; pass include_hardware=True to env_report |
| Checkpoints | Off; call save_checkpoint explicitly |
| Graph serialization | Off; include edge_index in JSON manually |
| Background threads | None (unless launch_dashboard_background is called) |
| File writes | Only to paths you explicitly provide |
| Reads | Dashboard reads only inside --logdir |
| External CDN / assets | None; dashboard is fully self-contained |
No ~/.tgraphx directory or user-level config is created by default.
Installation
pip install tgraphx
Optional extras:
pip install "tgraphx[tracking]" # TensorBoard integration
pip install "tgraphx[monitoring]" # psutil + pynvml (dashboard hardware panel)
pip install "tgraphx[dev]" # pytest, build, twine
For a specific PyTorch build (e.g. CPU-only or a particular CUDA version), install PyTorch before TGraphX:
# CPU-only example
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install tgraphx
Install from source:
git clone https://github.com/arashsajjadi/TGraphX.git
cd TGraphX
pip install -e ".[dev]"
See pytorch.org for GPU-specific install commands.
A Conda environment file is provided at environment.yml.
Quickstart
Vector node features (simplest case):
import torch
from tgraphx import Graph, LinearMessagePassing, build_model, fit
# 8 nodes, 32-dimensional vector features
x = torch.randn(8, 32)
src = torch.arange(8)
edge_index = torch.stack([src, (src + 1) % 8])
g = Graph(x, edge_index)
layer = LinearMessagePassing(in_shape=(32,), out_shape=(64,))
out = layer(g.node_features, g.edge_index) # [8, 64]
out.sum().backward()
Spatial node features, with [C, H, W] preserved through message passing:
import torch
from tgraphx import Graph, ConvMessagePassing
N, C, H, W = 6, 16, 8, 8
node_features = torch.randn(N, C, H, W)
src = torch.arange(N)
edge_index = torch.stack([src, (src + 1) % N])
g = Graph(node_features, edge_index)
layer = ConvMessagePassing(in_shape=(C, H, W), out_shape=(32, H, W))
out = layer(g.node_features, g.edge_index) # [6, 32, 8, 8]
out.sum().backward()
Graph Builders
TGraphX ships pure-PyTorch graph builders that return edge_index
tensors ([2, E], dtype=torch.long) ready for any GNN layer or
Graph constructor. They create fixed, rule-based adjacency
structures โ they do not implement learned adjacency.
Builder API
| Function | Description | Key params |
|---|---|---|
| build_grid_graph(rows, cols) | 2-D 4-connected grid | directed, self_loops |
| build_grid_graph_3d(depth, rows, cols) | 3-D 6-connected grid | directed, self_loops |
| build_fully_connected_graph(num_nodes) | Complete graph | self_loops |
| build_knn_graph(coords, k) | k-nearest-neighbour | directed, self_loops |
| build_radius_graph(coords, radius) | All pairs within radius | directed, self_loops |
| build_iou_graph(boxes, threshold) | Bounding-box IoU ≥ threshold | directed, self_loops |
| build_random_graph(num_nodes, num_edges) | Uniform random sample | directed, self_loops, seed |
O(N²) warning: build_knn_graph and build_radius_graph call torch.cdist internally, which requires O(N²) time and memory. For large graphs (N > 10 000), use approximate-NN libraries instead.
O(N²) warning: build_fully_connected_graph emits N·(N−1) edges. Memory grows quadratically with node count.
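To put numbers on the quadratic growth, the following sketch computes the directed edge count of a complete graph and the size of a dense float32 pairwise-distance matrix. It is plain arithmetic under the stated assumptions (4 bytes per element, one N×N matrix); actual peak memory in a given builder may be higher.

```python
def fully_connected_edges(n, self_loops=False):
    """Directed edge count of a complete graph on n nodes."""
    return n * n if self_loops else n * (n - 1)

def dense_distance_matrix_mb(n, bytes_per_elem=4):
    """Size in MiB of one n x n pairwise-distance matrix (float32 by default)."""
    return n * n * bytes_per_elem / 2**20

print(fully_connected_edges(1_000))                # 999000
print(round(dense_distance_matrix_mb(10_000), 1))  # 381.5
```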
Grid graph quickstart
import torch
from tgraphx import Graph, build_grid_graph
from tgraphx.layers.conv_message import ConvMessagePassing
# 3×3 patch grid: 9 nodes, each with a [4, 8, 8] spatial feature
node_features = torch.randn(9, 4, 8, 8)
edge_index = build_grid_graph(3, 3, directed=False, self_loops=True)
# edge_index: [2, 33] (24 neighbour + 9 self-loop edges)
g = Graph(node_features, edge_index)
layer = ConvMessagePassing(in_shape=(4, 8, 8), out_shape=(8, 8, 8))
out = layer(g.node_features, g.edge_index) # [9, 8, 8, 8]
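The edge count quoted in the comment above (24 neighbour edges plus 9 self-loops) can be checked with a pure-Python sketch of a 4-connected grid. The helper `grid_edges` is hypothetical and makes no claim about the edge ordering that build_grid_graph actually returns; it only reproduces the count.

```python
def grid_edges(rows, cols, self_loops=False):
    """Edges of a 4-connected grid in both directions, nodes numbered row-major."""
    src, dst = [], []
    for r in range(rows):
        for c in range(cols):
            i = r * cols + c
            if self_loops:
                src.append(i)
                dst.append(i)
            for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                rr, cc = r + dr, c + dc
                if 0 <= rr < rows and 0 <= cc < cols:
                    src.append(i)
                    dst.append(rr * cols + cc)
    return [src, dst]

edge_index = grid_edges(3, 3, self_loops=True)
print(len(edge_index[0]))  # 33
```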
2-D image patch graph
import torch
from tgraphx import build_grid_graph, image_to_patches
from tgraphx.layers.gat import TensorGATLayer
# Extract 4×4 patches from a [B, C, H, W] image
images = torch.randn(2, 3, 8, 8)
patches = image_to_patches(images, patch_size=4) # [2, 4, 3, 4, 4]
# Build the patch grid graph
edge_index = build_grid_graph(2, 2, directed=False, self_loops=True)
# Run GAT on one image's patches
x = patches[0] # [4, 3, 4, 4]
gat = TensorGATLayer(in_channels=3, out_channels=8, num_heads=2, spatial_rank=2)
out = gat(x, edge_index) # [4, 8, 4, 4]
3-D volume patch graph
import torch
from tgraphx import build_grid_graph_3d, volume_to_patches
from tgraphx.layers.gat import TensorGATLayer
# Extract 4×4×4 patches from a [B, C, D, H, W] volume
volumes = torch.randn(1, 2, 8, 8, 8)
patches = volume_to_patches(volumes, patch_size=4) # [1, 8, 2, 4, 4, 4]
# Build the 3-D patch grid graph
edge_index = build_grid_graph_3d(2, 2, 2, directed=False, self_loops=True)
# Run GAT on one volume's patches
x = patches[0] # [8, 2, 4, 4, 4]
gat = TensorGATLayer(in_channels=2, out_channels=4, num_heads=2, spatial_rank=3)
out = gat(x, edge_index) # [8, 4, 4, 4, 4]
Patch helper API
| Function | Input | Output | Notes |
|---|---|---|---|
| patch_grid_shape(H, W, patch_size, stride) | n/a | (n_h, n_w) | Raises if not exactly covered |
| image_to_patches(images, patch_size, stride) | [B, C, H, W] | [B, P, C, ph, pw] | Row-major; matches grid node order |
| volume_patch_grid_shape(D, H, W, patch_size, stride) | n/a | (n_d, n_h, n_w) | Raises if not exactly covered |
| volume_to_patches(volumes, patch_size, stride) | [B, C, D, H, W] | [B, P, C, pd, ph, pw] | Depth-row-col order; matches 3-D grid |
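The "exactly covered" rule and the row-major patch order can be sketched as follows. This is an illustrative reimplementation with a hypothetical name (`sketch_patch_grid_shape`), and it assumes the stride defaults to the patch size; the library's own defaults and error messages may differ.

```python
def sketch_patch_grid_shape(height, width, patch_size, stride=None):
    """Patches per axis; raises if the image is not exactly covered."""
    stride = patch_size if stride is None else stride  # assumed default
    dims = []
    for size in (height, width):
        if size < patch_size or (size - patch_size) % stride != 0:
            raise ValueError(
                f"size {size} is not exactly covered by "
                f"patch_size={patch_size}, stride={stride}"
            )
        dims.append((size - patch_size) // stride + 1)
    return tuple(dims)

n_h, n_w = sketch_patch_grid_shape(8, 8, patch_size=4)
print(n_h, n_w)  # 2 2
# Row-major patch index p maps to grid cell (p // n_w, p % n_w).
cells = [(p // n_w, p % n_w) for p in range(n_h * n_w)]
print(cells)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```

With this numbering, patch p corresponds to node p of a rows × cols grid graph whose nodes are also numbered row-major.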
Concept: tensor-aware node features
In a standard GNN a node carries a vector x_i ∈ ℝ^d. In TGraphX a node carries a tensor X_i ∈ ℝ^{C×H×W}. Every layer follows the standard message-passing template:
M_{i→j} = φ( X_i, X_j, E_{i→j} ) # per-edge messages
A_j = AGG_{i ∈ N(j)} M_{i→j} # permutation-invariant aggregation
X'_j = ψ( X_j, A_j ) # update
Each layer instantiates this with its own (φ, AGG, ψ):
| Layer | φ (message) | AGG | ψ (update) |
|---|---|---|---|
| ConvMessagePassing | Conv1×1(Concat(X_i, X_j[, E_ij])) | sum / mean / max | DeepCNNAggregator(A_j) (+ optional residual) |
| TensorGATLayer | α_{ij}^k · W^k X_i with α_{ij}^k = softmax_i(LeakyReLU(a_dst·pool(W^k X_j) + a_src·pool(W^k X_i) + b^k(e_ij))) | sum (weighted) | concat or mean over heads, optional residual |
| TensorGraphSAGELayer | W_neigh(X_i) (+ optional spatial cat or vector bias from e_ij) | mean / max | W_self(X_j) + AGG, optional L2 normalise |
| TensorGINLayer | X_i or ReLU(X_i + φ_e(e_ij)) (GINEConv) | sum | MLP((1+ε)·X_j + Σ_i M_ij) |
| LinearMessagePassing | Linear(Concat(x_i, x_j[, e_ij])) | sum / mean / max | identity (override to customise) |
Spatial dimensions H and W are preserved through every learned transform. All aggregations are permutation-invariant over the order of incoming edges, and every layer is permutation-equivariant over node reindexing (verified by tests/test_math.py).
The graph structure (which nodes are connected) is not learned by the model. Users supply edge_index based on domain knowledge (e.g., spatial proximity of patches, IoU overlap of bounding boxes, kNN on patch centres).
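The message / aggregate / update template can be sketched in a few lines of plain Python, with one scalar per node standing in for a [C, H, W] tensor. The names `message_passing`, `message_fn`, and `update_fn` are hypothetical, and the GIN-style instance omits the MLP; it is meant only to show the shape of the computation and that sum aggregation ignores edge order.

```python
def message_passing(x, edges, message_fn, update_fn):
    """One round of sum-aggregated message passing over directed edges (i -> j)."""
    agg = [0.0] * len(x)
    for i, j in edges:
        agg[j] += message_fn(x[i], x[j])  # per-edge message, summed at destination j
    return [update_fn(x[j], agg[j]) for j in range(len(x))]

x = [1.0, 2.0, 4.0]               # one scalar feature per node
edges = [(0, 1), (2, 1), (1, 0)]  # (source, destination) pairs
eps = 0.5

def gin_like(edge_list):
    # GIN-style instance without the MLP: message = source feature,
    # update = (1 + eps) * x_j + aggregated messages.
    return message_passing(
        x, edge_list,
        message_fn=lambda x_i, x_j: x_i,
        update_fn=lambda x_j, a_j: (1 + eps) * x_j + a_j,
    )

out = gin_like(edges)
out_reversed = gin_like(edges[::-1])
print(out)           # [3.5, 8.0, 6.0]
print(out_reversed)  # [3.5, 8.0, 6.0]
```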
Edge feature formats per layer
| Layer | Vector [E, D_e] | Spatial [E, C_e, H, W] |
|---|---|---|
| ConvMessagePassing | ❌ | ✅ (concatenated along channels; channel count must equal node channel count) |
| TensorGATLayer | ✅ (additive attention bias on logits) | ⚠️ accepted; mean-pooled to scalar attention bias (no per-pixel attention) |
| TensorGraphSAGELayer | ✅ (additive channel bias post-W_neigh) | ✅ (concatenated to source) |
| TensorGINLayer | ✅ (broadcast bias before ReLU) | ✅ (1×1 Conv2d projection) |
TensorGATLayer spatial edge features: [E, C_e, H, W] (or [E, C_e, D, H, W] for 3-D nodes) are accepted and mean-pooled over spatial dims before the attention bias projection. Spatial dims do not need to match the node spatial dims. Mismatched rank (e.g. 5-D edges into a 2-D-configured GAT) raises NotImplementedError. Use TensorGraphSAGELayer or TensorGINLayer for full spatial edge-feature processing (no pooling).
Factory API
make_layer: create any layer by name
from tgraphx import make_layer
# 2-D spatial GAT with 2 attention heads
layer = make_layer("gat", in_shape=(8, 4, 4), out_shape=(16, 4, 4), heads=2, residual=True)
# Vector linear message passing
layer = make_layer("linear", in_shape=(32,), out_shape=(64,), aggr="mean")
| Name | Layer class | Shape support |
|---|---|---|
| "conv" | ConvMessagePassing | 2-D / 3-D spatial |
| "gat" | TensorGATLayer | 2-D / 3-D spatial |
| "sage" | TensorGraphSAGELayer | 2-D / 3-D spatial |
| "gin" | TensorGINLayer | 2-D / 3-D spatial |
| "linear" | LinearMessagePassing | vector only |
| "legacy_attention" | AttentionMessagePassing | vector / 2-D spatial |
build_model: complete task model
from tgraphx import build_model
# Node classification on vector features
model = build_model(
task="node_classification",
layer="linear",
in_shape=(32,),
hidden_shape=(64,),
num_layers=3,
num_classes=5,
)
out = model(x, edge_index) # [N, 5]
# Graph classification on 2-D image patches
model = build_model(
task="graph_classification",
layer="gat",
in_shape=(3, 4, 4), # [C, ph, pw]
hidden_shape=(8, 4, 4),
num_layers=2,
num_classes=10,
heads=2,
pooling="mean",
)
out = model(x, edge_index, batch=batch) # [G, 10]
Task support matrix
| Task | Vector [N,D] | 2-D spatial [N,C,H,W] | 3-D volumetric [N,C,D,H,W] |
|---|---|---|---|
| node_classification | ✅ | ✅ | ✅ |
| node_regression | ✅ | ✅ | ✅ |
| graph_classification | ✅ | ✅ | ✅ |
| graph_regression | ✅ | ✅ | ✅ |
| edge_prediction | ✅ | ✅ (spatial pool) | ✅ (spatial pool) |
| link_prediction | deferred | deferred | deferred |
Spatial / volumetric tasks: after the GNN stack the factory applies global spatial average-pooling to flatten [N, C, *spatial] → [N, C] before the linear head. Spatial resolution is preserved inside every GNN layer and only collapsed at the final readout.
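The final readout step described here (averaging each channel over its spatial positions) looks like this on nested lists. It is a pure-Python sketch for the 2-D case with a hypothetical helper name, not the factory's code.

```python
def global_spatial_mean(x):
    """Collapse nested [N][C][H][W] lists to [N][C] by averaging over H and W."""
    return [
        [sum(map(sum, channel)) / (len(channel) * len(channel[0])) for channel in node]
        for node in x
    ]

# N=1 node, C=2 channels, H=W=2.
x = [[[[1.0, 3.0], [5.0, 7.0]], [[0.0, 0.0], [0.0, 8.0]]]]
pooled = global_spatial_mean(x)
print(pooled)  # [[4.0, 2.0]]
```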
build_model_from_config: config-driven construction
from tgraphx import build_model_from_config
# From a Python dict (no eval, no exec)
config = {
"model": {
"task": "graph_classification",
"layer": "gat",
"in_shape": [8, 4, 4],
"hidden_shape": [16, 4, 4],
"num_layers": 2,
"num_classes": 3,
"heads": 2,
"residual": True,
"dropout": 0.1,
}
}
model = build_model_from_config(config)
# From a JSON file
model = build_model_from_config("config.json")
# From a YAML file (requires PyYAML)
model = build_model_from_config("config.yaml")
Examples
Run every fast example in one go:
python examples/run_all_fast_examples.py
Individual demos:
# Tensor-aware GCN-style spatial message passing
python examples/minimal_spatial_message_passing.py
# Graph classification: short training loop with synthetic data
python examples/minimal_graph_classifier.py
# Tensor-aware multi-head GAT (attention weights sum to 1 per destination per head)
python examples/tensor_gat_minimal.py
# Tensor-aware GraphSAGE (mean / max / with-edge-features variants)
python examples/tensor_graphsage_minimal.py
# Custom user-defined message-passing layer subclass
python examples/custom_message_passing.py
# Trainability sanity: tiny overfit per GNN family
python examples/tiny_overfit_tensor_gat.py
python examples/tiny_overfit_edge_features.py
# Deep 8-layer stack gradient sanity
python examples/gradient_sanity_stack.py
# Graph builders + patch helpers
python examples/directed_vs_undirected_graphs.py
python examples/image_patch_graph.py
python examples/volume_patch_graph.py
python examples/gnn_family_with_graph_builders.py
# Factory + model examples (numbered)
python examples/01_vector_node_classification.py
python examples/02_spatial_graph_classification.py
python examples/03_volumetric_graph_classification.py
python examples/04_config_based_model.py
python examples/05_edge_prediction.py
# Hetero / temporal / sampling / transformer demos
python examples/hetero_graph_batch_demo.py
python examples/hetero_graph_classifier_demo.py
python examples/temporal_graph_batch_demo.py
python examples/temporal_graph_classifier_demo.py
python examples/neighbor_sampling_demo.py
python examples/sampling_demo_v028.py # random-walk + hetero + temporal sampling (v0.2.8)
python examples/graph_transformer_demo.py
python examples/gat_chunking_demo.py
python examples/v024_new_features.py
# Distributed-training helpers (single-process smoke)
python examples/ddp_training_smoke.py
# Performance / hardware
python examples/memory_report.py
python examples/mixed_precision_inference.py
python examples/torch_compile_benchmark.py
# Training + dashboard
python examples/training_minimal_fit.py
python examples/training_with_csvlogger.py
python examples/training_with_tensorboard.py
python examples/training_with_dashboard.py
python examples/checkpoint_save_load.py
API reference
Graph
from tgraphx import Graph
g = Graph(
node_features, # torch.Tensor [N, ...] required
edge_index=None, # torch.LongTensor [2, E] or None optional
edge_weight=None, # torch.Tensor [E] optional
edge_features=None, # torch.Tensor [E, ...] optional
node_labels=None, # torch.Tensor [N, ...] optional
edge_labels=None, # torch.Tensor [E, ...] optional
graph_label=None, # torch.Tensor (any shape) optional
metadata=None, # dict optional
)
g.clone() # deep copy (tensors and metadata)
g.to("cuda", dtype=torch.float32) # moves all tensor fields; dtype only
# applies to floating-point tensors
g.cpu(); g.cuda() # convenience aliases
g.add_self_loops(); g.remove_self_loops()
g.make_undirected(reduce="mean") # symmetrize, coalesce duplicates
g.is_undirected() # structural check (ignores weights)
g.validate() # re-run validation after manual mutation
g.num_nodes, g.num_edges, g.feature_shape, g.edge_feature_shape
g.has_edges, g.has_edge_weight, g.has_edge_features
Supported per-node feature layouts: [N, D], [N, C, H, W], and storage-level
[N, C, D, H, W]. Edge features mirror this: [E, D_e], [E, C_e, H, W], and
storage-level [E, C_e, D, H, W].
Graph.__init__ raises immediately on:
- non-Tensor inputs
- node_features or edge_features with fewer than 2 dimensions
- edge_index with wrong shape, wrong dtype (torch.long required), or out-of-range indices
- edge_weight that is not 1-D, or whose length differs from E
- per-edge tensors (edge_weight / edge_features / edge_labels) supplied without an edge_index
- device mismatch between node_features and any other tensor field
- length mismatch between node_features / edge_index and any per-node / per-edge tensor
GraphBatch
from tgraphx import GraphBatch
batch = GraphBatch([g1, g2, g3])
# batch.node_features  [N_total, ...]
# batch.edge_index     [2, E_total]     (indices offset per graph)
# batch.edge_weight    [E_total]        (concatenated if every graph has it)
# batch.edge_features  [E_total, ...]   (concatenated if every graph has it)
# batch.node_labels    [N_total, ...]   (concatenated if every graph has it)
# batch.edge_labels    [E_total, ...]   (concatenated if every graph has it)
# batch.graph_labels   [B, ...]         (stacked if every graph has it)
# batch.metadata       list[Any]        (verbatim, length B; None entries OK)
# batch.batch          [N_total]        (graph membership, dtype=long)
batch.to("cuda")
All graphs must share the same per-node feature shape (and per-edge feature
shape, when present). Passing graphs with different spatial sizes raises a
ValueError with a descriptive message. Optional per-edge tensors must be
present either on every graph that has edges or on none; mixing the two is
rejected so that per-edge data is never silently dropped.
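The per-graph index offset and the batch membership vector described above can be illustrated without TGraphX at all. The following is a plain-Python sketch of the bookkeeping only; the helper name batch_edge_indices is hypothetical and is not part of the library.

```python
def batch_edge_indices(num_nodes_per_graph, edges_per_graph):
    """Concatenate per-graph edge lists, shifting node ids by a running offset.

    Hypothetical helper for illustration; GraphBatch does this on tensors.
    """
    sources, dests, membership = [], [], []
    offset = 0
    for graph_id, (n, edges) in enumerate(zip(num_nodes_per_graph, edges_per_graph)):
        for src, dst in edges:
            # Node ids are local to each graph, so shift them into the
            # global numbering of the batch.
            sources.append(src + offset)
            dests.append(dst + offset)
        # Every node of this graph maps to the same graph index.
        membership.extend([graph_id] * n)
        offset += n
    return [sources, dests], membership


# Two graphs: 2 nodes with edge 0->1, and 3 nodes with edges 0->2 and 1->2.
edge_index, batch = batch_edge_indices([2, 3], [[(0, 1)], [(0, 2), (1, 2)]])
print(edge_index)  # [[0, 2, 3], [1, 4, 4]]
print(batch)       # [0, 0, 1, 1, 1]
```

The second graph's node ids are shifted by 2 (the node count of the first graph), which is why a batched edge_index can be passed to a layer as if it described one large disconnected graph.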
ConvMessagePassing
from tgraphx.layers import ConvMessagePassing
layer = ConvMessagePassing(
in_shape=(C, H, W), # tuple: per-node input shape (spatial only)
out_shape=(C_out, H, W), # H and W must stay equal to in_shape's H, W
aggr="sum", # "sum" (default) | "mean" | "max"
use_edge_features=False, # set True to concatenate edge tensors into messages
aggregator_params=None, # dict forwarded to DeepCNNAggregator; e.g.
# {"num_layers": 2, "dropout_prob": 0.1}
residual=False, # add skip connection when in_shape == out_shape
)
out = layer(node_features, edge_index) # [N, C_out, H, W]
out = layer(node_features, edge_index, edge_features) # with edge features
aggr="max" is supported via scatter_reduce_(reduce='amax'). When chunk_size is also set, aggr="max" falls back to the unchunked path with a warnings.warn. Use GraphClassifier(pooling="max") for graph-level max readout.
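To make the three aggr modes concrete, here is a dependency-free sketch that reduces one scalar message per edge onto its destination node. It illustrates the reduction semantics only; the function name aggregate and the choice of 0.0 for nodes without incoming edges are assumptions made for this example, not documented ConvMessagePassing behaviour.

```python
def aggregate(messages, dest, num_nodes, aggr="sum"):
    """Reduce per-edge scalar messages onto destination nodes.

    Assumption for this sketch: nodes with no incoming edge get 0.0.
    """
    buckets = [[] for _ in range(num_nodes)]
    for msg, j in zip(messages, dest):
        buckets[j].append(msg)
    if aggr == "sum":
        return [float(sum(b)) for b in buckets]
    if aggr == "mean":
        return [sum(b) / len(b) if b else 0.0 for b in buckets]
    if aggr == "max":
        return [max(b) if b else 0.0 for b in buckets]
    raise ValueError(f"unknown aggr: {aggr!r}")


messages = [1.0, 3.0, 5.0]   # one message per edge
dest = [0, 0, 2]             # row 1 of edge_index (destination nodes)

print(aggregate(messages, dest, 3, "sum"))   # [4.0, 0.0, 5.0]
print(aggregate(messages, dest, 3, "mean"))  # [2.0, 0.0, 5.0]
print(aggregate(messages, dest, 3, "max"))   # [3.0, 0.0, 5.0]
```

In the real layer each message is a [C, H, W] tensor and the reduction is applied element-wise, but the grouping by destination index is the same idea.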
AttentionMessagePassing
from tgraphx.layers import AttentionMessagePassing
# Spatial path
layer = AttentionMessagePassing(in_shape=(C, H, W), out_shape=(C_out, H, W))
# Vector path (also supported)
layer = AttentionMessagePassing(in_shape=(D,), out_shape=(D_out,))
out = layer(node_features, edge_index) # [N, C_out, H, W] or [N, D_out]
Important: This layer computes attn = sigmoid(q·k / √d) independently per edge. Attention weights are not normalised over each destination node's neighbourhood (no softmax). This differs from standard GAT. For true GAT, use TensorGATLayer below.
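The difference between per-edge sigmoid gating and destination-wise softmax is easy to see on a toy example. The sketch below uses plain Python on made-up raw scores; the function names are illustrative and neither function is TGraphX code.

```python
import math


def sigmoid_gates(scores):
    """Independent gate per edge: each value is in (0, 1), no normalisation."""
    return [1.0 / (1.0 + math.exp(-s)) for s in scores]


def softmax_per_destination(scores, dest):
    """Normalise scores over the incoming edges of each destination node."""
    # Subtract the per-destination maximum for numerical stability.
    max_per_dest = {}
    for s, j in zip(scores, dest):
        max_per_dest[j] = max(s, max_per_dest.get(j, -math.inf))
    exps = [math.exp(s - max_per_dest[j]) for s, j in zip(scores, dest)]
    denom = {}
    for e, j in zip(exps, dest):
        denom[j] = denom.get(j, 0.0) + e
    return [e / denom[j] for e, j in zip(exps, dest)]


scores = [2.0, 0.0, -1.0]   # hypothetical raw scores, one per edge
dest = [0, 0, 1]            # edges 0 and 1 both point at node 0

gates = sigmoid_gates(scores)
alphas = softmax_per_destination(scores, dest)

print(sum(gates[:2]))    # greater than 1: gates are not a distribution
print(sum(alphas[:2]))   # sums to 1 (up to rounding) over node 0's in-edges
```

With sigmoid gating, a node with many strongly matching neighbours receives a proportionally larger aggregate; with softmax the incoming weights always form a convex combination.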
TensorGATLayer (true multi-head GAT)
True GAT-style attention adapted to spatial features. For every destination
node j and every head k, attention weights satisfy Σ_i α_ij^k = 1.
from tgraphx.layers import TensorGATLayer
# 4 heads × 8 channels each = 32 output channels (heads concatenated)
layer = TensorGATLayer(
in_channels=16,
out_channels=32, # divisible by num_heads when concat_heads=True
num_heads=4,
concat_heads=True, # False → average heads, output is per-head channels
negative_slope=0.2, # LeakyReLU before edge softmax
attn_dropout=0.0, # dropout on attention weights (training only)
residual=True, # auto 1×1 projection if in/out channels differ
bias=True,
add_self_loops=False, # True ensures every node has at least 1 in-edge
use_edge_features=False, # set True to enable EGAT-style vector edge bias
edge_dim=None, # required when use_edge_features=True
)
out = layer(x, edge_index) # [N, 32, H, W]
# Inspect attention weights (e.g. for visualisation or testing):
out, attn = layer(x, edge_index, return_attention=True)
# attn shape: [E, num_heads]; sums to 1 over incoming edges per destination per head.
# EGAT-style vector edge attention bias (e.g. relative box coords, IoU,
# distances, scale ratio):
layer_e = TensorGATLayer(
in_channels=16, out_channels=32, num_heads=4,
use_edge_features=True, edge_dim=3,
)
out = layer_e(x, edge_index, edge_features=ef) # ef: [E, 3]
Attention is scalar per (edge, head) by default: the projected query and
key feature maps are mean-pooled over H × W before being scored, while the
value tensors keep their full spatial layout during aggregation. Per-channel
attention is available as an experimental mode (attention_mode="channel");
per-pixel attention is not supported.
Spatial edge features ([E, C_e, H, W] for spatial_rank=2;
[E, C_e, D, H, W] for spatial_rank=3) are accepted: spatial dims are
mean-pooled to a channel vector before the per-(edge, head) attention bias
projection (spatial dims need not match node spatial dims). Use
TensorGraphSAGELayer or TensorGINLayer for full spatial edge-feature
processing without pooling.
TensorGraphSAGELayer
Tensor-aware GraphSAGE: h_j' = W_self(h_j) + W_neigh(AGG_i h_i).
from tgraphx.layers import TensorGraphSAGELayer
layer = TensorGraphSAGELayer(
in_channels=16,
out_channels=32,
aggr="mean", # "mean" or "max"
normalize=False, # True → L2-normalise output channel vector per pixel
bias=True,
residual=False,
use_edge_features=False,
edge_dim=None, # required when use_edge_features=True
edge_features_kind="spatial", # or "vector" โ see below
)
out = layer(x, edge_index) # [N, 32, H, W]
# Spatial edge features [E, edge_dim, H, W]: concatenated to source.
layer_s = TensorGraphSAGELayer(
in_channels=16, out_channels=32,
use_edge_features=True, edge_dim=4, edge_features_kind="spatial",
)
out = layer_s(x, edge_index, edge_features=ef_spatial)
# Vector edge features [E, edge_dim]: projected to channel bias and added
# to W_neigh(h_src) before aggregation.
layer_v = TensorGraphSAGELayer(
in_channels=16, out_channels=32,
use_edge_features=True, edge_dim=3, edge_features_kind="vector",
)
out = layer_v(x, edge_index, edge_features=ef_vector) # ef_vector: [E, 3]
Isolated nodes (no incoming edges) receive only the self transform; the neighbour aggregate is zero.
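The update rule h_j' = W_self(h_j) + W_neigh(AGG_i h_i) and the isolated-node behaviour can be checked by hand on scalar features. This is a simplified sketch: W_self and W_neigh are reduced to plain multipliers and the function name sage_update is hypothetical.

```python
def sage_update(h, edge_index, w_self, w_neigh):
    """Scalar GraphSAGE step with mean aggregation.

    h: one scalar feature per node; edge_index: [sources, destinations].
    w_self / w_neigh stand in for the learned transforms (assumption).
    """
    src, dst = edge_index
    sums = [0.0] * len(h)
    counts = [0] * len(h)
    for i, j in zip(src, dst):
        sums[j] += h[i]
        counts[j] += 1
    out = []
    for j in range(len(h)):
        # Nodes without incoming edges get a zero neighbour aggregate.
        neigh = sums[j] / counts[j] if counts[j] else 0.0
        out.append(w_self * h[j] + w_neigh * neigh)
    return out


h = [1.0, 2.0, 4.0]
edge_index = [[0, 1], [2, 2]]   # edges 0->2 and 1->2; nodes 0 and 1 are isolated
print(sage_update(h, edge_index, w_self=2.0, w_neigh=10.0))
# [2.0, 4.0, 23.0]
```

Nodes 0 and 1 only see the self term (2·1 and 2·2), while node 2 gets 2·4 plus 10 times the mean of its two neighbours (1.5).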
TensorGINLayer
Tensor-aware GIN / GINEConv: h_j' = MLP((1+ε)·h_j + Σ_i m_ij).
from tgraphx.layers import TensorGINLayer
# Default 1×1 Conv MLP (preserves spatial layout)
layer = TensorGINLayer(
in_channels=16,
out_channels=32,
hidden_channels=24, # defaults to out_channels
eps=0.0,
train_eps=False, # set True to make ε a learnable scalar parameter
use_batchnorm=False,
)
out = layer(x, edge_index) # [N, 32, H, W]
# Custom MLP (any nn.Module mapping [N, in_channels, H, W] → [N, out_channels, H, W])
import torch.nn as nn
custom_mlp = nn.Sequential(
nn.Conv2d(16, 24, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(24, 32, kernel_size=1),
)
layer = TensorGINLayer(in_channels=16, out_channels=32, mlp=custom_mlp)
# GINEConv-style spatial edge inclusion: messages = ReLU(h_src + φ(e_ij))
layer_s = TensorGINLayer(
in_channels=16, out_channels=32,
use_edge_features=True, edge_dim=4, edge_features_kind="spatial",
)
out = layer_s(x, edge_index, edge_features=ef_spatial)
# Vector edge features [E, edge_dim]: projected to [E, in_channels, 1, 1]
# and broadcast over H × W before ReLU.
layer_v = TensorGINLayer(
in_channels=16, out_channels=32,
use_edge_features=True, edge_dim=3, edge_features_kind="vector",
)
out = layer_v(x, edge_index, edge_features=ef_vector) # ef_vector: [E, 3]
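As a worked illustration of the GIN update, the sketch below computes the quantity that is fed into the MLP, on scalar node features. It assumes that without edge features the message is simply the source feature, and that with edge features the message is the ReLU of the source feature plus an already-projected edge term; the function name gin_pre_mlp is hypothetical.

```python
def gin_pre_mlp(h, edge_index, eps=0.0, edge_term=None):
    """Return (1 + eps) * h_j + sum of incoming messages, per node j.

    edge_term: optional per-edge scalar, standing in for the projected
    edge feature (assumption for this sketch).
    """
    src, dst = edge_index
    agg = [0.0] * len(h)
    for e, (i, j) in enumerate(zip(src, dst)):
        if edge_term is None:
            msg = h[i]                              # plain GIN message
        else:
            msg = max(0.0, h[i] + edge_term[e])     # GINE-style: ReLU(h_src + edge term)
        agg[j] += msg
    return [(1.0 + eps) * h[j] + agg[j] for j in range(len(h))]


h = [1.0, -2.0, 3.0]
edge_index = [[0, 1], [2, 2]]   # edges 0->2 and 1->2

print(gin_pre_mlp(h, edge_index, eps=0.5))
# [1.5, -3.0, 3.5]
print(gin_pre_mlp(h, edge_index, eps=0.5, edge_term=[0.5, 1.0]))
# [1.5, -3.0, 6.0]
```

In the second call the message from node 1 is clipped to zero by the ReLU (-2.0 + 1.0 < 0), so only node 0 contributes to node 2.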
Custom layers via TensorMessagePassingLayer
import torch, torch.nn as nn
from tgraphx.layers import TensorMessagePassingLayer
class MyConv(TensorMessagePassingLayer):
    def __init__(self, c_in, c_out):
        super().__init__(in_shape=(c_in,), out_shape=(c_out,), aggr="mean")
        self.W_g = nn.Conv2d(c_in, c_out, kernel_size=1)
        self.W_v = nn.Conv2d(c_in, c_out, kernel_size=1)

    def message(self, src, dest, edge_attr):
        # Gate each source feature map by a sigmoid of the combined endpoints.
        gate = torch.sigmoid(self.W_g(src + dest))
        return gate * self.W_v(src)

    def update(self, node_feature, aggregated_message):
        return aggregated_message
The base class handles per-edge gather and aggregation (sum or mean)
for arbitrary trailing tensor shapes. See
examples/custom_message_passing.py.
CNNEncoder
from tgraphx.models import CNNEncoder
enc = CNNEncoder(
in_channels=3,
out_features=64,
num_layers=3, # total Conv2d blocks
hidden_channels=64,
dropout_prob=0.3,
use_batchnorm=True,
use_residual=True, # residual skip in intermediate blocks
pool_layers=1, # how many blocks include SafeMaxPool2d(2)
return_feature_map=True, # True → [N, out_features, H', W']
# False → [N, out_features] (global avg pool)
pre_encoder=None, # optional PreEncoder instance
)
features = enc(patches) # patches: [N, in_channels, patch_H, patch_W]
GraphClassifier
from tgraphx.models import GraphClassifier
clf = GraphClassifier(
in_shape=(C, H, W),
hidden_shape=(C_hidden, H, W),
num_classes=5,
num_layers=2,
aggr="sum",
pooling="mean", # "mean" | "sum" | "max"
)
logits = clf(
node_features, # [N, C, H, W]
edge_index, # [2, E]
batch=batch_vector, # [N]; required for graph-level output
edge_features=None, # optional
) # → [num_graphs, num_classes]
NodeClassifier
from tgraphx.models import NodeClassifier
nc = NodeClassifier(
in_shape=(64,), # vector features only
hidden_shape=(128,),
num_classes=3,
num_layers=2,
)
logits = nc(node_features, edge_index) # [N, num_classes]
CNN_GNN_Model
A full CNN → GNN → classify pipeline that accepts pre-split node patches.
from tgraphx.models import CNN_GNN_Model
model = CNN_GNN_Model(
cnn_params=dict(
in_channels=3,
out_features=64,
num_layers=2,
hidden_channels=64,
dropout_prob=0.0,
use_batchnorm=False,
use_residual=False,
pool_layers=1,
return_feature_map=True,
),
gnn_in_dim=(64, 8, 8), # must match CNN output shape exactly
gnn_hidden_dim=(64, 8, 8),
num_classes=10,
num_gnn_layers=2,
gnn_dropout=0.3, # forwarded to DeepCNNAggregator
residual=True, # per-layer skip connection
skip_cnn_to_classifier=False,
)
# raw_patches: pre-split by the user; shape [N, in_channels, pH, pW]
logits = model(raw_patches, edge_index) # [N, num_classes] (node-level)
logits = model(raw_patches, edge_index, batch=b) # [G, num_classes] (graph-level)
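Because gnn_in_dim must match the CNN output shape exactly, it can help to compute the expected shape before building the model. The helper below is a rough sketch under two assumptions that this README does not confirm: the convolution blocks preserve H and W, and each pooling block halves both (floor division). SafeMaxPool2d may behave differently on very small inputs, so treat the result as an estimate and verify it against a real forward pass.

```python
def expected_feature_map_shape(out_features, patch_h, patch_w, pool_layers):
    """Estimate the CNN output shape (C, H', W') for one patch.

    Assumptions (not taken from the library): convolutions keep the
    spatial size, and each pooling block halves H and W.
    """
    h, w = patch_h, patch_w
    for _ in range(pool_layers):
        h, w = h // 2, w // 2
    return (out_features, h, w)


# Under these assumptions, 16x16 patches with one pooling block would
# give the (64, 8, 8) shape used for gnn_in_dim in the example above.
print(expected_feature_map_shape(64, 16, 16, pool_layers=1))  # (64, 8, 8)
```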
get_device
from tgraphx.core.utils import get_device
device = get_device() # CUDA (if available) → MPS → CPU
device = get_device(device_id=1) # specific CUDA device
Shape conventions
| Tensor | Shape | dtype | Notes |
|---|---|---|---|
| node_features | [N, D], [N, C, H, W], or [N, C, D, H, W] | float | vector, 2-D spatial, or 3-D volumetric; N = number of nodes |
| node_features (vector) | [N, D] | float | For NodeClassifier / LinearMessagePassing |
| edge_index | [2, E] | torch.long | Row 0 = source nodes, row 1 = destination nodes |
| edge_features | [E, ...] | float | Optional; length must equal E |
| batch | [N] | torch.long | Maps each node to its graph index |
Device support
| Device | Status |
|---|---|
| CPU | ✅ Tested in full CI (Ubuntu, Python 3.10 / 3.11 / 3.12) |
| NVIDIA CUDA | ✅ Locally tested (PyTorch 2.x); no GPU runners in CI |
| Apple Silicon MPS | ⚠️ Best-effort; macOS smoke CI imports + builds only |
| Multi-GPU (DDP) | 🧰 Helpers shipped (tgraphx.distributed); full automatic DDP training framework intentionally out of scope |
from tgraphx.core.utils import get_device
device = get_device()
model.to(device)
g.to(device)
batch.to(device)
Supported Python and PyTorch versions
| Python | PyTorch | Status |
|---|---|---|
| 3.10 | ≥ 1.13 | ✅ CI (ubuntu-latest) |
| 3.11 | ≥ 1.13 | ✅ CI (ubuntu-latest, macos-latest, windows-latest) |
| 3.12 | ≥ 1.13 | ✅ CI (ubuntu-latest) |
| 3.9 | ≥ 1.13 | Listed in classifiers; not in CI matrix, so it should work but is unverified |
| 3.13 | ≥ 1.13 | Listed in classifiers; not in CI matrix, so it should work but is unverified |
Support status
Legend
| Label | Meaning |
|---|---|
| ✅ Stable | Tested in CI; API is stable |
| 🧪 Experimental | Available but not yet guaranteed-stable |
| ⚠️ Best-effort | Works in practice; known constraints documented |
| ⏳ Planned | On roadmap for a future release |
| ❌ Not supported | Out of scope for the current release |
| 🔌 Opt-in | Disabled by default; explicitly enabled by the user |
Backend support
| Backend | Forward | AMP | torch.compile | CI coverage | Status | Notes |
|---|---|---|---|---|---|---|
| CPU | ✅ | ⚠️ bfloat16 | ✅ | Full CI (Ubuntu Py 3.10/3.11/3.12) | ✅ Stable | Compile overhead for small graphs |
| CUDA | ✅ | ⚠️ op-dependent | ✅ | Local only; no GPU runners | ✅ Stable | index_add_ requires dtype match under float16 |
| MPS (Apple Silicon) | ✅ | ⚠️ limited | ⚠️ partial | macOS smoke CI (import + build) | ⚠️ Best-effort | PyTorch operator coverage varies |
| Linux | ✅ | ✅ | ✅ | Full CI (ubuntu-latest) | ✅ Stable | Primary CI platform |
| Windows | ✅ | ✅ | ✅ | Smoke CI (Py 3.11) | ⚠️ Best-effort | Imports + build + twine + dashboard CLI |
| macOS | ✅ | ⚠️ limited | ⚠️ | Smoke CI (Py 3.11) | ⚠️ Best-effort | MPS path; same smoke as Windows |
| Multi-GPU (DDP) | 🧰 helpers | ⚠️ user-managed | ⚠️ user-managed | No CI | 🧰 Helpers only | Rank-zero / world-size / barrier; full multi-GPU framework out of scope |
⚠️ Best-effort backend: MPS support depends on PyTorch operator coverage per release. CPU workflows are fully tested; MPS-specific AMP/compile paths may fall back or be skipped.
⚠️ Windows/macOS smoke CI: automated import, build, twine-check, and dashboard CLI smoke on every pull request. Full pytest still runs on Ubuntu only.
Feature support
| Feature | Status | Notes |
|---|---|---|
| Vector node features [N, D] | ✅ Stable | LinearMessagePassing, "linear" factory |
| 2-D spatial node features [N, C, H, W] | ✅ Stable | All four spatial layers |
| 3-D volumetric node features [N, C, D, H, W] | ✅ Stable | spatial_rank=3 |
| Arbitrary-rank tensors (rank ≥ 4 trailing) | ❌ Out of scope | Only vector, 2-D, and 3-D layouts ship; see Scope boundaries |
| Edge weights [E] | ✅ Stable | All layers |
| Vector edge features [E, D_e] | ✅ Stable | GAT, SAGE, GIN |
| Spatial edge features [E, C_e, H, W] | ⚠️ Best-effort | ConvMP (concat); GAT (mean-pooled scalar bias); SAGE/GIN (full) |
| Volumetric edge features [E, C_e, D, H, W] | ⚠️ Best-effort | Same as spatial; pass spatial_rank=3 |
| GraphTransformerLayer (vector node features) | 🧪 Experimental | tgraphx.layers.graph_transformer.GraphTransformerLayer; positional / Laplacian / edge-bias supported |
| Heterogeneous graphs (container + batch + HeteroConv + classifiers) | 🧪 Experimental | HeteroGraph, HeteroGraphBatch, HeteroConv, HeteroGraphClassifier, HeteroNodeClassifier; vector node features |
| Temporal graphs (container + batch + readout + classifiers) | 🧪 Experimental | TemporalGraphSequence, TemporalGraphBatch, temporal_readout, TemporalGraphClassifier, TemporalGraphRegressor; snapshot-loop pattern (no recurrent memory module) |
| Learned graph construction (soft adjacency, edge scorer) | ✅ Stable | tgraphx.learned_graph; discrete top-k is non-differentiable |
| PyG / DGL data converters (homogeneous + hetero) | ✅ Opt-in | tgraphx.interop; lazy imports; data-only, not an API replacement |
| MLflowLogger | ✅ Opt-in | Lazy mlflow import; pip install "tgraphx[mlflow]" |
| Dashboard | 🔌 Opt-in | Launch explicitly; zero overhead when off |
| Offline dashboard export | ✅ Stable | --export-html or export_dashboard_html() |
| Multi-run dashboard | ✅ Stable | Point --logdir at parent directory |
| Hardware monitoring | 🔌 Opt-in | pip install "tgraphx[monitoring]" |
| TensorBoard logging | 🔌 Opt-in | pip install "tgraphx[tracking]"; TensorBoardLogger |
Scalability support
| Feature | Status | Notes |
|---|---|---|
| ConvMessagePassing chunked forward | ✅ Stable | aggr="sum" / "mean"; max falls back with warning |
| TensorGraphSAGELayer chunked forward | ✅ Stable | mean and max; pass chunk_size=K to forward() |
| TensorGINLayer chunked forward | ✅ Stable | sum aggregation; pass chunk_size=K to forward() |
| TensorGATLayer chunked forward | ✅ Stable | Two-pass log-sum-exp; pass chunk_size=K to forward(); output matches unchunked within float32 tolerance |
| build_grid_graph / build_grid_graph_3d | ✅ Stable | O(E); scales well |
| build_random_graph | ✅ Stable | O(E) sample mode for large N |
| build_knn_graph / build_radius_graph | ⚠️ Best-effort | O(N²) time; chunk_size=K reduces peak memory to O(K×N) |
| build_fully_connected_graph | ⚠️ Best-effort | O(N²) edges; N > 5 000 emits warning |
| build_iou_graph | ⚠️ Best-effort | O(N²) IoU; chunk_size=K reduces peak memory to O(K×N) |
| Dashboard metrics API | ✅ Stable | Incremental ?since_row=N; --max-metric-rows cap; byte-seek tail-read |
| Subgraph / k-hop / neighbour sampling | ✅ Stable | tgraphx.sampling + SubgraphDataLoader / NeighborSamplerLoader |
| Random-walk sampling | ✅ Stable | random_walk_sample(graph, seeds, walk_length, …); deterministic with seed |
| Hetero sampling (induced + per-relation neighbour) | ✅ Stable | tgraphx.hetero_induced_subgraph, tgraphx.hetero_neighbor_sample |
| Temporal window sampling (sequence + batch) | ✅ Stable | tgraphx.temporal_window_sample, tgraphx.temporal_window_sample_batch |
| Distributed (DDP) helpers | ✅ Stable | tgraphx.distributed: rank-zero, world-size, barrier; never auto-initialises DDP |
⚠️ Scalability warning:
build_knn_graph, build_radius_graph, build_fully_connected_graph, and build_iou_graph use pairwise torch.cdist or enumerate all pairs. Memory and time grow as O(N²). A warnings.warn is emitted when the node count exceeds the threshold (10 000 for kNN/radius, 5 000 for fully-connected/IoU). For large graphs use an approximate-NN library instead.
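The way chunking bounds peak memory for a pairwise builder can be sketched in plain Python: only a K × N block of distances exists at any time, while the total work stays O(N²). This is an illustration of the idea, not the build_knn_graph implementation; the function name and the edge direction (neighbour as source, query node as destination) are assumptions.

```python
import heapq
import math


def knn_edges_chunked(points, k, chunk_size):
    """Build kNN edges while materialising at most chunk_size x N distances."""
    n = len(points)
    src, dst = [], []
    for start in range(0, n, chunk_size):
        chunk = range(start, min(start + chunk_size, n))
        # Distance block for this chunk of query nodes: len(chunk) x n values.
        block = [[math.dist(points[j], points[i]) for i in range(n)] for j in chunk]
        for row, j in zip(block, chunk):
            candidates = [(d, i) for i, d in enumerate(row) if i != j]
            for _, i in heapq.nsmallest(k, candidates):
                src.append(i)   # assumption: neighbour is the source
                dst.append(j)   # assumption: query node is the destination
    return [src, dst]


points = [(0.0, 0.0), (1.0, 0.0), (5.0, 0.0), (6.0, 0.0)]
print(knn_edges_chunked(points, k=1, chunk_size=2))
# [[1, 0, 3, 2], [0, 1, 2, 3]]
```

The result does not depend on chunk_size; only the size of the intermediate block does.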
Attention support
| Feature | Status | Notes |
|---|---|---|
| Scalar attention per (edge, head) | ✅ Stable | Default attention_mode="scalar" |
| Per-channel attention per (edge, head, channel) | 🧪 Experimental | attention_mode="channel"; score tensor [E, K, C_head] |
| Vector edge attention bias | ✅ Stable | use_edge_features=True, edge_dim=D |
| Spatial edge attention bias (2-D / 3-D) | ⚠️ Best-effort | Accepted; mean-pooled to scalar before projection |
| Per-pixel attention | ❌ Out of scope | Naive [E, K, H, W] score tensor is memory-prohibitive; deferred until a memory-safe variant is designed |
| Per-voxel attention | ❌ Out of scope | Same reason as per-pixel; [E, K, D, H, W] score tensor is memory-prohibitive |
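A quick back-of-the-envelope calculation shows why a naive per-pixel score tensor is considered memory-prohibitive. The numbers below (one million edges, 4 heads, 32 × 32 feature maps, float32) are illustrative choices for this sketch, not measurements from the library.

```python
def score_tensor_bytes(num_edges, num_heads, spatial=(), bytes_per_element=4):
    """Size in bytes of an attention score tensor [E, K, *spatial]."""
    count = num_edges * num_heads
    for size in spatial:
        count *= size
    return count * bytes_per_element


scalar = score_tensor_bytes(1_000_000, 4)                 # [E, K]
per_pixel = score_tensor_bytes(1_000_000, 4, (32, 32))    # [E, K, H, W]

print(scalar)     # 16000000      (16 MB)
print(per_pixel)  # 16384000000   (about 16.4 GB)
```

The per-pixel tensor is H·W = 1024 times larger than the scalar one, before counting the gradients and softmax intermediates that training would also need.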
Limitations
TGraphX is a focused tensor-aware GNN library, not a drop-in replacement for PyTorch Geometric or DGL. Detailed write-ups live in docs/limitations.md; the high-level boundaries are:
- Not a PyG / DGL drop-in. tgraphx.interop ships data converters (homogeneous + hetero, lazy imports), but layer APIs and call conventions differ.
- AttentionMessagePassing is not GAT. It uses per-edge sigmoid gating without destination-wise softmax. Use TensorGATLayer for true multi-head GAT.
- Scalar (default) and per-channel (attention_mode="channel", experimental) attention are the supported TensorGATLayer modes. Per-pixel / per-voxel attention is intentionally not implemented: naive scores [E, K, H, W] are memory-prohibitive and a memory-safe variant has not yet been designed.
- GAT edge features (vector, or matching-rank spatial that is mean-pooled). TensorGATLayer accepts [E, edge_dim] vectors and matching-rank spatial tensors ([E, edge_dim, H, W] for spatial_rank=2, [E, edge_dim, D, H, W] for spatial_rank=3). Spatial tensors are mean-pooled before the per-(edge, head) attention bias projection. Mismatched-rank edges raise NotImplementedError. Use TensorGraphSAGELayer or TensorGINLayer for spatial edge features without pooling.
- Supported node feature ranks: vector [N, D], 2-D spatial [N, C, H, W], 3-D volumetric [N, C, D, H, W]. Universal arbitrary-rank tensor support across every layer is a v0.3 design discussion.
- Hetero / temporal workflows are experimental. HeteroConv, HeteroGraphClassifier, HeteroNodeClassifier, TemporalGraphClassifier, and TemporalGraphRegressor exist and are tested, but the surface is intentionally small (vector-feature hetero; stateless snapshot-loop temporal). Full TGN / TGAT-style recurrent memory and tensor-aware hetero classifiers are deferred.
- Graph builders cover common structural patterns (grid, kNN, radius, IoU, fully connected, random); custom topology is up to the user. kNN / radius / IoU / fully-connected scale O(N²) and emit warnings on large N.
- Patch helpers require exact-divisible dimensions by default; use padding="auto" to right-pad.
- Multi-GPU. TGraphX provides DDP-aware helpers (tgraphx.distributed) and a single-process smoke example, not an automatic multi-GPU training framework.
- Dashboard is a local-first lightweight monitor, not a TensorBoard replacement. See docs/dashboard.md.
- Differentiability: all learned parameters are end-to-end differentiable; graph topology (edge_index) is user-supplied and not learned by the model.
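The exact-divisibility requirement for patch helpers comes down to simple arithmetic. The sketch below shows how much trailing padding an image would need so that its sides divide evenly into patches. The function names are hypothetical, and the assumption that padding="auto" pads the trailing edge of both height and width is an interpretation of "right-pad", not confirmed library behaviour.

```python
def right_pad_amount(size, patch):
    """Smallest non-negative padding that makes size divisible by patch."""
    return (-size) % patch


def patch_grid(height, width, patch, padding=None):
    """Return (rows, cols, pad_h, pad_w) for square patches of side `patch`.

    Mirrors the documented default: non-divisible sizes are rejected
    unless padding="auto" is requested.
    """
    pad_h = right_pad_amount(height, patch)
    pad_w = right_pad_amount(width, patch)
    if (pad_h or pad_w) and padding != "auto":
        raise ValueError(
            f"{height}x{width} is not divisible by patch size {patch}; "
            'pass padding="auto" to pad'
        )
    return (height + pad_h) // patch, (width + pad_w) // patch, pad_h, pad_w


print(patch_grid(32, 32, 8))                  # (4, 4, 0, 0)
print(patch_grid(30, 32, 8, padding="auto"))  # (4, 4, 2, 0)
```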
GNN family coverage
| GNN family | Implemented? | Tested? | Limitations |
|---|---|---|---|
| Tensor-aware GCN-style (Conv message passing) | ✅ ConvMessagePassing | ✅ | 2-D [N, C, H, W] and 3-D [N, C, D, H, W]; edge features must be matching-rank spatial with channel count = node channel count |
| Tensor-aware GAT (multi-head) | ✅ TensorGATLayer | ✅ | 2-D and 3-D node features (spatial_rank=2/3); scalar attention per (edge, head); vector [E, D_e] and matching-rank spatial / volumetric edge features (mean-pooled); no per-pixel / per-voxel attention |
| Tensor-aware GraphSAGE | ✅ TensorGraphSAGELayer | ✅ | 2-D and 3-D node features (spatial_rank=2/3); mean / max only; no LSTM aggregator |
| Tensor-aware GIN / GINEConv | ✅ TensorGINLayer | ✅ | 2-D and 3-D node features (spatial_rank=2/3) |
| MPNN-style custom layer | ✅ TensorMessagePassingLayer base | ✅ (subclass test + example) | None |
| Edge-conditioned MP (spatial / volumetric) | ✅ ConvMessagePassing, TensorGATLayer (mean-pooled), TensorGraphSAGELayer, TensorGINLayer | ✅ | edge features [E, C_e, H, W] (2-D) or [E, C_e, D, H, W] (3-D, matching the layer's spatial_rank) |
| Per-edge edge_weight ([E]) | ✅ all four message-passing layers, 2-D and 3-D | ✅ | scales messages before aggregation; on GAT applied after softmax-normalised attention |
| 3-D / volumetric node features | ✅ ConvMessagePassing, TensorGATLayer, TensorGraphSAGELayer, TensorGINLayer | ✅ | [N, C, D, H, W]; pass spatial_rank=3 to GAT/SAGE/GIN, or (C, D, H, W) in_shape to ConvMessagePassing. DeepCNNAggregator is rank-aware. LinearMessagePassing covers vector [N, D] and is unaffected. |
| Edge-conditioned MP (vector) | ✅ TensorGATLayer, TensorGraphSAGELayer, TensorGINLayer | ✅ | edge features [E, D_e]; edge_features_kind="vector" |
| aggr="sum" / "mean" / "max" base | ✅ all three modes | ✅ hand-computed + backward | ConvMessagePassing aggr="max" routes through scatter_max |
| Graph Transformer (vector features) | 🧪 GraphTransformerLayer (+ degree / Laplacian / adjacency-bias encodings) | ✅ | global multi-head self-attention [N, D]; O(N²); tensor-aware variant intentionally deferred |
| Heterogeneous graphs | 🧪 HeteroGraph, HeteroGraphBatch, HeteroConv, HeteroGraphClassifier, HeteroNodeClassifier | ✅ | vector features; relation-dispatch wrapper + per-type classifier; tensor-aware hetero classifiers intentionally deferred |
| Temporal graphs | 🧪 TemporalGraphSequence, TemporalGraphBatch, temporal_readout, TemporalGraphClassifier, TemporalGraphRegressor | ✅ | stateless snapshot-loop + readout; recurrent (TGN / TGAT) memory intentionally deferred |
| Sampling (homogeneous) | ✅ induced_subgraph, edge_subgraph, k_hop_subgraph, sample_nodes, sample_edges, neighbor_sample, random_walk_sample | ✅ | All deterministic with seed; per-call generator (no global RNG side effects) |
| Sampling loaders | ✅ SubgraphDataLoader, NeighborSamplerLoader | ✅ | Plain Python iterables; deterministic with seed |
| Sampling (hetero / temporal, v0.2.8) | ✅ hetero_induced_subgraph, hetero_neighbor_sample, temporal_window_sample, temporal_window_sample_batch | ✅ | Per-relation fanouts; window slicing on equal- and variable-length temporal batches |
| Distributed (DDP) helpers | ✅ tgraphx.distributed (get_rank, get_world_size, is_rank_zero, rank_zero_print, rank_zero_only, barrier) | ✅ | Never auto-initialises DDP; safe to import in single-process / CPU-only contexts |
| Learned graph construction | ✅ tgraphx.learned_graph | ✅ | soft adjacency, EdgeScorer (differentiable); top-k discrete (non-diff) |
| PyG / DGL converters (homogeneous + hetero) | ✅ tgraphx.interop | ✅ | data converters only (lazy imports); not an API replacement |
| MLflowLogger | ✅ tgraphx.tracking.MLflowLogger | ✅ | lazy mlflow import; opt-in via tgraphx[mlflow] extra |
| Arbitrary-rank tensor support across every layer | ❌ Out of scope | n/a | Only vector, 2-D, 3-D node-feature layouts ship; see Scope boundaries |
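The "deterministic with seed; per-call generator" property listed for the sampling utilities follows a common pattern: create a local random generator from the seed inside each call instead of touching the global one. The sketch below shows that pattern with the standard library; sample_nodes_sketch is a hypothetical stand-in and does not reproduce the signature or internals of tgraphx.sampling.

```python
import random


def sample_nodes_sketch(num_nodes, k, seed=None):
    """Pick k distinct node ids using a generator local to this call."""
    rng = random.Random(seed)   # local generator: global RNG state is untouched
    return sorted(rng.sample(range(num_nodes), k))


random.seed(0)
state_before = random.getstate()

first = sample_nodes_sketch(100, 5, seed=42)
second = sample_nodes_sketch(100, 5, seed=42)

print(first == second)                       # True: same seed, same sample
print(random.getstate() == state_before)     # True: no global side effects
```

Keeping the generator local means that sampling inside a data loader does not perturb other seeded components such as weight initialisation or dropout.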
Project structure
TGraphX/
├── tgraphx/
│   ├── __init__.py                  # public API re-exports
│   ├── core/                        # Graph, GraphBatch, hetero & temporal containers + batches
│   │   ├── graph.py                 # Graph, GraphBatch
│   │   ├── graph_utils.py           # edge topology helpers
│   │   ├── dataloader.py            # GraphDataset, GraphDataLoader
│   │   ├── hetero_graph.py          # HeteroGraph (typed nodes + edges)
│   │   ├── hetero_batch.py          # HeteroGraphBatch
│   │   ├── temporal.py              # TemporalGraphSequence
│   │   ├── temporal_batch.py        # TemporalGraphBatch
│   │   └── utils.py                 # load_config, get_device
│   ├── layers/                      # tensor-aware GNN layers + transformer + hetero/temporal
│   │   ├── base.py                  # TensorMessagePassingLayer, LinearMessagePassing
│   │   ├── conv_message.py          # ConvMessagePassing
│   │   ├── attention_message.py     # AttentionMessagePassing (legacy sigmoid)
│   │   ├── gat.py                   # TensorGATLayer (multi-head, scalar / channel modes, chunked)
│   │   ├── sage.py                  # TensorGraphSAGELayer
│   │   ├── gin.py                   # TensorGINLayer / GINEConv
│   │   ├── graph_transformer.py     # GraphTransformerLayer (vector)
│   │   ├── transformer_encodings.py # degree / Laplacian / adjacency-bias encodings
│   │   ├── hetero.py                # HeteroConv (relation-dispatch wrapper)
│   │   ├── hetero_readout.py        # hetero_mean/sum/max/concat_pool
│   │   ├── temporal_readout.py      # temporal_readout (last/mean/max + mask)
│   │   ├── factory.py               # make_layer()
│   │   ├── aggregator.py            # DeepCNNAggregator
│   │   ├── safe_pool.py             # SafeMaxPool2d
│   │   └── _scatter.py / _dim.py    # internal helpers
│   ├── models/                      # task heads + factories
│   │   ├── factory.py               # build_model(), build_model_from_config()
│   │   ├── graph_classifier.py      # GraphClassifier
│   │   ├── node_classifier.py       # NodeClassifier
│   │   ├── edge_predictor.py        # EdgePredictor
│   │   ├── regressors.py            # NodeRegressor, GraphRegressor
│   │   ├── hetero_models.py         # HeteroGraphClassifier, HeteroNodeClassifier
│   │   ├── temporal_models.py       # TemporalGraphClassifier, TemporalGraphRegressor
│   │   ├── cnn_encoder.py           # CNNEncoder
│   │   ├── cnn_gnn_model.py         # CNN_GNN_Model
│   │   └── pre_encoder.py           # PreEncoder (optional ResNet-18)
│   ├── graph_builders.py            # grid / kNN / radius / IoU / fully-connected / random + patch helpers
│   ├── learned_graph.py             # soft adjacency, EdgeScorer, top-k edges
│   ├── interop.py                   # PyG / DGL converters (lazy imports; homogeneous + hetero)
│   ├── sampling.py                  # induced / edge / k-hop / neighbour / random-walk sampling
│   ├── sampling_loaders.py          # SubgraphDataLoader, NeighborSamplerLoader
│   ├── hetero_sampling.py           # hetero_induced_subgraph, hetero_neighbor_sample (v0.2.8)
│   ├── temporal_sampling.py         # temporal_window_sample, temporal_window_sample_batch (v0.2.8)
│   ├── distributed.py               # rank-zero / world-size / barrier helpers
│   ├── training.py                  # train_epoch, evaluate, fit, set_seed, checkpointing, metrics
│   ├── tracking.py                  # CSVLogger, TensorBoardLogger, MLflowLogger, write_graph_stats
│   ├── performance.py               # env_report, recommended_device, estimate_message_memory
│   └── dashboard/                   # local-first training dashboard (off by default)
│       ├── app.py                   # DashboardServer, export_dashboard_html, CLI main()
│       ├── __init__.py              # launch_dashboard, launch_dashboard_background
│       └── static/                  # dashboard.css, dashboard.js (packaged in wheel)
├── tests/                           # 1230+ tests; CPU-safe, deterministic
├── examples/                        # 30+ self-contained demos (see "Examples" section)
├── benchmarks/                      # benchmark_layers.py / benchmark_graph_builders.py / benchmark_sampling.py
├── docs/                            # API reference, limitations, roadmap, performance, …
├── .github/workflows/               # tests.yml (Ubuntu full + macOS/Windows smoke), publish.yml
├── pyproject.toml
├── CHANGELOG.md
└── LICENSE
Development
# Install with dev dependencies (pytest, build, twine)
pip install -e ".[dev]"
# Run the test suite (CPU tests always run; CUDA/MPS skipped if unavailable)
pytest
# Run a specific test file
pytest tests/test_layers.py -v
# Run the examples
python examples/minimal_spatial_message_passing.py
python examples/minimal_graph_classifier.py
Authorship
Software author and maintainer: Arash Sajjadi, PhD Candidate in Computer Science, University of Saskatchewan (arash.sajjadi@usask.ca)
Academic supervision: Mark Eramian, PhD Supervisor / Academic Advisor, University of Saskatchewan
Related preprint: TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning, by Arash Sajjadi and Mark Eramian (arXiv:2504.03953)
The software package is developed and maintained by Arash Sajjadi. Mark Eramian is Arash Sajjadi's PhD supervisor and co-author of the related academic preprint. Software authorship and paper co-authorship are separate roles; both are acknowledged accurately above.
Citation
If you use TGraphX in your research, please cite:
@misc{sajjadi2025tgraphxtensorawaregraphneural,
title={TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning},
author={Arash Sajjadi and Mark Eramian},
year={2025},
eprint={2504.03953},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2504.03953},
}
License
TGraphX is released under the MIT License.
Questions, issues, or contributions are welcome; please open a GitHub issue or pull request.