DrJAX - Scalable and Differentiable MapReduce Primitives in JAX.

DrJAX is a library designed to embed a MapReduce programming model into JAX. It has three objectives:

  1. Create a simple JAX-based authoring surface for MapReduce computations.
  2. Leverage JAX's sharding mechanisms to enable highly optimized execution of MapReduce computations, especially in large-scale datacenter settings.
  3. Provide full differentiability of DrJAX computations, including differentiation through communication primitives like broadcasts and reductions.

DrJAX is designed to make it easy to author and execute parallel computations in the datacenter. It is tailored to large-scale parallel and distributed computations, including those involving larger models, and aims to ensure that they run efficiently. DrJAX embeds primitives like those defined by TensorFlow Federated, using JAX's mapping capabilities and primitive extensions.
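To make the broadcast, map, and reduce pattern concrete, here is a minimal sketch in plain JAX. It does not use the DrJAX API: the helper names (`broadcast`, `map_fn`, `reduce_sum`) and the constant `NUM_CLIENTS` are hypothetical stand-ins chosen for illustration, and the sketch assumes a single device with no sharding. It only shows that a computation built from these three steps can be differentiated end to end with `jax.grad`.

```python
import jax
import jax.numpy as jnp

# Hypothetical number of parallel workers; not a DrJAX setting.
NUM_CLIENTS = 3


def broadcast(x):
    # Replicate a value to every worker along a new leading axis.
    return jnp.broadcast_to(x, (NUM_CLIENTS,) + jnp.shape(x))


def map_fn(fn, xs):
    # Apply fn independently to each worker's slice.
    return jax.vmap(fn)(xs)


def reduce_sum(xs):
    # Sum the per-worker values back into a single value.
    return jnp.sum(xs, axis=0)


def broadcast_double_and_sum(x):
    y = broadcast(x)
    z = map_fn(lambda a: 2.0 * a, y)
    return reduce_sum(z)


x = jnp.float32(1.0)
result = broadcast_double_and_sum(x)
# Differentiate through the broadcast and the reduction.
grad = jax.grad(broadcast_double_and_sum)(x)
```

With three workers, each doubling the broadcast value, both `result` and `grad` evaluate to 6.0. In this sketch the derivative flows back through the reduction and the broadcast because both are ordinary JAX operations.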

System design

For details on DrJAX's system design, check out our paper.

Citation

To cite this repository, please use the following BibTeX citation:

@inproceedings{rush2024drjax,
  title={DrJAX: Scalable and Differentiable MapReduce Primitives in JAX},
  author={Rush, J Keith and Charles, Zachary and Garrett, Zachary and Augenstein, Sean and Mitchell, Nicole Elyse},
  booktitle={2nd Workshop on Advancing Neural Network Training: Computational Efficiency, Scalability, and Resource Optimization (WANT@ICML 2024)}
}

Disclaimers

This is not an officially supported Google product.

If you're interested in learning more about responsible AI practices, please see Google AI's Responsible AI Practices.

DrJAX is Apache 2.0 licensed. See the LICENSE file.
