
Basic tools and helpers for JAX


JaxHelper


Getting Started

Installation

python3 -m pip install jaxhelper

Explanation

This repository contains basic helper functions I use every day.
Here are some highlights:

  • remat: function decorator that rematerializes hidden states, i.e. recomputes them during the backward pass instead of storing them ("activation checkpointing")
  • softmax:
    • computes exp in fp32 and the matmul in bf16 (improving convergence and speed)
    • stores fewer intermediates, yet computes the gradient faster
  • attention:
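The remat helper described above can be approximated with JAX's built-in jax.checkpoint (also exported as jax.remat). This is a minimal sketch assuming the decorator simply wraps jax.checkpoint with its default policy; jaxhelper's actual implementation is not shown on this page, and mlp_block and loss are hypothetical example functions.

```python
import jax
import jax.numpy as jnp


def remat(fn):
    # Hypothetical sketch, not jaxhelper's code: jax.checkpoint keeps only
    # the inputs of fn on the forward pass and recomputes its intermediates
    # (e.g. hidden states) when gradients are taken.
    return jax.checkpoint(fn)


@remat
def mlp_block(w1, w2, x):
    h = jax.nn.gelu(x @ w1)  # hidden state: recomputed in the backward pass
    return h @ w2


def loss(w1, w2, x):
    return jnp.mean(mlp_block(w1, w2, x) ** 2)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (8, 32))
w2 = jax.random.normal(k2, (32, 8))
x = jax.random.normal(k3, (4, 8))

# Same gradients as without remat; only the memory/compute trade-off changes.
grads = jax.grad(loss, argnums=(0, 1))(w1, w2, x)
```

Rematerialization trades extra forward compute in the backward pass for lower activation memory, which is why it is usually applied per block rather than to a whole model.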
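The softmax points above can be illustrated with a custom VJP. This is a minimal sketch under stated assumptions, not jaxhelper's API: it upcasts to fp32 for the max-subtraction and exp, casts the probabilities back to the input dtype (e.g. bf16) so a following matmul runs in bf16, and saves only the softmax output for the backward pass, using the identity dx = y * (dy - sum(dy * y)). The names softmax, _softmax_fwd and _softmax_bwd are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def softmax(x):
    # exp in fp32 for numerical stability, result cast back to x.dtype
    x32 = x.astype(jnp.float32)
    e = jnp.exp(x32 - jnp.max(x32, axis=-1, keepdims=True))
    return (e / jnp.sum(e, axis=-1, keepdims=True)).astype(x.dtype)


def _softmax_fwd(x):
    y = softmax(x)
    # Only the output is kept as a residual; no exp/sum intermediates.
    return y, y


def _softmax_bwd(y, dy):
    # Softmax gradient: dx = y * (dy - sum(dy * y)), computed in fp32.
    y32 = y.astype(jnp.float32)
    dy32 = dy.astype(jnp.float32)
    dx = y32 * (dy32 - jnp.sum(dy32 * y32, axis=-1, keepdims=True))
    return (dx.astype(y.dtype),)


softmax.defvjp(_softmax_fwd, _softmax_bwd)

# Usage: bf16 logits; the probabilities stay bf16, so the matmul runs in bf16.
key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (4, 16)).astype(jnp.bfloat16)
v = jnp.ones((16, 8), dtype=jnp.bfloat16)
out = softmax(logits) @ v
```

Saving only y works because the softmax Jacobian can be expressed entirely in terms of the output, so the forward-pass exponentials never need to be stored.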

Download files

Source distribution: jaxhelper-0.0.5.tar.gz (3.6 kB)

Built distribution: jaxhelper-0.0.5-py3-none-any.whl (3.6 kB, Python 3)
