# dasjax

A dascore compatibility layer for JAX.
An experimental package for accelerating DASCore with JAX.
## Installation

```bash
python -m pip install -e ".[dev]"
```
## Usage

dasjax's main feature is the ability to create compiled DAS pipelines that can run on CPU, GPU, or TPU. Compiled pipelines also benefit from kernel fusion for increased efficiency.
### Compiled pipeline

Use `JaxPatchPipeline` when you want to compile a reusable sequence once and run it across many compatible patches.
```python
import dascore as dc

from dasjax import JaxPatchPipeline

patch = dc.get_example_patch("example_event_1")

pipeline = (
    JaxPatchPipeline()
    .scale(2.0)
    .add(1.0)
    .detrend(dim="time", type="constant")
    .normalize(dim="time")
)
compiled = pipeline.compile()
out = patch.pipe(compiled)
print(out.shape)
```
## Development

### Three-Tier Architecture

dasjax is organized as a small three-tier stack:
- **Pipeline layer:** `src/dasjax/pipeline.py` records operation chains and compiles reusable patch transforms. This is the main user-facing API.
- **Operation layer:** `src/dasjax/operations.py` defines the operation registry, validation rules, internal eager implementations, and compiled behavior.
- **Kernel layer:** `src/dasjax/kernels.py` contains the array-level JAX and callback-backed kernels that actually do the numerical work.

This split keeps the package easier to extend: add or update numerical behavior in the kernel layer, describe how it plugs into compiled execution in the operation layer, and expose it through the pipeline layer.
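The layering can be illustrated with a self-contained toy (this is a sketch of the pattern, not actual dasjax code; `scale_kernel`, `SCALE_SPEC`, and `ToyPipeline` are hypothetical names):

```python
import jax.numpy as jnp


# Kernel layer: a pure array function that does the numerical work.
def scale_kernel(data, factor):
    return data * factor


# Operation layer: a spec describing how the kernel plugs into pipelines,
# including a validation rule for its parameters.
SCALE_SPEC = {
    "name": "scale",
    "kernel": scale_kernel,
    "validate": lambda factor: isinstance(factor, (int, float)),
}


# Pipeline layer: records calls, then composes kernels into one callable.
class ToyPipeline:
    def __init__(self):
        self._ops = []

    def scale(self, factor):
        if not SCALE_SPEC["validate"](factor):
            raise ValueError("factor must be a number")
        self._ops.append((SCALE_SPEC["kernel"], factor))
        return self

    def compile(self):
        def run(data):
            for kernel, arg in self._ops:
                data = kernel(data, arg)
            return data

        return run
```

A real operation spec carries more metadata (eager implementations, shared parity test cases), but the division of responsibility is the same.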
### Roadmap

The table below tracks what is missing and roughly how much effort each addition requires.

#### Near-term: straightforward pure-JAX array ops

No new infrastructure needed; each maps directly to one or two `jnp` calls.
| Method | Implementation notes |
|---|---|
| `real`, `imag`, `angle`, `conj` | `jnp` one-liners for complex patches |
| `flip` | `jnp.flip` along a named dim |
| `roll` | `jnp.roll` circular shift along a dim |
| `pad` | `jnp.pad` with DASCore coordinate extension |
| `standardize` | zero-mean + unit-std (compare `normalize`) |
| `differentiate` | `jnp.diff` finite differences along a dim |
| `integrate` | `jnp.cumsum` / trapezoid along a dim |
| `dft` / `idft` | `jnp.fft.rfft` / `irfft` with coord reconstruction |
| `hilbert` / `envelope` | `hilbert` via FFT; `envelope = abs(hilbert(data))` |
| `taper` / `taper_range` | Hann / cosine windows broadcast along axis |
| `whiten` | spectral divide-by-amplitude via FFT |
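The `hilbert` / `envelope` row follows a standard FFT recipe. A minimal pure-`jnp` sketch of that recipe for real 1-D input (an illustration, not the dasjax implementation):

```python
import jax.numpy as jnp


def hilbert(x):
    """Analytic signal of a real 1-D array via the FFT."""
    n = x.shape[-1]  # static under jit, so plain Python branching is fine
    X = jnp.fft.fft(x, axis=-1)
    # Step weights: keep DC (and Nyquist for even n), double positive
    # frequencies, zero out negative frequencies.
    h = jnp.zeros(n)
    if n % 2 == 0:
        h = h.at[0].set(1.0).at[n // 2].set(1.0).at[1:n // 2].set(2.0)
    else:
        h = h.at[0].set(1.0).at[1:(n + 1) // 2].set(2.0)
    return jnp.fft.ifft(X * h, axis=-1)


def envelope(x):
    # Instantaneous amplitude: magnitude of the analytic signal.
    return jnp.abs(hilbert(x))
```

For a pure cosine whose frequency lands on an FFT bin, the envelope is constant at the cosine's amplitude, which makes a convenient sanity check.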
#### Medium-term: moderate effort or shape-changing

These either need more work in `kernels.py` or are shape-changing (segmented pipeline execution, using the same mechanism as `fbe`).
| Method | Implementation notes |
|---|---|
| `notch_filter` | SOS filter; same pattern as `pass_filter` |
| `savgol_filter` | polynomial fitting per frame; JAX-doable |
| `rolling` | rolling-window reductions (mean, std, …); needs strided views |
| `correlate` | cross-correlation via `jnp.fft` |
| `stft` / `istft` | expose the STFT kernel already used by `fbe` |
| `decimate` | anti-aliased downsampling; shape-changing |
| `aggregate` / `mean` / `std` / `sum` | axis reductions; shape-changing |
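The `correlate` row amounts to the standard FFT convolution trick with one input reversed. A minimal sketch for 1-D signals (an assumption about the eventual approach, not dasjax code), matching `numpy.correlate(a, b, "full")`:

```python
import jax.numpy as jnp


def fft_correlate(a, b):
    """Full cross-correlation of two 1-D signals via the FFT."""
    n = a.shape[0] + b.shape[0] - 1  # length of the full linear result
    fa = jnp.fft.rfft(a, n)
    # Correlation = convolution with the second input reversed; zero-padding
    # to n turns the circular product into a linear one.
    fb = jnp.fft.rfft(b[::-1], n)
    return jnp.fft.irfft(fa * fb, n)
```

For patch-sized inputs the FFT route is typically far cheaper than the direct sum, and it keeps the whole operation inside `jnp` so it fuses with the rest of the compiled pipeline.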
### Performance Notes

- The intended fast path is to build a `JaxPatchPipeline`, call `.compile()` once, and reuse the returned callable across many patches of compatible shape and dtype.
- Equivalent pipeline definitions reuse cached compiled callables automatically.
- Benchmarks live under `benchmarks/` and compare compiled dasjax pipelines against equivalent DASCore operation chains.
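The compile-once-reuse pattern mirrors plain `jax.jit` behavior, which can be demonstrated without dasjax at all (`chain` below is an illustrative stand-in for a fused pipeline body, not dasjax code):

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def chain(x):
    # Scale, offset, then remove the per-row mean; XLA fuses these steps.
    x = x * 2.0 + 1.0
    return x - x.mean(axis=-1, keepdims=True)


x = jnp.ones((512, 512))

t0 = time.perf_counter()
chain(x).block_until_ready()  # first call traces and compiles
first = time.perf_counter() - t0

t0 = time.perf_counter()
chain(x).block_until_ready()  # later calls reuse the cached executable
second = time.perf_counter() - t0
```

On typical hardware the second call is orders of magnitude faster than the first, which is why paying the compilation cost once and reusing the callable across many patches is the intended fast path.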
### Development Guidelines

- Add new JAX patch methods by defining an array kernel in `src/dasjax/kernels.py` and one operation spec in `src/dasjax/operations.py`. The operation spec is the single source of truth for pipeline support, validation, and shared parity test cases.
- Every new patch method must be tested against a DASCore baseline across the shared mixed-patch fixture in `tests/conftest.py`.
- Prefer comparing internal operation behavior and compiled pipeline outputs against the closest native DASCore method or operator. If DASCore has no direct method, compare against an equivalent `Patch.update(...)` baseline.
- Method-equivalence assertions should check data closeness with `equal_nan=True` when needed and should also verify coordinate preservation.
- Compiled pipeline parity should come from the same declared operation cases rather than a separate hand-maintained test matrix.
- Install Git hooks locally with `prek install`.
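The parity-testing guideline can be sketched with a hypothetical demean kernel checked against a baseline (here recomputed with NumPy as a stand-in for the native DASCore result; `demean_kernel` and `baseline` are illustrative names):

```python
import numpy as np
import jax.numpy as jnp


def demean_kernel(data, axis=-1):
    # Hypothetical candidate kernel: constant detrend along one axis.
    return data - jnp.mean(data, axis=axis, keepdims=True)


def baseline(data, axis=-1):
    # Stand-in for the native DASCore result.
    return data - np.mean(data, axis=axis, keepdims=True)


rng = np.random.default_rng(0)
data = rng.normal(size=(8, 16))

got = np.asarray(demean_kernel(jnp.asarray(data)))
want = baseline(data)

# equal_nan=True mirrors the NaN-tolerant closeness check in the guidelines;
# the shape assertion stands in for coordinate-preservation checks.
assert np.allclose(got, want, atol=1e-5, equal_nan=True)
assert got.shape == want.shape
```

The tolerance absorbs the float32/float64 difference between default JAX and NumPy dtypes; real parity tests would additionally verify that patch coordinates survive unchanged.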