
jaxamp: automatic mixed precision in JAX



Installation

pip install jaxamp

Usage

TL;DR: Like PyTorch AMP, but for JAX.

Replace loss_fn(model, minibatch) with jaxamp.amp(loss_fn)(model, minibatch) to run with mixed precision. Create a scaler_state = jaxamp.DynamicScalerState() and use jaxamp.dynamic_scale_grad or jaxamp.dynamic_scale_value_and_grad to apply a dynamic loss scaler:

def loss(model, minibatch):
  ...

scaler_state = jaxamp.DynamicScalerState()
amp_loss = jaxamp.amp(loss)
grad_fn = jaxamp.dynamic_scale_grad(amp_loss)
scaler_state, grad = grad_fn(model, minibatch, dynamic_scaler_state=scaler_state)

More details

Your usual training loop might look like this:

def loss_fn(model_state, minibatch):
  ...
  return loss, accuracy

def train_step(model_state, opt_state, minibatch, optimizer):

  value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

  (loss, accuracy), grads = value_and_grad_fn(model_state, minibatch)

  updates, opt_state = optimizer.update(grads, opt_state, model_state)
  model_state = optax.apply_updates(model_state, updates)
  return model_state, opt_state, loss, accuracy

def train_loop(model_state, opt_state, optimizer, dataloader):
  train_step_jit = jax.jit(train_step, static_argnums=3)

  for minibatch in dataloader:
    model_state, opt_state, loss, accuracy = train_step_jit(model_state, opt_state, minibatch, optimizer)
    log_metrics(loss, accuracy)
  return model_state, opt_state

Now, you can replace this with:

def train_step(
    model_state,
    opt_state,
    minibatch,
    dynamic_scaler_state,
    optimizer):
  amp_loss_fn = jaxamp.amp(loss_fn)
  
  value_and_grad_fn = jaxamp.dynamic_scale_value_and_grad(amp_loss_fn, has_aux=True)

  dynamic_scaler_state, ((loss, accuracy), grads) = value_and_grad_fn(
    model_state,
    minibatch,
    dynamic_scaler_state=dynamic_scaler_state)

  updates, opt_state = optimizer.update(grads, opt_state, model_state)
  model_state = optax.apply_updates(model_state, updates)
  return model_state, opt_state, dynamic_scaler_state, loss, accuracy

def train_loop(model_state, opt_state, optimizer, dataloader):
  train_step_jit = jax.jit(train_step, static_argnums=4)
  dynamic_scaler_state = jaxamp.DynamicScalerState()
  for minibatch in dataloader:
    model_state, opt_state, dynamic_scaler_state, loss, accuracy = train_step_jit(
      model_state,
      opt_state,
      minibatch,
      dynamic_scaler_state,
      optimizer)
    log_metrics(loss, accuracy)
  return model_state, opt_state

It should now be faster!

More details on amp

The amp function transforms an arbitrary function into one in which some operations are performed in low precision. The precision can be controlled via the compute_dtype keyword-only argument: amp_loss_fn = amp(loss_fn, compute_dtype=jnp.float16). You can also control which operations are performed in low precision (and how) via the amp_policy keyword-only argument. This argument takes a dictionary whose keys are either strings or JAX primitives (e.g. jax.lax.add_p) and whose values are functions that will be called to cast arrays to the relevant dtypes. These functions should have the signature:

def precision_fn(
    compute_dtype: Type,
    original_dtypes: Sequence[Type],
    *invars: Sequence[Array],
    **bind_params: Dict[str, Any]) -> Tuple[Sequence[Array], Dict[str, Any]]:
  '''
  Args:
    compute_dtype: the compute_dtype provided to `amp`.
    original_dtypes: the dtypes that the original user code expected the arguments
        of the op we are about to transform to have.
    invars: the input arrays to this operation (note that their dtypes may not match
        original_dtypes because of casting we may already have performed).
    bind_params: the "meta" parameters to the op (things like axis specifications).
  Returns:
    new_invars, new_bind_params: the transformed values for invars and bind_params.
  '''

For example, the function used to cast to compute_dtype is:

def use_compute_precision(
    compute_dtype: Type,
    original_dtypes: Sequence[Type],
    *invars: Sequence[Any],
    **bind_params: Dict[str, Any]
) -> Tuple[Sequence[Any], Dict[str, Any]]:
    invars = cast_tree(compute_dtype, invars)
    bind_params = cast_tree(compute_dtype, bind_params)
    bind_params = dict(bind_params)
    if "preferred_element_type" in bind_params:
        bind_params["preferred_element_type"] = compute_dtype
    return invars, bind_params
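
The cast_tree helper used above is not shown on this page; a minimal sketch of what such a helper might look like (an assumption for illustration, not jaxamp's actual implementation) is:

import jax
import jax.numpy as jnp

def cast_tree(dtype, tree):
    # Cast every floating-point array leaf of a pytree to `dtype`,
    # leaving integer/bool leaves and non-array values untouched.
    def cast_leaf(x):
        if isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(dtype)
        return x
    return jax.tree_util.tree_map(cast_leaf, tree)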

amp will walk through all the ops in your function and look up each op in your amp_policy dict. If the op is present, it will apply the specified function. Otherwise, it will cast the inputs back to their original dtypes and apply the op unchanged.

You can also provide string keys in amp_policy. In this case, if the current operation is executed inside a scope declared with jax.named_scope whose name matches the key, we will apply the specified transformation function. If two or more active scopes match policies in amp_policy, the outermost scope is used. There are two special scopes, "amp_step" and "amp_default"; by default, these both stop any automatic mixed precision from happening inside them.
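
As a concrete illustration, the policy below runs dot_general (matmuls) in the compute dtype via the use_compute_precision function shown above, and restores original dtypes for ops under a named scope. The scope name "keep_fp32" and the keep_original_precision helper are hypothetical, not part of jaxamp:

import jax
import jax.numpy as jnp

def keep_original_precision(compute_dtype, original_dtypes, *invars, **bind_params):
    # Cast each input back to the dtype the untransformed code expected,
    # so the matched op effectively runs at its original precision.
    invars = tuple(x.astype(d) for x, d in zip(invars, original_dtypes))
    return invars, dict(bind_params)

amp_policy = {
    jax.lax.dot_general_p: use_compute_precision,  # matmuls in compute_dtype
    "keep_fp32": keep_original_precision,          # ops under jax.named_scope("keep_fp32")
}

amp_loss_fn = jaxamp.amp(loss_fn, compute_dtype=jnp.bfloat16, amp_policy=amp_policy)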

Selectively Disabling AMP

You can disable amp for a specific function (or region of code) using jaxamp.amp_stop, which works as both a decorator and a context manager:

@jaxamp.amp_stop
def high_precision_matmul(W, x):
    return jnp.dot(W, x)

def high_precision_with_context(W, x):
    # will be low precision
    y = jnp.dot(W, x)
    
    with jaxamp.amp_stop():
        # in fp32 precision
        z = jnp.dot(W, y)
    return z

More details on dynamic loss scalers

We supply a loss scaling operation via DynamicScalerState and corresponding functions dynamic_scale_grad and dynamic_scale_value_and_grad.

DynamicScalerState has the following structure:

class DynamicScalerState(NamedTuple):
    patience: jax.Array = jnp.array(2000) # number of non-inf/NaN iterations to wait before increasing the scaler
    adjust_factor: jax.Array = jnp.array(2.0) # When increasing or decreasing the scaler, multiply or divide by this factor.
    scaler: jax.Array = jnp.array(2**15, dtype=jnp.float32) # current scaler value
    count: jax.Array = jnp.array(0) # number of non-inf/NaN iterations since the scaler was last increased.
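
For intuition, a dynamic scaler update behaves roughly as follows (an illustrative sketch of the semantics described in the comments above, not jaxamp's actual implementation):

import jax
import jax.numpy as jnp

def update_scaler_state(state: DynamicScalerState, grads_are_finite: jax.Array) -> DynamicScalerState:
    # Finite step: count it, and once `patience` consecutive finite steps have
    # passed, multiply the scaler by `adjust_factor` and reset the count.
    # Non-finite step (inf/NaN): divide the scaler by `adjust_factor` and reset the count.
    grow = jnp.logical_and(grads_are_finite, state.count + 1 >= state.patience)
    new_scaler = jnp.where(
        grads_are_finite,
        jnp.where(grow, state.scaler * state.adjust_factor, state.scaler),
        state.scaler / state.adjust_factor,
    )
    new_count = jnp.where(
        jnp.logical_and(grads_are_finite, jnp.logical_not(grow)),
        state.count + 1,
        0,
    )
    return state._replace(scaler=new_scaler, count=new_count)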

The gradient functions then have behavior like:

def dynamic_scale_value_and_grad(
    fun: Callable,
    *,
    has_aux: bool = False,
    redo_on_nan: int = 0,
    filter=True,
    **kwargs
):
    '''
    Apply dynamic loss scaling to the value_and_grad function.

    Args:
        fun: function to differentiate
        has_aux: same meaning as in jax.grad
        redo_on_nan: if the output is NaN, we will decrease the scaler
            and recompute up to this many times. If the output is still NaN, give up
            and return it.
        filter: if True, differentiate with equinox.filter_value_and_grad; otherwise use jax.value_and_grad

    Returns:
        grad_fn: a function that behaves like the output of jax.value_and_grad except:
            1. it has an extra required keyword argument dynamic_scaler_state
            2. the return value is now a tuple (next_dynamic_scaler_state, (value, grads)),
                or (next_dynamic_scaler_state, ((value, aux), grads)) if has_aux=True
    '''
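
For example, combining this with amp and an auxiliary output (redo_on_nan=3 is just an illustrative value):

grad_fn = jaxamp.dynamic_scale_value_and_grad(amp_loss_fn, has_aux=True, redo_on_nan=3)
dynamic_scaler_state, ((loss, accuracy), grads) = grad_fn(
    model_state, minibatch, dynamic_scaler_state=dynamic_scaler_state)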

More usage tips

When using optax, you may want to wrap your optimizer in optax.apply_if_finite to automatically skip updates whose gradients contain NaNs. Alternatively, you could use the redo_on_nan option described above.
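
For instance, a minimal sketch (max_consecutive_errors=5 is an arbitrary choice):

import optax

# Skip updates whose gradients contain inf/NaN, tolerating up to 5 such steps in a row.
optimizer = optax.apply_if_finite(optax.adamw(learning_rate=1e-3), max_consecutive_errors=5)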

License

jaxamp is distributed under the terms of the Apache 2.0 license.
