
Idiomatic-PyTorch rewrite of the Dinosaur spectral dynamical core 🦖⚡


dinosaur-torch


Requires Python 3.11+. License: Apache 2.0.

An idiomatic-PyTorch rewrite of Dinosaur, the spectral dynamical core behind NeuralGCM. Rather than a line-for-line JAX→PyTorch translation, this package is written the way a PyTorch library would be written from scratch, while staying numerically equivalent to the original.

Numerically validated against the original JAX Dinosaur.

Design

  • Tensors in, tensors out. Functions and modules operate on torch.Tensors. NumPy appears only at I/O and construction boundaries (grid/quadrature setup, xarray conversion). There is no asarray promotion at every call site, no global default-device convention, and no host-constant device cache.
  • Precomputed constants live in nn.Module buffers. Objects that hold tensors (spectral transforms, the dycore) are torch.nn.Modules with non-persistent buffers, so .to(device) / .float() work the standard way and state_dict() contains only learned parameters (none, for the dycore).
  • Static metadata is separate from tensors. GridSpec, SigmaCoordinates, etc. are frozen dataclasses (hashable, comparable, and cheap to create) that are used to construct the tensor-holding modules.
  • States are torch pytrees. Model state (State, diagnostics, …) is a plain dataclass registered via torch.utils._pytree.register_dataclass, so it composes natively with torch.compile, torch.func, and CUDA graphs. No custom pytree registry.
  • Standard test style: plain pytest with parametrization (no absl/parameterized).
  • Scope: the primitive-equations path used by NeuralGCM (transforms, sigma coordinates, primitive equations, IMEX time integration, filtering, vertical/horizontal interpolation, data utilities). Shallow-water and Held–Suarez model families are intentionally not ported: no published NeuralGCM checkpoint uses them.
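The first three design points can be illustrated with plain PyTorch. The sketch below is a toy, not dinosaur-torch's API: ToyGridSpec and ToyGrid are hypothetical names chosen for illustration. It shows a frozen dataclass used as static metadata, a module that stores a precomputed constant as a non-persistent buffer, and the consequences described above: state_dict() stays empty, while .double() still casts the buffer.

```python
import dataclasses
import math

import torch
from torch import nn


@dataclasses.dataclass(frozen=True)
class ToyGridSpec:
    """Hypothetical static metadata: hashable, comparable, holds no tensors."""

    n_lon: int
    n_lat: int


class ToyGrid(nn.Module):
    """Hypothetical tensor-holding module constructed from a frozen spec."""

    def __init__(self, spec: ToyGridSpec):
        super().__init__()
        self.spec = spec
        # Precomputed once at construction time: longitudes in radians.
        lon = torch.arange(spec.n_lon, dtype=torch.float32) * (2 * math.pi / spec.n_lon)
        # persistent=False keeps the buffer out of state_dict(), while
        # .to(device) / .float() / .double() still move and cast it.
        self.register_buffer("lon", lon, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Tensors in, tensors out: weight each longitude column by cos(lon).
        return x * torch.cos(self.lon)


spec = ToyGridSpec(n_lon=8, n_lat=4)
grid = ToyGrid(spec).double()

print(len(grid.state_dict()))     # 0: no learned parameters, no persistent buffers
print(grid.lon.dtype)             # torch.float64: the buffer followed .double()
print(spec == ToyGridSpec(8, 4))  # True: value equality, usable as a cache key
out = grid(torch.ones(spec.n_lat, spec.n_lon, dtype=torch.float64))
print(out.shape)                  # torch.Size([4, 8])
```

Because the spec is frozen and hashable, it can serve as a cache key for constructed modules, while everything device- or dtype-dependent lives in the module's buffers.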

Layout

| Module | Contents |
|---|---|
| associated_legendre.py, fourier.py | basis construction (NumPy, at setup time) |
| spherical_harmonic.py | GridSpec (static), RealSphericalHarmonics / FastSphericalHarmonics transforms, Grid (nn.Module: transforms + spectral operators) |
| sigma_coordinates.py | SigmaCoordinates (static) + SigmaLevels (nn.Module: vertical finite-difference / integral operators) |
| coordinate_systems.py | CoordinateSystem (nn.Module: horizontal × vertical), spectral up/downsampling |
| primitive_equations.py | State (torch-pytree dataclass), PrimitiveEquations (nn.Module IMEX ODE, dry/moist/cloud variants), Geopotential |
| time_integration.py | IMEX Runge-Kutta steppers (SIL3, CN-RK2/3/4, Euler), step filters, trajectories (plain loops), digital filter initialization |
| filtering.py | exponential / horizontal-diffusion spectral filters |
| vertical_interpolation.py | PressureCoordinates / PressureLevels, pressure ↔ sigma regridding (batched searchsorted/gather, no vmap) |
| horizontal_interpolation.py | conservative / bilinear / nearest lat-lon regridders (weights precomputed as buffers) |
| radiation.py | top-of-atmosphere incident solar radiation (SolarRadiation module) |
| scales.py, units.py | unit handling / nondimensionalization (NumPy + pint) |
| xarray_utils.py | ERA5-style dataset preparation: regrid_horizontal, fill_nan_with_nearest, selective_temporal_shift, grid_spec_from_dataset |
| pytree.py | tiny helpers over torch.utils._pytree |
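To make the IMEX idea behind time_integration.py concrete, here is a first-order implicit-explicit Euler step on a scalar toy ODE, written as a plain loop. This is a sketch under stated assumptions, not the package's stepper: StiffDecayODE, explicit_terms, implicit_terms, implicit_inverse, imex_euler_step and trajectory are hypothetical names. The nonstiff term is advanced with forward Euler and the stiff linear term with backward Euler, which reduces to a closed-form solve.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class StiffDecayODE:
    """Hypothetical IMEX split of du/dt = sin(u) - lam * u.

    explicit part F(u) = sin(u)   (nonstiff, evaluated at the old state)
    implicit part G(u) = -lam * u (stiff and linear, solved at the new state)
    """

    lam: float

    def explicit_terms(self, u: float) -> float:
        return math.sin(u)

    def implicit_terms(self, u: float) -> float:
        return -self.lam * u

    def implicit_inverse(self, x: float, dt: float) -> float:
        # Solve y - dt * G(y) = x for y; with G(y) = -lam * y this is closed-form.
        return x / (1.0 + dt * self.lam)


def imex_euler_step(ode: StiffDecayODE, u: float, dt: float) -> float:
    """Forward Euler on the explicit part, backward Euler on the implicit part."""
    return ode.implicit_inverse(u + dt * ode.explicit_terms(u), dt)


def trajectory(ode: StiffDecayODE, u0: float, dt: float, n_steps: int) -> list[float]:
    # A plain Python loop over steps, collecting every state.
    out = [u0]
    for _ in range(n_steps):
        out.append(imex_euler_step(ode, out[-1], dt))
    return out


ode = StiffDecayODE(lam=1000.0)
traj = trajectory(ode, u0=1.0, dt=0.1, n_steps=20)
# dt * lam = 100 is far beyond forward Euler's stability limit (dt * lam <= 2),
# yet treating the stiff term implicitly keeps the solution decaying toward 0.
print(traj[-1])
```

Higher-order IMEX Runge-Kutta schemes combine several explicit evaluations and implicit solves per step, but the division of labor between the two kinds of terms is the same as in this first-order step.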

Both spherical-harmonic layouts are implemented because published NeuralGCM checkpoints use both: RealSphericalHarmonics (modal shape (2M-1, L); used by the 2.8° deterministic checkpoint) and FastSphericalHarmonics (the zero-imag layout, with modal shape (2M, L); used, e.g., by the TL63 stochastic checkpoint and called RealSphericalHarmonicsWithZeroImag upstream).
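The one-row difference between the two modal shapes can be written out explicitly. This sketch assumes that M counts zonal wavenumbers m = 0 .. M-1, that L counts total wavenumbers, and that each m ≥ 1 carries a cosine and a sine coefficient while m = 0 has only a cosine. modal_shape is a hypothetical helper, not part of the package.

```python
def modal_shape(num_zonal: int, num_total: int, layout: str) -> tuple[int, int]:
    """Hypothetical helper: modal-array shape (rows, num_total) for each layout.

    Assumes m = 0 .. num_zonal - 1; every m >= 1 has cos and sin rows,
    while the m = 0 sine part is identically zero.
    """
    if layout == "real":
        # RealSphericalHarmonics-style: the all-zero m = 0 sine row is dropped.
        rows = 2 * num_zonal - 1
    elif layout == "zero_imag":
        # Zero-imag (FastSphericalHarmonics-style): that row is kept, filled with zeros.
        rows = 2 * num_zonal
    else:
        raise ValueError(f"unknown layout: {layout!r}")
    return rows, num_total


print(modal_shape(4, 5, "real"))       # (7, 5)
print(modal_shape(4, 5, "zero_imag"))  # (8, 5)
```

The zero-imag layout spends one row of zeros to give every zonal wavenumber the same (cos, sin) pair, so coefficients from the two layouts cannot be mixed without inserting or dropping that row.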

Status

The dycore and data path are complete and numerically validated against the original JAX implementation. Transforms, operators, the full primitive-equations step (dry and moist, including a 10-step baroclinic-wave trajectory), vertical/horizontal regridding, and solar radiation all agree to within 1e-5 to 1e-4 of each field's range. The package also ships 141 unit tests (pytest). A full SIL3 time step compiles with torch.compile(fullgraph=True) out of the box, with no shim rework and no graph breaks.
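One way to read "agree to within 1e-5 to 1e-4 of each field's range" is as the maximum absolute difference divided by the range of the reference field. The metric below is an assumption made for illustration, since the exact formula used for validation is not stated here; range_relative_error is a hypothetical name and the values are toy numbers.

```python
def range_relative_error(actual: list[float], expected: list[float]) -> float:
    """Max |actual - expected| divided by (max - min) of the reference field.

    Assumed metric for illustration; inputs are flat sequences of floats.
    """
    if len(actual) != len(expected):
        raise ValueError("fields must have the same length")
    field_range = max(expected) - min(expected)
    if field_range == 0:
        raise ValueError("reference field is constant; its range is zero")
    max_diff = max(abs(a - e) for a, e in zip(actual, expected))
    return max_diff / field_range


reference = [280.0, 250.0, 300.0, 220.0]      # toy reference field
candidate = [280.001, 250.0, 299.998, 220.0]  # toy port output
err = range_relative_error(candidate, reference)
print(err)  # roughly 0.002 / 80, i.e. about 2.5e-05
```

Normalizing by the field's range rather than pointwise makes the tolerance meaningful for fields whose values cross zero, where a pointwise relative error would blow up.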

Not ported: shallow water, Held–Suarez, hybrid coordinates, and leapfrog steppers. These are intentionally out of scope because no published NeuralGCM checkpoint uses them.



Download files


Source Distribution

dinosaur_torch-0.1.0.tar.gz (72.8 kB)

Built Distribution


dinosaur_torch-0.1.0-py3-none-any.whl (64.9 kB)

File details

Details for the file dinosaur_torch-0.1.0.tar.gz.

File metadata

  • Download URL: dinosaur_torch-0.1.0.tar.gz
  • Size: 72.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv 0.11.19 (Ubuntu 24.04)

File hashes

Hashes for dinosaur_torch-0.1.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | a13af848f0d937c8839d9ddc60790340e6c0eae28829d29a2cdea820f5d74edc |
| MD5 | c04f49777ff9c95c91acca7d5a80717c |
| BLAKE2b-256 | ad6c4d6e4bd835be669d909b6b4ce7a3a0db361158ac5a1fad36b1d8db57c7bd |


File details

Details for the file dinosaur_torch-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: dinosaur_torch-0.1.0-py3-none-any.whl
  • Size: 64.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv 0.11.19 (Ubuntu 24.04)

File hashes

Hashes for dinosaur_torch-0.1.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 14146a3f78fbd9d5fc2a370f933de8f1b0ecf6cab5e228701a2f5b51443db7f8 |
| MD5 | c9100c6adc9def4840cabdefc2688a39 |
| BLAKE2b-256 | f060d84546f422e4117864258ee7513545c811c7e88c9f5bf637ae6cfe4f4c30 |
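A downloaded distribution can be checked against the published SHA256 digest with the standard library alone. The sketch below assumes the wheel sits in the current directory under its published file name. sha256_of and verify are hypothetical helper names, and the digest constant is copied from the wheel's hash table.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected_hex: str) -> bool:
    # hexdigest() is lowercase, so normalize the expected value before comparing.
    return sha256_of(path) == expected_hex.strip().lower()


WHEEL_SHA256 = "14146a3f78fbd9d5fc2a370f933de8f1b0ecf6cab5e228701a2f5b51443db7f8"
wheel = Path("dinosaur_torch-0.1.0-py3-none-any.whl")
if wheel.exists():
    print("OK" if verify(wheel, WHEEL_SHA256) else "HASH MISMATCH")
```

pip can also enforce the digest at install time: a requirements-file line such as `dinosaur-torch==0.1.0 --hash=sha256:<digest>` combined with `pip install --require-hashes -r requirements.txt`. In hash-checking mode, pip then requires hashes for every dependency as well.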

