DrJAX - Scalable and Differentiable MapReduce Primitives in JAX.

DrJAX - Differentiable MapReduce Primitives in JAX

DrJAX is a library designed to embed a MapReduce programming model into JAX. It has three objectives:

  1. Create a simple JAX-based authoring surface for MapReduce computations.
  2. Leverage JAX's sharding mechanisms to enable highly optimized execution of MapReduce computations, especially in large-scale datacenter settings.
  3. Provide full differentiability of DrJAX computations, including differentiation through communication primitives like broadcasts and reductions.

DrJAX is designed to make it easy to author and execute parallel computations in the datacenter. It targets large-scale parallel and distributed computations, including those involving larger models, and aims to run them efficiently. DrJAX embeds primitives like those defined by TensorFlow Federated using JAX's mapping capabilities and primitive extensions.
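
To make the broadcast, map, and reduce pattern concrete, here is a minimal sketch written in plain JAX. It deliberately does not use the DrJAX API: the helper names `broadcast`, `map_fn`, and `reduce_sum`, and the constant `NUM_CLIENTS`, are hypothetical stand-ins chosen for illustration, and the real library may expose different names and signatures. The sketch only shows that a value can be replicated across a group, transformed per member, summed back, and differentiated end to end.

```python
import jax
import jax.numpy as jnp

NUM_CLIENTS = 3  # hypothetical group size, chosen for illustration


def broadcast(x):
    # Replicate a single value along a new leading "clients" axis.
    return jnp.broadcast_to(x, (NUM_CLIENTS,) + jnp.shape(x))


def map_fn(fn, xs):
    # Apply fn independently to each slice along the leading axis.
    return jax.vmap(fn)(xs)


def reduce_sum(xs):
    # Sum over the leading "clients" axis to get a single value back.
    return jnp.sum(xs, axis=0)


def broadcast_double_and_sum(x):
    y = broadcast(x)
    z = map_fn(lambda a: a * 2.0, y)
    return reduce_sum(z)


x = jnp.float32(1.0)
result = broadcast_double_and_sum(x)

# Differentiate through the broadcast and the reduction.
grad = jax.grad(broadcast_double_and_sum)(x)
```

With three group members, each copy of `x` is doubled and the copies are summed, so the function computes `6 * x` and its gradient with respect to `x` is 6. This version runs everything on a single device; it says nothing about how DrJAX shards the work across devices.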

System design

For details on DrJAX's system design, check out our paper.

Citation

To cite this repository, please use the following BibTeX citation:

@inproceedings{rush2024drjax,
  title={DrJAX: Scalable and Differentiable MapReduce Primitives in JAX},
  author={Rush, J Keith and Charles, Zachary and Garrett, Zachary and Augenstein, Sean and Mitchell, Nicole Elyse},
  booktitle={2nd Workshop on Advancing Neural Network Training: Computational Efficiency, Scalability, and Resource Optimization (WANT@ ICML 2024)}
}

Disclaimers

This is not an officially supported Google product.

If you're interested in learning more about responsible AI practices, please see Google AI's Responsible AI Practices.

DrJAX is Apache 2.0 licensed. See the LICENSE file.

Download files

Source Distributions

No source distribution files available for this release.

Built Distribution

drjax-0.1.1-py3-none-any.whl (22.6 kB)

Uploaded Python 3

File details

Details for the file drjax-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: drjax-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 22.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.9

File hashes

Hashes for drjax-0.1.1-py3-none-any.whl
  • SHA256: b39662b5ccd08760977ab436a9a35450a9abdbb740e13d552354b697f9762793
  • MD5: 5faffb1ad5b31609f2005bceec2e2419
  • BLAKE2b-256: eafe68074e9cff5d2a9b24a8511723d7a30111f71d86fd58577c8807401db13b