
DrJAX - Scalable and Differentiable MapReduce Primitives in JAX.

Project description

DrJAX - Differentiable MapReduce Primitives in JAX

DrJAX is a library designed to embed a MapReduce programming model into JAX. DrJAX has three objectives:

  1. Create a simple JAX-based authoring surface for MapReduce computations.
  2. Leverage JAX's sharding mechanisms to enable highly optimized execution of MapReduce computations, especially in large-scale datacenter settings.
  3. Provide full differentiability of DrJAX computations, including differentiation through communication primitives such as broadcasts and reductions.
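The broadcast / map / reduce pattern behind these objectives can be approximated in plain JAX. The sketch below does not use DrJAX's own API; it is an illustrative analogue under stated assumptions: a leading array axis of size `num_clients` stands in for a placement, `jnp.broadcast_to` plays the role of a broadcast, `jax.vmap` the per-client map, and `jnp.mean` the reduction. The helper `broadcast_map_mean` is hypothetical. Because every step is an ordinary JAX operation, `jax.grad` differentiates straight through the broadcast and the reduction.

```python
import jax
import jax.numpy as jnp


def broadcast_map_mean(server_value, client_data):
    """Plain-JAX analogue of broadcast -> map -> mean-reduce.

    Hypothetical helper for illustration only; not part of DrJAX.
    server_value: shape [d], a value held at the "server".
    client_data:  shape [num_clients, d], one row per client.
    """
    num_clients = client_data.shape[0]
    # "Broadcast": replicate the server value once per client.
    broadcasted = jnp.broadcast_to(server_value, (num_clients,) + server_value.shape)

    # Per-client computation: squared distance between the model and the client's data.
    def client_fn(w, x):
        return jnp.sum((w - x) ** 2)

    # "Map": apply client_fn to every (broadcast value, client row) pair.
    per_client = jax.vmap(client_fn)(broadcasted, client_data)
    # "Reduce": average the per-client results back to a single value.
    return jnp.mean(per_client)


w = jnp.array([1.0, 2.0])
data = jnp.array([[0.0, 0.0], [2.0, 2.0]])

loss = broadcast_map_mean(w, data)
# Gradient with respect to the server value flows back through the
# reduction, the map, and the broadcast.
grad = jax.grad(broadcast_map_mean)(w, data)
```

Here the explicit broadcast could be replaced by `jax.vmap(client_fn, in_axes=(None, 0))`; it is kept separate only to mirror the broadcast step. For DrJAX's actual primitive names and signatures, consult the project's repository and paper.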

DrJAX is designed to make it easy to author and execute parallel computations in the datacenter. It is tailored to large-scale parallel and distributed computations, including those involving large models, and aims to run them efficiently. DrJAX embeds primitives like those defined by TensorFlow Federated, using JAX's mapping capabilities and primitive extensions.
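One way to picture running such computations with JAX's sharding mechanisms is to place the client axis on a device mesh so that each device holds a slice of the clients. The sketch below is an assumption, not DrJAX's implementation: it builds a one-dimensional `jax.sharding.Mesh` with an illustrative axis name `"clients"`, shards per-client data along it with `NamedSharding`, and lets `jax.jit` compile the map and the mean-reduction, with XLA inserting whatever cross-device communication the reduction needs. It also runs on a single CPU device, where the mesh has one entry.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over all available devices; "clients" is an illustrative axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("clients",))

# Choose a client count divisible by the device count so the axis splits evenly.
num_clients = 4 * len(devices)
client_data = jnp.arange(num_clients * 3, dtype=jnp.float32).reshape(num_clients, 3)

# Shard the leading (client) axis across the mesh; the trailing axis is replicated.
client_sharding = NamedSharding(mesh, P("clients", None))
client_data = jax.device_put(client_data, client_sharding)


@jax.jit
def map_then_mean(x):
    # Per-client map: each client sums the squares of its own row.
    per_client = jax.vmap(lambda row: jnp.sum(row ** 2))(x)
    # Reduction across the sharded client axis.
    return jnp.mean(per_client)


result = map_then_mean(client_data)
```

Sharding only the client axis keeps each client's data on one device, so the map step needs no communication and only the final reduction crosses devices.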

System design

For details on DrJAX's system design, check out our paper.

Citation

To cite this repository, please use the following BibTeX citation:

@inproceedings{rush2024drjax,
  title={DrJAX: Scalable and Differentiable MapReduce Primitives in JAX},
  author={Rush, J Keith and Charles, Zachary and Garrett, Zachary and Augenstein, Sean and Mitchell, Nicole Elyse},
  booktitle={2nd Workshop on Advancing Neural Network Training: Computational Efficiency, Scalability, and Resource Optimization (WANT@ ICML 2024)}
}

Disclaimers

This is not an officially supported Google product.

If you're interested in learning more about responsible AI practices, please see Google AI's Responsible AI Practices.

DrJAX is Apache 2.0 licensed. See the LICENSE file.


Download files


Source Distributions

No source distribution files are available for this release.

Built Distribution


drjax-0.1.2-py3-none-any.whl (22.6 kB)

Uploaded Python 3

File details

Details for the file drjax-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: drjax-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 22.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.9

File hashes

Hashes for drjax-0.1.2-py3-none-any.whl
Algorithm     Hash digest
SHA256        489bdcc8a5571fbd5d35af4ab679e1eb2aba0bf65a4aab25a4529c1dc97ad11a
MD5           3fe2ac5640f53f460de6b8d4695c618d
BLAKE2b-256   568148f1a63f36c6aaab88d38da36a88a4364be4d4cdf19048a5aa1538c58c6f
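To check a downloaded wheel against the published SHA256 digest, hash the file locally and compare. This is a minimal sketch using only the Python standard library; the helper names `sha256_of` and `matches_digest` are hypothetical, and the wheel path is an assumption about where the file was saved.

```python
import hashlib
from pathlib import Path

# SHA256 digest published for drjax-0.1.2-py3-none-any.whl.
EXPECTED_SHA256 = "489bdcc8a5571fbd5d35af4ab679e1eb2aba0bf65a4aab25a4529c1dc97ad11a"


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_digest(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 equals the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()


# Usage (the path is an assumption about where the wheel was downloaded):
# print(matches_digest("drjax-0.1.2-py3-none-any.whl"))
```

pip can enforce the same check at install time in hash-checking mode: put `drjax==0.1.2 --hash=sha256:<digest>` in a requirements file and run `pip install --require-hashes -r requirements.txt`.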

