DrJAX - Scalable and Differentiable MapReduce Primitives in JAX.
DrJAX - Differentiable MapReduce Primitives in JAX
DrJAX is a library designed to embed a MapReduce programming model into JAX. DrJAX has three main objectives:
- Create a simple JAX-based authoring surface for MapReduce computations.
- Leverage JAX's sharding mechanisms to enable highly optimized execution of MapReduce computations, especially in large-scale datacenter settings.
- Provide full differentiability of DrJAX computations, including differentiation through communication primitives such as broadcasts and reductions.
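To make the broadcast / map / reduce pattern and its differentiability concrete, here is a minimal pure-JAX sketch. The helpers `broadcast`, `map_fn`, and `reduce_mean`, the constant `NUM_CLIENTS`, and the toy loss are hypothetical stand-ins written with plain `jax.vmap` and `jax.numpy` for this illustration. They are not DrJAX's API, which additionally manages placements and sharding. The point is that `jax.grad` flows through the reduction, the mapped computation, and the broadcast.

```python
import jax
import jax.numpy as jnp

# Hypothetical number of map shards ("clients"); an assumption for this sketch.
NUM_CLIENTS = 4


def broadcast(server_value):
    """Replicate a single value along a new leading clients axis."""
    return jnp.broadcast_to(server_value, (NUM_CLIENTS,) + jnp.shape(server_value))


def map_fn(fn, *client_values):
    """Apply fn independently to each client's slice (the "map" step)."""
    return jax.vmap(fn)(*client_values)


def reduce_mean(client_values):
    """Average over the clients axis (the "reduce" step)."""
    return jnp.mean(client_values, axis=0)


def mapreduce_loss(w, client_data):
    # client_data has shape (NUM_CLIENTS,): one scalar shard per client.
    w_per_client = broadcast(w)
    losses = map_fn(lambda w_i, x_i: (w_i * x_i) ** 2, w_per_client, client_data)
    return reduce_mean(losses)


client_data = jnp.arange(1.0, NUM_CLIENTS + 1.0)  # [1., 2., 3., 4.]
loss = mapreduce_loss(2.0, client_data)
# The gradient passes back through reduce_mean, map_fn, and broadcast;
# the transpose of a broadcast is a sum over the clients axis.
grad = jax.grad(mapreduce_loss)(2.0, client_data)
```

Here the loss is w² · mean(x²) = 4 · 7.5 = 30, and its derivative 2w · mean(x²) is also 30 at w = 2.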
DrJAX is designed to make it easy to author and execute parallel computations in the datacenter. It is tailored towards large-scale parallel and distributed computations, including those involving large models, and aims to ensure that such computations run efficiently. DrJAX embeds primitives like those defined by TensorFlow Federated, implementing them with JAX's mapping capabilities and primitive extensions.
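For intuition about how a MapReduce-style computation can be spread across devices with JAX's sharding mechanisms, the sketch below uses only plain JAX (`Mesh`, `NamedSharding`, `PartitionSpec`, `jax.device_put`, `jax.jit`). The mesh axis name `"clients"`, the function `sharded_mapreduce`, and the choice of four clients per device are assumptions for illustration. This is not how DrJAX configures sharding internally. It runs on a single CPU device as well as on several accelerators.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D device mesh whose single axis plays the role of the "clients" axis.
# The axis name "clients" is an assumption for this sketch.
mesh = Mesh(np.array(jax.devices()), axis_names=("clients",))

# Use a multiple of the device count so the clients axis divides evenly.
num_clients = 4 * len(jax.devices())
client_data = jnp.arange(num_clients, dtype=jnp.float32)

# Shard the leading (clients) axis: each device holds a contiguous block of clients.
client_data = jax.device_put(client_data, NamedSharding(mesh, PartitionSpec("clients")))


@jax.jit
def sharded_mapreduce(w, data):
    # Map: per-client computation, vectorized over the clients axis.
    per_client = jax.vmap(lambda x: (w * x) ** 2)(data)
    # Reduce: the compiler inserts whatever cross-device communication the mean needs.
    return jnp.mean(per_client)


result = sharded_mapreduce(2.0, client_data)
```

Because the data is already laid out along the mesh, the compiled program keeps the map step local to each device and only communicates for the final reduction.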
System design
For details on DrJAX's system design, check out our paper.
Citation
To cite this repository, please use the following BibTeX citation:
@inproceedings{rush2024drjax,
title={DrJAX: Scalable and Differentiable MapReduce Primitives in JAX},
author={Rush, J Keith and Charles, Zachary and Garrett, Zachary and Augenstein, Sean and Mitchell, Nicole Elyse},
booktitle={2nd Workshop on Advancing Neural Network Training: Computational Efficiency, Scalability, and Resource Optimization (WANT@ ICML 2024)}
}
Disclaimers
This is not an officially supported Google product.
If you're interested in learning more about responsible AI practices, please see Google AI's Responsible AI Practices.
DrJAX is Apache 2.0 licensed. See the LICENSE file.
File details
Details for the file drjax-0.1.4-py3-none-any.whl.
File metadata
- Download URL: drjax-0.1.4-py3-none-any.whl
- Upload date:
- Size: 23.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 532852e4d7e8f885488d487b693f38f9757e96489ca52c9cce010a284aac3e63 |
| MD5 | 62bf8afde412265525880e299d5480ba |
| BLAKE2b-256 | ecfe3b104a57a1ff4e0592193f4b3c03b280ff8a1c0fa447335928b6f6668e2c |
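To check a downloaded wheel against the published SHA256 digest for drjax-0.1.4-py3-none-any.whl, a short standard-library sketch like the following could be used. The helper names `sha256_of` and `verify_wheel` are hypothetical and written for this illustration; only the digest value is taken from this page.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed on this page for drjax-0.1.4-py3-none-any.whl.
EXPECTED_SHA256 = "532852e4d7e8f885488d487b693f38f9757e96489ca52c9cce010a284aac3e63"


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_wheel(path, expected=EXPECTED_SHA256):
    """Return True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of(path) == expected.lower()
```

Alternatively, pip's hash-checking mode can do this at install time via a requirements line such as `drjax==0.1.4 --hash=sha256:<digest>`, using the SHA256 value above.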