Physics-guided flow models for weather prediction
Project description
WeatherFlow: A Deep Learning Library for Weather Prediction
Introduction
WeatherFlow is a Python library built on PyTorch that aims to provide a flexible and extensible framework for developing and evaluating deep learning models for weather prediction. It leverages recent advancements in flow matching and incorporates design principles for handling spatiotemporal data, particularly geopotential height fields. This documentation covers the package structure, API, usage examples, and development guidelines.
Key Features
- Modular Design: The library is organized into modules for data loading, model definition, training, evaluation, and visualization, making it easy to extend and customize.
- ERA5 Data Integration: Includes a
WeatherDatasetclass for easy loading and preprocessing of ERA5 reanalysis data, a standard dataset for weather prediction research. It supports both local netCDF files and direct loading from the WeatherBench 2 Google Cloud Storage (requires authentication). - Flow Matching Models: Implements a
WeatherFlowMatchmodel that utilizes the principles of flow matching for generating weather predictions. - Physics-Guided Architecture (Planned): Future versions will incorporate physics-informed components.
- Configurable Training: Provides a flexible
train_modelfunction with options for various optimizers, learning rate schedulers, and early stopping. - Comprehensive Evaluation: Includes functions for calculating standard weather prediction metrics (RMSE, ACC) and visualizing predictions.
- Extensible Design: The modular structure allows users to easily add custom models, data processing steps, and evaluation metrics.
Installation
Dependencies:
- Python >= 3.8
- torch >= 1.9
- xarray
- numpy
- matplotlib
- cartopy (for visualizations)
- fsspec (for Google Cloud Storage access)
- gcsfs (for Google Cloud Storage access)
- tqdm (for progress bars)
- wandb (optional, for experiment tracking)
- scipy
- netCDF4
- h5py
Quick Start
Here's a quick example of how to load data, train a model, and visualize predictions:
API Reference
weatherflow.data
-
WeatherDataset(data_path, variables, years, input_length=1, lead_time=1):data_path: Path to the directory containing ERA5 netCDF files (one per year).variables: List of variable names (e.g.,['z', 't']for geopotential and temperature). Use the short names.years: A list of years (integers) to include in the dataset, or the string 'all' to include all available .nc files.input_length: The number of timesteps to include in the input sequence (default: 1).lead_time: The number of timesteps between the last input timestep and the target timestep (default: 1, meaning a 6-hour forecast).
-
create_data_loaders(variables, pressure_levels, data_dir, train_years, val_years, batch_size, num_workers): A convenience function to create PyTorchDataLoaderinstances for training and validation. It handles splitting the data by year.
weatherflow.models
WeatherFlowMatch(input_dim, hidden_dim):input_dim: The flattened dimension of a single input timestep (e.g., 64 * 32 = 2048 for a 64x32 grid).hidden_dim: The hidden dimension of the internal layers of the model.forward(self, x, t): Performs a forward pass.xis the input tensor (shape:[batch, channels, height, width]), andtis the time tensor (shape:[batch]). Returns the predicted velocity field.compute_loss: Includes a magnitude loss in addition to the direction loss.
weatherflow.utils
plot_prediction_comparison: Plots the difference between the true and predicted values.create_prediction_animation: Creates an animation of predictions.plot_raw_fields: Plots the data before any transformation.calculate_metrics: Calculates root mean squared error and anomaly correlation coefficient.generate_wb2_predictions: Prepares model output to be in the correct weatherbench2 format.evaluate_saved_predictions: Performs weatherbench2 evaluation.
Advanced Usage
-
Custom Models: You can create your own models by subclassing
nn.Moduleand implementing theforwardmethod. You'll likely want to modify thetrain_modelfunction to work with your custom model. -
Custom Loss Functions: You can define custom loss functions beyond the standard MSE loss.
-
Data Augmentation: Add data augmentation to the
WeatherDataset(e.g., random rotations, flips) to improve model robustness. -
Distributed Training: Adapt the training loop to use PyTorch's distributed training capabilities for larger datasets and models.
-
More Sophisticated Visualization: Use the
WeatherVisualizerclass as a starting point to create more advanced visualizations, such as animations of weather patterns over time, or plots that highlight specific regions or features.
This documentation gives you a solid foundation for using and extending the weatherflow package. Remember to refer to the docstrings within the code for more detailed information on specific functions and classes. Good luck with your weather prediction projects!
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file weatherflow-0.2.1.tar.gz.
File metadata
- Download URL: weatherflow-0.2.1.tar.gz
- Upload date:
- Size: 31.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.11
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
9e7615593f98a6e36fee07881559b6c7c6c2a2661c47ae15439c4a09adae5597
|
|
| MD5 |
3ef0b9078712b53e95978aea60bdeb33
|
|
| BLAKE2b-256 |
f18719fc6d4bed8aa07b9321bec072a2b36b987a50ba10f602592531509fb7a2
|
File details
Details for the file weatherflow-0.2.1-py3-none-any.whl.
File metadata
- Download URL: weatherflow-0.2.1-py3-none-any.whl
- Upload date:
- Size: 39.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.11
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
33e35b45244be21428f84321ea2054548c1c6850f988b6ee493aadcb04c7e66d
|
|
| MD5 |
e621213471ac55983a6ef9c80e158f3e
|
|
| BLAKE2b-256 |
c916517ad0328f588e0b4a8b425e8a3be0a0e417e131dfae0e854ead091a0217
|