
Grid Foundation Model



gridfm-graphkit


This library, from the GridFM team, provides tools to train, fine-tune, and interact with a foundation model for the electric power grid.


Installation

Create and activate a virtual environment using a supported Python version (3.10, 3.11, or 3.12; 3.12 is recommended):

python -m venv venv
source venv/bin/activate
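Before installing, it can help to confirm that the activated environment runs one of the supported interpreters. This is a minimal, illustrative check; the SUPPORTED set and the is_supported helper are not part of gridfm-graphkit and simply mirror the versions listed above.

```python
import sys

# Versions listed in this README; not read from the package metadata.
SUPPORTED = {(3, 10), (3, 11), (3, 12)}


def is_supported(version_info=sys.version_info) -> bool:
    """Return True if the interpreter's major.minor version is in SUPPORTED."""
    return (version_info[0], version_info[1]) in SUPPORTED


if __name__ == "__main__":
    v = sys.version_info
    status = "supported" if is_supported(v) else "NOT supported"
    print(f"Python {v.major}.{v.minor} is {status} by gridfm-graphkit")
```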

Install gridfm-graphkit from PyPI

pip install gridfm-graphkit

torch-scatter is a required dependency. It cannot be declared in pyproject.toml because the correct wheel depends on your PyTorch and CUDA versions, so it must be installed separately.

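Get the PyTorch + CUDA version string for torch-scatter. CPU-only builds get a +cpu suffix unless their version already carries one (as wheels from the PyTorch CPU index do)

TORCH_CUDA_VERSION=$(python -c "import torch; v = torch.__version__; print(v if ('+' in v or torch.version.cuda) else v + '+cpu')")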

Install the correct torch-scatter wheel

pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html
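The shell commands above map torch.__version__ and torch.version.cuda to a PyG wheel index page. Writing the same mapping as a plain Python function makes the edge cases explicit and easy to test. This is a hypothetical helper (pyg_find_links_url is not part of gridfm-graphkit); it assumes the data.pyg.org/whl/torch-<version>.html page naming used in the pip command above and avoids appending +cpu twice when a CPU build already reports a +cpu suffix. PyG may not publish a page for every PyTorch patch release, so if the URL does not resolve, check the PyG installation instructions for the closest supported version.

```python
from typing import Optional

PYG_WHEEL_BASE = "https://data.pyg.org/whl"


def pyg_find_links_url(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build the -f URL for `pip install torch-scatter`.

    torch_version: value of torch.__version__, e.g. "2.5.1+cu124" or "2.5.1+cpu".
    cuda_version:  value of torch.version.cuda, e.g. "12.4", or None on CPU-only builds.
    """
    tag = torch_version
    # Append "+cpu" only for CPU builds whose version has no local suffix yet.
    if cuda_version is None and "+" not in tag:
        tag += "+cpu"
    return f"{PYG_WHEEL_BASE}/torch-{tag}.html"
```

For example, pyg_find_links_url(torch.__version__, torch.version.cuda) returns the URL to pass to pip install torch-scatter -f.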

For documentation generation and unit testing, install with the optional dev and test extras:

pip install "gridfm-graphkit[dev,test]"

CLI commands

Command-line interface to train, fine-tune, evaluate, and run inference on GridFM models, configured with YAML files and tracked with MLflow.

gridfm_graphkit <command> [OPTIONS]

Available commands:

  • train - Train a new model from scratch
  • finetune - Fine-tune an existing pre-trained model
  • evaluate - Evaluate model performance on a dataset
  • predict - Run inference and save predictions
  • benchmark - Measure dataloader throughput
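When launching many runs from a script, it can be convenient to build the command as an argument list rather than a shell string. The sketch below is hypothetical (build_gridfm_command and run_gridfm are not part of the package): True becomes a bare flag such as --tf32 or --compile, False and None are skipped, and any other value is passed as --name value. List-valued options such as --plugins are deliberately not handled, since this README does not specify how multiple values are passed on the command line.

```python
import shlex
import subprocess
from typing import Dict, List, Union

OptionValue = Union[str, int, float, bool, None]


def build_gridfm_command(command: str, options: Dict[str, OptionValue]) -> List[str]:
    """Build an argv list such as ["gridfm_graphkit", "train", "--config", "c.yaml"]."""
    argv = ["gridfm_graphkit", command]
    for name, value in options.items():
        if value is None or value is False:
            continue  # unset option or disabled flag: leave it out
        argv.append(f"--{name}")
        if value is not True:  # True means a bare flag such as --tf32
            argv.append(str(value))
    return argv


def run_gridfm(command: str, options: Dict[str, OptionValue], dry_run: bool = True) -> int:
    """Print the command (dry run) or execute it and return its exit code."""
    argv = build_gridfm_command(command, options)
    if dry_run:
        print(shlex.join(argv))
        return 0
    return subprocess.run(argv, check=True).returncode
```

The identity checks (is None, is False) keep an integer 0 in the command, which matters for --num_workers 0, the documented setting for debugging worker crashes.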

Training Models

gridfm_graphkit train --config path/to/config.yaml

Arguments

Argument Type Description Default
--config str Required. Path to the training configuration YAML file. None
--exp_name str MLflow experiment name. timestamp
--run_name str MLflow run name. run
--log_dir str MLflow tracking/logging directory. mlruns
--data_path str Root dataset directory. data
--compile [MODE] str Enable torch.compile mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If the flag is passed without a value, the mode is default. None
--bfloat16 flag Cast model to torch.bfloat16 (model.to(torch.bfloat16)). False
--tf32 flag Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). False
--dataset_wrapper str Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. None
--plugins list[str] Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. []
--num_workers int Override data.workers from YAML. Use 0 to debug worker crashes. None
--dataset_wrapper_cache_dir str Disk cache directory for the dataset wrapper; the cache is loaded from here when present and saved here after it is first populated. None
--profiler str Enable Lightning profiler (simple, advanced, pytorch). None
--compute_dc_ac_metrics flag Compute ground-truth AC/DC power balance metrics on the test split. False
--mp_context str DataLoader multiprocessing start method (spawn, fork, forkserver). Defaults to PyTorch's automatic choice. On Linux, spawn is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. None
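The --compile flag takes an optional value: when it is absent there is no compilation, a bare --compile selects mode default, and --compile MODE selects one of the listed modes. The argparse sketch below reproduces that behavior with nargs="?" and const; it illustrates the documented semantics and is not gridfm-graphkit's actual parser.

```python
import argparse

# Modes listed in the README table for --compile.
COMPILE_MODES = ["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]


def make_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="compile-flag-sketch")
    parser.add_argument("--config", required=True, help="Path to the YAML config.")
    parser.add_argument(
        "--compile",
        nargs="?",          # the value after the flag is optional
        const="default",    # used when the flag appears without a value
        default=None,       # used when the flag is absent: no compilation
        choices=COMPILE_MODES,
        metavar="MODE",
        help="torch.compile mode",
    )
    return parser
```

For example, make_parser().parse_args(["--config", "c.yaml", "--compile"]).compile is "default", while omitting --compile leaves it as None.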

Examples

Standard Training:

gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data

Fine-Tuning Models

gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pt

Arguments

Argument Type Description Default
--config str Required. Fine-tuning configuration file. None
--model_path str Required. Path to a pre-trained model state dict. None
--exp_name str MLflow experiment name. timestamp
--run_name str MLflow run name. run
--log_dir str MLflow logging directory. mlruns
--data_path str Root dataset directory. data
--compile [MODE] str Enable torch.compile mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If the flag is passed without a value, the mode is default. None
--bfloat16 flag Cast model to torch.bfloat16 (model.to(torch.bfloat16)). False
--tf32 flag Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). False
--dataset_wrapper str Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. None
--plugins list[str] Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. []
--num_workers int Override data.workers from YAML. Use 0 to debug worker crashes. None
--dataset_wrapper_cache_dir str Disk cache directory for the dataset wrapper; the cache is loaded from here when present and saved here after it is first populated. None
--profiler str Enable Lightning profiler (simple, advanced, pytorch). None
--compute_dc_ac_metrics flag Compute ground-truth AC/DC power balance metrics on the test split. False
--mp_context str DataLoader multiprocessing start method (spawn, fork, forkserver). Defaults to PyTorch's automatic choice. On Linux, spawn is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. None
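The --num_workers option overrides data.workers from the YAML file. A minimal sketch of that kind of override, assuming the YAML has already been parsed into nested dicts (for example with yaml.safe_load); apply_num_workers_override is a hypothetical helper, not part of the package.

```python
import copy
from typing import Any, Dict, Optional


def apply_num_workers_override(config: Dict[str, Any], num_workers: Optional[int]) -> Dict[str, Any]:
    """Return a copy of `config` whose data.workers is replaced by `num_workers`.

    The input dict is left untouched; None means "no override".
    """
    merged = copy.deepcopy(config)
    if num_workers is not None:  # compare with None so that 0 is a valid override
        merged.setdefault("data", {})["workers"] = num_workers
    return merged
```

Checking against None rather than truthiness is what lets --num_workers 0 take effect.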

Evaluating Models

gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pt

Arguments

Argument Type Description Default
--config str Required. Path to evaluation config. None
--model_path str Path to the trained model state dict. None
--normalizer_stats str Path to normalizer_stats.pt from a training run. Restores fit_on_train normalizers from saved statistics instead of re-fitting them on the current split. None
--exp_name str MLflow experiment name. timestamp
--run_name str MLflow run name. run
--log_dir str MLflow logging directory. mlruns
--data_path str Dataset directory. data
--compile [MODE] str Enable torch.compile mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If the flag is passed without a value, the mode is default. None
--bfloat16 flag Cast model to torch.bfloat16 (model.to(torch.bfloat16)). False
--tf32 flag Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). False
--dataset_wrapper str Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. None
--plugins list[str] Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. []
--num_workers int Override data.workers from YAML. Use 0 to debug worker crashes. None
--dataset_wrapper_cache_dir str Disk cache directory for the dataset wrapper; the cache is loaded from here when present and saved here after it is first populated. None
--profiler str Enable Lightning profiler (simple, advanced, pytorch). None
--compute_dc_ac_metrics flag Compute ground-truth AC/DC power balance metrics on the test split. False
--save_output flag Save predictions as <grid_name>_predictions.parquet under MLflow artifacts (.../artifacts/test). False
--mp_context str DataLoader multiprocessing start method (spawn, fork, forkserver). Defaults to PyTorch's automatic choice. On Linux, spawn is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. None
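--dataset_wrapper selects a class by name from DATASET_WRAPPER_REGISTRY, and --plugins imports extra packages so they can register their own entries. The sketch below shows the common decorator-based registry pattern behind options like these. It is an assumption about the general mechanism, not gridfm-graphkit's code: the dict here only stands in for the real registry, and register_dataset_wrapper, load_plugins, wrap_dataset and the toy ListCacheDataset are hypothetical names.

```python
import importlib
from typing import Any, Callable, Dict, Iterable, Optional, Type

# Stand-in for the library's registry: maps a wrapper name to a class.
DATASET_WRAPPER_REGISTRY: Dict[str, Type] = {}


def register_dataset_wrapper(name: str) -> Callable[[Type], Type]:
    """Class decorator that records the class under `name` in the registry."""
    def decorator(cls: Type) -> Type:
        if name in DATASET_WRAPPER_REGISTRY:
            raise ValueError(f"dataset wrapper {name!r} is already registered")
        DATASET_WRAPPER_REGISTRY[name] = cls
        return cls
    return decorator


def load_plugins(packages: Iterable[str]) -> None:
    """Import each package; its module-level decorators then populate the registry."""
    for package in packages:
        importlib.import_module(package)


def wrap_dataset(dataset: Any, wrapper_name: Optional[str]) -> Any:
    """Wrap `dataset` with the registered class, or return it unchanged if no name is given."""
    if wrapper_name is None:
        return dataset
    if wrapper_name not in DATASET_WRAPPER_REGISTRY:
        known = ", ".join(sorted(DATASET_WRAPPER_REGISTRY))
        raise KeyError(f"unknown dataset wrapper {wrapper_name!r} (registered: {known})")
    return DATASET_WRAPPER_REGISTRY[wrapper_name](dataset)


@register_dataset_wrapper("ListCacheDataset")
class ListCacheDataset:
    """Toy wrapper that materializes the wrapped dataset into memory once."""

    def __init__(self, dataset: Iterable[Any]) -> None:
        self._items = list(dataset)

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, idx: int) -> Any:
        return self._items[idx]
```

Because registration happens at import time, a plugin package only needs to be imported (which is what --plugins is described as doing) for its wrappers to become selectable by name.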

Example with saved normalizer stats

When evaluating a model on a dataset, you can pass the normalizer statistics from the original training run to ensure the same normalization parameters are used:

gridfm_graphkit evaluate \
  --config examples/config/HGNS_PF_datakit_case118.yaml \
  --model_path mlruns/<experiment_id>/<run_id>/artifacts/model/best_model_state_dict.pt \
  --normalizer_stats mlruns/<experiment_id>/<run_id>/artifacts/stats/normalizer_stats.pt \
  --data_path data
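The model and normalizer paths in the example above follow a fixed layout under the MLflow run directory, so a small helper can assemble and check them before launching an evaluation. This is a sketch that assumes the mlruns/<experiment_id>/<run_id>/artifacts/... layout shown in the command; run_artifacts is a hypothetical name.

```python
from pathlib import Path
from typing import Tuple, Union


def run_artifacts(mlruns: Union[str, Path], experiment_id: str, run_id: str) -> Tuple[Path, Path]:
    """Return (model_state_dict, normalizer_stats) paths for one MLflow run.

    Raises FileNotFoundError if either file is missing, so a wrong ID
    fails before the evaluation starts.
    """
    artifacts = Path(mlruns) / experiment_id / run_id / "artifacts"
    model = artifacts / "model" / "best_model_state_dict.pt"
    stats = artifacts / "stats" / "normalizer_stats.pt"
    missing = [str(p) for p in (model, stats) if not p.is_file()]
    if missing:
        raise FileNotFoundError("missing run artifacts: " + ", ".join(missing))
    return model, stats
```

The two returned paths correspond to the values passed to --model_path and --normalizer_stats.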

Note: The --normalizer_stats flag only affects normalizers with fit_strategy = "fit_on_train" (e.g. HeteroDataMVANormalizer). Per-sample normalizers (HeteroDataPerSampleMVANormalizer) always recompute their statistics from the current dataset regardless of this flag.
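The difference between the two fit strategies can be illustrated with toy scalar normalizers. The classes below are hypothetical and far simpler than HeteroDataMVANormalizer and HeteroDataPerSampleMVANormalizer, which operate on heterogeneous graph data; they only show why saved statistics can be restored into a fit_on_train normalizer while a per-sample normalizer has nothing to restore.

```python
import statistics
from typing import Any, Dict, List, Mapping, Optional, Sequence


class FitOnTrainNormalizer:
    """Toy normalizer whose statistics are fitted once, on the training split."""

    fit_strategy = "fit_on_train"

    def __init__(self) -> None:
        self.stats: Optional[Dict[str, float]] = None

    def fit(self, train_values: Sequence[float]) -> None:
        # Fall back to 1.0 when the spread is zero to avoid dividing by zero.
        self.stats = {
            "mean": statistics.fmean(train_values),
            "std": statistics.pstdev(train_values) or 1.0,
        }

    def load_stats(self, stats: Mapping[str, float]) -> None:
        self.stats = dict(stats)

    def transform(self, values: Sequence[float]) -> List[float]:
        if self.stats is None:
            raise RuntimeError("call fit() or load_stats() first")
        return [(v - self.stats["mean"]) / self.stats["std"] for v in values]


class PerSampleNormalizer:
    """Toy normalizer that recomputes its statistics from every sample."""

    fit_strategy = "per_sample"

    def transform(self, values: Sequence[float]) -> List[float]:
        mean = statistics.fmean(values)
        std = statistics.pstdev(values) or 1.0
        return [(v - mean) / std for v in values]


def restore_normalizers(normalizers: Mapping[str, Any], saved: Mapping[str, Mapping[str, float]]) -> None:
    """Load saved stats into fit_on_train normalizers; per-sample ones are skipped."""
    for name, norm in normalizers.items():
        if getattr(norm, "fit_strategy", None) == "fit_on_train" and name in saved:
            norm.load_stats(saved[name])
```

Restoring the training statistics keeps evaluation inputs on the scale the model was trained on, while the per-sample normalizer produces the same output whether or not saved statistics are supplied.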


Running Predictions

gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pt

Arguments

Argument Type Description Default
--config str Required. Path to prediction config file. None
--model_path str Path to the trained model state dict. Optional; it may instead be set in the config. None
--normalizer_stats str Path to normalizer_stats.pt from a training run. Restores fit_on_train normalizers from saved statistics. None
--exp_name str MLflow experiment name. timestamp
--run_name str MLflow run name. run
--log_dir str MLflow logging directory. mlruns
--data_path str Dataset directory. data
--dataset_wrapper str Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. None
--plugins list[str] Python packages to import for plugin registration, e.g. gridfm_graphkit_ee. []
--num_workers int Override data.workers from YAML. Use 0 to debug worker crashes. None
--dataset_wrapper_cache_dir str Disk cache directory for the dataset wrapper; the cache is loaded from here when present and saved here after it is first populated. None
--output_path str Directory where predictions are saved as <grid_name>_predictions.parquet. data
--compile [MODE] str Enable torch.compile mode. Valid values: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs. If the flag is passed without a value, the mode is default. None
--bfloat16 flag Cast model to torch.bfloat16 (model.to(torch.bfloat16)). False
--tf32 flag Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision("high"). False
--profiler str Enable Lightning profiler (simple, advanced, pytorch). None
--mp_context str DataLoader multiprocessing start method (spawn, fork, forkserver). Defaults to PyTorch's automatic choice. On Linux, spawn is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. None
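Predictions are written as <grid_name>_predictions.parquet, so the grid name can be recovered from each file name. A hypothetical helper (find_prediction_files is not part of the package) for collecting the files in an output directory:

```python
from pathlib import Path
from typing import Dict, Union

SUFFIX = "_predictions.parquet"


def find_prediction_files(output_dir: Union[str, Path]) -> Dict[str, Path]:
    """Map each grid name to its <grid_name>_predictions.parquet file."""
    files: Dict[str, Path] = {}
    for path in sorted(Path(output_dir).glob(f"*{SUFFIX}")):
        if path.is_file():
            files[path.name[: -len(SUFFIX)]] = path
    return files
```

Each file can then be loaded with pandas.read_parquet, which requires a Parquet engine such as pyarrow.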

Benchmarking Dataloader Throughput

gridfm_graphkit benchmark --config path/to/config.yaml

Arguments

Argument Type Description Default
--config str Required. Path to configuration YAML file. None
--data_path str Root dataset directory. data
--epochs int Number of epochs to iterate through the train dataloader. 3
--dataset_wrapper str Registered dataset wrapper name (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset. None
--dataset_wrapper_cache_dir str Directory for dataset wrapper disk cache. None
--num_workers int Override data.workers from YAML. None
--plugins list[str] Python packages to import for plugin registration. []
--mp_context str DataLoader multiprocessing start method (spawn, fork, forkserver). Defaults to PyTorch's automatic choice. On Linux, spawn is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. None
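The benchmark command iterates over the train dataloader for --epochs passes. A generic way to measure that kind of throughput is to time full passes over an iterable; the sketch below reports batches per second for each epoch. It is not the command's implementation, and the metrics the real command reports may differ. The first epoch may include one-time costs such as populating a dataset wrapper cache, so later epochs are often more representative.

```python
import time
from typing import Iterable, List


def benchmark_loader(loader: Iterable, epochs: int = 3) -> List[float]:
    """Iterate `loader` for `epochs` full passes and return batches/second per epoch."""
    if epochs < 1:
        raise ValueError("epochs must be at least 1")
    rates: List[float] = []
    for _ in range(epochs):
        start = time.perf_counter()
        n_batches = 0
        for _batch in loader:  # only iteration is timed; batches are discarded
            n_batches += 1
        elapsed = time.perf_counter() - start
        rates.append(n_batches / elapsed if elapsed > 0 else float("inf"))
    return rates
```

Any re-iterable object works here, including a PyTorch DataLoader, since each epoch starts a fresh iteration.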

Use the built-in help for full command details:

gridfm_graphkit --help
gridfm_graphkit <command> --help

Running Tests

Unit and Integration Tests

Install the test dependencies first (if not already done):

pip install -e ".[dev,test]"

Run the full unit test suite:

pytest ./tests

Run the base set integration tests:

pytest ./integrationtests/test_base_set.py

Running Base Set Tests on an LSF Cluster (GPU)

To submit the base set integration tests as an interactive LSF job with GPU access, use bsub. Adjust the paths to match your environment:

bsub -gpu "num=1" \
     -n 16 \
     -R "rusage[mem=32GB] span[hosts=1]" \
     -Is \
     -J gridfm_base_set_tests \
     /bin/bash -c "
       cd /path/to/gridfm-graphkit && \
       export PATH=/path/to/cuda/bin:\$PATH \
               CUDA_HOME=/path/to/cuda \
               LD_LIBRARY_PATH=/path/to/cuda/lib64:\$LD_LIBRARY_PATH && \
       source /path/to/venv/bin/activate && \
       pytest ./integrationtests/test_base_set.py
     "

Key bsub options used above:

Option Description
-gpu "num=1" Request 1 GPU
-n 16 Request 16 CPU slots
-R "rusage[mem=32GB] span[hosts=1]" Reserve 32 GB of memory and place all requested slots on a single host
-Is Run as an interactive shell session
-J <job_name> Assign a name to the job

Concrete example (adapt paths to your cluster setup):

bsub -gpu "num=1" -n 16 -R "rusage[mem=32GB] span[hosts=1]" -Is -J hpo_trial_190 /bin/bash -c "cd /dccstor/terratorch/users/rkie/gitco/gridfm-graphkit && export PATH=/opt/share/cuda-12.8.1/bin:\$PATH CUDA_HOME=/opt/share/cuda-12.8.1 LD_LIBRARY_PATH=/opt/share/cuda-12.8.1/lib64:\$LD_LIBRARY_PATH && source /u/rkie/venvs/venv_gridfm-graphkit/bin/activate && pytest ./integrationtests/test_base_set.py"
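If submissions are scripted, building the bsub command as an argument list avoids the backslash escaping in the commands above: because no outer shell is involved, $PATH and $LD_LIBRARY_PATH reach the inner /bin/bash unexpanded without \$. The sketch below is hypothetical (build_bsub_command is not part of the project), reproduces only the options listed in the table, and takes placeholder paths as arguments.

```python
import shlex
from typing import List


def build_bsub_command(
    repo_dir: str,
    venv_dir: str,
    cuda_home: str,
    test_path: str = "./integrationtests/test_base_set.py",
    job_name: str = "gridfm_base_set_tests",
    gpus: int = 1,
    slots: int = 16,
    mem: str = "32GB",
) -> List[str]:
    """Return an argv list equivalent to the interactive bsub command above."""
    q = shlex.quote
    # The inner script is run by /bin/bash; $PATH and $LD_LIBRARY_PATH are left
    # for that shell to expand, so no backslash escaping is needed here.
    inner = " && ".join(
        [
            f"cd {q(repo_dir)}",
            f"export PATH={q(cuda_home + '/bin')}:$PATH"
            f" CUDA_HOME={q(cuda_home)}"
            f" LD_LIBRARY_PATH={q(cuda_home + '/lib64')}:$LD_LIBRARY_PATH",
            f"source {q(venv_dir + '/bin/activate')}",
            f"pytest {q(test_path)}",
        ]
    )
    return [
        "bsub",
        "-gpu", f"num={gpus}",
        "-n", str(slots),
        "-R", f"rusage[mem={mem}] span[hosts=1]",
        "-Is",
        "-J", job_name,
        "/bin/bash", "-c", inner,
    ]
```

The resulting list can be passed to subprocess.run(cmd, check=True) on a host where bsub is available.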



Download files

Source Distribution

gridfm_graphkit-0.0.7.tar.gz (67.2 kB)

Built Distribution

gridfm_graphkit-0.0.7-py3-none-any.whl (74.5 kB)

File details

Details for the file gridfm_graphkit-0.0.7.tar.gz.

File metadata

  • Download URL: gridfm_graphkit-0.0.7.tar.gz
  • Size: 67.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.20

File hashes

Hashes for gridfm_graphkit-0.0.7.tar.gz
Algorithm Hash digest
SHA256 364c3c0b2a1e7ec6a968a2a578ef34fcad7e1dd5ac807a8859704200f94f935a
MD5 717f0d40b2990cf01c6ad6c7be9a1596
BLAKE2b-256 b109fa3fe8d6e1f72fa62214943143be201b663955d4772eb61f0a532fd46a9c


File details

Details for the file gridfm_graphkit-0.0.7-py3-none-any.whl.

File hashes

Hashes for gridfm_graphkit-0.0.7-py3-none-any.whl
Algorithm Hash digest
SHA256 7ccd629016ba129386df6fcf8cbae4493568be695484afd34b07733e56b74884
MD5 29f1d3cde40ecfeb9ae0dfc3b69be8d0
BLAKE2b-256 0fa34bcca27b7e6045def9c1a058ba62567a3aa4faab9bb70e502ef8ea070002

