
Tensor-aware graph neural networks preserving spatial node feature layouts



TGraphX

Python 3.10+ · PyTorch 1.13+ · MIT License

Preprint: TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning · Sajjadi & Eramian, arXiv 2025

Developed by Arash Sajjadi, PhD Candidate in Computer Science, University of Saskatchewan. Academic supervision: Mark Eramian.

TGraphX is a PyTorch library for graph neural networks whose node features are multi-dimensional tensors: [C, H, W] image-patch feature maps, [C, D, H, W] volumetric maps, or plain vectors [D]. Convolutional message passing operates directly on the spatial node features, so local structure is never destroyed by flattening.


Install

pip install tgraphx

Optional extras (all lazy-imported, none required for the base package):

pip install "tgraphx[tracking]"     # TensorBoard
pip install "tgraphx[mlflow]"       # MLflow
pip install "tgraphx[monitoring]"   # psutil + pynvml (dashboard hardware panel)
pip install "tgraphx[pyg]"          # PyTorch Geometric dataset adapter
pip install "tgraphx[ogb]"          # OGB dataset adapter
pip install "tgraphx[pillow]"       # ImageFolder dataset
pip install "tgraphx[dev]"          # pytest + build + twine

DGL ships platform-sensitive wheels and is not packaged as a TGraphX extra. Follow the DGL install guide; once DGL is on the import path, the DGL adapter becomes available automatically.

For a custom PyTorch build (CPU-only or specific CUDA), install PyTorch first, then TGraphX:

pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install tgraphx

60-second quickstart

Vector node features:

import torch
from tgraphx import Graph, LinearMessagePassing

x = torch.randn(8, 32)
edge_index = torch.stack([torch.arange(8), (torch.arange(8) + 1) % 8])
g = Graph(x, edge_index)

layer = LinearMessagePassing(in_shape=(32,), out_shape=(64,))
out = layer(g.node_features, g.edge_index)   # [8, 64]
out.sum().backward()

Spatial [C, H, W] node features — preserved through message passing:

import torch
from tgraphx import Graph, ConvMessagePassing

N, C, H, W = 6, 16, 8, 8
g = Graph(torch.randn(N, C, H, W), torch.stack(
    [torch.arange(N), (torch.arange(N) + 1) % N]))

layer = ConvMessagePassing(in_shape=(C, H, W), out_shape=(32, H, W))
out = layer(g.node_features, g.edge_index)   # [6, 32, 8, 8]
out.sum().backward()

A Colab tutorial walks through every workflow.


Why tensor-aware GNNs

Standard GNN frameworks expect flat vector node features. Flattening a [C, H, W] feature map into a [C·H·W] vector discards the spatial structure that makes CNNs effective. TGraphX keeps each node's representation as a tensor and applies 1×1 convolutions during message passing, so every neighbourhood aggregation step acts like a miniature CNN across neighbouring feature maps.

M_{i→j} = φ( X_i, X_j, E_{i→j} )
A_j     = AGG_{i ∈ N(j)} M_{i→j}
X'_j    = ψ( X_j, A_j )

Spatial dimensions H and W are preserved through every learned transform. Aggregations are permutation-invariant; the layers are permutation-equivariant under node relabelling (verified by tests/test_math_invariants_v030.py).
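The update above can be sketched from scratch in a few lines. The following is an illustrative toy, not TGraphX's internal implementation: φ and ψ are 1×1 convolutions, AGG is a sum via index_add_, and H and W pass through untouched. The class and variable names are invented for the example.

import torch
import torch.nn as nn

class TinyConvMP(nn.Module):
    """Toy tensor-aware message passing: phi/psi are 1x1 convs, AGG is a sum."""
    def __init__(self, c_in, c_out):
        super().__init__()
        self.phi = nn.Conv2d(2 * c_in, c_out, kernel_size=1)      # message fn
        self.psi = nn.Conv2d(c_in + c_out, c_out, kernel_size=1)  # update fn

    def forward(self, x, edge_index):                 # x: [N, C, H, W]
        src, dst = edge_index                         # each [E]
        msgs = self.phi(torch.cat([x[src], x[dst]], dim=1))   # [E, c_out, H, W]
        agg = msgs.new_zeros(x.size(0), msgs.size(1), *x.shape[2:])
        agg.index_add_(0, dst, msgs)                  # permutation-invariant sum
        return self.psi(torch.cat([x, agg], dim=1))   # [N, c_out, H, W]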


Stable core APIs

| Component | Class | Input node shape | Notes |
| --- | --- | --- | --- |
| Graph data structure | Graph, GraphBatch | any | Validated; supports .to(device) |
| Tensor-aware GCN-style message passing | ConvMessagePassing | [N, C, H, W] | aggr="sum" / "mean" / "max"; chunked forward |
| Tensor-aware multi-head GAT | TensorGATLayer | [N, C, H, W] | Softmax over incoming edges per destination per head; two-pass log-sum-exp chunked forward |
| Tensor-aware GraphSAGE | TensorGraphSAGELayer | [N, C, H, W] | Mean / max aggregation; chunked forward |
| Tensor-aware GIN / GINEConv | TensorGINLayer | [N, C, H, W] | (1+ε)·h_j + Σ h_i; chunked forward |
| Vector message passing | LinearMessagePassing | [N, D] | Base layer with linear projections |
| Custom-layer base class | TensorMessagePassingLayer | any | Override message / update |
| Vector model zoo | GCNConv, GATv2Conv, APPNP | [N, D] | New in v0.3.0; permutation-equivariance and tiny-overfit tested |
| Pooling helpers | global_mean_pool, global_sum_pool, global_max_pool | [N, *] | Per-graph reductions |
| CNN patch encoder | CNNEncoder | [N, C_in, pH, pW] | Outputs spatial feature maps |
| Graph classification | GraphClassifier | [N, C, H, W] | Mean / sum / max readout |
| Node classification | NodeClassifier | [N, D] | Vector features |

3-D volumetric node features [N, C, D, H, W] are supported by ConvMessagePassing, TensorGATLayer, TensorGraphSAGELayer, and TensorGINLayer (pass spatial_rank=3).
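A volumetric variant of the quickstart, assuming the same Graph and in_shape / out_shape conventions shown earlier (only spatial_rank=3 is documented above; the rest of the call mirrors the 2-D example):

import torch
from tgraphx import Graph, ConvMessagePassing

N, C, D, H, W = 4, 8, 4, 4, 4
edge_index = torch.stack([torch.arange(N), (torch.arange(N) + 1) % N])
g = Graph(torch.randn(N, C, D, H, W), edge_index)

# spatial_rank=3 switches the layer to [C, D, H, W] volumetric features.
layer = ConvMessagePassing(in_shape=(C, D, H, W), out_shape=(16, D, H, W),
                           spatial_rank=3)
out = layer(g.node_features, g.edge_index)   # expected: [4, 16, 4, 4, 4]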

Factories

from tgraphx import make_layer, build_model

layer = make_layer("gat", in_shape=(8, 4, 4), out_shape=(16, 4, 4), heads=2, residual=True)
model = build_model(
    task="graph_classification", layer="conv",
    in_shape=(3, 4, 4), hidden_shape=(8, 4, 4),
    num_layers=2, num_classes=10, pooling="mean",
)

Supported task / input-layout matrix:

| Task | [N, D] | [N, C, H, W] | [N, C, D, H, W] |
| --- | --- | --- | --- |
| node_classification | yes | yes | yes |
| node_regression | yes | yes | yes |
| graph_classification | yes | yes | yes |
| graph_regression | yes | yes | yes |
| edge_prediction | yes | yes (spatial pool) | yes (spatial pool) |

Datasets, transforms, metrics, benchmarks

TGraphX ships a unified dataset registry, native synthetic datasets, folder-backed datasets, and optional adapters for torchvision, PyG, DGL, and OGB. TGraphX does not redistribute third-party datasets; adapters delegate download and parsing to the upstream library, and any download requires the user to pass download=True explicitly.

from tgraphx.datasets import list_datasets, dataset_info, get_dataset

print(list_datasets())                        # 34 registered datasets
ds = get_dataset("synthetic:patch_graph", num_graphs=32, seed=0)
g = ds[0]

Cache root resolution: an explicit <root> argument, else $TGRAPHX_DATA, else ~/.cache/tgraphx/datasets. Inspect with tgraphx.datasets.cache_summary(); clear with clear_cache(...).
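A minimal cache-inspection sketch (cache_summary and clear_cache are named above; the clear_cache argument shown here is hypothetical):

from tgraphx.datasets import cache_summary, clear_cache

print(cache_summary())                  # where each cached dataset lives on disk
clear_cache("synthetic:patch_graph")    # hypothetical argument: clear one entry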

| Layer | Highlights |
| --- | --- |
| tgraphx.datasets | Native synthetic + folder + torchvision/PyG/DGL/OGB adapters; safe atomic downloads with SHA-256 checksums and path-traversal-blocked archive extraction |
| tgraphx.transforms | Compose, AddSelfLoops, RemoveSelfLoops, ToUndirected, NormalizeFeatures, StandardizeFeatures, AddDegreeFeatures, RandomNodeSplit, RandomLinkSplit, AddDegreeEncoding, AddLaplacianEigenvectors, PatchifyImage, BuildGridGraph, … (usage sketched below) |
| tgraphx.metrics | Pure-PyTorch accuracy, top_k_accuracy, precision_recall_f1, classification_report, mae, mse, rmse, r2_score, hits_at_k, mean_reciprocal_rank, ndcg_at_k, roc_auc, average_precision |
| benchmarks/ | benchmark_dataset_loading.py, benchmark_training_synthetic.py, benchmark_tensor_vs_flatten.py, benchmark_transforms.py, benchmark_metrics.py, benchmark_layers.py, benchmark_graph_builders.py, benchmark_sampling.py; every benchmark accepts --small and --output JSON |
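A short sketch tying the transforms and metrics above together, assuming transforms are callables over a Graph, Compose applies them in order, and the metric helpers take (predictions, targets):

import torch
from tgraphx.datasets import get_dataset
from tgraphx.transforms import Compose, AddSelfLoops, ToUndirected, NormalizeFeatures
from tgraphx.metrics import accuracy

pipeline = Compose([AddSelfLoops(), ToUndirected(), NormalizeFeatures()])
g = pipeline(get_dataset("synthetic:patch_graph", num_graphs=1, seed=0)[0])

preds  = torch.randint(0, 10, (100,))   # stand-in predictions
labels = torch.randint(0, 10, (100,))
print(accuracy(preds, labels))          # argument order assumed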

Detailed write-ups: docs/datasets.md, docs/transforms.md, docs/metrics.md, docs/benchmarks.md, docs/dataset_license_policy.md.

Importing tgraphx, tgraphx.datasets, tgraphx.transforms, tgraphx.metrics, tgraphx.experiments, or tgraphx.explain does not import torch_geometric / dgl / ogb.


Experiment manager

tgraphx.experiments (new in v0.3.0) is a lightweight, dashboard-compatible experiment manager. YAML / JSON configs are validated against an explicit schema (no eval, no exec). Every run writes its artefacts under run_dir only:

runs/<run_name>/<timestamp>/
├── run_metadata.json           # run name, status, device, seed, version
├── experiment_config.json      # exact config copy
├── experiment_summary.json     # epochs, best metric, final loss
├── metrics.csv                 # dashboard-compatible
└── checkpoints/{best,latest}.pt

from tgraphx.experiments import Runner, load_config

cfg = load_config("examples/configs/synthetic_patch_graph.yaml")
runner = Runner(cfg)
history = runner.fit()

Or via the CLI (registered as console scripts):

tgraphx-train  examples/configs/synthetic_patch_graph.yaml
tgraphx-grid   examples/configs/grid_sweep.yaml
tgraphx-report runs/                  # → Markdown summary

Built-in callbacks: EarlyStopping, ModelCheckpoint, CSVLoggerCallback, LearningRateLogger. Multi-seed and grid sweeps via GridRunner. The dashboard reads the same files when you run tgraphx-dashboard --logdir runs/<run_name>.
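A hypothetical wiring of those callbacks into Runner, assuming they are importable from tgraphx.experiments and that Runner accepts a callbacks list; the constructor arguments are guesses, so check docs/experiments.md for the real schema:

from tgraphx.experiments import Runner, load_config, EarlyStopping, ModelCheckpoint

cfg = load_config("examples/configs/synthetic_patch_graph.yaml")
runner = Runner(cfg, callbacks=[
    EarlyStopping(patience=5),   # hypothetical constructor arguments
    ModelCheckpoint(),           # writes checkpoints/{best,latest}.pt
])
history = runner.fit()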

See docs/experiments.md for the full schema.


Explainability

tgraphx.explain (new in v0.3.0) provides diagnostic explainability tools. They run on CPU by default, do not retain autograd graphs, and never make causal claims about a model's predictions.

from tgraphx.explain import (
    node_feature_saliency, integrated_gradients,
    edge_perturbation_attribution, attention_to_edge_scores,
    patch_saliency_to_image_grid,
    export_explanation_metadata, export_edge_scores_csv,
)

sal      = node_feature_saliency(model, graph, target=label)
ig       = integrated_gradients(model, graph, target=label, steps=16)
edge_imp = edge_perturbation_attribution(model, graph, target=label, max_edges=64)

# For TensorGATLayer outputs.
out, attn = layer(x, edge_index, return_attention=True)
edge_scores = attention_to_edge_scores(attn, edge_index, head_reduce="mean")

# Patch-level heatmap from grid_shape metadata.
heatmap = patch_saliency_to_image_grid(sal, grid_shape=graph.metadata["grid_shape"])

# Export artefacts the dashboard can render (explicit paths only).
export_explanation_metadata("runs/demo/explanation_metadata.json",
                            method="saliency", target=int(label))
export_edge_scores_csv("runs/demo/explanation_edges.csv",
                       graph.edge_index, edge_scores, top_k=20)

See docs/explainability.md for examples and limits.


Optional integrations

| Adapter | Install | Notes |
| --- | --- | --- |
| Native synthetic / folder datasets | (none) | Deterministic, learnable, no network |
| Torchvision-backed datasets | already a TGraphX base dependency | MNIST, CIFAR-10/100, SVHN, STL-10, FakeData, … converted to patch graphs |
| PyTorch Geometric | pip install "tgraphx[pyg]" | Planetoid, TUDataset, generic adapter; data-format converters only |
| DGL | follow upstream install | Citation graphs, generic adapter |
| OGB | pip install "tgraphx[ogb]" | Node / link / graph property prediction wrappers + OGBEvaluatorWrapper |
| MLflow | pip install "tgraphx[mlflow]" | MLflowLogger, lazy import |
| TensorBoard | pip install "tgraphx[tracking]" | TensorBoardLogger, lazy import |
| Hardware monitoring | pip install "tgraphx[monitoring]" | psutil + pynvml for the dashboard's hardware panel |

TGraphX is interoperable with the PyG / DGL / OGB ecosystems through small data-format converters; it is not a drop-in replacement for any of them.


Backend and platform support

| Platform | Forward | torch.compile | AMP | CI coverage |
| --- | --- | --- | --- | --- |
| Linux + CPU | yes | yes | bfloat16 (recommended) | Full Ubuntu CI on Python 3.10 / 3.11 / 3.12 |
| NVIDIA CUDA | yes | yes | float16 / bfloat16 | Local tests; no GPU runners in CI |
| Apple Silicon MPS | yes | partial | partial | macOS smoke CI (import + build) |
| Windows + CPU | yes | yes | yes | Windows smoke CI (Python 3.11) |
| macOS + CPU | yes | yes | yes | macOS smoke CI (Python 3.11) |
| Multi-GPU (DDP) | helpers | user-managed | user-managed | tgraphx.distributed rank-zero / barrier helpers; full DDP setup is the user's responsibility |

For the AMP recommendations, dtype-cast policy, and per-platform caveats, see docs/performance.md. For the precise API stability classification, see docs/api_stability.md.
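As an orientation sketch, running a layer under AMP uses plain torch.autocast; nothing here is TGraphX-specific beyond the quickstart layer, and the bfloat16-on-CPU choice follows the table above:

import torch
from tgraphx import ConvMessagePassing

layer = ConvMessagePassing(in_shape=(16, 8, 8), out_shape=(32, 8, 8))
x = torch.randn(6, 16, 8, 8)
edge_index = torch.stack([torch.arange(6), (torch.arange(6) + 1) % 6])

# bfloat16 autocast on CPU, per the recommendation above; on CUDA use
# torch.autocast("cuda", dtype=torch.float16) or bfloat16.
with torch.autocast("cpu", dtype=torch.bfloat16):
    out = layer(x, edge_index)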


Dashboard, local-first, privacy

TGraphX ships a local-first training dashboard that reads run artefacts (metrics.csv, run_metadata.json, experiment_config.json, experiment_summary.json, dataset_metadata.json, transform_metadata.json, metrics_summary.json, benchmark_results.json, explanation_metadata.json, explanation_edges.csv, explanation_patch_heatmap.json, hetero_graph_metadata.json, temporal_metadata.json, sampling_metadata.json, hardware_report.json).

tgraphx-dashboard --logdir runs/demo
# → http://127.0.0.1:8765 (localhost only, no token needed)

tgraphx-dashboard --logdir runs/demo --host 0.0.0.0 --token MY_SECRET_TOKEN
tgraphx-dashboard --logdir runs/demo --export-html snapshot.html

Properties:

  • Off by default. The dashboard is opt-in; nothing about importing TGraphX runs a server.
  • Local-first. Reads files from your --logdir; never executes code, never loads checkpoints, never runs models.
  • No telemetry, no analytics, no external CDN; the offline HTML export is fully self-contained.
  • LAN access requires --token; localhost mode does not.
  • Path-traversal protected: every file read is validated against the resolved logdir.
  • Background threads are launched only when you call launch_dashboard_background(...) explicitly.

Helpers for writing metadata files the dashboard understands:

from tgraphx import (
    write_run_metadata, write_dataset_metadata, write_transform_metadata,
    write_metrics_summary, write_benchmark_results,
    write_explanation_metadata, write_experiment_config,
    write_hardware_report, write_sampling_metadata,
    write_hetero_graph_metadata, write_temporal_metadata,
)

write_graph_stats from earlier releases is preserved.
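A hypothetical call for orientation: write_run_metadata is real (imported above), and the export helpers in the explainability section take an explicit path first, but the keyword fields here are assumptions, not the documented signature:

write_run_metadata(
    "runs/demo/run_metadata.json",   # explicit path, per the local-first policy
    run_name="demo",                 # hypothetical keyword fields
    seed=0,
)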

See docs/dashboard.md for the full UI, security model, and offline export.


Privacy summary

| Behaviour | Default |
| --- | --- |
| Telemetry / analytics | None — never |
| Remote calls at import | None |
| Dashboard | Off — launch explicitly |
| CSV / TensorBoard / MLflow logging | Off — create a logger explicitly |
| Hardware monitoring | Off — pass include_hardware=True to env_report |
| Checkpoints | Off — call save_checkpoint or attach ModelCheckpoint explicitly |
| Dataset downloads | Off — adapters require download=True |
| Background threads | None unless launch_dashboard_background is called |
| File writes | Only to paths the user provides |

No ~/.tgraphx directory or user-level config is created.


Examples and tutorials

python examples/run_all_fast_examples.py        # every CPU-safe demo
python examples/datasets_quickstart.py           # registry + synthetic datasets
python examples/transforms_metrics_demo.py       # Compose + classification report
python examples/synthetic_datasets_demo.py       # all 7 native synthetic datasets
python examples/sampling_demo_v028.py            # random walk + hetero + temporal sampling
python examples/training_with_dashboard.py      # fit() + CSVLogger → dashboard
python examples/graph_transformer_demo.py        # vector graph transformer
python examples/gat_chunking_demo.py             # GAT chunked forward parity

The full set of demos lives under examples/ (see examples/README.md when present).

Validation scripts

A small set of stand-alone validation scripts verifies, end-to-end, that the surfaces TGraphX advertises actually work:

| Script | Purpose |
| --- | --- |
| examples/device_validation.py | Runs vector + spatial layer smoke tests on CPU / CUDA / MPS, optionally under autocast (--amp); emits a JSON report |
| examples/dashboard_artifact_validation.py | Writes every supported dashboard metadata file, exports an offline HTML snapshot, and checks for CDN / token / eval leaks |
| examples/experiment_end_to_end_validation.py | Trains a tiny synthetic experiment, asserts dashboard files exist, resumes from a checkpoint |
| examples/explainability_end_to_end_validation.py | Trains, runs saliency / integrated gradients / edge perturbation, exports dashboard-readable explanation artefacts |
| examples/public_datasets/*.py | Manual, opt-in scripts that exercise the torchvision / PyG / DGL / OGB adapters against real upstream loaders; they require --download, cap dataset size, skip cleanly if the optional package is missing, and never run in CI |

See docs/public_dataset_validation.md and docs/device_validation.md for the exact invocations and policy.


Boundaries

TGraphX is intentionally focused.

  • TGraphX is not a drop-in replacement for PyG or DGL; the optional adapters convert data only.
  • TGraphX provides DDP-aware helpers and a single-process smoke example, not an automatic multi-GPU training framework.
  • Per-pixel and per-voxel GAT attention scores are not shipped — naive [E, K, H, W] score tensors are memory-prohibitive. Per-channel attention is shipped as attention_mode="channel".
  • Recurrent temporal memory modules (TGN, TGAT) are not shipped; temporal workflows use a stateless snapshot-loop pattern (a minimal sketch follows this list).
  • Synthetic datasets are sanity / tutorial datasets, not benchmarks; benchmark scripts are reproducibility tools, not real-world performance comparisons.
  • kNN, radius, IoU, and fully-connected graph builders are inherently O(N²) and warn on large N; chunked variants reduce peak memory.
  • Universal arbitrary-rank node-feature support across every layer is a future direction; the supported layouts today are vector [N, D], 2-D spatial [N, C, H, W], and 3-D volumetric [N, C, D, H, W].
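A minimal sketch of the stateless snapshot-loop pattern mentioned above, using only APIs from the quickstart (the per-snapshot Graph list and the mean-over-time readout are illustrative choices, not a shipped temporal API):

import torch
from tgraphx import Graph, ConvMessagePassing

N, C, H, W, T = 6, 16, 8, 8, 4
edge_index = torch.stack([torch.arange(N), (torch.arange(N) + 1) % N])
snapshots = [Graph(torch.randn(N, C, H, W), edge_index) for _ in range(T)]

layer = ConvMessagePassing(in_shape=(C, H, W), out_shape=(C, H, W))
per_step = [layer(g.node_features, g.edge_index) for g in snapshots]  # no hidden state
summary = torch.stack(per_step).mean(dim=0)   # simple temporal readout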

Detailed limitations live in docs/limitations.md; the roadmap is in docs/roadmap.md.


Citation

@misc{sajjadi2025tgraphxtensorawaregraphneural,
    title={TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning},
    author={Arash Sajjadi and Mark Eramian},
    year={2025},
    eprint={2504.03953},
    archivePrefix={arXiv},
    primaryClass={cs.CV},
    url={https://arxiv.org/abs/2504.03953},
}

Authorship

Software author and maintainer: Arash Sajjadi, PhD Candidate in Computer Science, University of Saskatchewan (arash.sajjadi@usask.ca). Academic supervision: Mark Eramian, PhD supervisor / academic advisor, University of Saskatchewan.


License

TGraphX is released under the MIT License.

