
A dascore compatibility layer for JAX.


dasjax

An experimental package for accelerating DASCore with JAX.

Installation

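Install the released package from PyPI:

```shell
python -m pip install dasjax
```

For development, install in editable mode with the dev extras from the root of a clone of the repository:

```shell
python -m pip install -e ".[dev]"
```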

Usage

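dasjax's main feature is building compiled DAS (distributed acoustic sensing) processing pipelines that run on CPU, GPU, or TPU. Because a whole pipeline is compiled as a single program, adjacent operations can be fused into fewer kernels, which reduces intermediate memory traffic.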
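As a rough illustration of the mechanism, using plain JAX rather than dasjax itself: wrapping a chain of array operations in `jax.jit` hands the whole chain to XLA as one program, which can fuse the elementwise steps instead of materializing each intermediate array. The function `chain` below is a hypothetical stand-in for a compiled pipeline, not part of the dasjax API.

```python
import jax
import jax.numpy as jnp


def chain(data: jax.Array) -> jax.Array:
    """Hypothetical chain: scale, offset, then an elementwise nonlinearity."""
    return jnp.tanh(data * 2.0 + 1.0)


# jax.jit traces the Python function once per input shape/dtype and compiles
# the whole chain with XLA, which is free to fuse the elementwise steps.
compiled_chain = jax.jit(chain)

data = jnp.linspace(-1.0, 1.0, 12).reshape(3, 4)
result = compiled_chain(data)

# The same code runs on whichever backend JAX finds: "cpu", "gpu", or "tpu".
print(jax.default_backend())

# The optimized HLO text can be inspected to see which ops XLA fused.
hlo = compiled_chain.lower(data).compile().as_text()
print(result)
```

In dasjax the equivalent step is `pipeline.compile()`; the sketch only shows why compiling a chain as a unit enables fusion.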

Compiled pipeline

Use JaxPatchPipeline to compile a reusable sequence of operations once and apply it to many compatible patches (same shape and dtype).

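```python
import dascore as dc
from dasjax import JaxPatchPipeline

# Load one of DASCore's bundled example patches.
patch = dc.get_example_patch("example_event_1")

# Record the operation chain; nothing is computed yet.
pipeline = (
    JaxPatchPipeline()
    .scale(2.0)
    .add(1.0)
    .detrend(dim="time", type="constant")
    .normalize(dim="time")
)
# Compile once, then reuse `compiled` on patches with the same shape and dtype.
compiled = pipeline.compile()

# Patch.pipe calls compiled(patch) and returns the processed patch.
out = patch.pipe(compiled)

print(out.shape)
```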
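The compiled callable works on whole patches; the array math it fuses can be approximated by the plain-JAX sketch below. It assumes that time is the last array axis and that `normalize` divides by the L2 norm along that axis. Both are assumptions for illustration, and `example_pipeline_kernel` is a hypothetical name, not dasjax's implementation.

```python
import jax
import jax.numpy as jnp


@jax.jit
def example_pipeline_kernel(data: jax.Array) -> jax.Array:
    """Array-level sketch of scale -> add -> detrend -> normalize.

    Assumes time is the last axis and an L2 norm for normalize.
    """
    out = data * 2.0                                     # .scale(2.0)
    out = out + 1.0                                      # .add(1.0)
    out = out - out.mean(axis=-1, keepdims=True)         # .detrend(type="constant")
    norm = jnp.linalg.norm(out, axis=-1, keepdims=True)  # per-channel L2 norm
    return out / norm                                    # .normalize(dim="time")


key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (5, 100))  # 5 channels x 100 time samples
out = example_pipeline_kernel(data)
print(out.shape)
```

Under these assumptions the constant detrend removes the +1.0 offset and the normalization removes the x2.0 scale, so the result equals detrend followed by normalize alone; the example chain exercises fusion more than it changes the output.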

Development

Three-Tier Architecture

dasjax is organized as a small three-tier stack:

  1. Pipeline layer: src/dasjax/pipeline.py records operation chains and compiles reusable patch transforms. This is the main user-facing API.
  2. Operation layer: src/dasjax/operations.py defines the operation registry, validation rules, internal eager implementations, and compiled behavior.
  3. Kernel layer: src/dasjax/kernels.py contains the array-level JAX and callback-backed kernels that actually do the numerical work.

This split makes the package easier to extend: implement or update numerical behavior in the kernel layer, describe how it plugs into compiled execution in the operation layer, and expose it through the pipeline layer.
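A toy version of the three tiers might look like the sketch below. Every name in it (`scale_kernel`, `demean_kernel`, `OpSpec`, `OPERATIONS`, `MiniPipeline`) is hypothetical and only mirrors the described roles: kernels do the array math, an operation registry validates arguments, and a pipeline records steps and compiles them with `jax.jit`.

```python
from dataclasses import dataclass
from typing import Callable

import jax
import jax.numpy as jnp


# Kernel layer: pure array functions that do the numerical work.
def scale_kernel(data: jax.Array, factor: float) -> jax.Array:
    return data * factor


def demean_kernel(data: jax.Array, axis: int) -> jax.Array:
    return data - data.mean(axis=axis, keepdims=True)


# Operation layer: a registry describing each op and how to validate it.
@dataclass(frozen=True)
class OpSpec:
    name: str
    kernel: Callable[..., jax.Array]
    validate: Callable[[dict], None]


def _require_number(params: dict) -> None:
    if not isinstance(params.get("factor"), (int, float)):
        raise TypeError("factor must be a number")


def _require_axis(params: dict) -> None:
    if not isinstance(params.get("axis"), int):
        raise TypeError("axis must be an int")


OPERATIONS = {
    "scale": OpSpec("scale", scale_kernel, _require_number),
    "demean": OpSpec("demean", demean_kernel, _require_axis),
}


# Pipeline layer: records a chain of ops and compiles it once.
class MiniPipeline:
    def __init__(self, steps: tuple = ()):
        self.steps = steps

    def _add(self, name: str, **params) -> "MiniPipeline":
        OPERATIONS[name].validate(params)  # fail fast while building
        return MiniPipeline(self.steps + ((name, params),))

    def scale(self, factor: float) -> "MiniPipeline":
        return self._add("scale", factor=factor)

    def demean(self, axis: int = -1) -> "MiniPipeline":
        return self._add("demean", axis=axis)

    def compile(self) -> Callable[[jax.Array], jax.Array]:
        steps = self.steps

        def run(data: jax.Array) -> jax.Array:
            for name, params in steps:
                data = OPERATIONS[name].kernel(data, **params)
            return data

        return jax.jit(run)  # the whole chain becomes one XLA program


compiled = MiniPipeline().scale(2.0).demean(axis=-1).compile()
print(compiled(jnp.array([[1.0, 2.0, 3.0]])))
```

Returning a new pipeline from each method keeps recorded chains immutable, so a partially built pipeline can be reused as a prefix for several variants.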

Roadmap

The tables below track patch methods that dasjax does not support yet and roughly how much effort each addition requires.

Near-term: straightforward pure-JAX array ops

No new infrastructure needed; each maps directly to one or two jnp calls.

Method | Implementation notes
--- | ---
real, imag, angle, conj | jnp one-liners for complex patches
flip | jnp.flip along a named dim
roll | jnp.roll circular shift along a dim
pad | jnp.pad with DASCore coordinate extension
standardize | zero-mean + unit-std (compare normalize)
differentiate | jnp.diff finite differences along a dim
integrate | jnp.cumsum / trapezoid along a dim
dft / idft | jnp.fft.rfft / irfft with coord reconstruction
hilbert / envelope | hilbert via FFT; envelope = abs(hilbert(data))
taper / taper_range | hann / cosine windows broadcast along axis
whiten | spectral divide-by-amplitude via FFT
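Several of these near-term methods can be sketched as plain array kernels. The function names and signatures below are hypothetical, not existing dasjax APIs. `standardize` uses the population standard deviation (ddof=0), which may not match DASCore's convention. `differentiate` swaps in `jnp.gradient` (central differences) instead of `jnp.diff` so the axis length is preserved, and `envelope` builds the analytic signal with the standard FFT-based Hilbert weights.

```python
import jax.numpy as jnp


def standardize(data, axis=-1):
    """Zero mean and unit (population) standard deviation along ``axis``."""
    mean = data.mean(axis=axis, keepdims=True)
    std = data.std(axis=axis, keepdims=True)
    return (data - mean) / std


def differentiate(data, step, axis=-1):
    """Central finite differences; jnp.gradient keeps the axis length."""
    return jnp.gradient(data, step, axis=axis)


def integrate(data, step, axis=-1):
    """Cumulative trapezoid integral that starts at zero and keeps the shape."""
    moved = jnp.moveaxis(data, axis, -1)
    areas = 0.5 * (moved[..., 1:] + moved[..., :-1]) * step
    total = jnp.concatenate(
        [jnp.zeros_like(moved[..., :1]), jnp.cumsum(areas, axis=-1)], axis=-1
    )
    return jnp.moveaxis(total, -1, axis)


def dft(data, step, axis=-1):
    """Real FFT plus the matching frequency coordinate (Hz if step is in s)."""
    n = data.shape[axis]
    return jnp.fft.rfft(data, axis=axis), jnp.fft.rfftfreq(n, d=step)


def envelope(data, axis=-1):
    """abs(analytic signal), using the FFT-based Hilbert transform."""
    n = data.shape[axis]
    # Keep DC (and Nyquist for even n), double positive freqs, zero negatives.
    weights = jnp.zeros(n).at[0].set(1.0)
    if n % 2 == 0:
        weights = weights.at[1 : n // 2].set(2.0).at[n // 2].set(1.0)
    else:
        weights = weights.at[1 : (n + 1) // 2].set(2.0)
    shape = [1] * data.ndim
    shape[axis] = n
    spectrum = jnp.fft.fft(data, axis=axis) * weights.reshape(shape)
    return jnp.abs(jnp.fft.ifft(spectrum, axis=axis))


t = jnp.arange(64) / 64.0
signal = jnp.cos(2 * jnp.pi * 4 * t)[None, :]  # one channel, 4 full cycles
env = envelope(signal)
print(env.min(), env.max())  # both close to 1.0 for a pure cosine
print(integrate(jnp.ones((1, 5)), step=0.5))  # values 0, 0.5, 1, 1.5, 2
```

Coordinate handling (for example, replacing the time coordinate with the frequency coordinate after `dft`) would live in the operation layer rather than in these kernels.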

Medium-term: moderate effort or shape-changing

Each of these either needs more work in kernels.py or changes the patch shape, which requires segmented pipeline execution (the same mechanism fbe uses).

Method | Implementation notes
--- | ---
notch_filter | SOS filter; same pattern as pass_filter
savgol_filter | polynomial fitting per frame; JAX-doable
rolling | rolling-window reductions (mean, std, …); needs strided views
correlate | cross-correlation via jnp.fft
stft / istft | expose the STFT kernel already used by fbe
decimate | anti-aliased downsampling; shape-changing
aggregate / mean / std / sum | axis reductions; shape-changing
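Two of the medium-term items can be sketched as follows; both function names are hypothetical. `correlate_fft` computes linear cross-correlation with zero-padded FFTs. `rolling_mean_valid` swaps in a cumulative-sum trick instead of the strided views named in the table: it handles means but not every reduction, and it returns only full windows, so its output is shorter than its input. That shape change is why `window` must be a static argument under `jax.jit`.

```python
import jax
import jax.numpy as jnp


def correlate_fft(a, b, axis=-1):
    """Linear cross-correlation r[k] = sum_m a[m + k] * b[m] via the FFT.

    Zero-padding to len(a) + len(b) - 1 avoids circular wrap-around;
    non-negative lags come first and negative lags sit at the end.
    """
    n = a.shape[axis] + b.shape[axis] - 1
    fa = jnp.fft.rfft(a, n=n, axis=axis)
    fb = jnp.fft.rfft(b, n=n, axis=axis)
    return jnp.fft.irfft(fa * jnp.conj(fb), n=n, axis=axis)


def rolling_mean_valid(data, window, axis=-1):
    """Mean over each full window; output is (window - 1) samples shorter."""
    moved = jnp.moveaxis(data, axis, -1)
    csum = jnp.cumsum(moved, axis=-1)
    csum = jnp.concatenate([jnp.zeros_like(csum[..., :1]), csum], axis=-1)
    out = (csum[..., window:] - csum[..., :-window]) / window
    return jnp.moveaxis(out, -1, axis)


# The output shape depends on `window`, so it must be known at compile time.
rolling_mean_jit = jax.jit(rolling_mean_valid, static_argnames=("window", "axis"))

b = jnp.zeros(8).at[0].set(1.0)
a = jnp.zeros(8).at[3].set(1.0)  # `b` delayed by 3 samples
lags = correlate_fft(a, b)
print(int(jnp.argmax(lags)))  # 3

print(rolling_mean_jit(jnp.arange(1.0, 6.0), window=2))  # 1.5, 2.5, 3.5, 4.5
```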

Performance Notes

  • The intended fast path is to build a JaxPatchPipeline, call .compile() once, and reuse the returned callable across many patches of compatible shape and dtype.
  • Compiling an equivalent pipeline definition again reuses the cached compiled callable automatically.
  • Benchmarks live under benchmarks/ and compare compiled dasjax pipelines against equivalent DASCore operation chains.
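Both caching behaviors can be imitated in plain JAX, as a sketch under assumptions: `compile_chain` and its tuple-based step format are hypothetical, not dasjax's cache. `functools.lru_cache` returns the same jitted callable for equal (hashable) definitions, and `jax.jit` itself retraces only when it sees a new input shape or dtype. The counter increments only during tracing, because Python side effects do not run when a cached compiled program is reused.

```python
import functools

import jax
import jax.numpy as jnp

trace_count = {"n": 0}


def _apply(step, data):
    name, value = step
    if name == "scale":
        return data * value
    if name == "add":
        return data + value
    raise ValueError(f"unknown step {name!r}")


@functools.lru_cache(maxsize=None)
def compile_chain(steps):
    """Return one jitted callable per distinct, hashable chain definition."""

    def run(data):
        trace_count["n"] += 1  # Python side effect: runs only while tracing
        for step in steps:
            data = _apply(step, data)
        return data

    return jax.jit(run)


f1 = compile_chain((("scale", 2.0), ("add", 1.0)))
f2 = compile_chain((("scale", 2.0), ("add", 1.0)))  # equal definition
print(f1 is f2)  # True: cache hit, same compiled callable

f1(jnp.ones((4, 100)))   # first call: trace and compile
f1(jnp.zeros((4, 100)))  # same shape and dtype: reuse the compiled program
f1(jnp.ones((4, 200)))   # new shape: trace and compile again
print(trace_count["n"])  # 2
```

This is why the fast path is to reuse one compiled callable across patches of compatible shape and dtype: each new shape pays the compilation cost again.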

Development Guidelines

  • Add new JAX patch methods by defining an array kernel in src/dasjax/kernels.py and one operation spec in src/dasjax/operations.py.
  • The operation spec is the single source of truth for pipeline support, validation, and shared parity test cases.
  • Every new patch method must be tested against a DASCore baseline across the shared mixed-patch fixture in tests/conftest.py.
  • Prefer comparing internal operation behavior and compiled pipeline outputs against the closest native DASCore method or operator. If DASCore has no direct method, compare against an equivalent Patch.update(...) baseline.
  • Method-equivalence assertions should check data closeness with equal_nan=True when needed and should also verify coordinate preservation.
  • Compiled pipeline parity should come from the same declared operation cases rather than a separate hand-maintained test matrix.
  • Install Git hooks locally with prek install.
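A parity check following these guidelines might be shaped like the sketch below. All names (`jax_demean`, `reference_demean`, `assert_parity`) are hypothetical. In the real test suite the baseline would be the closest DASCore method and the inputs would come from the mixed-patch fixture in tests/conftest.py; here a NumPy function stands in for the baseline and a plain dict stands in for patch coordinates.

```python
import jax.numpy as jnp
import numpy as np


def jax_demean(data, coords):
    """Hypothetical JAX kernel: new data, coordinates passed through unchanged."""
    return data - jnp.nanmean(data, axis=-1, keepdims=True), coords


def reference_demean(data, coords):
    """Stand-in for the DASCore baseline (NumPy here, for illustration)."""
    return data - np.nanmean(data, axis=-1, keepdims=True), coords


def assert_parity(kernel, reference, data, coords, rtol=1e-5, atol=1e-6):
    got, got_coords = kernel(jnp.asarray(data), coords)
    want, want_coords = reference(data, coords)
    # Data closeness; equal_nan=True treats NaNs in matching places as equal.
    assert np.allclose(np.asarray(got), want, rtol=rtol, atol=atol, equal_nan=True)
    # Coordinate preservation: every coordinate must come back unchanged.
    assert got_coords.keys() == want_coords.keys()
    for name in want_coords:
        np.testing.assert_array_equal(got_coords[name], want_coords[name])


data = np.array([[1.0, 2.0, np.nan, 5.0], [0.0, 0.0, 0.0, 0.0]], dtype=np.float32)
coords = {"distance": np.arange(2), "time": np.arange(4) * 0.01}
assert_parity(jax_demean, reference_demean, data, coords)
print("parity ok")
```

Declaring the input cases once and feeding them to both the eager and the compiled paths keeps the parity matrix in one place, as the guidelines above require.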



Download files

Download the file for your platform.

Source Distribution

dasjax-0.0.2.tar.gz (24.5 kB)

Uploaded Source

Built Distribution


dasjax-0.0.2-py3-none-any.whl (21.3 kB)

Uploaded Python 3

File details

Details for the file dasjax-0.0.2.tar.gz.

File metadata

  • Download URL: dasjax-0.0.2.tar.gz
  • Upload date:
  • Size: 24.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for dasjax-0.0.2.tar.gz
Algorithm | Hash digest
--- | ---
SHA256 | 0201d1a99df990fd1d043da6fa4de043e33b745010199d71d6be10c56e5ec054
MD5 | 6673c16f13f4399849b3dc56db628c06
BLAKE2b-256 | 6971180dc4740706a56e582d2682fac0d74b4729512b147dd35075d2e1e5e3ec


File details

Details for the file dasjax-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: dasjax-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 21.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for dasjax-0.0.2-py3-none-any.whl
Algorithm | Hash digest
--- | ---
SHA256 | 369852defbac7bf0d78e1d1b1bfe6e9c5fb998e37aa903c75e3b38e6ae6687c4
MD5 | 22668b6a2389b0fc6ab5b7dc8dda7d60
BLAKE2b-256 | e71f9f2c2ad0cd602eb4d46075a5bdafae08ba77139dbf5a8b4985ae1d96bd1d

