Named Tensors for Legible Deep Learning in JAX

Haliax

Though you don’t seem to be much for listening, it’s best to be careful. If you managed to catch hold of even just a piece of my name, you’d have all manner of power over me.
— Patrick Rothfuss, The Name of the Wind

Haliax is a JAX library for building neural networks with named tensors, in the tradition of Alexander Rush's Tensor Considered Harmful. Named tensors improve the legibility and compositionality of tensor programs by referring to axes by name instead of by the positional indices typically used in NumPy, PyTorch, and similar libraries.
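
To make the contrast concrete, here is a small sketch in plain NumPy rather than Haliax. The helper `mean_over` and the `axis_names` tuple are hypothetical stand-ins invented for this illustration and are not part of the Haliax API; they only show what it means to select an axis by name instead of by position.

```python
import numpy as np

# Positional style: the reader has to remember that axis 1 means "position".
x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)  # (batch, position, embed)
mean_positional = x.mean(axis=1)

# Named style (toy stand-in, NOT the Haliax API): carry the names alongside the data.
axis_names = ("batch", "position", "embed")


def mean_over(array, names, axis_name):
    """Average over the axis called `axis_name`; return the result and the remaining names."""
    idx = names.index(axis_name)
    remaining = names[:idx] + names[idx + 1:]
    return array.mean(axis=idx), remaining


mean_named, remaining_names = mean_over(x, axis_names, "position")
```

Both calls compute the same values, but the named version states its intent at the call site and keeps working if the axis order changes.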

Despite the focus on legibility, Haliax is also fast, typically about as fast as "pure" JAX code. Haliax is also built to be scalable: it can support Fully-Sharded Data Parallelism (FSDP) and Tensor Parallelism with just a few lines of code. Haliax powers Levanter, our companion library for training large language models and other foundation models, with scale proven up to 20B parameters and up to a TPU v3-256 pod slice.

Example: Attention

Here's a minimal attention module implementation in Haliax. For a more detailed introduction, please see the Haliax tutorial. (We use the excellent Equinox library for its module system and tree transformations.)

import equinox as eqx
import jax
import jax.numpy as jnp
import haliax as hax
import haliax.nn as hnn

Pos = hax.Axis("position", 1024)  # sequence length
KPos = Pos.alias("key_position")
Head = hax.Axis("head", 8)  # number of attention heads
Key = hax.Axis("key", 64)  # key size
Embed = hax.Axis("embed", 512)  # embedding size

# alternatively:
# (axis names must match the "position" / "key_position" names used in rename() below)
#Pos, KPos, Head, Key, Embed = hax.make_axes(position=1024, key_position=1024, head=8, key=64, embed=512)


def attention_scores(Key, KPos, query, key, mask):
    # how similar is each query to each key
    scores = hax.dot(query, key, axis=Key) / jnp.sqrt(Key.size)

    if mask is not None:
        scores -= 1E9 * (1.0 - mask)

    # convert to probabilities
    # use the imported alias: the bare name `haliax` is never bound by the imports above
    scores = hnn.softmax(scores, KPos)
    return scores


def attention(Key, KPos, query, key, value, mask):
    scores = attention_scores(Key, KPos, query, key, mask)
    answers = hax.dot(scores, value, axis=KPos)

    return answers


# Causal Mask means that if pos >= key_pos, then pos can attend to key_pos
causal_mask = hax.arange(Pos).broadcast_axis(KPos) >= hax.arange(KPos)


class Attention(eqx.Module):
    proj_q: hnn.Linear  # [Embed] -> [Head, Key]
    proj_k: hnn.Linear  # [Embed] -> [Head, Key]
    proj_v: hnn.Linear  # [Embed] -> [Head, Key]
    proj_answer: hnn.Linear  # output projection from [Head, Key] -> [Embed]

    @staticmethod
    def init(Embed, Head, Key, *, key):
        k_q, k_k, k_v, k_ans = jax.random.split(key, 4)
        proj_q = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_q)
        proj_k = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_k)
        proj_v = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_v)
        proj_answer = hnn.Linear.init(In=(Head, Key), Out=Embed, key=k_ans)
        return Attention(proj_q, proj_k, proj_v, proj_answer)

    def __call__(self, x, mask=None):
        q = self.proj_q(x)
        # Rename "position" to "key_position" for self attention
        k = self.proj_k(x).rename({"position": "key_position"})
        v = self.proj_v(x).rename({"position": "key_position"})

        # fall back to the causal mask when the caller does not supply one
        if mask is None:
            mask = causal_mask
        answers = attention(Key, KPos, q, k, v, mask)

        x = self.proj_answer(answers)
        return x
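
For comparison, the same computation can be sketched with positional axes in plain NumPy. This is only an illustrative reference under stated assumptions, not code from Haliax: the function name `attention_reference`, the axis layout `[position, head, key]`, and the small sizes are all choices made for this example.

```python
import numpy as np


def attention_reference(query, key, value, mask=None):
    """Positional-axis version of scaled dot-product attention.

    query:      [position, head, key]
    key, value: [key_position, head, key]
    mask:       [position, key_position], True/1 where attention is allowed
    """
    key_size = query.shape[-1]
    # contract over the "key" axis: how similar is each query to each key
    scores = np.einsum("phd,khd->phk", query, key) / np.sqrt(key_size)
    if mask is not None:
        # broadcast the mask over the "head" axis and push masked scores far down
        scores = scores - 1e9 * (1.0 - mask[:, None, :])
    # softmax over "key_position", which is the last axis in this layout
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # contract over "key_position" to get one answer per position and head
    return np.einsum("phk,khd->phd", weights, value)


rng = np.random.default_rng(0)
n_pos, n_head, n_key = 5, 2, 4
x_q = rng.normal(size=(n_pos, n_head, n_key))
x_k = rng.normal(size=(n_pos, n_head, n_key))
x_v = rng.normal(size=(n_pos, n_head, n_key))

# same rule as the causal mask above: position p may attend to key_position k when p >= k
causal = np.arange(n_pos)[:, None] >= np.arange(n_pos)[None, :]
out = attention_reference(x_q, x_k, x_v, causal)
```

Note how the einsum strings and the `mask[:, None, :]` indexing depend on remembering the axis order, which is exactly the bookkeeping that the named `hax.dot(..., axis=Key)` and `softmax(scores, KPos)` calls make explicit.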

Haliax was created by the research engineering team at Stanford's Center for Research on Foundation Models (CRFM). You can find us in the #levanter channel on the unofficial Jax LLM Discord.

Documentation

Tutorials

These are some tutorials to get you started with Haliax. They are available as Colab notebooks:

API Reference

Haliax's API documentation is available at haliax.readthedocs.io.

Contributing

We welcome contributions! Please see CONTRIBUTING.md for more information. We also have a list of good first issues to help you get started. (If those don't appeal, don't hesitate to reach out to us on Discord!)

License

Haliax is licensed under the Apache License, Version 2.0. See LICENSE for the full license text.
