
Named Tensors for Legible Deep Learning in JAX


Haliax


Though you don’t seem to be much for listening, it’s best to be careful. If you managed to catch hold of even just a piece of my name, you’d have all manner of power over me.
— Patrick Rothfuss, The Name of the Wind

Haliax is a JAX library for building neural networks with named tensors, in the tradition of Alexander Rush's Tensor Considered Harmful. Named tensors improve the legibility and compositionality of tensor programs by using named axes instead of positional indices as typically used in NumPy, PyTorch, etc.
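To make the contrast with positional indexing concrete, here is a minimal sketch in plain NumPy (not Haliax) of contracting two arrays by axis name instead of by position. The helper `named_dot` and its convention of pairing each array with a tuple of axis names are hypothetical, written only for illustration; in Haliax the real operation is `hax.dot` on `NamedArray`s.

```python
import string

import numpy as np


def named_dot(a, a_names, b, b_names, contract):
    """Contract `a` and `b` over the axis named `contract`.

    Hypothetical illustration, not Haliax's implementation: each array is
    paired with a tuple of axis names, and the contraction is chosen by name,
    never by position. Axes that share a name but are not contracted are kept
    as shared (batch) axes. Returns the result and its axis names.
    """
    # one einsum letter per distinct axis name, in first-seen order
    all_names = list(dict.fromkeys(a_names + b_names))
    letters = dict(zip(all_names, string.ascii_letters))
    out_names = tuple(n for n in all_names if n != contract)
    spec = "{},{}->{}".format(
        "".join(letters[n] for n in a_names),
        "".join(letters[n] for n in b_names),
        "".join(letters[n] for n in out_names),
    )
    return np.einsum(spec, a, b), out_names


rng = np.random.default_rng(0)
q = rng.normal(size=(4, 2, 3))  # axes: position, head, key
k = rng.normal(size=(2, 5, 3))  # axes: head, key_position, key (different order)

scores, names = named_dot(
    q, ("position", "head", "key"),
    k, ("head", "key_position", "key"),
    contract="key",
)
# names == ("position", "head", "key_position"); scores.shape == (4, 2, 5)
```

Because `k` stores its axes in a different order from `q`, a positional expression would need a hand-chosen transpose; the named version only needs to know which axis to contract.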

Despite the focus on legibility, Haliax is also fast, typically about as fast as "pure" JAX code. Haliax is also built to be scalable: it can support Fully-Sharded Data Parallelism (FSDP) and Tensor Parallelism with just a few lines of code. Haliax powers Levanter, our companion library for training large language models and other foundation models, which has been proven at scales of up to 70B parameters and up to TPU v4-2048.
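One way to see why name-based sharding needs so few lines: if every array axis has a name, a single dictionary from axis names to device-mesh axes determines how every array is partitioned. The sketch below is a hypothetical pure-Python helper, not Haliax's partitioning API. It only builds the per-dimension tuple that `jax.sharding.PartitionSpec(*spec)` would take, and the mesh axis names `"data"` and `"model"` are assumptions.

```python
from typing import Dict, Optional, Sequence, Tuple


def partition_spec_for(
    axis_names: Sequence[str], mapping: Dict[str, str]
) -> Tuple[Optional[str], ...]:
    """Map each named array axis to a mesh axis (None means replicated).

    Hypothetical helper for illustration: axes missing from `mapping` are
    left unsharded, and a mesh axis may be used at most once per array.
    """
    spec = tuple(mapping.get(name) for name in axis_names)
    used = [m for m in spec if m is not None]
    if len(used) != len(set(used)):
        raise ValueError(f"mesh axis used more than once in {spec}")
    return spec


# Assumed mesh axes "data" and "model".
# FSDP-style: shard parameters along their "embed" axis over "data";
# tensor parallelism: shard the "head" axis over "model".
param_mapping = {"embed": "data", "head": "model"}
# activations are split along "batch" over "data" instead
compute_mapping = {"batch": "data", "head": "model"}

proj_q_spec = partition_spec_for(("embed", "head", "key"), param_mapping)
# ('data', 'model', None)
acts_spec = partition_spec_for(("batch", "position", "head", "key"), compute_mapping)
# ('data', None, 'model', None)
```

The same two dictionaries cover every parameter and activation in the model, which is the property that keeps the sharding configuration short; Haliax ships its own helpers for this, so treat the function above only as an illustration of the idea.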

Example: Attention

Here's a minimal attention module implementation in Haliax. For a more detailed introduction, please see the Haliax tutorial. (We use the excellent Equinox library for its module system and tree transformations.)
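In equation form, the attention code below computes the following for each query position $p$, head $h$, and key position $j$, where $d$ ranges over the `Key` axis, $Q$, $K$, $V$ are the projected query, key, and value, and $M$ is the 0/1 mask (for the causal mask, $M_{p,j} = 1$ exactly when $p \ge j$):

$$
s_{p,h,j} = \frac{1}{\sqrt{|\mathrm{Key}|}} \sum_{d} Q_{p,h,d}\, K_{j,h,d} \;-\; 10^{9}\,\bigl(1 - M_{p,j}\bigr),
\qquad
a_{p,h,j} = \frac{\exp(s_{p,h,j})}{\sum_{j'} \exp(s_{p,h,j'})},
\qquad
o_{p,h,d} = \sum_{j} a_{p,h,j}\, V_{j,h,d}.
$$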

import equinox as eqx
import jax
import jax.numpy as jnp
import haliax as hax
import haliax.nn as hnn

Pos = hax.Axis("position", 1024)  # sequence length
KPos = Pos.alias("key_position")
Head = hax.Axis("head", 8)  # number of attention heads
Key = hax.Axis("key", 64)  # key size
Embed = hax.Axis("embed", 512)  # embedding size

# alternatively:
# Pos, KPos, Head, Key, Embed = hax.make_axes(position=1024, key_position=1024, head=8, key=64, embed=512)


def attention_scores(Key, KPos, query, key, mask):
    # how similar is each query to each key
    scores = hax.dot(query, key, axis=Key) / jnp.sqrt(Key.size)

    if mask is not None:
        scores -= 1E9 * (1.0 - mask)

    # convert to probabilities
    scores = hnn.softmax(scores, KPos)
    return scores


def attention(Key, KPos, query, key, value, mask):
    scores = attention_scores(Key, KPos, query, key, mask)
    answers = hax.dot(scores, value, axis=KPos)

    return answers


# Causal Mask means that if pos >= key_pos, then pos can attend to key_pos
causal_mask = hax.arange(Pos).broadcast_axis(KPos) >= hax.arange(KPos)


class Attention(eqx.Module):
    proj_q: hnn.Linear  # [Embed] -> [Head, Key]
    proj_k: hnn.Linear  # [Embed] -> [Head, Key]
    proj_v: hnn.Linear  # [Embed] -> [Head, Key]
    proj_answer: hnn.Linear  # output projection from [Head, Key] -> [Embed]

    @staticmethod
    def init(Embed, Head, Key, *, key):
        k_q, k_k, k_v, k_ans = jax.random.split(key, 4)
        proj_q = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_q)
        proj_k = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_k)
        proj_v = hnn.Linear.init(In=Embed, Out=(Head, Key), key=k_v)
        proj_answer = hnn.Linear.init(In=(Head, Key), Out=Embed, key=k_ans)
        return Attention(proj_q, proj_k, proj_v, proj_answer)

    def __call__(self, x, mask=causal_mask):
        # x: [Pos, Embed] -> q: [Pos, Head, Key]
        q = self.proj_q(x)
        # Rename "position" to "key_position" for self attention
        k = self.proj_k(x).rename({"position": "key_position"})
        v = self.proj_v(x).rename({"position": "key_position"})

        # use the mask that was passed in (causal by default)
        answers = attention(Key, KPos, q, k, v, mask)

        x = self.proj_answer(answers)
        return x
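For comparison, here is a sketch of the same masked attention written with positional axes in plain NumPy. It is not part of Haliax and is meant only as a reference for checking shapes and values. The axis order `[position, head, key]` chosen for `q`, `k`, and `v` is an assumption the programmer has to track by hand, which is exactly the bookkeeping that named axes remove.

```python
import numpy as np


def softmax(x, axis):
    # subtract the max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention_positional(q, k, v, mask):
    """q: [pos, head, key]; k, v: [key_pos, head, key]; mask: [pos, key_pos].

    Returns [pos, head, key]. Every axis is identified only by its position.
    """
    d = q.shape[-1]
    # scores: [pos, head, key_pos]
    scores = np.einsum("phd,jhd->phj", q, k) / np.sqrt(d)
    # broadcast the [pos, key_pos] mask across the head axis by hand
    scores = scores - 1e9 * (1.0 - mask[:, None, :])
    weights = softmax(scores, axis=-1)  # normalize over key_pos
    return np.einsum("phj,jhd->phd", weights, v)


def make_causal_mask(n):
    # position p may attend to key position j when p >= j
    pos = np.arange(n)
    return (pos[:, None] >= pos[None, :]).astype(np.float32)


rng = np.random.default_rng(0)
n, heads, d = 6, 2, 4
q = rng.normal(size=(n, heads, d))
k = rng.normal(size=(n, heads, d))
v = rng.normal(size=(n, heads, d))
out = attention_positional(q, k, v, make_causal_mask(n))
```

With the causal mask, the first position can attend only to itself, so `out[0]` equals `v[0]`, and changing the last value vector leaves every earlier output unchanged.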

Haliax was created by the research engineering team at Stanford's Center for Research on Foundation Models (CRFM). You can find us in the #levanter channel on the unofficial JAX LLM Discord.

Documentation

Tutorials

These are some tutorials to get you started with Haliax. They are available as Colab notebooks:

API Reference

Haliax's API documentation is available at haliax.readthedocs.io.

Contributing

We welcome contributions! Please see CONTRIBUTING.md for more information. We also have a list of good first issues to help you get started. (If those don't appeal, don't hesitate to reach out to us on Discord!)

License

Haliax is licensed under the Apache License, Version 2.0. See LICENSE for the full license text.


Download files


Source Distribution

haliax-1.4.dev404.tar.gz (798.7 kB)

Built Distribution


haliax-1.4.dev404-py3-none-any.whl (127.3 kB)

File details

Details for the file haliax-1.4.dev404.tar.gz.

File metadata

  • Download URL: haliax-1.4.dev404.tar.gz
  • Upload date:
  • Size: 798.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for haliax-1.4.dev404.tar.gz
Algorithm Hash digest
SHA256 523d8ed2835ee88c5a453c8fe68d4975dca8f018a6dbfb87d0bc02c957c192c1
MD5 a472132ccac7bdd1a43ef1cafe3a052e
BLAKE2b-256 e5b3c588006fe15d0706aeac3780254c30645d3347d2b6271687ab44dc04a7f1


Provenance

The following attestation bundles were made for haliax-1.4.dev404.tar.gz:

Publisher: publish_dev.yaml on stanford-crfm/haliax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file haliax-1.4.dev404-py3-none-any.whl.

File metadata

  • Download URL: haliax-1.4.dev404-py3-none-any.whl
  • Upload date:
  • Size: 127.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for haliax-1.4.dev404-py3-none-any.whl
Algorithm Hash digest
SHA256 ea5830d125076b7bd9caa83e2fc73e8c1b5cf55a8d308a1529c52f8f0c7ebda1
MD5 7e047d600ff169117ced2b7b7ef55a96
BLAKE2b-256 c4cc175511681fbc11f56c08a74987a3d6c4e713a890b2cb2fb27ba3b5f3c926


Provenance

The following attestation bundles were made for haliax-1.4.dev404-py3-none-any.whl:

Publisher: publish_dev.yaml on stanford-crfm/haliax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
