A zero-copy multiprocess dataloader for JAX.
Project description
Hydrax 🐉
A zero-copy multiprocess dataloader for JAX. Built for Project RedRocket.
Installation
pip install hydrax
If you want to use any of the extra modules,
hydrax.image,
hydrax.tqdm, or
hydrax.pandas,
you can use pip install hydrax[image,tqdm,pandas]
(or pip install hydrax[all]
).
Wheels are currently available for Linux x86_64 with CPython 3.10, 3.11, and 3.12.
Ensure you have JAX installed and working. If you install via pip, the latest version of JAX will be installed if it is not already, but jaxlib will not be.
From Source
Clone this repository, install Python development files and a C compiler, and run:
source path/to/your/venv/bin/activate
python -m build --wheel
pip install 'dist/hydrax-<...>.whl'
Documentation
Read the online documentation for the latest version at https://redhottensors.github.io/hydrax/.
For local HTML documentation, run make html
in /sphinx
and browse the generated Sphinx documentation in
_build/html/index.html
. You will need pip install furo
, which should also install Sphinx.
Usage
from hydrax import Dataloader, DataGroup, TrainingBatch, ValidationBatch
def my_loader(data, arrays, seed):
# load data from data source into arrays, optionally augmenting using 'seed'.
# if 'seed' is None this is a data from a validation batch
# return any additional data for the batch
if __name__ == "main":
my_data = ...
array_defs = {
"array_name": ((dim, ...), numpy_dtype, jax_dtype),
...
}
train = DataGroup(batch_size, my_data[1000:], loader_arrays=array_defs)
valid = DataGroup(batch_size, my_data[:1000], loader_arrays=array_defs)
dataloader = Dataloader(
my_loader,
train,
validation = ("epoch", 1, valid), # run validation after every epoch
end_at = ("epoch", 5) # run 5 epochs in total
)
with dataloader: # a with block is required
# consider using hydrax.tqdm.tbatches instead of a vanilla for loop here
for batch in dataloader:
if isinstance(batch, TrainingBatch):
run_training_batch(batch)
elif isinstance(batch, ValidationBatch):
run_validation_batch(batch)
del batch # important, release batch before waiting for next one or cleaning up
# with hydrax.tqdm.tbatches
from hydrax.tdqm import tbatches
for batch in tbatches(dataloader, report_interval=1000): # tbatches includes a with block for the dataloader
...
del batch # important, see above
Deadlocks / Stalls
If you are experiencing deadlocks as a result of retaining batch or array references between iterations, consider using
debug_batch_references or
gc.get_referrers to find out what's holding on to your
batches, though do keep in mind that JAX dispatch will retain references while running ahead. You can check your work by
running the Dataloader with depth = 1
, which will immediately deadlock if the first batch is not properly released.
Batch Structure
In Hydrax, a single Dataloader is usually responsible for producing both your training and validation batches, in order to conserve resources and ensure perfectly smooth loading throughout.
Each batch produced by the Dataloader is either a TrainingBatch instance or a ValidationBatch instance, which both inherit the common functionality of Batch. (You can click any of the preceding links to view the online documentation.)
The most important properties of a Batch
are:
arrays
--{ 'array_name': jax.Array, ... }
, corresponding to each array defined by the sourceDataGroup
. The first dimension of the array is the batch size.additional
--{ 'key': [item_0_value, ...] }
, corresponding to any additional data returned by your loader function. Each list'slen
is the batch size. If no corresponding item was returned, the element isNone
. Useget_additional(key[, index])
if your loader sometimes omits returning certain keys.data
-- A proxy type to the original data descriptors for each item, with length equal to the batch size.
As mentioned above, remember to release any references to a batch or its arrays as soon as you're done with them.
Loader Processes
Read the documentation for loader_func
carefully. If you receive a warning from Hydrax about your loader, you should fix your code. Failure to do this could
result in your batch data changing out from underneath you, leading to significant training issues such as NaNs.
Do not attempt to construct a Dataloader inside a loader process. Ensure your training code is guarded with
if __name__ == '__main__':
, or is otherwise prevented from running. As a last resort, you can check
hydrax.is_worker
and bail.
KeyboardInterrupt (Ctrl+C / SIGINT)
The Dataloader installs a handler for KeyboardInterrupt
(Ctrl+C / SIGINT) which stops the flow of batches as soon as
possible. After the dataloader has completed, you can check if this occurred by reading its interrupted
property.
You may want to save a checkpoint along with the numbers of the current epoch and batch, so that you can resume from
where you left off with start_at
.
If you send a second KeyboardInterrupt
, Hydrax will raise a KeyboardInterrupt
at the beginning of the next
batch. This exception may cause you to lose progress unless you or a framework takes care to save a checkpoint in
response.
If you send a third KeyboardInterrupt
, the Python interpreter is immediately stopped and control is returned to you.
You will lose all progress since the last checkpoint.
Compatibility
A convienient wrapper with tqdm progress bars is provided in
hydrax.tqdm. The corresponding extra is
tqdm
.
ICC-profile-aware 8bbp image loading with Pillow is provided in
hydrax.image, and support is included for
Oklab as well. The corresponding extra is image
.
Compatibility for Pandas datasets is provided by
hydrax.pandas. The corresponding extra
is pandas
.
License
Hydrax is available under the terms of the Mozilla Public License, version 2.0.
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distributions
Built Distributions
File details
Details for the file hydrax-0.2.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.whl
.
File metadata
- Download URL: hydrax-0.2.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.whl
- Upload date:
- Size: 56.4 kB
- Tags: CPython 3.12, manylinux: glibc 2.5+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.8
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | fd83ac712798f1b9cd02394b20ace9f9713feb9063174e4aea6cf3d2cc5aeb26 |
|
MD5 | 75d5e04ea40a500d00e6fd8265c06b25 |
|
BLAKE2b-256 | 9503192ddda4dfd01b208df998e5fdf04d56c23a9f0eebde2c8b258a66c64410 |
File details
Details for the file hydrax-0.2.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.whl
.
File metadata
- Download URL: hydrax-0.2.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.whl
- Upload date:
- Size: 55.7 kB
- Tags: CPython 3.11, manylinux: glibc 2.5+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.8
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | e145b665848f85144f55e101ac603ec753032763edd105f2f64e87a2236460b8 |
|
MD5 | f67506983ca7c87ad82284f1f168df5a |
|
BLAKE2b-256 | 92d38783e76ef821a58eab5397f047bfc49405f358568d8cab4861ffb6082288 |
File details
Details for the file hydrax-0.2.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.whl
.
File metadata
- Download URL: hydrax-0.2.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.whl
- Upload date:
- Size: 55.2 kB
- Tags: CPython 3.10, manylinux: glibc 2.5+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.11.8
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | a60eb8e87364996db724ecd1e04245c8f15db6cade421e0f01f24849ff6bf1d6 |
|
MD5 | e9ae9af4509e487a0180109be3518be8 |
|
BLAKE2b-256 | ba5585f7eca49ac130df1ea9a391bc5a92bc79cde1c347cb661c1e0deb8bf563 |