Automatically broadcast inputs by dynamically applying jax.vmap

Automatically broadcast via JAX vmap

Automatically broadcast a function that takes inputs of specific ranks so that it accepts any leading batch dimensions, using jax.vmap. See the JAX documentation for more information about vmap.

This module defines a decorator that takes the base ranks of a function's arguments as input. The transformed function accepts any broadcastable combination of batched inputs and automatically applies jax.vmap as required. One could equivalently apply vmap by hand; however, the in_axes have to be chosen differently for different inputs. The decorator takes care of this automatically: if the underlying ranks are known, the batch dimensions can be inferred.
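The inference step can be illustrated with a small sketch. The helper infer_batch_shape below is hypothetical (it is not part of the jax_autovmap API) and only mirrors the idea described above: every dimension beyond an argument's base rank counts as a leading batch dimension, and the batch shapes of all arguments are combined with the usual NumPy broadcasting rules. That auto_vmap broadcasts batch shapes in exactly this way is an assumption.

```python
import numpy as np
import jax.numpy as jnp


def infer_batch_shape(args, ranks):
    """Hypothetical helper: broadcast batch shape of args given base ranks.

    Every dimension beyond an argument's base rank is treated as a leading
    batch dimension; the batch shapes of all arguments are then combined
    with NumPy broadcasting rules."""
    batch_shapes = []
    for x, rank in zip(args, ranks):
        ndim = np.ndim(x)
        if ndim < rank:
            raise ValueError(f"expected at least {rank} dims, got shape {np.shape(x)}")
        # Leading dims beyond the base rank are the batch dims.
        batch_shapes.append(np.shape(x)[: ndim - rank])
    # Raises ValueError if the batch shapes are not broadcast-compatible.
    return np.broadcast_shapes(*batch_shapes)


s = jnp.array(2.0)       # rank 0, no batch dims
v = jnp.ones((5, 3))     # rank 1, batch shape (5,)
m = jnp.ones((5, 3, 3))  # rank 2, batch shape (5,)
print(infer_batch_shape((s, v, m), (0, 1, 2)))  # (5,)
# Batch shapes (2, 1), (5,) and () broadcast to (2, 5).
print(infer_batch_shape((jnp.zeros((2, 1)), v, jnp.ones((3, 3))), (0, 1, 2)))  # (2, 5)
```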

Example

Consider the following function which takes numeric arguments with fixed and known ranks as input:

import jax.numpy as jnp

def foo(s, v, m):  # s - scalar, v - vector, m - matrix
    return v @ m @ v + s * v.size

If we have inputs of the appropriate ranks, the function can be applied without any problem:

s = jnp.array(2.0)    # scalar
v = jnp.ones(3)       # vector
m = jnp.ones((3, 3))  # matrix
print(foo(s, v, m))   # prints 15.0

Assume now, however, that we have 5 matrices and 5 vectors to which we want to apply the above function as a batch. Many NumPy functions accept inputs with leading batch dimensions, but here we have an issue: v @ m @ v expects v to be a single vector and m a single matching matrix, and with the extra leading dimension the shapes no longer line up.

s = jnp.array(2.0)
v = jnp.ones((5, 3))
m = jnp.ones((5, 3, 3))
print(foo(s, v, m))  # throws TypeError

There are multiple possible ways to solve this:

  • We could write the function more carefully, so that it accepts both single and batched inputs. Or we could always require the first dimension to be a batch index (as is often done in convolutional neural networks).
  • Given known inputs, we can transform our function by hand: jax.vmap(foo, in_axes=(None, 0, 0)). However, the axes change depending on the inputs. If we want to expose functions that accept batched inputs to the user, we need a clear naming scheme to indicate which arguments are batched.
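For comparison, the manual approach from the second bullet looks like this with plain jax.vmap. It shows why in_axes depends on the inputs: the same function needs different axes depending on whether the matrix is batched or shared by all vectors.

```python
import jax
import jax.numpy as jnp


def foo(s, v, m):  # s - scalar, v - vector, m - matrix
    return v @ m @ v + s * v.size


s = jnp.array(2.0)
v = jnp.ones((5, 3))
m = jnp.ones((5, 3, 3))

# v and m carry a leading batch axis; s does not, so it is not mapped (None).
foo_batched = jax.vmap(foo, in_axes=(None, 0, 0))
print(foo_batched(s, v, m))  # [15. 15. 15. 15. 15.]

# With a single matrix shared by all five vectors, in_axes must change.
m_single = jnp.ones((3, 3))
foo_shared_m = jax.vmap(foo, in_axes=(None, 0, None))
print(foo_shared_m(s, v, m_single))  # [15. 15. 15. 15. 15.]
# foo_batched(s, v, m_single) would fail: m_single's first axis has size 3,
# which does not match the batch size 5 of v.
```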

Sometimes, the best solution is one of the above. This module provides another, more flexible solution. If the ranks the function expects are known, we can derive which arguments have (leading) batch dimensions and apply jax.vmap accordingly. That is exactly what the auto_vmap wrapper does. Because the choice of axes depends only on the input shapes, which are static under jax.jit, this extra flexibility comes at no runtime cost once the transformed function is JIT-compiled.
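A simplified sketch of this idea follows. The names simple_auto_vmap and _map_batch_dims are hypothetical, and the code is not the package's implementation: it assumes positional array arguments with integer ranks (None meaning the argument is never mapped), aligns batch dimensions from the right as in NumPy broadcasting, and applies one jax.vmap per batch dimension. The trace counter illustrates the jax.jit point: the Python-side axis bookkeeping runs only while tracing, so repeated calls with the same shapes reuse the compiled function.

```python
import functools

import jax
import jax.numpy as jnp


def _map_batch_dims(fn, args, n_batch):
    """Apply fn to args, vmapping over leading batch dims.

    n_batch[i] is the number of leading batch dims of args[i], or None if
    args[i] is never mapped. Batch dims are aligned from the right, so only
    arguments with the most batch dims own the outermost batch axis."""
    k = max((n for n in n_batch if n is not None), default=0)
    if k == 0:
        return fn(*args)
    size = max(a.shape[0] for a, n in zip(args, n_batch) if n == k)
    new_args, in_axes, new_n = [], [], []
    for a, n in zip(args, n_batch):
        if n == k and a.shape[0] == size:
            in_axes.append(0)      # mapped over this axis
            new_n.append(n - 1)
        elif n == k and a.shape[0] == 1:
            a = a[0]               # size-1 batch axis: drop it and broadcast
            in_axes.append(None)
            new_n.append(n - 1)
        elif n == k:
            raise ValueError(f"incompatible batch sizes {a.shape[0]} and {size}")
        else:
            in_axes.append(None)   # no batch dim at this level: broadcast
            new_n.append(n)
        new_args.append(a)

    def inner(*xs):
        return _map_batch_dims(fn, xs, new_n)

    return jax.vmap(inner, in_axes=tuple(in_axes))(*new_args)


def simple_auto_vmap(*ranks):
    """Hypothetical, simplified decorator in the spirit of auto_vmap."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapped(*args):
            if len(args) != len(ranks):
                raise TypeError("expected one rank (or None) per argument")
            args = tuple(jnp.asarray(a) for a in args)
            n_batch = [None if r is None else a.ndim - r for a, r in zip(args, ranks)]
            if any(n is not None and n < 0 for n in n_batch):
                raise ValueError("an argument has fewer dims than its declared rank")
            return _map_batch_dims(fn, args, n_batch)
        return wrapped
    return decorator


calls = []  # records how often the Python body of foo runs


@simple_auto_vmap(0, 1, 2)
def foo(s, v, m):
    calls.append(1)  # side effect: executes only while JAX traces foo
    return v @ m @ v + s * v.size


s, v, m = jnp.array(2.0), jnp.ones((5, 3)), jnp.ones((5, 3, 3))
print(foo(s, v, m))                               # [15. 15. 15. 15. 15.]
print(foo(s, v, jnp.ones((3, 3))).shape)          # (5,): one matrix shared by all vectors
print(foo(jnp.arange(4.0)[:, None], v, m).shape)  # (4, 5): batch shapes (4, 1) and (5,)

foo_jit = jax.jit(foo)
calls.clear()
foo_jit(s, v, m)
foo_jit(s, v, m)
n_traces = len(calls)
print(n_traces)  # 1: the axis bookkeeping ran once, during tracing
```

Mapping one batch axis per jax.vmap call keeps the sketch short; unbatched arguments are passed with in_axes=None instead of being copied along the batch.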

We can define the more flexible function as follows:

from jax_autovmap import auto_vmap

@auto_vmap(s=0, v=1, m=2)
def foo(s, v, m):
    return v @ m @ v + s * v.size

print(foo(s, v, m))  # prints [15. 15. 15. 15. 15.]

The ranks can be specified by keyword argument as above, or positionally (in this case @auto_vmap(0, 1, 2)). Not every argument needs a rank: arguments without one can either be omitted or, if ranks are given positionally, specified as None.
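To illustrate how the positional and keyword forms can describe the same ranks, the sketch below normalizes both into a mapping from parameter names to ranks using inspect.signature(...).bind_partial. The helper bind_ranks is hypothetical; how auto_vmap stores ranks internally is not documented here.

```python
import inspect


def bind_ranks(fn, *ranks, **named_ranks):
    """Hypothetical helper: map ranks given positionally or by keyword to
    parameter names. Omitted parameters and None entries get no rank."""
    # bind_partial raises TypeError for unknown names or too many ranks.
    bound = inspect.signature(fn).bind_partial(*ranks, **named_ranks)
    return {name: r for name, r in bound.arguments.items() if r is not None}


def foo(s, v, m):
    return v @ m @ v + s * v.size


print(bind_ranks(foo, 0, 1, 2))        # {'s': 0, 'v': 1, 'm': 2}
print(bind_ranks(foo, s=0, v=1, m=2))  # {'s': 0, 'v': 1, 'm': 2}
print(bind_ranks(foo, None, 1, 2))     # {'v': 1, 'm': 2}
print(bind_ranks(foo, v=1, m=2))       # {'v': 1, 'm': 2}
```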

If an argument is a pytree (a nested Python structure of arrays) and its rank is a single integer, all of its constituents (leaves) are assumed to have that rank. Alternatively, the rank can be a matching pytree, just like in_axes in jax.vmap:

@auto_vmap({'s': 0, 'v': 1, 'm': 2})
def foo(inputs):
    return inputs['v'] @ inputs['m'] @ inputs['v']

print(foo(dict(s=s, v=v, m=m)))
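The sketch below shows how a rank pytree could be matched against an argument pytree with jax.tree_util.tree_map to get the number of batch dimensions per leaf, and how, with at most one batch dimension per leaf, this translates into an in_axes pytree for plain jax.vmap. The helper leaf_batch_dims is hypothetical and not part of jax_autovmap.

```python
import jax
import jax.numpy as jnp


def leaf_batch_dims(tree, ranks):
    """Hypothetical helper: number of leading batch dims of every leaf.

    `ranks` is either one int for all leaves or a pytree with the same
    structure as `tree` (like in_axes in jax.vmap)."""
    if isinstance(ranks, int):
        return jax.tree_util.tree_map(lambda x: jnp.ndim(x) - ranks, tree)
    return jax.tree_util.tree_map(lambda x, r: jnp.ndim(x) - r, tree, ranks)


def foo(inputs):
    return inputs['v'] @ inputs['m'] @ inputs['v']


inputs = dict(s=jnp.array(2.0), v=jnp.ones((5, 3)), m=jnp.ones((5, 3, 3)))
dims = leaf_batch_dims(inputs, {'s': 0, 'v': 1, 'm': 2})
print(dims)  # s has 0 batch dims, v and m have 1 each

# One batch dim per leaf at most: map axis 0 where it exists, else None.
in_axes = jax.tree_util.tree_map(lambda n: 0 if n > 0 else None, dims)
print(jax.vmap(foo, in_axes=(in_axes,))(inputs))  # [9. 9. 9. 9. 9.]
```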

Download files

Source distribution: jax_autovmap-0.1.0.tar.gz (6.2 kB)

Built distribution: jax_autovmap-0.1.0-py3-none-any.whl (6.5 kB, Python 3)
