Idiomatic-PyTorch rewrite of the Dinosaur spectral dynamical core 🦖⚡
dinosaur-torch
An idiomatic-PyTorch rewrite of Dinosaur, the spectral dynamical core behind NeuralGCM. Rather than a line-for-line JAX→PyTorch translation, this package is written the way a PyTorch library would be written from scratch, while staying numerically equivalent to the original.
Numerically validated against the original JAX Dinosaur.
Design
- Tensors in, tensors out. Functions and modules operate on torch.Tensors. NumPy appears only at I/O and construction boundaries (grid/quadrature setup, xarray conversion). There is no asarray promotion at every call site, no global default-device convention, and no host-constant device cache.
- Precomputed constants live in nn.Module buffers. Objects that hold tensors (spectral transforms, the dycore) are torch.nn.Modules with non-persistent buffers, so .to(device) / .float() work the standard way and state_dict() contains only learned parameters (none, for the dycore).
- Static metadata is separate from tensors. GridSpec, SigmaCoordinates, etc. are frozen dataclasses (hashable, comparable, cheap) used to construct the tensor-holding modules.
- States are torch pytrees. Model state (State, diagnostics, …) is a plain dataclass registered via torch.utils._pytree.register_dataclass, so it composes natively with torch.compile, torch.func, and CUDA graphs. No custom pytree registry.
- Standard test style: plain pytest with parametrization (no absl/parameterized).
- Scope: the primitive-equations path used by NeuralGCM (transforms, sigma coordinates, primitive equations, IMEX time integration, filtering, vertical/horizontal interpolation, data utilities). Shallow-water and Held–Suarez model families are intentionally not ported: no published NeuralGCM checkpoint uses them.
Layout
| module | contents |
|---|---|
| associated_legendre.py, fourier.py | basis construction (NumPy, at setup time) |
| spherical_harmonic.py | GridSpec (static), RealSphericalHarmonics / FastSphericalHarmonics transforms, Grid (nn.Module: transforms + spectral operators) |
| sigma_coordinates.py | SigmaCoordinates (static) + SigmaLevels (nn.Module: vertical finite-difference / integral operators) |
| coordinate_systems.py | CoordinateSystem (nn.Module: horizontal × vertical), spectral up/downsampling |
| primitive_equations.py | State (torch-pytree dataclass), PrimitiveEquations (nn.Module IMEX ODE, dry/moist/cloud variants), Geopotential |
| time_integration.py | IMEX Runge-Kutta steppers (SIL3, CN-RK2/3/4, Euler), step filters, trajectories (plain loops), digital filter initialization |
| filtering.py | exponential / horizontal-diffusion spectral filters |
| vertical_interpolation.py | PressureCoordinates / PressureLevels, pressure ↔ sigma regridding (batched searchsorted/gather, no vmap) |
| horizontal_interpolation.py | conservative / bilinear / nearest lat-lon regridders (weights precomputed as buffers) |
| radiation.py | top-of-atmosphere incident solar radiation (SolarRadiation module) |
| scales.py, units.py | unit handling / nondimensionalization (NumPy + pint) |
| xarray_utils.py | ERA5-style dataset preparation: regrid_horizontal, fill_nan_with_nearest, selective_temporal_shift, grid_spec_from_dataset |
| pytree.py | tiny helpers over torch.utils._pytree |
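The vertical_interpolation.py row mentions batched searchsorted/gather instead of vmap. Here is a minimal sketch of that idea for piecewise-linear interpolation. It assumes increasing coordinates along the last axis and constant extrapolation beyond the end levels. Both are assumptions, not necessarily the package's rules, and interp_columns is a hypothetical name.

```python
import torch


def interp_columns(x_new: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor:
    """Batched piecewise-linear interpolation along the last axis.

    x_new: (..., n_new) target coordinates, e.g. pressure levels.
    xp:    (..., n) increasing source coordinates per column, e.g. sigma * surface pressure.
    fp:    (..., n) values at xp.
    Targets outside [xp[..., 0], xp[..., -1]] take the nearest end value
    (an assumed extrapolation rule, chosen for illustration).
    """
    # For each target, index of the first source level >= target, per column.
    idx = torch.searchsorted(xp.contiguous(), x_new.contiguous())
    # Keep a valid bracketing pair (idx - 1, idx) even at the ends.
    idx = idx.clamp(1, xp.shape[-1] - 1)
    x0, x1 = torch.gather(xp, -1, idx - 1), torch.gather(xp, -1, idx)
    f0, f1 = torch.gather(fp, -1, idx - 1), torch.gather(fp, -1, idx)
    # Linear weight; clamping it to [0, 1] yields constant extrapolation.
    w = ((x_new - x0) / (x1 - x0)).clamp(0.0, 1.0)
    return f0 + w * (f1 - f0)
```

Surface pressure enters only through xp = sigma * ps. As a result, every column gets its own bracketing indices from a single searchsorted call, with no per-column loop or vmap.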
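The time_integration.py and filtering.py rows combine into a step-then-filter loop. The sketch below illustrates that pattern on a 1-D periodic advection-diffusion toy problem, not on the primitive equations. It uses a first-order IMEX Euler step, with advection explicit and diffusion implicit. A generic exponential filter exp(-a (k/k_max)^(2p)) is applied after each step, and the trajectory is built with a plain Python loop. The function names and filter parameters are hypothetical, and the package's SIL3 and CN-RK steppers are higher order.

```python
import torch


def imex_euler_step(u_hat, k, dt, nu, c):
    """One first-order IMEX (semi-implicit) Euler step for u_t = -c u_x + nu u_xx.

    Advection is treated explicitly and diffusion implicitly; because diffusion is
    diagonal in Fourier space, the implicit solve is a per-mode division.
    """
    explicit = -1j * c * k * u_hat  # -c u_x in spectral space
    implicit_diag = -nu * k**2      # nu u_xx in spectral space (diagonal)
    return (u_hat + dt * explicit) / (1.0 - dt * implicit_diag)


def exponential_filter(k, k_max, attenuation=16.0, order=8):
    """Generic filter exp(-a (k / k_max)^(2p)): ~1 at large scales, exp(-a) at k_max."""
    return torch.exp(-attenuation * (k / k_max) ** (2 * order))


def trajectory(u_hat, k, dt, nu, c, n_steps):
    """Plain Python loop: step, apply the filter after each step, collect frames."""
    filt = exponential_filter(k, k.max())
    frames = []
    for _ in range(n_steps):
        u_hat = filt * imex_euler_step(u_hat, k, dt, nu, c)
        frames.append(u_hat)
    return torch.stack(frames)
```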
Both spherical-harmonics layouts are implemented because published NeuralGCM
checkpoints use both: RealSphericalHarmonics (modal shape (2M-1, L), the
2.8° deterministic checkpoint) and FastSphericalHarmonics (zero-imag layout,
modal shape (2M, L), e.g. the TL63 stochastic checkpoint; named
RealSphericalHarmonicsWithZeroImag upstream).
Status
The dycore and data path are complete and numerically validated against the
original JAX implementation. Transforms, operators, the full
primitive-equations step (dry and moist, including a 10-step baroclinic-wave
trajectory), vertical/horizontal regridding, and solar radiation all match the
JAX outputs to within 1e-5 to 1e-4 of each field's range. The package also
ships 141 unit tests (pytest). A full SIL3 time step compiles with
torch.compile(fullgraph=True) out of the box, with no shim rework and no graph
breaks.
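One plausible reading of "match to within 1e-5 to 1e-4 of each field's range" is a maximum absolute error normalized by the reference field's max minus min. Below is a minimal sketch of such a check with hypothetical function names. The project's own test harness may define the metric differently.

```python
import torch


def range_normalized_max_error(actual: torch.Tensor, expected: torch.Tensor) -> float:
    """Largest absolute deviation divided by the range (max - min) of the reference field."""
    # Guard against a constant reference field (zero range).
    span = (expected.max() - expected.min()).clamp_min(torch.finfo(expected.dtype).tiny)
    return ((actual - expected).abs().max() / span).item()


def assert_matches_reference(actual, expected, tol=1e-4):
    """Fail if the range-normalized error exceeds tol (1e-4 = loose end of the quoted band)."""
    err = range_normalized_max_error(actual, expected)
    assert err <= tol, f"range-normalized error {err:.1e} exceeds {tol:.0e}"
```

In a comparison test, expected would hold the JAX output converted to a tensor, for example via torch.from_numpy on a NumPy copy, and actual would hold the PyTorch result for the same inputs.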
Not ported: shallow water, Held–Suarez, hybrid coordinates, and leapfrog steppers (intentionally out of scope — no published NeuralGCM checkpoint uses them).
File details
Details for the file dinosaur_torch-0.1.1.tar.gz.
File metadata
- Download URL: dinosaur_torch-0.1.1.tar.gz
- Upload date:
- Size: 73.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.11.19 (uv publish, Ubuntu 24.04)
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e76b10615c7132194305f7844e4f56b74894fcc53fc664856a21ebe5c842a163 |
| MD5 | aae715a76b1b072fe85c49eda97ed025 |
| BLAKE2b-256 | 410ed862dbcdc803175d2e094a6d4203eee6477546fe7836469278589fdcfa72 |
File details
Details for the file dinosaur_torch-0.1.1-py3-none-any.whl.
File metadata
- Download URL: dinosaur_torch-0.1.1-py3-none-any.whl
- Upload date:
- Size: 65.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.11.19 (uv publish, Ubuntu 24.04)
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c282e9c49dad284b805d449c93a2e59cd005272620071846fbb3346ca0bab0fe |
| MD5 | aeeddb2d02f644c0484305eb5c0c7961 |
| BLAKE2b-256 | ce3e47fbec417725bf96938ce5dff410fe387f08ce1ac87ac4aca119c680eedb |
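To check a downloaded file against the SHA256 digests listed above, here is a minimal sketch that uses only the Python standard library. The names sha256_of and verify are hypothetical helper names.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hex SHA-256 of a file, read in chunks so large files are not loaded at once."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected_sha256: str) -> None:
    """Raise ValueError if the file's SHA-256 differs from the expected hex digest."""
    actual = sha256_of(path)
    if actual != expected_sha256.lower():
        raise ValueError(f"{path.name}: SHA-256 mismatch ({actual} != {expected_sha256})")


# Digest copied from the wheel's "File hashes" table.
WHEEL_SHA256 = "c282e9c49dad284b805d449c93a2e59cd005272620071846fbb3346ca0bab0fe"
```

For example, verify(Path("dinosaur_torch-0.1.1-py3-none-any.whl"), WHEEL_SHA256) raises ValueError if the wheel on disk differs from the published digest. From the command line, pip hash reports the same SHA256 digest.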