Refx

A reference system for JAX

What is Refx?

Refx defines a simple reference system that allows you to create DAGs on top of JAX PyTrees. It enables some key features:

  • Shared state
  • Tractable mutability
  • Semantic partitioning

Refx is intended to be a low-level library that can be used as a building block within other JAX libraries.

Why Refx?

Functional systems like flax and haiku are powerful, but they add a lot of complexity that is often transferred to the user. Pytree-based systems like equinox are simpler, but they lack the ability to share parameters and modules.

Refx aims to provide a foundation for neural network libraries that have the simplicity of pytree-based systems and the power of functional systems.

Installation

pip install refx

Getting Started

Refx's main data structure is the Ref class. It is a wrapper around a value that can be used as a leaf in a pytree. Its value attribute lets you read and mutate the wrapped value.

import jax
import refx

r1 = refx.Ref(1)
r2 = refx.Ref(2)

pytree = {
    'a': [r1, r1, r2],
    'b': r2
}

pytree['a'][0].value = 10

assert pytree['a'][1].value == 10
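
The aliasing in this example does not depend on JAX. As a minimal pure-Python sketch, the hypothetical MiniRef class below (an illustration only, not refx's implementation) shows why a write through one position in the container is visible at every other position that holds the same object:

```python
class MiniRef:
    """Hypothetical stand-in for a reference: a mutable box around one value."""

    def __init__(self, value):
        self.value = value


shared = MiniRef(1)
other = MiniRef(2)

# The same MiniRef object appears twice under 'a'.
tree = {"a": [shared, shared, other], "b": other}

# Writing through one alias changes what every alias sees.
tree["a"][0].value = 10

first_alias = tree["a"][0].value
second_alias = tree["a"][1].value
same_object = tree["a"][0] is tree["a"][1]
```

Because both list slots point at one object, there is nothing to keep in sync: the shared state is the object identity itself.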

Ref is not a pytree node, so you cannot pass pytrees containing Refs directly to JAX functions. To interact with JAX, refx provides the following functions:

  • deref: converts a pytree of references to a pytree of values and indexes.
  • reref: converts a pytree of values and indexes to a pytree of references.

deref must be called before crossing a JAX boundary and reref must be called after crossing a JAX boundary.

pytree = refx.deref(pytree)

@jax.jit
def f(pytree):
    pytree = refx.reref(pytree)
    # perform some computation / update the references
    pytree['a'][2].value = 50
    return refx.deref(pytree)

pytree = f(pytree)
pytree = refx.reref(pytree)

assert pytree['b'].value == 50

As you can see in this example, we've effectively implemented shared state and tractable mutability with pure pytrees.
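
To make the "values and indexes" idea concrete, here is a toy sketch of one way such a round trip could work. The names MiniRef, Slot, toy_deref and toy_reref are hypothetical and this is not how refx is implemented; it only assumes that references sharing an index should become the same object again after the round trip:

```python
from dataclasses import dataclass


class MiniRef:
    """Hypothetical mutable box around one value."""

    def __init__(self, value):
        self.value = value


@dataclass
class Slot:
    """Immutable-style record: the value plus an index identifying its ref."""
    value: object
    index: int


def toy_deref(tree):
    """Replace each MiniRef with a Slot; aliases of one ref share an index."""
    index_of = {}  # id(ref) -> index

    def visit(node):
        if isinstance(node, MiniRef):
            index = index_of.setdefault(id(node), len(index_of))
            return Slot(node.value, index)
        if isinstance(node, dict):
            return {k: visit(v) for k, v in node.items()}
        if isinstance(node, list):
            return [visit(v) for v in node]
        return node

    return visit(tree)


def toy_reref(tree):
    """Rebuild MiniRefs from Slots; equal indexes map to one shared object."""
    ref_of = {}  # index -> MiniRef

    def visit(node):
        if isinstance(node, Slot):
            if node.index not in ref_of:
                ref_of[node.index] = MiniRef(node.value)
            return ref_of[node.index]
        if isinstance(node, dict):
            return {k: visit(v) for k, v in node.items()}
        if isinstance(node, list):
            return [visit(v) for v in node]
        return node

    return visit(tree)


r1, r2 = MiniRef(1), MiniRef(2)
tree = {"a": [r1, r1, r2], "b": r2}

flat = toy_deref(tree)      # plain data, no mutable references
rebuilt = toy_reref(flat)   # fresh references with the sharing restored
rebuilt["a"][2].value = 50  # also visible through rebuilt["b"]
```

In this sketch the rebuilt references are new objects, so the originals r1 and r2 are left untouched; only the sharing structure is carried across.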

Trace-level awareness

In JAX, unconstrained mutability can lead to tracer leakage. To prevent this, refx only allows mutating references from the same trace level they were created in.

import jax.numpy as jnp

r = refx.Ref(1)

@jax.jit
def f():
    # ValueError: Cannot mutate ref from different trace level
    r.value = jnp.array(1.0)
    ...
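
The rule can be illustrated without JAX. The sketch below is a toy model under stated assumptions: TraceState, toy_trace and LevelRef are hypothetical names, and an integer counter stands in for JAX's trace level. It is not refx's mechanism, only the shape of the check:

```python
import contextlib


class TraceState:
    """Toy stand-in for the current trace level."""
    level = 0


@contextlib.contextmanager
def toy_trace():
    """Simulate entering a transformed function by bumping the level."""
    TraceState.level += 1
    try:
        yield
    finally:
        TraceState.level -= 1


class LevelRef:
    """Hypothetical ref that remembers the level it was created at."""

    def __init__(self, value):
        self._value = value
        self._level = TraceState.level

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value):
        # Reject writes made at a different level than the creation level.
        if self._level != TraceState.level:
            raise ValueError("Cannot mutate ref from different trace level")
        self._value = new_value


outer = LevelRef(1)
outer.value = 2  # same level as creation: allowed

with toy_trace():
    inner = LevelRef(0)
    inner.value = 5  # created and mutated at the same level: allowed
    try:
        outer.value = 3  # created outside, mutated inside: rejected
        rejected = False
    except ValueError:
        rejected = True
```

Reads are left unrestricted in this sketch; only writes are checked, since a write from an inner level is what would let a traced value escape into outer state.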

Partitioning

Each reference has a collection: Hashable attribute that can be used to partition references into different groups. refx provides the tree_partition function to partition a pytree based on a predicate function.

r1 = refx.Ref(1, collection="params")
r2 = refx.Ref(2, collection="batch_stats")

pytree = {
    'a': [r1, r1, r2],
    'b': r2
}

(params, rest), treedef = refx.tree_partition(
    pytree, lambda x: isinstance(x, refx.Ref) and x.collection == "params")

assert params == {
    ('a', '0'): r1,
    ('a', '1'): r1,
    ('a', '2'): refx.NOTHING,
    ('b',): refx.NOTHING
}
assert rest == {
    ('a', '0'): refx.NOTHING,
    ('a', '1'): refx.NOTHING,
    ('a', '2'): r2,
    ('b',): r2,
}

You can pass more than one predicate function to partition a pytree into multiple groups. tree_partition always returns one more partition than the number of functions passed to it; the last partition (rest) contains all the remaining elements of the pytree. tree_partition also returns a treedef that can be used to reconstruct the pytree with the merge_partitions function:

pytree = refx.merge_partitions(partitions, treedef)
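
The shape of these results can be reproduced with a toy version in plain Python. Everything here is hypothetical (NOTHING, flatten_with_paths, toy_partition, toy_merge), and it assumes that each leaf goes to the first predicate it satisfies, which the text above does not spell out. Unlike merge_partitions, toy_merge only recombines the flat path mapping and does not rebuild the nested structure:

```python
NOTHING = object()  # hypothetical sentinel marking "not in this partition"


def flatten_with_paths(tree, path=()):
    """Yield (path, leaf) pairs; dict keys and list positions become strings."""
    if isinstance(tree, dict):
        for key, child in tree.items():
            yield from flatten_with_paths(child, path + (str(key),))
    elif isinstance(tree, list):
        for i, child in enumerate(tree):
            yield from flatten_with_paths(child, path + (str(i),))
    else:
        yield path, tree


def toy_partition(tree, *predicates):
    """Return len(predicates) + 1 flat dicts; each leaf lands in exactly one."""
    partitions = [{} for _ in range(len(predicates) + 1)]
    for path, leaf in flatten_with_paths(tree):
        # Index of the first matching predicate, else the final "rest" slot.
        target = next(
            (i for i, pred in enumerate(predicates) if pred(leaf)),
            len(predicates),
        )
        for i, partition in enumerate(partitions):
            partition[path] = leaf if i == target else NOTHING
    return partitions


def toy_merge(partitions):
    """Recombine partitions into one flat path -> leaf mapping."""
    merged = {}
    for partition in partitions:
        for path, leaf in partition.items():
            if leaf is not NOTHING:
                merged[path] = leaf
    return merged


# Leaves are (collection, value) tuples standing in for tagged references.
w = ("params", 1.0)
s = ("batch_stats", 2.0)
tree = {"a": [w, w, s], "b": s}

params, rest = toy_partition(tree, lambda leaf: leaf[0] == "params")
merged = toy_merge([params, rest])
```

Every partition keeps the full set of paths and fills the gaps with the sentinel, which is what makes merging a simple key-by-key recombination.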

If you only need a single partition, you can use the get_partition function:

r1 = refx.Ref(1, collection="params")
r2 = refx.Ref(2, collection="batch_stats")

pytree = {
    'a': [r1, r1, r2],
    'b': r2
}

params = refx.get_partition(
    pytree, lambda x: isinstance(x, refx.Ref) and x.collection == "params")

assert params == {
    ('a', '0'): r1,
    ('a', '1'): r1,
    ('a', '2'): refx.NOTHING,
    ('b',): refx.NOTHING
}

Updating references

refx provides the update_refs function to update references in a pytree. It updates the references in a target pytree with the values from a source pytree. The source pytree can be either a pytree of references or a pytree of values. As an example, here is how you can use update_refs to perform gradient descent on a pytree of references:

def by_collection(c):
    return lambda x: isinstance(x, refx.Ref) and x.collection == c

(params, rest), treedef = refx.tree_partition(
    refx.deref(pytree), by_collection("params"))

def loss(params):
    pytree = refx.merge_partitions((params, rest), treedef)
    ...

grads = jax.grad(loss)(params)

# gradient descent
params = jax.tree_util.tree_map(lambda w, g: w - 0.1 * g, params, grads)

refx.update_refs(
    refx.get_partition(pytree, by_collection("params")), params)
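
The update step itself can be sketched in plain Python. MiniRef and toy_update_refs are hypothetical names, the gradients are made-up numbers rather than the output of jax.grad, and the trees are flat dicts for brevity; the sketch only illustrates writing new values into existing references in place:

```python
class MiniRef:
    """Hypothetical mutable box around one value."""

    def __init__(self, value):
        self.value = value


def toy_update_refs(target, source):
    """Write each source entry into the MiniRef stored under the same key.

    Source entries may be plain values or MiniRefs.
    """
    for key, ref in target.items():
        new = source[key]
        ref.value = new.value if isinstance(new, MiniRef) else new


w = MiniRef(1.0)
b = MiniRef(0.5)
target = {"w1": w, "w2": w, "b": b}  # w is shared between two keys

# Illustrative gradients only; in practice these would come from jax.grad.
grads = {"w1": 2.0, "w2": 2.0, "b": -1.0}

# One gradient descent step computed on plain values, then written back.
stepped = {k: ref.value - 0.1 * grads[k] for k, ref in target.items()}
toy_update_refs(target, stepped)
```

Because the references are mutated in place, any other structure that holds w or b sees the updated values without being rebuilt.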
