Refx

A reference system for JAX

What is Refx?

Refx defines a simple reference system that lets you build directed acyclic graphs (DAGs) on top of JAX pytrees. It enables several key features:

  • Shared state
  • Tractable mutability
  • Semantic partitioning

Refx is intended to be a low-level library that can be used as a building block within other JAX libraries.

Why Refx?

Functional systems like Flax and Haiku are powerful, but they add a lot of complexity that is often transferred to the user. Pytree-based systems like Equinox are simpler, but they cannot share parameters and modules.

Refx aims to be a foundation for neural network libraries that combine the simplicity of pytree-based systems with the power of functional systems.

Installation

pip install refx

Getting Started

Refx's main data structure is the Ref class. A Ref wraps a value and can be used as a leaf in a pytree; its value attribute reads and mutates the wrapped value.

import jax
import refx

r1 = refx.Ref(1)
r2 = refx.Ref(2)

pytree = {
    'a': [r1, r1, r2],
    'b': r2
}

pytree['a'][0].value = 10

assert pytree['a'][1].value == 10
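The sharing above is ordinary Python object identity: pytree['a'][0] and pytree['a'][1] hold the same Ref object, so a write through one is visible through the other. Here is a minimal sketch of that idea using a hypothetical ToyRef class. ToyRef is not part of refx and has no collection attribute and no trace-level checks:

```python
class ToyRef:
    """Hypothetical stand-in for refx.Ref: a mutable box holding one value."""

    def __init__(self, value):
        self.value = value


r1 = ToyRef(1)
r2 = ToyRef(2)

# The same ToyRef object appears at several positions in the tree.
pytree = {"a": [r1, r1, r2], "b": r2}

# Mutating through one position...
pytree["a"][0].value = 10

# ...is visible through every other position holding the same object.
assert pytree["a"][1].value == 10
assert pytree["a"][0] is pytree["a"][1]
```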

Ref is not a pytree node, so pytrees containing Refs cannot be passed directly to JAX transformations. To interact with JAX, refx provides the following functions:

  • deref: converts a pytree of references to a pytree of values and indexes.
  • reref: converts a pytree of values and indexes to a pytree of references.

Call deref before a pytree crosses a JAX boundary (for example, going into or out of a jitted function) and reref after it has crossed.

pytree = refx.deref(pytree)

@jax.jit
def f(pytree):
    pytree = refx.reref(pytree)
    # perform some computation / update the references
    pytree['a'][2].value = 50
    return refx.deref(pytree)

pytree = f(pytree)
pytree = refx.reref(pytree)

assert pytree['b'].value == 50

As this example shows, shared state and tractable mutability have been implemented with pure pytrees.
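deref and reref only need to record which leaves are the same object. One plausible encoding is sketched below in pure Python. It rests on our own assumptions: the ToyValue/ToyIndex classes and the traversal order are hypothetical and are not refx's internals. The first occurrence of a reference becomes a value tagged with an index, later occurrences become just that index, and reref builds one fresh reference per index:

```python
from dataclasses import dataclass


class ToyRef:
    """Hypothetical mutable box standing in for refx.Ref."""

    def __init__(self, value):
        self.value = value


@dataclass(frozen=True)
class ToyValue:
    """First occurrence of a reference: the wrapped value plus its index."""
    value: object
    index: int


@dataclass(frozen=True)
class ToyIndex:
    """Any later occurrence of the same reference: only its index."""
    index: int


def _map_leaves(tree, fn):
    # Visit leaves in a fixed order (sorted dict keys, then list order) so that
    # toy_deref and toy_reref encounter occurrences in the same sequence.
    if isinstance(tree, dict):
        return {k: _map_leaves(tree[k], fn) for k in sorted(tree)}
    if isinstance(tree, list):
        return [_map_leaves(x, fn) for x in tree]
    return fn(tree)


def toy_deref(tree):
    seen = {}  # id(ref) -> index assigned on first sight

    def leaf(x):
        if not isinstance(x, ToyRef):
            return x
        if id(x) in seen:
            return ToyIndex(seen[id(x)])
        seen[id(x)] = len(seen)
        return ToyValue(x.value, seen[id(x)])

    return _map_leaves(tree, leaf)


def toy_reref(tree):
    refs = {}  # index -> freshly created ToyRef

    def leaf(x):
        if isinstance(x, ToyValue):
            refs[x.index] = ToyRef(x.value)
            return refs[x.index]
        if isinstance(x, ToyIndex):
            return refs[x.index]
        return x

    return _map_leaves(tree, leaf)


r1, r2 = ToyRef(1), ToyRef(2)
flat = toy_deref({"a": [r1, r1, r2], "b": r2})
# flat == {"a": [ToyValue(1, 0), ToyIndex(0), ToyValue(2, 1)], "b": ToyIndex(1)}

tree = toy_reref(flat)
tree["a"][2].value = 50
assert tree["b"].value == 50        # aliasing survived the round trip
assert tree["a"][0] is tree["a"][1]
```

Because indices are assigned during a deterministic traversal, the same encoding can be decoded in the same order. In real JAX code, the value carrier would also have to be a registered pytree node so that the wrapped array is traced while the index stays static. This sketch leaves that out.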

Trace-level awareness

In JAX, unconstrained mutability can lead to tracer leakage. To prevent this, refx only allows a reference to be mutated at the same trace level at which it was created.

import jax
import jax.numpy as jnp
import refx

r = refx.Ref(1)

@jax.jit
def f():
    # ValueError: Cannot mutate ref from different trace level
    r.value = jnp.array(1.0)
    ...

Partitioning

Each reference has a collection attribute (any Hashable value) that can be used to group references. refx provides the tree_partition function to partition a pytree based on one or more predicate functions.

r1 = refx.Ref(1, collection="params")
r2 = refx.Ref(2, collection="batch_stats")

pytree = {
    'a': [r1, r1, r2],
    'b': r2
}

(params, rest), treedef = refx.tree_partition(
    pytree, lambda x: isinstance(x, refx.Ref) and x.collection == "params")

assert params == {
    ('a', '0'): r1,
    ('a', '1'): r1,
    ('a', '2'): refx.NOTHING,
    ('b',): refx.NOTHING
}
assert rest == {
    ('a', '0'): refx.NOTHING,
    ('a', '1'): refx.NOTHING,
    ('a', '2'): r2,
    ('b',): r2,
}
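The partitioning step can be sketched in pure Python. The version below makes three assumptions: leaves are addressed by tuples of string path keys (as in the asserts above), each leaf goes to the first predicate it satisfies, and no treedef is returned. The helper names, the NOTHING placeholder, and the first-match rule are all hypothetical:

```python
NOTHING = object()  # hypothetical placeholder for "not in this partition"


class ToyRef:
    """Hypothetical stand-in for refx.Ref with a collection label."""

    def __init__(self, value, collection=None):
        self.value = value
        self.collection = collection


def flatten_with_paths(tree, prefix=()):
    """Flatten nested dicts/lists into {path: leaf}, using string path keys."""
    if isinstance(tree, dict):
        children = [(str(k), tree[k]) for k in sorted(tree)]
    elif isinstance(tree, (list, tuple)):
        children = [(str(i), x) for i, x in enumerate(tree)]
    else:
        return {prefix: tree}
    flat = {}
    for key, child in children:
        flat.update(flatten_with_paths(child, prefix + (key,)))
    return flat


def toy_tree_partition(tree, *predicates):
    """Return len(predicates) + 1 flat dicts with identical keys.

    Each leaf goes to the first predicate it satisfies, or to the last
    ("rest") partition; every other partition gets NOTHING at that path.
    """
    flat = flatten_with_paths(tree)
    partitions = [dict.fromkeys(flat, NOTHING) for _ in range(len(predicates) + 1)]
    for path, leaf in flat.items():
        slot = next((i for i, p in enumerate(predicates) if p(leaf)), len(predicates))
        partitions[slot][path] = leaf
    return partitions


r1 = ToyRef(1, collection="params")
r2 = ToyRef(2, collection="batch_stats")
params, rest = toy_tree_partition(
    {"a": [r1, r1, r2], "b": r2},
    lambda x: isinstance(x, ToyRef) and x.collection == "params",
)
assert params == {("a", "0"): r1, ("a", "1"): r1, ("a", "2"): NOTHING, ("b",): NOTHING}
assert rest == {("a", "0"): NOTHING, ("a", "1"): NOTHING, ("a", "2"): r2, ("b",): r2}
```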

You can pass several predicate functions to split a pytree into multiple groups. tree_partition always returns one more partition than the number of predicates passed to it; the last partition (rest) contains all remaining elements of the pytree. tree_partition also returns a treedef that can be used to reconstruct the pytree with the merge_partitions function:

pytree = refx.merge_partitions((params, rest), treedef)
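Merging reverses partitioning: for each path, it keeps the one partition entry that is not a placeholder. The sketch below works at the flat level only. It assumes every path is filled in exactly one partition. toy_merge_flat and NOTHING are hypothetical, and rebuilding the nested structure from the treedef is omitted:

```python
NOTHING = object()  # hypothetical placeholder for "not in this partition"


def toy_merge_flat(partitions):
    """Combine flat {path: leaf} partitions into one flat dict.

    Assumes all partitions share the same keys and each path holds a real
    leaf in exactly one of them.
    """
    merged = {}
    for path in partitions[0]:
        filled = [p[path] for p in partitions if p[path] is not NOTHING]
        if len(filled) != 1:
            raise ValueError(f"path {path} is filled in {len(filled)} partitions")
        merged[path] = filled[0]
    return merged


params = {("a", "0"): 1, ("a", "1"): 1, ("a", "2"): NOTHING, ("b",): NOTHING}
rest = {("a", "0"): NOTHING, ("a", "1"): NOTHING, ("a", "2"): 2, ("b",): 2}
merged = toy_merge_flat([params, rest])
assert merged == {("a", "0"): 1, ("a", "1"): 1, ("a", "2"): 2, ("b",): 2}
```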

If you only need a single partition, you can use the get_partition function:

r1 = refx.Ref(1, collection="params")
r2 = refx.Ref(2, collection="batch_stats")

pytree = {
    'a': [r1, r1, r2],
    'b': r2
}

params = refx.get_partition(
    pytree, lambda x: isinstance(x, refx.Ref) and x.collection == "params")

assert params == {
    ('a', '0'): r1,
    ('a', '1'): r1,
    ('a', '2'): refx.NOTHING,
    ('b',): refx.NOTHING
}

Updating references

refx provides the update_refs function to update references in a pytree. It writes the values from a source pytree into the references of a target pytree. The source pytree can be either a pytree of references or a pytree of values. As an example, here is how you can use update_refs to perform gradient descent on a pytree of references:

def by_collection(c):
    return lambda x: isinstance(x, refx.Ref) and x.collection == c

(params, rest), treedef = refx.tree_partition(
    refx.deref(pytree), by_collection("params"))

def loss(params):
    pytree = refx.merge_partitions((params, rest), treedef)
    ...  # compute and return a scalar loss from pytree

grads = jax.grad(loss)(params)

# gradient descent
params = jax.tree_util.tree_map(lambda w, g: w - 0.1 * g, params, grads)

refx.update_refs(
    refx.get_partition(pytree, by_collection("params")), params)
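Here is a minimal sketch of the write-back step. It assumes that both target and source are flat {path: leaf} dicts like the partitions above, and that NOTHING marks paths to skip. ToyRef, NOTHING, and toy_update_refs are hypothetical stand-ins, not refx's implementation:

```python
NOTHING = object()  # hypothetical placeholder for "not in this partition"


class ToyRef:
    """Hypothetical mutable box standing in for refx.Ref."""

    def __init__(self, value):
        self.value = value


def toy_update_refs(target, source):
    """Write values from `source` into the refs of `target`, path by path.

    Both arguments are flat {path: leaf} dicts. Source leaves may be refs or
    plain values; paths marked NOTHING in `target` are skipped.
    """
    for path, ref in target.items():
        if ref is NOTHING:
            continue
        new = source[path]
        ref.value = new.value if isinstance(new, ToyRef) else new


w1, w2 = ToyRef(1.0), ToyRef(4.0)
target = {("a",): w1, ("b",): w2, ("c",): NOTHING}
grads = {("a",): 2.0, ("b",): 4.0, ("c",): NOTHING}

# One gradient-descent step on plain values (learning rate 0.5), then write back.
new_params = {
    path: NOTHING if ref is NOTHING else ref.value - 0.5 * grads[path]
    for path, ref in target.items()
}
toy_update_refs(target, new_params)
assert (w1.value, w2.value) == (0.0, 2.0)

# A source made of refs also works: their .value is copied.
toy_update_refs({("a",): w1}, {("a",): ToyRef(7.0)})
assert w1.value == 7.0
```

Copying values into existing refs, rather than replacing them, means every other place in the pytree that holds the same ref sees the update.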

Download files

Source Distribution

refx-0.0.3.tar.gz (10.4 kB)

Uploaded Source

Built Distribution

refx-0.0.3-py3-none-any.whl (10.3 kB)

Uploaded Python 3

File details

Details for the file refx-0.0.3.tar.gz.

File metadata

  • Download URL: refx-0.0.3.tar.gz
  • Upload date:
  • Size: 10.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.4.0 CPython/3.8.10 Linux/5.13.0-1027-gcp

File hashes

Hashes for refx-0.0.3.tar.gz
Algorithm Hash digest
SHA256 1598eb96877deb5aba42a562f79b1b24c245cacfd6645b7ff64b2a00bb52081e
MD5 f3b9a1af026e4475c28d94772976d160
BLAKE2b-256 0ae19e1a64fbe01bf948422c6998032317fe71912bcbae987b337b590f5e2954

File details

Details for the file refx-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: refx-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 10.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.4.0 CPython/3.8.10 Linux/5.13.0-1027-gcp

File hashes

Hashes for refx-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 7f1c5951616889425f01350ac74a8b3f11d180465b8e96a22b8f2612ab40a61c
MD5 8868bba6e91013bcd98ab82d44dafa67
BLAKE2b-256 02852ac346d5038951109af2ffc0506ae9fc54ff29266d9572859b6aac22feef
