Grid Foundation Model
gridfm-graphkit
This library is brought to you by the GridFM team to train, finetune and interact with a foundation model for the electric power grid.
Installation
Create and activate a virtual environment. Use Python 3.10, 3.11, or 3.12 (3.12 is recommended).
python -m venv venv
source venv/bin/activate
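Before installing, it can help to confirm that the active interpreter falls in the supported range. The following is a minimal sketch; the helper name `is_supported_python` is hypothetical and not part of gridfm-graphkit.

```python
import sys

# Supported (major, minor) versions, as listed in the installation note above.
SUPPORTED = {(3, 10), (3, 11), (3, 12)}


def is_supported_python(version_info=None) -> bool:
    """Return True if the (major, minor) pair is one of the supported versions."""
    info = sys.version_info if version_info is None else version_info
    return (info[0], info[1]) in SUPPORTED


if __name__ == "__main__":
    major, minor = sys.version_info[0], sys.version_info[1]
    print(f"Python {major}.{minor} supported: {is_supported_python()}")
```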
Install gridfm-graphkit from PyPI
pip install gridfm-graphkit
torch-scatter is a required dependency. It cannot be bundled in pyproject.toml because the correct wheel depends on your PyTorch and CUDA versions, so it must be installed separately.
Get PyTorch + CUDA version for torch-scatter
TORCH_CUDA_VERSION=$(python -c "import torch; v = torch.__version__; print(v if '+' in v or torch.version.cuda is not None else v + '+cpu')")
Install the correct torch-scatter wheel
pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html
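The same wheel index URL can be derived in Python, which may be easier to reuse in setup scripts. This is a sketch only: the function name `torch_scatter_index_url` is hypothetical, and it appends `+cpu` only when the version string has no local suffix (such as `+cpu` or `+cu121`) already, which is an assumption about how the installed PyTorch build reports its version.

```python
def torch_scatter_index_url(torch_version: str, cuda_version) -> str:
    """Build the PyG wheel index URL for a given PyTorch build.

    torch_version: value of torch.__version__, e.g. "2.3.0+cu121" or "2.3.0".
    cuda_version:  value of torch.version.cuda (None for CPU-only builds).
    """
    tag = torch_version
    # Append "+cpu" only for CPU-only builds whose version has no local suffix yet.
    if cuda_version is None and "+" not in tag:
        tag += "+cpu"
    return f"https://data.pyg.org/whl/torch-{tag}.html"


if __name__ == "__main__":
    # The version strings below are illustrative inputs, not a recommendation.
    print(torch_scatter_index_url("2.3.0", None))
    print(torch_scatter_index_url("2.3.0+cu121", "12.1"))
```

With PyTorch installed, the real inputs would be `torch.__version__` and `torch.version.cuda`.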
For documentation generation and unit testing, install with the optional dev and test extras:
pip install "gridfm-graphkit[dev,test]"
CLI commands
The command-line interface trains, fine-tunes, evaluates, and runs inference on GridFM models using YAML configs and MLflow tracking.
gridfm_graphkit <command> [OPTIONS]
Available commands:
- `train` - Train a new model from scratch
- `finetune` - Fine-tune an existing pre-trained model
- `evaluate` - Evaluate model performance on a dataset
- `predict` - Run inference and save predictions
Training Models
gridfm_graphkit train --config path/to/config.yaml
Arguments
| Argument | Type | Description | Default |
|---|---|---|---|
| `--config` | `str` | Required. Path to the training configuration YAML file. | None |
| `--exp_name` | `str` | MLflow experiment name. | timestamp |
| `--run_name` | `str` | MLflow run name. | run |
| `--log_dir` | `str` | MLflow tracking/logging directory. | mlruns |
| `--data_path` | `str` | Root dataset directory. | data |
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | None |
| `--bfloat16` | flag | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | False |
| `--tf32` | flag | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | False |
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | None |
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | [] |
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | None |
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | None |
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | None |
| `--compute_dc_ac_metrics` | flag | Compute ground-truth AC/DC power balance metrics on the test split. | False |
| `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | None |
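The `--mp_context` values correspond to the start methods of Python's `multiprocessing` module. As an illustration only (this is not how gridfm-graphkit handles the flag internally, and `resolve_start_method` is a hypothetical helper), the following standard-library sketch checks whether a start method is available on the current platform before creating a context from it:

```python
import multiprocessing as mp


def resolve_start_method(requested=None):
    """Return a multiprocessing context for the requested start method.

    None selects the platform default; otherwise the name must be one of
    the start methods available on this platform.
    """
    if requested is not None and requested not in mp.get_all_start_methods():
        raise ValueError(f"start method {requested!r} is not available here")
    return mp.get_context(requested)


if __name__ == "__main__":
    print("available:", mp.get_all_start_methods())
    ctx = resolve_start_method("spawn")
    print("selected:", ctx.get_start_method())
```

A context object like this is the kind of value that PyTorch's `DataLoader` accepts through its `multiprocessing_context` parameter.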
Examples
Standard Training:
gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
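When launching runs from a script, the command line can be assembled programmatically. The sketch below only builds and prints the argument list; `build_command` is a hypothetical helper, and the flags it emits are whichever keyword names the caller passes, so they should be taken from the argument table above.

```python
import shlex


def build_command(command: str, config: str, **options) -> list[str]:
    """Assemble an argv list for `gridfm_graphkit <command>`.

    Keyword options map to `--<name>` flags: True yields a bare flag,
    None or False omits the option, anything else is passed as the value.
    """
    argv = ["gridfm_graphkit", command, "--config", config]
    for name, value in options.items():
        if value is None or value is False:
            continue
        argv.append(f"--{name}")
        if value is not True:
            argv.append(str(value))
    return argv


argv = build_command(
    "train",
    "examples/config/case30_ieee_base.yaml",
    data_path="examples/data",
    tf32=True,
    num_workers=0,
)
print(shlex.join(argv))
```

The resulting list can be passed to `subprocess.run(argv, check=True)` once the package is installed.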
Fine-Tuning Models
gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pt
Arguments
| Argument | Type | Description | Default |
|---|---|---|---|
| `--config` | `str` | Required. Fine-tuning configuration file. | None |
| `--model_path` | `str` | Required. Path to a pre-trained model state dict. | None |
| `--exp_name` | `str` | MLflow experiment name. | timestamp |
| `--run_name` | `str` | MLflow run name. | run |
| `--log_dir` | `str` | MLflow logging directory. | mlruns |
| `--data_path` | `str` | Root dataset directory. | data |
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | None |
| `--bfloat16` | flag | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | False |
| `--tf32` | flag | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | False |
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | None |
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | [] |
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | None |
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | None |
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | None |
| `--compute_dc_ac_metrics` | flag | Compute ground-truth AC/DC power balance metrics on the test split. | False |
| `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | None |
Evaluating Models
gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pt
Arguments
| Argument | Type | Description | Default |
|---|---|---|---|
| `--config` | `str` | Required. Path to evaluation config. | None |
| `--model_path` | `str` | Path to the trained model state dict. | None |
| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics instead of re-fitting on current split. | None |
| `--exp_name` | `str` | MLflow experiment name. | timestamp |
| `--run_name` | `str` | MLflow run name. | run |
| `--log_dir` | `str` | MLflow logging directory. | mlruns |
| `--data_path` | `str` | Dataset directory. | data |
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | None |
| `--bfloat16` | flag | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | False |
| `--tf32` | flag | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | False |
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | None |
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | [] |
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | None |
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | None |
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | None |
| `--compute_dc_ac_metrics` | flag | Compute ground-truth AC/DC power balance metrics on the test split. | False |
| `--save_output` | flag | Save predictions as `<grid_name>_predictions.parquet` under MLflow artifacts (`.../artifacts/test`). | False |
| `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | None |
Example with saved normalizer stats
When evaluating a model on a dataset, you can pass the normalizer statistics from the original training run to ensure the same normalization parameters are used:
gridfm_graphkit evaluate \
--config examples/config/HGNS_PF_datakit_case118.yaml \
--model_path mlruns/<experiment_id>/<run_id>/artifacts/model/best_model_state_dict.pt \
--normalizer_stats mlruns/<experiment_id>/<run_id>/artifacts/stats/normalizer_stats.pt \
--data_path data
Note: The `--normalizer_stats` flag only affects normalizers with `fit_strategy = "fit_on_train"` (e.g. `HeteroDataMVANormalizer`). Per-sample normalizers (`HeteroDataPerSampleMVANormalizer`) always recompute their statistics from the current dataset regardless of this flag.
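The two artifact paths in the example command share the same run directory, so they can be derived from the MLflow log directory, experiment ID, and run ID. The sketch below assumes the directory layout shown in that example; `run_artifact_paths` is a hypothetical helper, and the placeholder IDs are left unfilled on purpose.

```python
from pathlib import Path


def run_artifact_paths(log_dir: str, experiment_id: str, run_id: str) -> dict:
    """Return the model and normalizer-stats paths for one MLflow run,
    following the directory layout used in the example command."""
    artifacts = Path(log_dir) / experiment_id / run_id / "artifacts"
    return {
        "model_path": artifacts / "model" / "best_model_state_dict.pt",
        "normalizer_stats": artifacts / "stats" / "normalizer_stats.pt",
    }


paths = run_artifact_paths("mlruns", "<experiment_id>", "<run_id>")
for flag, path in paths.items():
    print(f"--{flag} {path.as_posix()}")
```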
Running Predictions
gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pt
Arguments
| Argument | Type | Description | Default |
|---|---|---|---|
| `--config` | `str` | Required. Path to prediction config file. | None |
| `--model_path` | `str` | Path to trained model state dict. Optional; may be defined in config. | None |
| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics. | None |
| `--exp_name` | `str` | MLflow experiment name. | timestamp |
| `--run_name` | `str` | MLflow run name. | run |
| `--log_dir` | `str` | MLflow logging directory. | mlruns |
| `--data_path` | `str` | Dataset directory. | data |
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | None |
| `--plugins` | `list[str]` | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | [] |
| `--num_workers` | `int` | Override `data.workers` from YAML. Use `0` to debug worker crashes. | None |
| `--dataset_wrapper_cache_dir` | `str` | Disk cache directory for dataset wrapper; cache is loaded from here when present and saved after first population. | None |
| `--output_path` | `str` | Directory where predictions are saved as `<grid_name>_predictions.parquet`. | data |
| `--compile [MODE]` | `str` | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If flag is passed without a value, mode is `default`. | None |
| `--bfloat16` | flag | Cast model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | False |
| `--tf32` | flag | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | False |
| `--profiler` | `str` | Enable Lightning profiler (`simple`, `advanced`, `pytorch`). | None |
| `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | None |
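Because predictions are written as `<grid_name>_predictions.parquet`, the files in an output directory can be located by that suffix. The following sketch only discovers the files and does not read them, since the column layout is not described here; `find_prediction_files` is a hypothetical helper.

```python
from pathlib import Path


def find_prediction_files(output_path: str) -> dict:
    """Map grid name -> file path for every *_predictions.parquet file
    found directly inside `output_path`."""
    suffix = "_predictions.parquet"
    return {
        p.name[: -len(suffix)]: p
        for p in sorted(Path(output_path).glob(f"*{suffix}"))
    }


if __name__ == "__main__":
    # "data" is the documented default for --output_path.
    for grid_name, path in find_prediction_files("data").items():
        print(grid_name, path)
```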
Benchmarking Dataloader Throughput
gridfm_graphkit benchmark --config path/to/config.yaml
Arguments
| Argument | Type | Description | Default |
|---|---|---|---|
| `--config` | `str` | Required. Path to configuration YAML file. | None |
| `--data_path` | `str` | Root dataset directory. | data |
| `--epochs` | `int` | Number of epochs to iterate through the train dataloader. | 3 |
| `--dataset_wrapper` | `str` | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | None |
| `--dataset_wrapper_cache_dir` | `str` | Directory for dataset wrapper disk cache. | None |
| `--num_workers` | `int` | Override `data.workers` from YAML. | None |
| `--plugins` | `list[str]` | Python packages to import for plugin registration. | [] |
| `--mp_context` | `str` | DataLoader multiprocessing start method (`spawn`, `fork`, `forkserver`). Defaults to PyTorch's automatic choice. On Linux, `spawn` is recommended for safety (CUDA + fork is unsafe); other choices emit a warning. | None |
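To make concrete what a dataloader throughput benchmark measures, the sketch below times repeated passes over any iterable of batches and reports batches per second for each epoch. It is a generic illustration and not the implementation behind `gridfm_graphkit benchmark`; `measure_throughput` and the stand-in loader are hypothetical.

```python
import time


def measure_throughput(loader, epochs: int = 3) -> list[float]:
    """Iterate `loader` for `epochs` passes and return batches/second per epoch."""
    rates = []
    for _ in range(epochs):
        start = time.perf_counter()
        n_batches = sum(1 for _ in loader)  # consume every batch once
        elapsed = time.perf_counter() - start
        rates.append(n_batches / elapsed if elapsed > 0 else float("inf"))
    return rates


if __name__ == "__main__":
    # A list of lists stands in for a real DataLoader here.
    fake_loader = [list(range(32)) for _ in range(100)]
    for epoch, rate in enumerate(measure_throughput(fake_loader, epochs=3)):
        print(f"epoch {epoch}: {rate:.0f} batches/s")
```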
Use built-in help for full command details:
gridfm_graphkit --help
gridfm_graphkit <command> --help
Running Tests
Unit and Integration Tests
Install the test dependencies first (if not already done):
pip install -e ".[dev,test]"
Run the full unit test suite:
pytest ./tests
Run the base set integration tests:
pytest ./integrationtests/test_base_set.py
Running Base Set Tests on an LSF Cluster (GPU)
To submit the base set integration tests as an interactive LSF job with GPU access, use bsub. Adjust the paths to match your environment:
bsub -gpu "num=1" \
-n 16 \
-R "rusage[mem=32GB] span[hosts=1]" \
-Is \
-J gridfm_base_set_tests \
/bin/bash -c "
cd /path/to/gridfm-graphkit && \
export PATH=/path/to/cuda/bin:\$PATH \
CUDA_HOME=/path/to/cuda \
LD_LIBRARY_PATH=/path/to/cuda/lib64:\$LD_LIBRARY_PATH && \
source /path/to/venv/bin/activate && \
pytest ./integrationtests/test_base_set.py
"
Key bsub options used above:
| Option | Description |
|---|---|
| `-gpu "num=1"` | Request 1 GPU |
| `-n 16` | Request 16 CPU slots |
| `-R "rusage[mem=32GB] span[hosts=1]"` | Reserve 32 GB of memory on a single host |
| `-Is` | Run as an interactive shell session |
| `-J <job_name>` | Assign a name to the job |
Concrete example (adapt paths to your cluster setup):
bsub -gpu "num=1" -n 16 -R "rusage[mem=32GB] span[hosts=1]" -Is -J hpo_trial_190 /bin/bash -c "cd /dccstor/terratorch/users/rkie/gitco/gridfm-graphkit && export PATH=/opt/share/cuda-12.8.1/bin:\$PATH CUDA_HOME=/opt/share/cuda-12.8.1 LD_LIBRARY_PATH=/opt/share/cuda-12.8.1/lib64:\$LD_LIBRARY_PATH && source /u/rkie/venvs/venv_gridfm-graphkit/bin/activate && pytest ./integrationtests/test_base_set.py"
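If the job is submitted from a script, building the `bsub` arguments as a list avoids the nested quoting of the one-liner. This sketch only assembles and prints the command using the options from the table above; `build_bsub_command` is a hypothetical helper, and the paths in the usage example are the same placeholders as in the template command.

```python
import shlex


def build_bsub_command(job_name: str, inner_command: str, gpus: int = 1,
                       slots: int = 16, mem_gb: int = 32) -> list[str]:
    """Assemble an interactive `bsub` invocation that runs `inner_command` in bash."""
    return [
        "bsub",
        "-gpu", f"num={gpus}",
        "-n", str(slots),
        "-R", f"rusage[mem={mem_gb}GB] span[hosts=1]",
        "-Is",
        "-J", job_name,
        "/bin/bash", "-c", inner_command,
    ]


# Steps executed inside the job, joined with && so a failure stops the chain.
inner = " && ".join([
    "cd /path/to/gridfm-graphkit",
    "source /path/to/venv/bin/activate",
    "pytest ./integrationtests/test_base_set.py",
])
bsub_argv = build_bsub_command("gridfm_base_set_tests", inner)
print(shlex.join(bsub_argv))
```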
File details
Details for the file gridfm_graphkit-0.0.7.tar.gz.
File metadata
- Download URL: gridfm_graphkit-0.0.7.tar.gz
- Upload date:
- Size: 67.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.20
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 364c3c0b2a1e7ec6a968a2a578ef34fcad7e1dd5ac807a8859704200f94f935a |
| MD5 | 717f0d40b2990cf01c6ad6c7be9a1596 |
| BLAKE2b-256 | b109fa3fe8d6e1f72fa62214943143be201b663955d4772eb61f0a532fd46a9c |
File details
Details for the file gridfm_graphkit-0.0.7-py3-none-any.whl.
File metadata
- Download URL: gridfm_graphkit-0.0.7-py3-none-any.whl
- Upload date:
- Size: 74.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.20
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7ccd629016ba129386df6fcf8cbae4493568be695484afd34b07733e56b74884 |
| MD5 | 29f1d3cde40ecfeb9ae0dfc3b69be8d0 |
| BLAKE2b-256 | 0fa34bcca27b7e6045def9c1a058ba62567a3aa4faab9bb70e502ef8ea070002 |