
Brachy

Pytorch-like ML with JAX: an (increasingly less) simple neural network library on top of JAX.

Better JIT wrapper

First, it is very annoying that jax.jit cannot handle dictionaries as static arguments, arguments that are pytrees in which some leaves are static and some should be traced, or functions that return static values. So, we provide a general wrapper in structure_util.improved_static that takes care of this by automatically separating out the traceable and non-traceable components of each argument before passing them to jit. It can also handle non-JAX types in return values. However, be careful with these: we assume that any non-JAX type in a return value is a fixed function of the static arguments and the shapes of the traced arguments (i.e. its value does not change unless the function needs to be re-traced).

import jax
from jax import numpy as jnp
from brachy import structure_util as su

jit = su.jit
# su.jit is an alias for su.improved_static(jax.jit).

@jit
def foo(x,y):
    if x['q'] == 'go ahead!':
        return {'a': x['a'], 'b': y['b']}
    else:
        return {'a': 2*y['a'], 'b': y['b']}

x = {
    'q': 'stop',
    'a': jnp.ones(3)
}
y = {
    'a': jnp.ones(5),
    'b': ['hello', 'friend']
}

z = foo(x,y) 
## z should be:
# {'a': jnp.array([2,2,2,2,2]), 'b': ['hello', 'friend']}

Further, this wrapper will also automatically extract and mark as static the static components of a structure tree (described below). That is, if a structure tree has an otherwise traceable value under the 'aux' key (e.g. a boolean configuration flag), that value will not be traced.

This wrapper su.improved_static can also be used to impart similar behavior to other JAX primitives (e.g. xmap).
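For instance, here is a minimal sketch of the 'aux' behavior (it assumes su.jit recognizes the dict argument as a structure tree by its keys; the structure tree format itself is described below, and the tree contents here are made up purely for illustration):

from jax import numpy as jnp
from brachy import structure_util as su

@su.jit
def scale(tree, x):
    # tree['aux']['double'] is a plain python bool. Since values under 'aux'
    # are kept static, branching on it inside a jitted function is fine;
    # changing it simply triggers a re-trace.
    if tree['aux']['double']:
        return 2 * tree['params']['w'] * x
    return tree['params']['w'] * x

tree = {
    'params': {'w': jnp.ones(3)},   # traced as usual
    'buffers': {},
    'aux': {'double': True},        # static, never traced
    'apply': None,
    'submodules': {},
}

z = scale(tree, jnp.arange(3.0))
# z should be jnp.array([0., 2., 4.])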

Overview

Hax tries to keep your code as close to the functional spirit of JAX as possible while also making it easy to port code from pytorch.

In pytorch, a module is an instance of a class that stores various model parameters (i.e. network weights) as class members. Such modules are relatively straightforward to write, but they have two important drawbacks. First, this style of module does not play nicely with JAX's functional programming style, which makes it difficult to implement more complicated ideas such as meta-gradient descent (i.e. differentiating with respect to hyperparameters). Second, as models grow in size and complexity, it will likely become more and more important to "mix and match" components of pre-trained modules in an easy way. Right now, extracting the output of an intermediate layer, or adding a new layer somewhere in the middle of a module's computation, requires careful inspection of the source code and often extra work to transfer pretrained weights to the new architecture. This is not really necessary: model architectures are usually straightforward to describe as simple trees. Hax exploits this to solve both problems by providing utilities that compute directly with architectures described in tree form.

A Hax module is a pair consisting of a "structure tree" and a "global config". Both of these are python dictionaries. The global config should ideally be a simple JSON-style dict of configuration values (e.g. {'training_mode': True}). The structure tree contains both the model weights and the functions describing how to apply those weights. We could have organized the structure tree as a python class, but we wanted structure trees to be as hackable as possible, and wrapping them in a complicated class mechanism to provide ease of use in common cases might make that harder. That said, Hax does still provide a class StateOrganizer that can be used to wrap a structure tree in an object that behaves very similarly to a pytorch module, which is useful for building structure trees.

Formally, a Hax structure tree S is a dict whose keys are "params", "buffers", "aux", "apply", and "submodules". The value S["submodules"] is a dict whose values are themselves structure trees (i.e. S["submodules"] specifies the children of S in the tree). The values S["params"] and S["buffers"] are both dicts whose values are JAX types. By a JAX type, we mean a value that is a valid argument to a traced JAX function (e.g. a pytree whose leaves are all JAX arrays). That is, the function:

@jax.jit
def identity(x):
    return jax.tree_util.tree_map(lambda a: a, x)

will run without error on any JAX type.

The value S["apply"] is a function with signature:

def apply(
    structure_tree: Hax.structure_tree,
    global_config: dict,
    *args,
    **kwargs) -> Hax.structure_tree, Any

Hax.structure_tree is simply an alias for dict, so any function that takes two dicts as its first two arguments and returns a dict is acceptable. The additional arguments to apply are implementation specific. The first return value is an updated version of the input argument structure_tree, and the second return value is the "output" of the module. For example, a linear layer might be implemented as follows:

def linear_apply(S: Hax.structure_tree, global_config: dict, x: Array) -> Hax.structure_tree, Array:
    weight = S["params"]["weight"]
    bias = S["params"]["bias"]

    y = x @ weight + bias

    return S, y
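
To make this concrete, here is a minimal sketch of a complete structure tree for such a linear layer, reusing linear_apply from above (the initialization shown is made up for illustration; the real linear module in Brachy may differ):

import jax
from jax import numpy as jnp

def linear_init(rng, in_features, out_features):
    return {
        'params': {
            'weight': jax.random.normal(rng, (in_features, out_features)) / jnp.sqrt(in_features),
            'bias': jnp.zeros(out_features),
        },
        'buffers': {},
        'aux': {},
        'apply': linear_apply,
        'submodules': {},
    }

S = linear_init(jax.random.PRNGKey(0), in_features=4, out_features=2)
global_config = {}

# Apply the module by calling the function stored under 'apply'.
S, y = S['apply'](S, global_config, jnp.ones((3, 4)))  # y has shape (3, 2)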

In the linear example, we did not need to change the input structure tree. However, layers that involve state, randomization, or different behavior between training and evaluation require more delicate handling:

def dropout_apply(S: Hax.structure_tree, global_config: dict, x: Array) -> Hax.structure_tree, Array:
    if not global_config["is_training"]:
        return S, x

    rng = S["buffers"]["rng"]
    rng, subkey = jax.random.split(rng)

    p = S["buffers"]["p"]
    y = x * jax.random.bernoulli(subkey, p, shape=x.shape)/p

    S["buffers"]["rng"] = rng

    return S, y
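
A quick usage sketch (hypothetical tree contents): the caller threads the updated tree, which now holds the advanced rng, into the next call, and flips the global config to get eval behavior:

import jax
from jax import numpy as jnp

S = {
    'params': {},
    'buffers': {'rng': jax.random.PRNGKey(0), 'p': 0.9},  # p is the keep probability
    'aux': {},
    'apply': dropout_apply,
    'submodules': {},
}
x = jnp.ones((4, 8))

S, y_train = dropout_apply(S, {'is_training': True}, x)   # S['buffers']['rng'] has advanced
S, y_eval  = dropout_apply(S, {'is_training': False}, x)  # identity: y_eval is just x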

Note that it is strongly advised NOT to change the "apply" or "aux" values of the input S inside these apply functions as this will cause retracing when jitting. Instead, these values are meant to be edited as part of an overall surgery on the model architecture.

Structure Tree Terminology

Technically, many of the functions in this package do not require a structure tree to have all of the keys "params", "buffers", "aux", and "apply": only the "submodules" key is strictly needed. Given a structure tree tree, we say that tree is a leaf if tree["submodules"] is the empty dict {}. Further, we say that tree is a node with path [k1, k2, k3] if there is a root tree root such that tree = root["submodules"][k1]["submodules"][k2]["submodules"][k3]. In general, the path of tree["submodules"][k] is P + [k], where P is the path of tree.
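
For example (an illustrative tree showing only the "submodules" key):

root = {
    'submodules': {
        'encoder': {                                # path: ['encoder']
            'submodules': {
                'layer0': {'submodules': {}},       # a leaf; path: ['encoder', 'layer0']
            },
        },
    },
}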

Structure Tree Utils

brachy.structure_util contains the core functions for converting structure trees into a module's forward-pass function and back. Key utilities include:

  • structure_tree_map(func: Callable, *trees: List[dict], path={}) -> Union[dict, Tuple[dict,...]]. The first argument is a function func(*nodes, path) that outputs a leaf node (or a tuple of leaf nodes). The remaining arguments trees must be one or more structure trees. The output is a structure tree such that, for each unique path P appearing in any of the trees in trees, the output tree has a node at path P equal to the output of func, with nodes being the list [subtree of tree at path P for tree in trees] and path=P. If func returns a tuple of nodes, then structure_tree_map returns the corresponding tuple of trees.
  • StateOrganizer: this class does a lot of the heavy lifting to make defining new structure trees feel similar to defining a module in pytorch. Eventually, one calls organizer.create_module() to obtain a tuple tree, global_config. While building the tree, if you assign a tuple to a new attribute of a StateOrganizer object via organizer.name = (subtree, sub_global_config), then the tree returned by organizer.create_module() will have subtree as the value of ["submodules"]["name"]. Also, global_config will be merged with sub_global_config (values in sub_global_config do not override existing values). See the examples directory for how to use StateOrganizer objects.
  • apply_tree(tree: dict, global_config: dict, *args, **kwargs). This function is a shorthand for tree['apply'](tree, global_config, *args, **kwargs).
  • bind_module(tree: dict, global_config: dict) -> dict, Callable. This function is mostly unnecessary given the brachy.structure_util.jit functionality described earlier. It takes as input a structure tree and a global config and returns a state dictionary and an apply function. The state dictionary is just the original structure tree with all but the "params", "buffers", and "submodules" keys removed; it represents the current mutable state of the module. The apply function applies the tree: it takes a state dictionary and whatever inputs the module requires, and returns both an updated state dictionary and the outputs of the module. The returned apply function can be jitted because the unhashable global_config dictionary is captured in a closure rather than passed as an argument. To change the global config dictionary, use apply.bind_global_config(new_global_config). To recover a full structure tree, use tree, global_config = unbind_module(state, apply). (See the usage sketch after this list.)
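
Here is a short usage sketch of apply_tree and bind_module, reusing the linear-layer tree S and global_config from the sketch earlier (hedged: the calls follow the descriptions above, and unbind_module is assumed to live in structure_util alongside bind_module):

from jax import numpy as jnp
from brachy import structure_util as su

# apply_tree is just shorthand for calling the tree's own apply function.
S, y = su.apply_tree(S, global_config, jnp.ones((3, 4)))

# bind_module splits the tree into its mutable state plus an apply function
# that closes over everything else (including global_config).
state, apply_fn = su.bind_module(S, global_config)
state, y = apply_fn(state, jnp.ones((3, 4)))

# Recover a full structure tree and global config from the bound pair.
S, global_config = su.unbind_module(state, apply_fn)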

Random number generator utils

The file rng_util.py contains a context manager that makes it easier to pass JAX PRNG keys down through a tree of functions without having to write rng, subkey = jax.random.split(rng) all over the place. See the comments at the top of the file, or the usages in the resnet example and in nn.py, for more info.

This utility can be combined with the StateOrganizer via the decorators organized_init_with_rng and organized_apply defined in structure_util.py. See the language modeling example for these decorators in use.

Installing

From pip

You can now pip install brachy! However, this will explicitly NOT install jax, since the jax installation process differs between CPU and GPU machines. You should install the appropriate jax version for your hardware yourself (see the official JAX installation instructions).

BU SCC setup instructions

You need python3 and jax (pytorch is useful for dataloaders and for running tests). Currently there seems to be some issue preventing simultaneously loading jax, pytorch, and tensorflow. However, we probably don't need tensorflow, so this is not a huge problem.

module load python3 pytorch cuda/11.6 jax/0.4.6

You should probably also set up a virtual environment: python -m venv brachyenv to create it, source brachyenv/bin/activate to activate it, and deactivate to leave it.

Some of the examples require additional packages listed in the requirements.txt file. You can pip install --upgrade pip and then pip install -r requirements.txt to get them to run. Or just run an example and pip install packages one by one as you hit "ModuleNotFoundError" messages.
