
Machine learning weather nowcasting library


mlcast

⚠️ This package is under active development. The API and functionality are subject to change until the v1.0.0 release.

The MLCast Community is a collaborative effort bringing together meteorological services, research institutions, and academia across Europe to develop a unified Python package for AI-based nowcasting. This is an initiative of the E-AI WG6 (Nowcasting) of EUMETNET.

This repo contains the mlcast package for machine learning-based weather nowcasting.

Installation

Because mlcast is in rapid development, the recommended path is to clone the repository and install it locally rather than installing a pinned release from PyPI.

Local development: clone and install locally with uv

Fork the repository on GitHub first, then clone your fork. This lets you track upstream changes while keeping your own modifications on a separate branch:

git clone https://github.com/<your-github-username>/mlcast
cd mlcast

# Install uv if not already installed
curl -LsSf https://astral.sh/uv/install.sh | sh

# Install dependencies, choosing the variant that matches your hardware.
# CPU only
uv sync

# GPU (CUDA 12.8)
uv sync --extra gpu-cu128

# GPU (CUDA 13.0)
uv sync --extra gpu-cu130

Next, you can jump straight to using mlcast or, if you intend to modify the code, set up the development toolchain as follows:

# Install dev dependencies
uv sync --extra dev

# Install the pre-commit git hook (runs checks automatically on every commit)
uv run pre-commit install

PyPI release

Tagged releases are published to PyPI and can be installed with pip:

pip install mlcast

For active development or access to unreleased changes, clone the repository and install locally with uv as described above.

Usage

mlcast exposes two interfaces for training: a command-line interface (CLI) for interactive and scripted use, and a Python API for programmatic control. Both are built on Fiddle, a configuration library that lets you build a full experiment graph, override any parameter, and reproduce runs exactly from a saved YAML file.

Configuration model

Training in mlcast is currently built around a single base configuration function, training_experiment, which defines the default ConvGRU ensemble nowcasting setup: dataset, data module, network, Lightning module, and trainer. Rather than writing a new config from scratch, the intended workflow is to start from this base and apply targeted modifications:

  • set: overrides — change a single scalar parameter (e.g. batch size, learning rate, number of epochs)
  • fiddlers — apply a named mutator function that keeps multiple related parameters in sync (e.g. switching the dataset class, toggling masking, changing the logger)
  • direct graph edits (Python API only) — replace a sub-object entirely, for example swapping in a different network architecture

Any combination of these can be layered on top of the base config, and the fully resolved config is always saved to YAML alongside the training logs so runs can be reproduced exactly.

The diagram below shows the full default config graph as built by training_experiment:

[Figure: training_experiment config graph]

Command-line interface

Install the package and run:

mlcast train

This trains with the built-in training_experiment defaults. All parameters are controlled via --config flags:

| Prefix | Purpose | Example |
|---|---|---|
| (none) | Use the built-in default config | mlcast train |
| set: | Override a single parameter | --config set:data.batch_size=32 |
| fiddler: | Apply a semantic mutator (multi-parameter change) | --config fiddler:use_random_sampler |
| config: | Switch to a different @auto_config function | --config config:my_experiment |
| path/to/config.yaml | Load a previously saved config | --config saved.yaml |

Multiple --config flags are applied in order and can be combined freely.
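To make the ordering rule concrete, here is a minimal pure-Python sketch of how a sequence of set: overrides could be applied left to right, so that a later flag for the same key wins. It is not mlcast's implementation (mlcast resolves overrides on a Fiddle config graph): the nested-dict config, its default values, and the helper apply_set_flags are hypothetical, and parsing values with ast.literal_eval is an assumption.

```python
import ast
from typing import Any


def _parse_value(text: str) -> Any:
    """Interpret a Python literal (int, float, bool, list, ...); otherwise keep the string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text  # e.g. a path such as /data/radar.zarr


def apply_set_flags(cfg: dict, flags: list[str]) -> dict:
    """Apply 'set:dotted.path=value' flags in order; later flags overwrite earlier ones."""
    for flag in flags:
        if not flag.startswith("set:"):
            raise ValueError(f"this sketch only handles set: flags, got {flag!r}")
        path, _, raw = flag[len("set:"):].partition("=")
        *parents, leaf = path.split(".")
        node = cfg
        for key in parents:
            node = node[key]  # raises KeyError if the path does not exist
        node[leaf] = _parse_value(raw)
    return cfg


# Hypothetical base config with made-up defaults.
base = {"data": {"batch_size": 8, "dataset_factory": {"zarr_path": "radar.zarr"}},
        "trainer": {"max_epochs": 10}}
resolved = apply_set_flags(base, [
    "set:data.batch_size=32",
    "set:data.dataset_factory.zarr_path=/data/radar.zarr",
    "set:data.batch_size=64",  # applied last, so it replaces the earlier 32
])
print(resolved["data"]["batch_size"])  # 64
```

Keys not mentioned by any flag (here trainer.max_epochs) keep their base values, which is what makes it safe to stack many small overrides.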

Examples:

# Override dataset path and batch size
mlcast train \
    --config set:data.dataset_factory.zarr_path=/data/radar.zarr \
    --config set:data.batch_size=32

# Switch to random sampler and log to MLflow
mlcast train \
    --config fiddler:use_random_sampler \
    --config fiddler:use_mlflow_logger

# Resume from a saved config with an epoch override
mlcast train \
    --config logs/mlcast/version_0/config.yaml \
    --config set:trainer.max_epochs=50

# Inspect the fully resolved config without starting training
mlcast train --config fiddler:use_random_sampler --print_config_and_exit

Run mlcast train --help for a full list of examples and available fiddlers.

Python API

The Python API gives you full programmatic control over the config graph before anything is instantiated.

Run the default experiment with tweaks:

import fiddle as fdl
from mlcast.config import training_experiment, train_from_config
from mlcast.config.fiddlers import use_random_sampler

cfg = training_experiment.as_buildable()  # returns a fdl.Config graph — see src/mlcast/config/base.py

# Apply a fiddler to switch the dataset sampler
use_random_sampler(cfg)

# Override individual parameters directly on the config graph
cfg.data.batch_size = 32
cfg.trainer.max_epochs = 50

# Validates cross-parameter contracts, builds all objects, persists config
# YAML to the active logger, then calls trainer.fit() + trainer.test()
train_from_config(cfg)

Custom network architecture:

You can swap in any architecture by replacing cfg.pl_module.network with a fdl.Config node. The network must implement the nowcasting forward interface — see Custom network interface below.

As an example, here is how to wrap an mfai HalfUNet (a plain single-step U-Net) to satisfy the interface. The wrapper channel-stacks the past frames and runs the U-Net autoregressively for each requested forecast step:

Note: input_steps equals dataset_factory.input_steps (6 by default) and can be read directly from the config graph before building.

import einops
import fiddle as fdl
import torch
import torch.nn as nn
from jaxtyping import Float
from mfai.torch.models import HalfUNet
from mlcast.config import training_experiment, train_from_config
from mlcast.config.fiddlers import use_random_sampler

# Minimal adapter: channel-stack past frames → HalfUNet → one step at a time.
# NowcastLightningModule calls network(x, steps=N, ensemble_size=M), so any
# custom network must accept those keyword arguments.
class HalfUNetNowcaster(nn.Module):
    def __init__(self, input_steps: int = 6, num_vars: int = 1):
        super().__init__()
        self.input_steps = input_steps
        self.num_vars = num_vars
        self.unet = HalfUNet(
            input_shape=(256, 256),
            in_channels=input_steps * num_vars,
            out_channels=num_vars,
            settings=HalfUNet.settings_kls(),  # instantiate the settings dataclass with its defaults
        )

    @property
    def input_channels(self) -> int:
        # Externally, the HalfUNetNowcaster respects the required input shape structure
        # (batch, input_steps, num_vars, H, W), even though the internal U-Net is channel-stacked.
        # Adding this property allows the config consistency checks to verify that
        # the dataset and model agree on the expected number of input channels.
        return self.num_vars

    def forward(
        self,
        x: Float[torch.Tensor, "batch input_steps in_channels H W"],
        steps: int,
        ensemble_size: int = 1,  # accepted for interface compatibility; this deterministic wrapper ignores it
    ) -> Float[torch.Tensor, "batch steps out_channels H W"]:
        # channel-stack all input frames: (b, t, c, h, w) -> (b, t*c, h, w)
        x_flat = einops.rearrange(x, "b t c h w -> b (t c) h w")
        preds = []
        for _ in range(steps):
            y = self.unet(x_flat)   # [B, num_vars, H, W]
            preds.append(y.unsqueeze(1))
            # slide window: drop the oldest timestep (first num_vars channels),
            # append the latest prediction as the newest timestep
            x_flat = torch.cat([x_flat[:, self.num_vars:], y], dim=1)
        return torch.cat(preds, dim=1)

cfg = training_experiment.as_buildable()
use_random_sampler(cfg)

cfg.pl_module.network = fdl.Config(
    HalfUNetNowcaster,
    input_steps=cfg.data.dataset_factory.input_steps,
    num_vars=len(cfg.data.dataset_factory.standard_names),
)

train_from_config(cfg)
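The window update in forward relies on the channel order produced by the "b t c h w -> b (t c) h w" pattern: t is the outer axis, so the first num_vars stacked channels belong to the oldest frame. The pure-Python sketch below tracks channel labels instead of tensors to show that dropping the first num_vars entries and appending the newest prediction keeps the window in chronological order. The label format and the helper names are illustrative only.

```python
def stack_channels(input_steps: int, num_vars: int) -> list[str]:
    """Channel labels in '(t c)' order: time is the outer axis, variables the inner one."""
    return [f"t{t}_v{v}" for t in range(input_steps) for v in range(num_vars)]


def slide(window: list[str], new_frame: list[str], num_vars: int) -> list[str]:
    """Mirror torch.cat([x_flat[:, num_vars:], y], dim=1) on channel labels."""
    assert len(new_frame) == num_vars
    return window[num_vars:] + new_frame


window = stack_channels(input_steps=3, num_vars=2)
print(window)  # ['t0_v0', 't0_v1', 't1_v0', 't1_v1', 't2_v0', 't2_v1']
window = slide(window, ["p3_v0", "p3_v1"], num_vars=2)
print(window)  # ['t1_v0', 't1_v1', 't2_v0', 't2_v1', 'p3_v0', 'p3_v1']
```

If the pattern were "(c t)" instead, the oldest frame would no longer occupy a contiguous leading block of channels and the simple slice would mix variables from different times.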

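For lower-level control you can call the steps of train_from_config individually. Note that the sequence below validates, builds, and runs the experiment but does not include the step that persists the config YAML to the active logger: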

import fiddle as fdl
from mlcast.config.consistency_checks import validate_config

validate_config(cfg)          # raises ValueError on any contract violation
experiment = fdl.build(cfg)   # instantiates all objects
experiment.run()              # trainer.fit() + trainer.test()
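validate_config is described as raising ValueError on contract violations. As a hedged illustration only, here is what such cross-parameter checks could look like, written against types.SimpleNamespace objects standing in for the config graph. The two contracts are taken from the Available fiddlers table (return_mask must agree with masked_loss; the network's input_channels must match the number of standard_names, assuming one channel per variable). The function check_contracts_sketch is hypothetical and is not the contents of consistency_checks.py.

```python
from types import SimpleNamespace


def check_contracts_sketch(cfg) -> None:
    """Hypothetical cross-parameter checks; raises ValueError on the first mismatch."""
    factory = cfg.data.dataset_factory
    if factory.return_mask != cfg.pl_module.masked_loss:
        raise ValueError(
            f"dataset_factory.return_mask={factory.return_mask} but "
            f"pl_module.masked_loss={cfg.pl_module.masked_loss}"
        )
    n_vars = len(factory.standard_names)
    if cfg.pl_module.network.input_channels != n_vars:
        raise ValueError(
            f"network.input_channels={cfg.pl_module.network.input_channels} "
            f"but {n_vars} standard_names are configured"
        )


# Stand-in config graph with a deliberate masking mismatch.
cfg = SimpleNamespace(
    data=SimpleNamespace(dataset_factory=SimpleNamespace(
        return_mask=True, standard_names=["var_a"])),  # placeholder variable name
    pl_module=SimpleNamespace(masked_loss=False,
                              network=SimpleNamespace(input_channels=1)),
)
try:
    check_contracts_sketch(cfg)
except ValueError as err:
    print(err)  # dataset_factory.return_mask=True but pl_module.masked_loss=False
```

Running checks like these before fdl.build means a misconfigured run fails immediately instead of after data loading or the first training step.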

Available fiddlers

| Fiddler | Arguments | What it does |
|---|---|---|
| use_mlflow_logger | (none) | Replaces the default TensorBoardLogger with MLFlowLogger and appends LogSystemInfoCallback; respects the MLFLOW_TRACKING_URI environment variable |
| set_variables | standard_names | Sets the list of input variables on the dataset and updates network.input_channels to match |
| toggle_masking | enabled | Toggles masked-loss mode by setting both dataset_factory.return_mask and pl_module.masked_loss to the same value |
| use_anon_s3_dataset | zarr_path, endpoint_url | Points the dataset at an anonymous S3 object store, setting zarr_path and the required storage_options together |
| use_random_sampler | (none) | Switches the dataset factory to the on-the-fly random sampler (useful during development when no precomputed CSV is available) |
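A fiddler is an ordinary function that mutates the config graph in place. The sketch below re-creates the described behaviour of toggle_masking and set_variables against SimpleNamespace stand-ins (fdl.Config nodes support the same attribute assignment, as in cfg.data.batch_size = 32 above). The _sketch function names are hypothetical, set_variables_sketch assumes one input channel per variable, and the real implementations in src/mlcast/config/fiddlers.py may differ.

```python
from types import SimpleNamespace


def toggle_masking_sketch(cfg, enabled: bool) -> None:
    """Hypothetical re-creation of toggle_masking: set both masking switches together."""
    cfg.data.dataset_factory.return_mask = enabled
    cfg.pl_module.masked_loss = enabled


def set_variables_sketch(cfg, standard_names: list[str]) -> None:
    """Hypothetical re-creation of set_variables, assuming one channel per variable."""
    cfg.data.dataset_factory.standard_names = list(standard_names)
    cfg.pl_module.network.input_channels = len(standard_names)


# Stand-in for the config graph; in mlcast these would be fdl.Config nodes.
cfg = SimpleNamespace(
    data=SimpleNamespace(dataset_factory=SimpleNamespace(
        return_mask=False, standard_names=["var_a"])),
    pl_module=SimpleNamespace(masked_loss=False,
                              network=SimpleNamespace(input_channels=1)),
)

toggle_masking_sketch(cfg, enabled=True)
set_variables_sketch(cfg, ["var_a", "var_b"])  # placeholder variable names
print(cfg.data.dataset_factory.return_mask, cfg.pl_module.masked_loss)  # True True
print(cfg.pl_module.network.input_channels)  # 2
```

Bundling the related assignments in one function is what keeps the dataset and the loss (or the dataset and the network) from drifting apart when only one of them is edited.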

Project Structure

mlcast/
├── src/mlcast/
│   ├── __main__.py                      # CLI entry point (mlcast train)
│   ├── nowcasting_module.py             # Generic Lightning module for nowcasting
│   ├── losses.py                        # CRPS, AFCRPS, MSE loss functions
│   ├── callbacks.py                     # Training callbacks
│   ├── visualization.py                 # TensorBoard image logging helpers
│   ├── config/
│   │   ├── base.py                      # Default training_experiment @auto_config
│   │   ├── fiddlers.py                  # Semantic config mutators
│   │   ├── consistency_checks.py        # Cross-parameter validation
│   │   ├── loader.py                    # YAML config loader
│   │   └── orchestrator.py              # train_from_config, config persistence
│   ├── data/
│   │   ├── source_data_datamodule.py    # Lightning DataModule
│   │   ├── source_data_datasets.py      # Zarr-backed PyTorch datasets
│   │   └── normalization.py             # Normalisation registry
│   └── models/
│       └── convgru.py                   # ConvGRU encoder-decoder
├── tests/
├── pyproject.toml
└── README.md

Implemented architectures

ConvGruModel

ConvGruModel (in src/mlcast/models/convgru.py) is an encoder-decoder architecture. It is not autoregressive at forecast time: rather than generating each forecast frame from the previously predicted frame, the decoder performs a temporal roll-out entirely in latent space. The ConvGRU at each spatial scale has its hidden state initialised from the encoder and then unrolls over forecast_steps steps, driven by noise or zeros. Forecast frames are only materialised at the end, by upsampling the final decoder hidden states back to the original spatial resolution.
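The difference between the two roll-out styles can be seen with scalar toys. This is not ConvGruModel code: the functions autoregressive_rollout and latent_rollout and the lambdas passed to them (a linear extrapolator, an encoder that keeps the last frame, a halving recurrent cell, an identity decoder) are made up purely to contrast the control flow. In the first, every prediction re-enters the input window; in the second, the recurrence only ever sees its own hidden state and the noise/zero inputs.

```python
from typing import Callable, Sequence


def autoregressive_rollout(frames: Sequence[float],
                           step_fn: Callable[[list[float]], float],
                           steps: int) -> list[float]:
    """Each prediction is appended to the input window and fed back in."""
    window = list(frames)
    preds = []
    for _ in range(steps):
        y = step_fn(window)
        preds.append(y)
        window = window[1:] + [y]  # the prediction becomes an input
    return preds


def latent_rollout(frames, encode, cell, decode, steps, noise) -> list[float]:
    """The hidden state starts from the encoder and evolves on noise/zero inputs;
    outputs are decoded afterwards and never fed back into the recurrence."""
    h = encode(frames)
    hidden = []
    for t in range(steps):
        h = cell(noise[t], h)
        hidden.append(h)
    return [decode(h_t) for h_t in hidden]


frames = [1.0, 2.0, 4.0]
print(autoregressive_rollout(frames, lambda w: 2 * w[-1] - w[-2], steps=3))  # [6.0, 8.0, 10.0]
print(latent_rollout(frames, encode=lambda f: f[-1], cell=lambda z, h: h / 2 + z,
                     decode=lambda h: h, steps=3, noise=[0.0] * 3))  # [2.0, 1.0, 0.5]
```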

Encoding — a stack of EncoderBlock layers unrolls a ConvGRU sequentially over the input_steps real observed frames. Each block halves the spatial resolution via PixelUnshuffle(2). The last hidden state of each block is retained.

Decoding — a stack of DecoderBlock layers performs a latent-space roll-out at each spatial scale. Each decoder block's ConvGRU is initialised with the final hidden state from the corresponding encoder block, then unrolls over forecast_steps steps with noise or zeros as input — so the forecast sequence emerges from the evolution of hidden states across multiple spatial scales, never from feeding predictions back as inputs. Spatial resolution is doubled at each block via PixelShuffle(2).
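The encoder and decoder resolutions follow from the shape rules of the two PyTorch layers: PixelUnshuffle(2) maps (C, H, W) to (4C, H/2, W/2) and requires H and W to be divisible by 2, and PixelShuffle(2) performs the inverse change. The sketch below tracks only spatial sizes, because the ConvGRU hidden channel counts are not given here. The number of blocks (3) and the 256 x 256 input are assumptions chosen for illustration, and resolution_schedule is a hypothetical helper.

```python
def resolution_schedule(h: int, w: int, num_blocks: int, factor: int = 2):
    """Spatial size after each encoder block (PixelUnshuffle) and each decoder block (PixelShuffle).

    Halving num_blocks times requires H and W to be divisible by factor**num_blocks.
    """
    scale = factor ** num_blocks
    if h % scale or w % scale:
        raise ValueError(f"{h}x{w} is not divisible by {scale} for {num_blocks} blocks")
    encoder = [(h // factor**k, w // factor**k) for k in range(1, num_blocks + 1)]
    decoder = [(h // factor**k, w // factor**k) for k in range(num_blocks - 1, -1, -1)]
    return encoder, decoder


enc, dec = resolution_schedule(256, 256, num_blocks=3)  # 3 blocks is an assumption
print(enc)  # [(128, 128), (64, 64), (32, 32)]
print(dec)  # [(64, 64), (128, 128), (256, 256)]
```

A practical consequence is that the input domain must be divisible by 2 raised to the number of blocks, so a 100 x 100 crop would fail with three blocks.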

Ensemble — when ensemble_size > 1 the decoder is run ensemble_size times, each time with freshly sampled Gaussian noise. The results are concatenated along the channel dimension.
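A toy sketch of the ensemble mechanics described above: the decoder is run once per member with freshly drawn Gaussian noise, and the per-member outputs are concatenated along the channel axis. It assumes each member contributes num_vars channels, so concatenation yields member-major order (member 0's channels first). Scalar noise, the toy_decoder stand-in, and the helpers run_ensemble and member_slice are made up; the real model operates on tensors.

```python
import random
from typing import Callable


def run_ensemble(decode_member: Callable[[float, int], list[float]],
                 ensemble_size: int, num_vars: int, seed: int = 0) -> list[float]:
    """Run the toy decoder once per member with fresh Gaussian noise and
    concatenate the per-member outputs along the channel axis."""
    rng = random.Random(seed)
    channels: list[float] = []
    for _ in range(ensemble_size):
        noise = rng.gauss(0.0, 1.0)  # freshly sampled for every member
        channels.extend(decode_member(noise, num_vars))
    return channels


def member_slice(channels: list[float], member: int, num_vars: int) -> list[float]:
    """Recover one member's channels, assuming member-major order."""
    return channels[member * num_vars:(member + 1) * num_vars]


def toy_decoder(z: float, n: int) -> list[float]:
    """Made-up stand-in for the decoder: one value per variable, shifted by the noise."""
    return [z + v for v in range(n)]


out = run_ensemble(toy_decoder, ensemble_size=3, num_vars=2)
print(len(out))  # 6 channels = ensemble_size * num_vars
print(member_slice(out, 1, num_vars=2))
```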

Deterministic variant (diagram source):

[Figure: ConvGruModel deterministic architecture]

Stochastic / ensemble variant (diagram source):

[Figure: ConvGruModel stochastic architecture]

Custom network interface

Any network architecture can be used by replacing cfg.pl_module.network with a fdl.Config node pointing at your class. The only requirement is that forward accepts the following signature:

import torch
import torch.nn as nn
from jaxtyping import Float


class MyNowcaster(nn.Module):  # hypothetical class name
    def forward(
        self,
        x: Float[torch.Tensor, "batch input_steps in_channels H W"],
        steps: int,          # number of forecast steps to produce
        ensemble_size: int,  # number of stochastic ensemble members
    ) -> Float[torch.Tensor, "batch steps out_channels H W"]:
        ...

ConvGruModel and the set_variables fiddler assume the network receives its input channel count through a parameter named input_channels. If your network uses a different parameter name, set that parameter explicitly on the config node.

Contributing

Please feel free to raise issues or PRs if you have any suggestions or questions.

Links to presentations for discussion about the API

License

This project is dual-licensed under either:

at your option.

See LICENSE for more details.

Download files


Source Distribution

mlcast-0.1.0.tar.gz (663.5 kB)

Uploaded Source

Built Distribution


mlcast-0.1.0-py3-none-any.whl (54.2 kB)

Uploaded Python 3

File details

Details for the file mlcast-0.1.0.tar.gz.

File metadata

  • Download URL: mlcast-0.1.0.tar.gz
  • Upload date:
  • Size: 663.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.13

File hashes

Hashes for mlcast-0.1.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9f92505df760ca47188f5b1485d40e904aa3f9cb29863d23ce024c1cd3343309 |
| MD5 | 1fbdba1bb71818057fa68a3c28e20369 |
| BLAKE2b-256 | 46208386afb0afb5e8df92ca367b345dd3dfb100addd63e1a9d4853181867f73 |


Provenance

The following attestation bundles were made for mlcast-0.1.0.tar.gz:

Publisher: pypi-release.yml on mlcast-community/mlcast

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file mlcast-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: mlcast-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 54.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.13

File hashes

Hashes for mlcast-0.1.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | ca6dd21cb355ddb721819bb5b9e7738cdf3b9287b6eeb37ce8912c70c4b4098f |
| MD5 | 5ece47652bfa6040c8446bcddfaf5378 |
| BLAKE2b-256 | d7ab04ca245dd781f951ea1c211a08c66554930cb65e66283d255631237d1ec6 |


Provenance

The following attestation bundles were made for mlcast-0.1.0-py3-none-any.whl:

Publisher: pypi-release.yml on mlcast-community/mlcast

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
