Simple Xarray + JAX Integration

This is an experiment in integrating Xarray and JAX in a simple way, leveraging equinox.

import equinox as eqx
import jax
import jax.numpy as jnp
import xarray as xr
import xarray_jax as xj

# Construct a DataArray.
da = xr.DataArray(
    xr.Variable(["x", "y"], jnp.ones((2, 3))),
    coords={"x": [1, 2], "y": [3, 4, 5]},
    name="foo",
    attrs={"attr1": "value1"},
)

# Do some operations inside a JIT compiled function.
@eqx.filter_jit
def some_function(data):
    neg_data = -1.0 * data
    return neg_data * neg_data.coords["y"] # Multiply data by coords.

da = some_function(da)

# Construct an xr.DataArray with dummy data (useful for tree manipulation).
da_mask = jax.tree.map(lambda _: True, da)

# Use jax.grad.
@eqx.filter_jit
def fn(data):
    # jax.grad needs a scalar array output, so unwrap the 0-d DataArray.
    return (data**2.0).sum().data

grad = jax.grad(fn)(da)

# Convert to a custom XjDataArray, implemented as an equinox module.
# (Useful for avoiding unexpected interactions between xarray and JAX.)
xj_da = xj.from_xarray(da)

# Convert back to an xr.DataArray.
da = xj.to_xarray(xj_da)

Installation

pip install xarray_jax

Status

  • PyTree node registrations
    • xr.Variable
    • xr.DataArray
    • xr.Dataset
  • Minimal shadow types, implemented as equinox modules, to handle edge cases. (Note: these types are plain data structures that hold the data of the corresponding xarray types; they have none of the xarray methods.)
    • XjVariable
    • XjDataArray
    • XjDataset
  • xj.from_xarray and xj.to_xarray functions to convert between xj and xr types.
  • Support for xr types with dummy data (useful for tree manipulation).
  • Support for transformations that change the dimensionality of the data.

Sharp Edges

Prefer eqx.filter_jit over jax.jit

There are some edge cases with metadata that eqx.filter_jit handles but jax.jit does not.

Operations that Increase the Dimensionality of the Data

Operations that increase the dimensionality of the data (e.g. jnp.expand_dims) do not fail immediately, but they cause errors in later operations, as in the example below.

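import jax
import jax.numpy as jnp
import xarray as xr
import xarray_jax  # Assumed to register the xarray PyTree nodes on import.

var = xr.Variable(dims=("x", "y"), data=jnp.ones((3, 3)))

# This will not error, even though the data now has one more dimension than dims.
var = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), var)

# The error from expanding the dimensionality will be triggered here.
var = var + 1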
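One plausible explanation for the deferred error, shown with a hypothetical MiniVariable class (not part of xarray or xarray_jax): a PyTree registration splits an object into array leaves, which jax.tree.map transforms, and static auxiliary data such as dimension names, which it passes through untouched. Assuming xarray_jax stores dims as auxiliary data in a similar way, nothing compares the new shape against dims until a later operation does.

```python
import jax
import jax.numpy as jnp


class MiniVariable:
    """Hypothetical stand-in for xr.Variable: one array leaf plus static dims."""

    def __init__(self, dims, data):
        self.dims = tuple(dims)
        self.data = data

    def check(self):
        # Mimics the kind of consistency check a later xarray operation performs.
        if len(self.dims) != self.data.ndim:
            raise ValueError(
                f"dims {self.dims} do not match data with {self.data.ndim} dimensions"
            )
        return self


def _flatten(v):
    # Children are traced/transformed by JAX; auxiliary data (dims) stays static.
    return (v.data,), v.dims


def _unflatten(dims, children):
    # Rebuild without validating, so shape changes go unnoticed here.
    return MiniVariable(dims, children[0])


jax.tree_util.register_pytree_node(MiniVariable, _flatten, _unflatten)

var = MiniVariable(("x", "y"), jnp.ones((3, 3)))

# tree.map only sees the array leaf; dims are carried through unchanged.
expanded = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), var)
print(expanded.dims, expanded.data.shape)  # ('x', 'y') (1, 3, 3)

# The mismatch only surfaces when something compares dims with the data.
try:
    expanded.check()
except ValueError as err:
    print(err)
```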

Dispatching to jnp is not supported yet

Pending resolution of https://github.com/pydata/xarray/issues/7848.

import jax.numpy as jnp
import xarray as xr

var = xr.Variable(dims=("x", "y"), data=jnp.ones((3, 3)))

# This will fail.
jnp.square(var)

# This will work.
xr.apply_ufunc(jnp.square, var)

Distinction from the GraphCast Implementation

This experiment is largely inspired by the GraphCast implementation and directly reuses its _HashableCoords class.

However, this experiment aims to:

  1. Take a more minimalist approach (and thus omit some features, such as support for JAX arrays as coordinates).
  2. Find a solution more compatible with common JAX PyTree manipulation patterns that trigger errors with Xarray types. For example, it's common to use boolean masks to filter out elements of a PyTree, but this tends to fail with Xarray types.
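The masking pattern, sketched on a plain dict rather than Xarray types: the mask is built by mapping every leaf to True (the same trick used for da_mask in the first example), some entries are switched off, and the mask then selects leaves. Leaves replaced by None count as empty subtrees, so they drop out of jax.tree.leaves. Equinox's eqx.partition and eqx.combine implement the same idea while returning both complementary halves.

```python
import jax
import jax.numpy as jnp

# A generic PyTree (hypothetical example, not an xarray object).
params = {"weights": jnp.ones((2, 2)), "bias": jnp.zeros(2), "scale": jnp.array(3.0)}

# Build a mask with the same structure by replacing every leaf with True,
# then switch off the leaf we want to exclude.
mask = jax.tree.map(lambda _: True, params)
mask["scale"] = False

# Keep the selected leaves and replace the others with None (an empty subtree).
selected = jax.tree.map(lambda keep, x: x if keep else None, mask, params)

print(sorted(selected))                # ['bias', 'scale', 'weights']
print(len(jax.tree.leaves(selected)))  # 2
```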

Acknowledgements

This repo was made possible by great discussions within the JAX + Xarray open source community, especially this one. In particular, the author would like to acknowledge @shoyer, @mjwillson, and @TomNicholas.

Download files


Source Distribution

xarray_jax-0.0.5.tar.gz (9.4 kB)

Uploaded Source

Built Distribution


xarray_jax-0.0.5-py3-none-any.whl (10.5 kB)

Uploaded Python 3

File details

Details for the file xarray_jax-0.0.5.tar.gz.

File metadata

  • Download URL: xarray_jax-0.0.5.tar.gz
  • Upload date:
  • Size: 9.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.7.1 CPython/3.10.12 Linux/6.8.0-45-generic

File hashes

Hashes for xarray_jax-0.0.5.tar.gz
Algorithm Hash digest
SHA256 54cf8f9832d5ff50f8798fc385555f4d6cd019a2e416127323606b1f4498f485
MD5 b9a30c117050ebe43f59dcb993076572
BLAKE2b-256 daa1aefd27e63c03b0811468b8e979ad17cece1ee2d838b05f4e8f8c910130f2


File details

Details for the file xarray_jax-0.0.5-py3-none-any.whl.

File metadata

  • Download URL: xarray_jax-0.0.5-py3-none-any.whl
  • Upload date:
  • Size: 10.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.7.1 CPython/3.10.12 Linux/6.8.0-45-generic

File hashes

Hashes for xarray_jax-0.0.5-py3-none-any.whl
Algorithm Hash digest
SHA256 34ac654b2566cc80dc13b3d8ab05b0c9ee00a7aec8d89d688faacdbe07201e75
MD5 57256e227e84366a1fa2cf2e0d38534f
BLAKE2b-256 4e2ec46e09c47eb6fcfc966413a91eead97080f31bdeb7247f6ef52bc5e2b7b8

