JAXline is a distributed JAX training framework.

JAXline - Experiment framework for JAX

What is JAXline

JAXline is a distributed JAX training and evaluation framework. It is designed to be forked, covering only the most general aspects of experiment boilerplate. This ensures that it can serve as an effective starting point for a wide variety of use cases.

Many users will only need to fork the experiment.py file and rely on JAXline for everything else. Other users with more custom requirements will want to (and are encouraged to) fork other components of JAXline too, depending on their particular use case.

Quickstart

Installation

JAXline is written in pure Python, but depends on C++ code via JAX and TensorFlow (the latter is used for writing summaries).

Because JAX and TensorFlow installation differs depending on your CUDA version, JAXline does not list JAX or TensorFlow as dependencies in requirements.txt.

First, follow the respective instructions to install JAX and TensorFlow with the relevant accelerator support.

Then, install JAXline using pip:

$ pip install git+https://github.com/deepmind/jaxline

Building your own experiment

  1. Create an experiment.py file and inside it define an Experiment class that inherits from experiment.AbstractExperiment.

  2. Implement the methods required by AbstractExperiment in your own Experiment class (i.e. the abstractmethods). Optionally override the default implementations of AbstractExperiment's other methods.

  3. Define a config, either in experiment.py or elsewhere, that sets any values you do not wish to inherit from base_config. At the very least this will include config.experiment_kwargs, which holds the config required by your Experiment. Make sure this config object is included in the flags accessible to experiment.py.

  4. Add the following lines to the bottom of your experiment.py to ensure that your Experiment class is correctly passed through to platform.py:

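    # Assumes sys, absl's flags module and jaxline's platform module are imported at the top of the file.
    if __name__ == '__main__':
      flags.mark_flag_as_required('config')
      platform.main(Experiment, sys.argv[1:])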
    
  5. Run your experiment.py.
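The steps above can be sketched with a toy experiment. So that the sketch runs without JAXline installed, it defines a hypothetical stand-in base class, AbstractExperimentStandIn, in place of experiment.AbstractExperiment; in a real experiment.py you would subclass the latter. The step and evaluate methods, the global_step, rng and writer arguments and the returned scalar dicts follow the descriptions in this README, but the exact signatures, the constructor arguments (mode, init_rng, learning_rate, target) and the toy model itself are assumptions for illustration.

```python
import abc
from typing import Any, Dict, Optional


class AbstractExperimentStandIn(abc.ABC):
  """Hypothetical stand-in for jaxline's experiment.AbstractExperiment.

  Only here so the sketch runs without JAXline installed; in a real
  experiment.py you would subclass experiment.AbstractExperiment instead.
  """

  CHECKPOINT_ATTRS: Dict[str, str] = {}
  NON_BROADCAST_CHECKPOINT_ATTRS: Dict[str, str] = {}

  @abc.abstractmethod
  def step(self, *, global_step, rng, writer) -> Dict[str, Any]:
    """Runs one training step and returns a dict of scalars."""

  @abc.abstractmethod
  def evaluate(self, *, global_step, rng, writer) -> Dict[str, Any]:
    """Runs evaluation and returns a dict of scalars."""


class Experiment(AbstractExperimentStandIn):
  """Toy experiment: 'trains' a single scalar weight towards a target."""

  # A plain float is picklable, so it can be checkpointed whole.
  NON_BROADCAST_CHECKPOINT_ATTRS = {"_weight": "weight"}

  def __init__(self, mode: str, init_rng: Optional[Any] = None,
               learning_rate: float = 0.5, target: float = 3.0):
    self.mode = mode
    self._lr = learning_rate
    self._target = target
    self._weight = 0.0

  def step(self, *, global_step, rng, writer):
    # Gradient of 0.5 * (w - target) ** 2 with respect to w is (w - target).
    grad = self._weight - self._target
    self._weight -= self._lr * grad
    return {"loss": 0.5 * grad ** 2}

  def evaluate(self, *, global_step, rng, writer):
    return {"abs_error": abs(self._weight - self._target)}
```

In the real setup these methods are called by JAXline's training and evaluation loops rather than by hand.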

Checkpointing

So far this version of JAXline only supports in-memory checkpointing, as handled by our InMemoryCheckpointer. It allows you to save multiple separate checkpoint series in memory in your train and eval jobs (see below).

The user is expected to override the CHECKPOINT_ATTRS and NON_BROADCAST_CHECKPOINT_ATTRS dicts (at least one of them) in order to map checkpointable attributes of their own Experiment class to the names under which they should be stored in the checkpoint. CHECKPOINT_ATTRS specifies JAX DeviceArrays for which JAXline should take only the first slice (corresponding to device 0) when checkpointing. NON_BROADCAST_CHECKPOINT_ATTRS specifies any other picklable object that JAXline should checkpoint whole.
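To make the difference between the two dicts concrete, the sketch below mimics the described behaviour with NumPy arrays that carry a leading device axis, standing in for replicated JAX DeviceArrays. ToyExperiment, snapshot and NUM_DEVICES are hypothetical names, and this is not the InMemoryCheckpointer implementation.

```python
import pickle

import numpy as np

NUM_DEVICES = 4  # Hypothetical device count for illustration.


class ToyExperiment:
  # Replicated per-device arrays: only the device-0 slice is stored.
  CHECKPOINT_ATTRS = {"_params": "params"}
  # Any other picklable object: stored whole.
  NON_BROADCAST_CHECKPOINT_ATTRS = {"_history": "history"}

  def __init__(self):
    params = np.arange(3, dtype=np.float32)
    # Leading axis = device axis, as if params were replicated on each device.
    self._params = np.broadcast_to(params, (NUM_DEVICES,) + params.shape).copy()
    self._history = {"best_loss": 0.25}


def snapshot(experiment):
  """Hypothetical helper following the two dicts as described above."""
  state = {}
  for attr, name in experiment.CHECKPOINT_ATTRS.items():
    state[name] = np.asarray(getattr(experiment, attr))[0]  # device 0 only
  for attr, name in experiment.NON_BROADCAST_CHECKPOINT_ATTRS.items():
    state[name] = getattr(experiment, attr)  # the whole object
  # Round-trip through pickle to confirm everything is picklable.
  return pickle.loads(pickle.dumps(state))
```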

You can specify how often checkpoints are saved, and whether that interval is measured in steps or seconds, by setting the save_checkpoint_interval and interval_type config values.
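Step-based versus time-based intervals can be illustrated with a small trigger object. IntervalTrigger is a hypothetical helper, and the 'steps'/'secs' spelling of interval_type is an assumption; check base_config in your JAXline version for the accepted values.

```python
import time
from typing import Callable, Optional


class IntervalTrigger:
  """Hypothetical helper: fires every `interval` steps or seconds."""

  def __init__(self, interval: float, interval_type: str,
               clock: Callable[[], float] = time.time):
    if interval_type not in ("steps", "secs"):
      raise ValueError(f"Unknown interval_type: {interval_type!r}")
    self._interval = interval
    self._type = interval_type
    self._clock = clock  # Injectable so the sketch can be tested.
    self._last: Optional[float] = None

  def __call__(self, step: int) -> bool:
    now = step if self._type == "steps" else self._clock()
    if self._last is None:
      self._last = now  # The first call only sets the baseline.
      return False
    if now - self._last >= self._interval:
      self._last = now
      return True
    return False
```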

config.max_checkpoints_to_keep can be used to specify the maximum number of checkpoints to keep. By default this is set to 5.
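A bounded in-memory series of the kind max_checkpoints_to_keep implies can be sketched with a collections.deque. InMemoryCheckpointSeries is a hypothetical class for illustration only, not JAXline's InMemoryCheckpointer.

```python
import collections


class InMemoryCheckpointSeries:
  """Hypothetical sketch: keeps at most `max_to_keep` snapshots in memory."""

  def __init__(self, max_to_keep: int = 5):
    # A bounded deque silently drops the oldest entry once it is full.
    self._checkpoints = collections.deque(maxlen=max_to_keep)

  def save(self, global_step: int, state: dict) -> None:
    self._checkpoints.append((global_step, dict(state)))

  def latest(self):
    return self._checkpoints[-1] if self._checkpoints else None

  def steps(self):
    return [step for step, _ in self._checkpoints]
```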

By setting config.best_model_eval_metric, you can specify which value in the scalars dictionary returned by your evaluate method to use as a 'fitness score'. JAXline will then save a separate series of checkpoints corresponding to steps at which the fitness score is better than any previously seen. Depending on whether you are maximizing or minimizing the eval metric, set config.best_model_eval_metric_higher_is_better to True or False.
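The "better than any previously seen" rule can be sketched as follows. BestModelTracker and the metric names are hypothetical, and treating ties as no improvement is an assumption.

```python
import math


class BestModelTracker:
  """Hypothetical sketch of 'save a checkpoint when the fitness improves'."""

  def __init__(self, metric: str, higher_is_better: bool = True):
    self._metric = metric
    self._higher_is_better = higher_is_better
    # Start from the worst possible score so the first eval always wins.
    self._best = -math.inf if higher_is_better else math.inf

  def is_new_best(self, eval_scalars: dict) -> bool:
    score = float(eval_scalars[self._metric])
    improved = (score > self._best if self._higher_is_better
                else score < self._best)
    if improved:
      self._best = score
    return improved
```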

Logging

So far this version of JAXline only supports logging to TensorBoard via our TensorBoardLogger.

The user is expected to return a dictionary of scalars from their step and evaluate methods, and TensorBoardLogger.write_scalars will periodically write these scalars to TensorBoard.

All logging will happen asynchronously to the main thread so as not to interrupt the training loop.
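One common way to keep logging off the training thread is a queue drained by a background worker thread, sketched below. AsyncScalarWriter is a hypothetical class; JAXline's own mechanism may differ, and write_fn stands in for whatever actually writes to TensorBoard.

```python
import queue
import threading


class AsyncScalarWriter:
  """Hypothetical sketch: hands scalars to a background thread for writing."""

  def __init__(self, write_fn):
    self._write_fn = write_fn  # e.g. something that writes to TensorBoard
    self._queue = queue.Queue()
    self._thread = threading.Thread(target=self._worker, daemon=True)
    self._thread.start()

  def _worker(self):
    while True:
      item = self._queue.get()
      if item is None:  # Sentinel: stop the worker.
        break
      step, scalars = item
      self._write_fn(step, scalars)

  def write_scalars(self, step, scalars):
    # Returns immediately, so the training loop never waits on I/O here.
    self._queue.put((step, dict(scalars)))

  def close(self):
    # Flushes everything queued so far, then stops the worker thread.
    self._queue.put(None)
    self._thread.join()
```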

You can specify how often to log, and whether that interval is measured in steps or seconds, by setting the log_train_data_interval and interval_type config values. If config.log_all_train_data is set to True (it is False by default), JAXline will cache the scalars from intermediate steps and log them all at once at the end of each period.
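The effect of config.log_all_train_data can be sketched with a small buffer. TrainScalarBuffer is hypothetical, and logging only the most recent step when the flag is False is an assumption for illustration.

```python
class TrainScalarBuffer:
  """Hypothetical sketch of log_all_train_data-style caching."""

  def __init__(self, log_all_train_data: bool = False):
    self._log_all = log_all_train_data
    self._cache = []

  def record(self, step, scalars):
    if self._log_all:
      # Keep every intermediate step until the period ends.
      self._cache.append((step, dict(scalars)))
    else:
      # Keep only the most recent step's scalars.
      self._cache = [(step, dict(scalars))]

  def flush(self):
    """Called at the end of each logging period; returns what to write."""
    out, self._cache = self._cache, []
    return out
```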

JAXline passes the TensorBoardLogger instance through to the step and evaluate methods to allow the user to perform additional logging inside their Experiment class if they so wish. A particular use case for this is if you want to write images, which can be achieved via ExperimentWriter.write_images.

Launching

So far this version of JAXline does not support launching remotely.

Distribution strategy

JAX makes it super simple to distribute your jobs across multiple hosts and cores. As such, JAXline leaves it up to the user to implement distributed training and evaluation.

Essentially, by decorating a function with jax.pmap you tell JAX to slice the inputs along the first dimension and then run the function in parallel on each input slice, across all available local devices (or a subset thereof). In other words, jax.pmap invokes the single-program multiple-data (SPMD) paradigm. Then by using jax.lax collective communications operations from within your pmapped function, you can tell JAX to communicate results between all devices on all hosts. For example, you may want to use jax.lax.psum to sum up the gradients across all devices on all hosts, and return the result to each device (an all-reduce).
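Here is a minimal data-parallel gradient all-reduce with jax.pmap and jax.lax.psum, as described above. It runs on however many local devices are available (a single CPU device works too); the toy linear loss and the axis name "devices" are illustrative choices.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
  # Toy linear-regression loss for a single scalar weight.
  return jnp.mean((x * w - y) ** 2)


def grad_step(w, x, y):
  # Each device computes gradients on its own slice of the batch...
  g = jax.grad(loss_fn)(w, x, y)
  # ...then psum adds them up across all devices on the "devices" axis and
  # returns the total to every device (an all-reduce).
  return jax.lax.psum(g, axis_name="devices")


# pmap slices every argument along its leading axis, one slice per device.
p_grad_step = jax.pmap(grad_step, axis_name="devices")

n = jax.local_device_count()
w = jnp.zeros((n,))                                      # one weight replica per device
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)  # per-device batches
y = 2.0 * x
grads = p_grad_step(w, x, y)  # shape (n,), identical on every device
```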

JAX will then automatically detect which devices are available on each host allowing jax.pmap and jax.lax to work their magic.

One very important thing to bear in mind is that each time you call jax.pmap, a separate TPU program will be compiled for the computation it wraps. Therefore you do not want to be doing this regularly! In particular, for a standard ML experiment you will want to call jax.pmap once to wrap your parameter update function and then call this wrapped function on each step, rather than calling jax.pmap on each step, which will kill your performance! This is a very common mistake for newcomers to JAX. Luckily, the slowdown is severe enough that it should be easy to notice. In JAXline we actually call jax.pmap once more in next_device_state to wrap our function that updates device state between steps, so we end up with 2 TPU programs rather than just 1 (but this adds negligible overhead).
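The "wrap once, call every step" pattern looks like this in a toy training loop. The model, learning rate and data are illustrative; jax.lax.pmean averages the gradients across devices.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1


def update(w, x, y):
  g = jax.grad(lambda w: jnp.mean((x * w - y) ** 2))(w)
  g = jax.lax.pmean(g, axis_name="devices")  # average gradients across devices
  return w - LEARNING_RATE * g


# Wrap the update function with jax.pmap ONCE, outside the training loop...
p_update = jax.pmap(update, axis_name="devices")

n = jax.local_device_count()
w = jnp.zeros((n,))
x = jnp.ones((n, 8))
y = 3.0 * x
for step in range(100):
  # ...and only call the wrapped function on each step. Writing
  # jax.pmap(update, ...)(w, x, y) here would re-wrap it on every step,
  # which is the mistake the text above warns against.
  w = p_update(w, x, y)
```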

Random number handling

Random numbers in JAX might seem a bit unfamiliar to users coming from ordinary NumPy and TensorFlow. These libraries use global stateful PRNGs: every time you call a random op, it updates the PRNG's global state. However, stateful PRNGs would be incompatible with JAX's functional design, leading to problems with reproducibility and parallelizability, so JAX uses stateless PRNGs instead. The downside is that the user needs to thread random state through their program, splitting a new PRNG off from the old one every time they want to draw new random numbers. This can be quite onerous, especially in a distributed setting, where you may have independent PRNGs on each device.

In JAXline we take care of this for you. On each step, in next_device_state, we split a new PRNG from the old one and optionally specialize it to the host and/or device based on the random_mode_train config value you specify. We then pass this new PRNG through to your step function to use on that particular step. At evaluation time, we pass a fresh PRNG to your evaluate method, initialized according to the random_mode_eval config value you specify. This PRNG will be the same on each call to evaluate (as you normally want your evaluation to be deterministic). If you want different random behaviour on each call, a simple solution is to fold in the global step, i.e. call jax.random.fold_in(rng, global_step) at the top of your evaluate method.
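The key-threading patterns mentioned above can be sketched with jax.random.split and jax.random.fold_in. next_rng and specialize are hypothetical helpers; they are not JAXline's next_device_state, and they do not reproduce the exact behaviour of the random_mode_train or random_mode_eval settings.

```python
import jax


def next_rng(rng):
  """Hypothetical: split off a key for this step, keep the rest for later."""
  rng, step_key = jax.random.split(rng)
  return rng, step_key


def specialize(key, host_id, device_id):
  """Hypothetical: derive distinct keys per host and per device."""
  return jax.random.fold_in(jax.random.fold_in(key, host_id), device_id)


def eval_key_for_call(eval_rng, global_step):
  """The suggestion above: fold global_step into a fixed evaluation key."""
  return jax.random.fold_in(eval_rng, global_step)


carry = jax.random.PRNGKey(42)
carry, key_step0 = next_rng(carry)
carry, key_step1 = next_rng(carry)
per_device = [specialize(key_step0, host_id=jax.process_index(), device_id=d)
              for d in range(jax.local_device_count())]
```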

Of course you are free to completely ignore the PRNGs we pass through to your step and evaluate methods and handle random numbers in your own way, should you have different requirements.

Debugging

Post mortem debugging

By setting the flag --jaxline_post_mortem on the command line, tasks will pause on exceptions (except SystemExit and KeyboardInterrupt) and enter post-mortem debugging using pdb. Paused tasks will hang until you attach a debugger.
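The behaviour described for --jaxline_post_mortem resembles the following wrapper around pdb.post_mortem. run_with_post_mortem is hypothetical, as is re-raising the exception once the debugging session ends; the debugger parameter exists only so the sketch can be exercised without an interactive session.

```python
import pdb
import sys
import traceback


def run_with_post_mortem(fn, debugger=pdb.post_mortem):
  """Hypothetical wrapper: drop into a debugger when fn raises."""
  try:
    return fn()
  except Exception:
    # SystemExit and KeyboardInterrupt derive from BaseException, not
    # Exception, so they propagate without triggering the debugger.
    traceback.print_exc()
    debugger(sys.exc_info()[2])  # Inspect the frame where the error occurred.
    raise  # Assumption: re-raise once the debugging session ends.
```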

Disabling pmap and jit

By setting the flag --jaxline_disable_pmap_jit on the command-line, all pmaps and jits will be disabled, making it easier to inspect and trace code in a debugger.
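Plain JAX offers a similar switch for jit: inside a jax.disable_jit() context, jitted functions run eagerly on concrete values, so print statements and breakpoint() see real numbers. Whether --jaxline_disable_pmap_jit is implemented this way is an assumption; the function below is illustrative.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sqrt_of_sum(x):
  # Under jit, x is an abstract tracer and float(x.sum()) raises an error.
  # With jit disabled, x is a concrete array, so this line works (as would
  # printing x or calling breakpoint() here).
  total = float(x.sum())
  return jnp.sqrt(total)


with jax.disable_jit():
  result = sqrt_of_sum(jnp.array([1.0, 3.0]))
```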

Citing JAXline

Please use this reference.

Contributing

Thank you for your interest in JAXline. The primary goal of open-sourcing JAXline was to allow us to open-source our research more easily. Unfortunately, we are not currently able to accept pull requests from external contributors, though we hope to do so in future. Please feel free to open GitHub issues.

Download files

Source Distribution

jaxline-0.0.8.tar.gz (35.8 kB)

File details

Details for the file jaxline-0.0.8.tar.gz.

File metadata

  • Download URL: jaxline-0.0.8.tar.gz
  • Size: 35.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for jaxline-0.0.8.tar.gz
Algorithm Hash digest
SHA256 475a1ab56cb556127fa99df0ab23cbccc0d345c6b1a02707f4414932cef50f76
MD5 5473dcc2fb723bc5711f8cfbdc0d78b7
BLAKE2b-256 ad55a9b8c9293f7323dbe2eda32ac1bea449f7a7e9c80aa0d7221de220b84fae
