Hydrax 🐉

A zero-copy multiprocess dataloader for JAX. Built for Project RedRocket.

Installation

pip install hydrax

To use any of the optional modules (hydrax.image, hydrax.tqdm, or hydrax.pandas), install the corresponding extras with pip install 'hydrax[image,tqdm,pandas]' (or pip install 'hydrax[all]').

Wheels are currently available for Linux x86_64 with CPython 3.10, 3.11, and 3.12.

Ensure you have JAX installed and working. If you install via pip, the latest version of JAX will be installed if it is not already, but jaxlib will not be.

From Source

Clone this repository, install the Python development headers, a C compiler, and the build package (pip install build), and then run:

source path/to/your/venv/bin/activate
python -m build --wheel
pip install 'dist/hydrax-<...>.whl'

Documentation

Read the online documentation for the latest version at https://redhottensors.github.io/hydrax/.

For local HTML documentation, run make html in the sphinx directory and open the generated Sphinx documentation at _build/html/index.html within that directory. This requires furo (pip install furo), which should also install Sphinx.

Usage

from hydrax import Dataloader, DataGroup, TrainingBatch, ValidationBatch

def my_loader(data, arrays, seed):
    # Load data from the data source into 'arrays', optionally augmenting using 'seed'.
    # If 'seed' is None, this is data for a validation batch.
    # Return any additional data for the batch.
    ...

if __name__ == "__main__":
    my_data = ...
    array_defs = {
        "array_name": ((dim, ...), numpy_dtype, jax_dtype),
        ...
    }

    train = DataGroup(batch_size, my_data[1000:], loader_arrays=array_defs)
    valid = DataGroup(batch_size, my_data[:1000], loader_arrays=array_defs)

    dataloader = Dataloader(
        my_loader,
        train,
        validation=("epoch", 1, valid),  # run validation after every epoch
        end_at=("epoch", 5),             # run 5 epochs in total
    )

    with dataloader:  # a with block is required
        # consider using hydrax.tqdm.tbatches instead of a vanilla for loop here
        for batch in dataloader:
            if isinstance(batch, TrainingBatch):
                run_training_batch(batch)
            elif isinstance(batch, ValidationBatch):
                run_validation_batch(batch)

            del batch  # important: release the batch before waiting for the next one or cleaning up

    # Alternatively, with hydrax.tqdm.tbatches (use this instead of the with block above):
    from hydrax.tqdm import tbatches

    for batch in tbatches(dataloader, report_interval=1000):  # tbatches includes a with block for the dataloader
        ...
        del batch  # important, see above

Deadlocks / Stalls

If you are experiencing deadlocks because batch or array references are retained between iterations, use debug_batch_references or gc.get_referrers to find out what is holding on to your batches, keeping in mind that JAX dispatch will retain references while it runs ahead. You can check your work by running the Dataloader with depth = 1, which will deadlock immediately if the first batch is not properly released.
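As a rough illustration of the gc.get_referrers approach, the sketch below uses a hypothetical FakeBatch stand-in (not part of Hydrax) and a list that accidentally keeps it alive, then asks the garbage collector which objects still refer to it. In real code, expect additional referrers such as frames or JAX internals while dispatch runs ahead.

```python
import gc


class FakeBatch:
    """Hypothetical stand-in for a Hydrax batch object."""


def find_holders(obj, ignore=()):
    """Return every object that still refers to `obj`, skipping those in `ignore`."""
    return [
        ref for ref in gc.get_referrers(obj)
        if not any(ref is skipped for skipped in ignore)
    ]


history = []  # a common mistake: keeping each batch around for later logging

batch = FakeBatch()
history.append(batch)

# Ignore this module's globals, which legitimately hold `batch` here.
for holder in find_holders(batch, ignore=(globals(),)):
    print(type(holder).__name__)  # 'list' is the accidental `history` reference

history.clear()  # the fix: drop the stale reference (and `del batch` when done)
```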

Batch Structure

In Hydrax, a single Dataloader is usually responsible for producing both your training and validation batches, in order to conserve resources and ensure perfectly smooth loading throughout.

Each batch produced by the Dataloader is either a TrainingBatch instance or a ValidationBatch instance, both of which inherit the common functionality of Batch.

The most important properties of a Batch are:

  • arrays -- { 'array_name': jax.Array, ... }, corresponding to each array defined by the source DataGroup. The first dimension of the array is the batch size.
  • additional -- { 'key': [item_0_value, ...] }, corresponding to any additional data returned by your loader function. Each list's length is the batch size. If your loader did not return a value for that key for a given item, the corresponding element is None. Use get_additional(key[, index]) if your loader sometimes omits certain keys.
  • data -- A proxy for the original data descriptors of each item, with length equal to the batch size.
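The shape of additional can be modeled in plain Python. This is a hypothetical illustration of the structure described above, not Hydrax's implementation: collate_additional turns the values a loader returned for each item into one list per key, filling None where an item omitted the key. get_additional_sketch is an assumed stand-in for get_additional that returns None for keys no item returned; check the online documentation for the real method's exact behavior.

```python
from typing import Any, Dict, List, Optional


def collate_additional(per_item: List[Optional[Dict[str, Any]]]) -> Dict[str, List[Any]]:
    """Merge per-item loader return values into {key: [value_0, ...]}.

    Each list has one entry per item (the batch size); items that did not
    return a given key contribute None.
    """
    keys: List[str] = []
    for returned in per_item:
        for key in (returned or {}):
            if key not in keys:
                keys.append(key)
    return {
        key: [(returned or {}).get(key) for returned in per_item]
        for key in keys
    }


def get_additional_sketch(additional, key, index=None):
    """Hypothetical lookup: None instead of KeyError for keys no item returned."""
    values = additional.get(key)
    if index is None:
        return values
    return None if values is None else values[index]


additional = collate_additional([{"caption": "a cat"}, None, {"caption": "a dog", "score": 0.9}])
print(additional)
# {'caption': ['a cat', None, 'a dog'], 'score': [None, None, 0.9]}
print(get_additional_sketch(additional, "label", 0))  # None
```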

As mentioned above, remember to release any references to a batch or its arrays as soon as you're done with them.
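Why releasing matters can be illustrated with a toy model. It assumes, as a simplification and not a description of Hydrax's internals, that the loader owns a fixed number of batch slots (compare the depth parameter) and can reuse a slot only after the consumer drops every reference to the batch occupying it. BufferPool and SketchBatch are hypothetical names; weakref.finalize hands the slot back when the batch object is collected, which CPython's reference counting does as soon as the last reference disappears.

```python
import weakref


class SketchBatch:
    """Hypothetical batch that occupies one buffer slot while alive."""

    def __init__(self, slot):
        self.slot = slot


class BufferPool:
    """Toy model: `depth` reusable slots, each freed only when its batch dies."""

    def __init__(self, depth):
        self.free = list(range(depth))

    def try_next_batch(self):
        """Return a new batch, or None if every slot is still referenced
        (a real loader would block here instead, i.e. stall)."""
        if not self.free:
            return None
        slot = self.free.pop()
        batch = SketchBatch(slot)
        # Give the slot back once the consumer drops its last reference.
        weakref.finalize(batch, self.free.append, slot)
        return batch


pool = BufferPool(depth=1)
batch = pool.try_next_batch()
print(pool.try_next_batch())              # None: the only slot is still held by `batch`
del batch                                 # CPython frees the object immediately
print(pool.try_next_batch() is not None)  # True: the slot was reused
```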

Loader Processes

Read the documentation for loader_func carefully. If you receive a warning from Hydrax about your loader, you should fix your code. Failure to do this could result in your batch data changing out from underneath you, leading to significant training issues such as NaNs.
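A minimal sketch of a loader written in this spirit follows. It assumes (verify against the loader_func documentation) that arrays is a dict of writable per-item NumPy arrays matching the shapes in your array definitions and that data is the descriptor you passed to DataGroup. example_loader, the "pixels" and "label" fields, and the "image" array name are hypothetical. The sketch copies results into the provided buffers and keeps no reference to arrays after returning.

```python
import numpy as np


def example_loader(data, arrays, seed):
    """Hypothetical loader for items shaped like {"pixels": ..., "label": ...}.

    Assumes arrays["image"] is a writable float32 (H, W) array for this item.
    """
    image = np.asarray(data["pixels"], dtype=np.float32)

    if seed is not None:  # training item: augment deterministically from the seed
        rng = np.random.default_rng(seed)
        if rng.random() < 0.5:
            image = image[:, ::-1]  # horizontal flip

    # Copy into the provided buffer instead of rebinding or storing it.
    arrays["image"][...] = image

    # Return per-item extras; never return or stash `arrays` itself.
    return {"label": data["label"]}


out = {"image": np.zeros((2, 3), dtype=np.float32)}
extra = example_loader({"pixels": [[0, 1, 2], [3, 4, 5]], "label": 7}, out, seed=None)
print(out["image"])
print(extra)  # {'label': 7}
```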

Do not attempt to construct a Dataloader inside a loader process. Ensure your training code is guarded with if __name__ == '__main__':, or is otherwise prevented from running. As a last resort, you can check hydrax.is_worker and bail.

KeyboardInterrupt (Ctrl+C / SIGINT)

The Dataloader installs a handler for KeyboardInterrupt (Ctrl+C / SIGINT) which stops the flow of batches as soon as possible. After the dataloader has completed, you can check if this occurred by reading its interrupted property. You may want to save a checkpoint along with the numbers of the current epoch and batch, so that you can resume from where you left off with start_at.

If you send a second KeyboardInterrupt, Hydrax will raise a KeyboardInterrupt at the beginning of the next batch. This exception may cause you to lose progress unless you or a framework takes care to save a checkpoint in response.

If you send a third KeyboardInterrupt, the Python interpreter is immediately stopped and control is returned to you. You will lose all progress since the last checkpoint.
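The three-stage escalation can be modeled with Python's signal module. This is a hypothetical sketch of the behavior described above, not Hydrax's code: InterruptTracker, handle, and before_batch are illustrative names, and os._exit stands in for stopping the interpreter immediately.

```python
import os
import signal


class InterruptTracker:
    """Hypothetical model of escalating SIGINT handling; not Hydrax's code.

    1st SIGINT: request a graceful stop (no more batches).
    2nd SIGINT: raise KeyboardInterrupt at the start of the next batch.
    3rd SIGINT: terminate the interpreter immediately.
    """

    def __init__(self):
        self.count = 0

    @property
    def interrupted(self):
        return self.count >= 1

    def handle(self, signum, frame):
        self.count += 1
        if self.count >= 3:
            os._exit(128 + signum)  # conventional exit status; no cleanup, no checkpoint

    def before_batch(self):
        """Call at the beginning of each batch; returns False to stop cleanly."""
        if self.count >= 2:
            raise KeyboardInterrupt
        return not self.interrupted

    def install(self):
        signal.signal(signal.SIGINT, self.handle)  # main thread only


tracker = InterruptTracker()
tracker.handle(signal.SIGINT, None)  # simulate the first Ctrl+C
print(tracker.before_batch())        # False: stop, then save a checkpoint
tracker.handle(signal.SIGINT, None)  # simulate the second Ctrl+C
try:
    tracker.before_batch()
except KeyboardInterrupt:
    print("KeyboardInterrupt at the start of the next batch")
```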

Compatibility

A convenient wrapper with tqdm progress bars is provided in hydrax.tqdm. The corresponding extra is tqdm.

ICC-profile-aware 8-bit image loading with Pillow is provided in hydrax.image, along with support for the Oklab color space. The corresponding extra is image.
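For reference, the standard sRGB-to-Oklab conversion published by Björn Ottosson is sketched below for a single pixel with 8-bit channels. This is the textbook transform in pure Python, not hydrax.image's API; the function names are illustrative, and real code would vectorize it with NumPy.

```python
def srgb_to_linear(c):
    """Undo the sRGB transfer curve for one channel in [0, 1]."""
    if c <= 0.04045:
        return c / 12.92
    return ((c + 0.055) / 1.055) ** 2.4


def srgb8_to_oklab(r8, g8, b8):
    """Convert one 8-bit sRGB pixel to Oklab (L, a, b)."""
    r, g, b = (srgb_to_linear(v / 255.0) for v in (r8, g8, b8))

    # Linear sRGB -> approximate cone responses (LMS).
    l = 0.4122214708 * r + 0.5363325363 * g + 0.0514459929 * b
    m = 0.2119034982 * r + 0.6806995451 * g + 0.1073969566 * b
    s = 0.0883024619 * r + 0.2817188376 * g + 0.6299787005 * b

    # Cube-root nonlinearity (inputs are non-negative for in-gamut colors).
    l_, m_, s_ = l ** (1 / 3), m ** (1 / 3), s ** (1 / 3)

    return (
        0.2104542553 * l_ + 0.7936177850 * m_ - 0.0040720468 * s_,
        1.9779984951 * l_ - 2.4285922050 * m_ + 0.4505937099 * s_,
        0.0259040371 * l_ + 0.7827717662 * m_ - 0.8086757660 * s_,
    )


print(srgb8_to_oklab(255, 0, 0))  # pure red: roughly (0.628, 0.225, 0.126)
```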

Compatibility for Pandas datasets is provided by hydrax.pandas. The corresponding extra is pandas.

License

Hydrax is available under the terms of the Mozilla Public License, version 2.0.

Download files

Source Distributions

No source distribution files are available for this release.

Built Distributions

hydrax-0.2.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.whl (56.4 kB)

Uploaded CPython 3.12 manylinux: glibc 2.5+ x86-64

hydrax-0.2.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.whl (55.7 kB)

Uploaded CPython 3.11 manylinux: glibc 2.5+ x86-64

hydrax-0.2.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.whl (55.2 kB)

Uploaded CPython 3.10 manylinux: glibc 2.5+ x86-64
