Simple Xarray + JAX Integration
This is an experiment in integrating Xarray and JAX in a simple way, leveraging equinox.
import equinox as eqx
import jax
import jax.numpy as jnp
import xarray as xr
import xarray_jax as xj

# Construct a DataArray.
da = xr.DataArray(
    xr.Variable(["x", "y"], jnp.ones((2, 3))),
    coords={"x": [1, 2], "y": [3, 4, 5]},
    name="foo",
    attrs={"attr1": "value1"},
)

# Do some operations inside a JIT-compiled function.
@eqx.filter_jit
def some_function(data):
    neg_data = -1.0 * data
    return neg_data * neg_data.coords["y"]  # Multiply data by coords.

da = some_function(da)

# Construct an xr.DataArray with dummy data (useful for tree manipulation).
da_mask = jax.tree.map(lambda _: True, da)

# Use jax.grad.
@eqx.filter_jit
def fn(data):
    return (data**2.0).sum().data

grad = jax.grad(fn)(da)

# Convert to a custom XjDataArray, implemented as an equinox module.
# (Useful for avoiding potentially weird xarray interactions with JAX.)
xj_da = xj.from_xarray(da)

# Convert back to an xr.DataArray.
da = xj.to_xarray(xj_da)
Installation
pip install xarray_jax
Status
- PyTree node registrations
  - xr.Variable
  - xr.DataArray
  - xr.Dataset
- Minimal shadow types implemented as equinox modules to handle edge cases (Note: these types are merely data structures that contain the data of the corresponding xarray types. They don't have any of the methods of the xarray types.)
  - XjVariable
  - XjDataArray
  - XjDataset
- xj.from_xarray and xj.to_xarray functions to convert between xj and xr types.
- Support for xr types with dummy data (useful for tree manipulation).
- Support for transformations that change the dimensionality of the data.
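As a rough illustration of what a PyTree node registration does (a sketch of the general mechanism, not this package's actual implementation), the code below registers a hypothetical `LabeledArray` class, standing in for `xr.Variable`, with `jax.tree_util.register_pytree_node`. The array is the only leaf, while the dimension names and attributes travel as static auxiliary data. Because `jax.jit` caches compiled functions on the tree structure, that auxiliary data should be hashable, so the attributes are stored as a sorted tuple. The helper names `_flatten`, `_unflatten`, and `double` are illustrative.

```python
import jax
import jax.numpy as jnp


class LabeledArray:
    """Hypothetical stand-in for xr.Variable: an array plus dimension names."""

    def __init__(self, dims, data, attrs=None):
        self.dims = tuple(dims)
        self.data = data
        self.attrs = dict(attrs or {})


def _flatten(arr):
    # The array is the only leaf; dims and attrs become hashable static aux data.
    aux = (arr.dims, tuple(sorted(arr.attrs.items())))
    return (arr.data,), aux


def _unflatten(aux, children):
    dims, attrs = aux
    (data,) = children
    # No shape validation here, so placeholder leaves (e.g. True) are accepted.
    return LabeledArray(dims, data, dict(attrs))


jax.tree_util.register_pytree_node(LabeledArray, _flatten, _unflatten)


@jax.jit
def double(arr):
    # Inside jit, `arr` is rebuilt from traced leaves via _unflatten.
    return LabeledArray(arr.dims, arr.data * 2.0, arr.attrs)


x = LabeledArray(("x", "y"), jnp.ones((2, 3)), {"units": "m"})
y = double(x)  # dims and attrs survive the round trip through jit

# Same structure as `x`, with a boolean placeholder in place of the array.
mask = jax.tree.map(lambda _: True, x)
```

Note that `_unflatten` deliberately skips any check that the data matches `dims`; that leniency is what lets a boolean placeholder stand in for the array, and it is also why a shape change made through tree mapping goes unnoticed until later.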
Sharp Edges
Prefer eqx.filter_jit over jax.jit
There are some edge cases with metadata that eqx.filter_jit handles but jax.jit does not.
Operations that Increase the Dimensionality of the Data
Operations that increase the dimensionality of the data (e.g. jnp.expand_dims) will cause problems downstream.
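import jax
import jax.numpy as jnp
import xarray as xr
# Importing xarray_jax is assumed to register the xarray PyTree nodes.
import xarray_jax  # noqa: F401
var = xr.Variable(dims=("x", "y"), data=jnp.ones((3, 3)))
# This will not error.
var = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), var)
# The error from expanding the dimensionality will be triggered here.
var = var + 1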
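The error surfaces late because `jax.tree.map` replaces only the array leaf: the dimension names are carried along unchanged, so the variable still claims two dimensions while holding three-dimensional data. xarray's own constructor rejects that mismatch immediately, as the sketch below shows. One possible workaround (an assumption, not a feature of xarray_jax) is to change the dimensionality outside of tree mapping and rebuild the variable with a matching `dims` tuple; the new dimension name `"z"` is purely illustrative.

```python
import jax.numpy as jnp
import xarray as xr

var = xr.Variable(dims=("x", "y"), data=jnp.ones((3, 3)))

# Constructing a Variable whose dims disagree with the data's ndim fails eagerly.
try:
    xr.Variable(dims=("x", "y"), data=jnp.ones((1, 3, 3)))
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True

# Expand the data and the dimension names together so they stay consistent.
expanded = xr.Variable(
    dims=("z",) + var.dims,
    data=jnp.expand_dims(var.data, axis=0),
    attrs=var.attrs,
)
expanded = expanded + 1  # arithmetic works because dims match the data
```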
Dispatching to jnp is not supported yet
Pending resolution of https://github.com/pydata/xarray/issues/7848.
import jax.numpy as jnp
import xarray as xr
var = xr.Variable(dims=("x", "y"), data=jnp.ones((3, 3)))
# This will fail.
jnp.square(var)
# This will work.
xr.apply_ufunc(jnp.square, var)
Distinction from the GraphCast Implementation
This experiment is largely inspired by the GraphCast implementation and directly reuses its _HashableCoords.
However, this experiment aims to:
- Take a more minimalist approach (and thus omit some features, such as support for JAX arrays as coordinates).
- Find a solution more compatible with common JAX PyTree manipulation patterns that trigger errors with Xarray types. For example, it's common to use boolean masks to filter out elements of a PyTree, but this tends to fail with Xarray types.
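For context, the boolean-mask pattern mentioned above typically looks like the sketch below on an ordinary PyTree (a dict of arrays, chosen here for illustration): build a same-structure tree of booleans with `jax.tree.map`, then use it to keep or drop leaves. With xarray containers, the placeholder step is the fragile one, since a scalar `True` no longer matches the dimensions the container records.

```python
import jax
import jax.numpy as jnp

params = {"weight": jnp.ones((2, 2)), "bias": jnp.zeros(2), "step": jnp.array(0)}

# Same structure as `params`, with a boolean in place of every array.
mask = jax.tree.map(lambda _: True, params)
mask["step"] = False  # exclude this leaf

# Keep masked-in leaves; replace the rest with None (an empty subtree).
trainable = jax.tree.map(lambda x, keep: x if keep else None, params, mask)
```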
Acknowledgements
This repo was made possible by great discussions within the JAX + Xarray open source community, especially this one. In particular, the author would like to acknowledge @shoyer, @mjwillson, and @TomNicholas.