
Triton Lib

Triton Lib (tlib) is a Python library providing universal Einstein-notation functionality for Triton kernels, along with a functional expansion of triton-lang's frontend. Tlib is written purely in Python and is compatible with functions decorated with @triton.jit. The design of this library is two-fold:

  • Provide an expanded set of base functional operations (tlf) for Triton, which are numerically stable and support masking: tlf.{mean|var|std|mse|kl_div|...}

  • Provide einx-style ops syntax for all base and expanded Triton ops, dynamically generated at compile time so they incur no overhead within kernels.

Tlib is built on the einx syntax and compiler, with a few major changes to enable compatibility with Triton. All functions are dynamically generated and then compiled with Python's exec() during Triton's compile time, creating no bottlenecks at kernel runtime.
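tlib's generator is more involved than this, but the core idea — parse the pattern once, emit a specialized function via exec() so nothing is parsed at call time — can be sketched in plain Python. The names below are illustrative only, not tlib's API, and NumPy stands in for triton.language:

```python
import numpy as np

def make_reduction(pattern: str, op_name: str):
    """Generate a reduction function for a pattern like 'a [b] c'.

    The bracketed axis is the one reduced over. Here we emit NumPy
    code; tlib instead emits triton.language calls.
    """
    axes = pattern.split()
    # Resolve the bracketed axis name to a positional axis index.
    axis = next(i for i, a in enumerate(axes) if a.startswith("["))
    src = (
        "def generated(x):\n"
        "    import numpy as np\n"
        f"    return np.{op_name}(x, axis={axis})\n"
    )
    namespace = {}
    exec(src, namespace)  # compile the specialized function once
    return namespace["generated"]

row_sum = make_reduction("a [b]", "sum")
x = np.arange(6).reshape(2, 3)
print(row_sum(x))  # same as np.sum(x, axis=1)
```

Because the generated function is ordinary Python by the time Triton's compiler sees it, the pattern-parsing cost is paid exactly once, at kernel compile time.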

Getting Started

  • Installation (COMING SOON)
  • Tutorial (COMING SOON)
  • Einx Notation
  • API Reference (COMING SOON)

Installation

Tlib is built using new features from Triton 3.4.0, which is only compatible with torch >= 2.8.0. Tlib can be installed with pip:

pip install triton-lib

or built from source:

git clone https://github.com/Hprairie/tlib.git
cd tlib
pip install -e .

What does tlib look like in kernels?

Tlib provides ops for almost all base Triton frontend ops, plus additional ones from tlf (tlib.functional).

import triton
import triton.language as tl
import tlib
import tlib.functional as tlf

@triton.jit
def my_kernel(x_ptr, y_ptr, o_ptr, LENGTH: tl.constexpr):
    # arange indexing ops
    x = tl.load(x_ptr + tlib.arange("a b", tlib.dict(a=LENGTH, b=LENGTH)))
    y = tl.load(y_ptr + tlib.arange("a b c", tlib.dict(a=LENGTH, b=LENGTH, c=LENGTH)))

    # Rearrange ops
    o = tlib.rearrange("a b -> b a", x) # This is equivalent to tl.trans(x, (1, 0))
    o = tlib.rearrange("a b -> (a b)", x) # This is equivalent to tl.reshape(x, (LENGTH * LENGTH,))
    o = tlib.rearrange("a b c -> c (a b)", y) # This is equivalent to tl.trans(y, (2, 0, 1)) followed by tl.reshape(y, (LENGTH, LENGTH * LENGTH))

    # Unary Ops
    o = tlib.cumsum("a [b]", x) # This is equivalent to tl.cumsum(x, axis=1)
    o = tlib.cumprod("a [b] c", y) # This is equivalent to tl.cumprod(y, axis=1)
    o = tlib.flip("[a] b", x) # This is equivalent to tl.flip(x, axis=0)
    o = tlib.sort("a b [c]", y) # This is equivalent to tl.sort(y, axis=2)
    o = tlib.softmax("a [b]", x) # This is equivalent to tl.softmax(x, axis=1)

    # Binary Ops
    o = tlib.add("a b, a b c", (x, y)) # This is equivalent to x[:, :, None] + y
    o = tlib.add("a c, a b c", (x, y)) # This is equivalent to x[:, None, :] + y
    o = tlib.add("b c, a b c", (x, y)) # This is equivalent to x[None, :, :] + y
    o = tlib.subtract("a b, a b c", (x, y)) # This is equivalent to x[:, :, None] - y
    o = tlib.multiply("a b, a b c", (x, y)) # This is equivalent to x[:, :, None] * y
    o = tlib.divide("a b, a b c", (x, y)) # This is equivalent to x[:, :, None] / y

    # Reduction Ops
    out = tlib.sum("a [b]", x) # This is equivalent to tl.sum(x, axis=1)
    out = tlib.mean("a [b]", x) # This is equivalent to tlf.mean(x, axis=1)
    out = tlib.var("a [b]", x) # This is equivalent to tlf.var(x, axis=1)
    out = tlib.count_nonzero("a [b]", x) # This is equivalent to tlf.count_nonzero(x, axis=1)
    out = tlib.max("a [b]", x) # This is equivalent to tl.max(x, axis=1)
    out = tlib.min("a [b]", x) # This is equivalent to tl.min(x, axis=1)
    out = tlib.argmax("a [b]", x) # This is equivalent to tl.argmax(x, axis=1)
    out = tlib.argmin("a [b]", x) # This is equivalent to tl.argmin(x, axis=1)
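To check concretely what the bracket and broadcasting notation compute, here is a host-side NumPy analogue of a few of the ops above. NumPy stands in for triton.language; the shapes mirror the kernel example with a concrete LENGTH:

```python
import numpy as np

LENGTH = 4
x = np.arange(LENGTH * LENGTH, dtype=np.float32).reshape(LENGTH, LENGTH)
y = np.arange(LENGTH ** 3, dtype=np.float32).reshape(LENGTH, LENGTH, LENGTH)

# tlib.rearrange("a b -> b a", x): a plain transpose
o1 = np.transpose(x, (1, 0))

# tlib.add("a b, a b c", (x, y)): x is broadcast over the trailing c axis
o2 = x[:, :, None] + y

# tlib.sum("a [b]", x): reduce over the bracketed axis
o3 = np.sum(x, axis=1)

print(o1.shape, o2.shape, o3.shape)
```

The einx-style string fully determines the axis index and broadcasting pattern, which is why tlib can resolve everything to plain tl calls at compile time.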

Why create/use Tlib

I will discuss both the ops and the functional libraries added in tlib. Adding Einstein-notation ops to Triton seemed like a no-brainer. The readability of Einstein notation in other high-level frameworks such as torch, tensorflow, jax, etc., makes it an incredibly appealing tool. Porting this functionality to Triton, where each expression is evaluated at compile time and converted directly to tl syntax, gives kernels the feel of a high-level abstraction without the performance cost such abstractions usually carry.

Furthermore, in the pursuit of readability, I wanted to expand the functionality of the tl base language to match what torch offers. The most direct way to do this was to implement standard triton.jit functions for the new functional operations.

Limitations

As you might have noticed from the examples, there are some API differences between tlib and einx/einops. First, when passing multiple tensors, say to tlib.rearrange, they need to be wrapped in a tuple:

o = rearrange("a b c, d e f -> a c b, d f e", x, y)        # einx
o = tlib.rearrange("a b c, d e f -> a c b, d f e", (x, y)) # tlib

Additionally, dictionaries aren't supported in Triton, so tlib provides a wrapper, tlib.dict, which behaves like a dictionary but is a tl.constexpr.
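tlib.dict's internals aren't shown here, but the idea — an immutable, keyword-constructed mapping from axis names to compile-time sizes — can be sketched on the host. This is illustrative only, not tlib's implementation:

```python
# Illustrative sketch (not tlib's implementation): an immutable
# keyword-constructed mapping of the kind tlib.dict provides,
# where axis names map to compile-time sizes.
class AxisDict:
    def __init__(self, **sizes):
        self._sizes = tuple(sizes.items())  # tuple, so contents are fixed

    def __getitem__(self, name):
        for key, value in self._sizes:
            if key == name:
                return value
        raise KeyError(name)

    def keys(self):
        return tuple(key for key, _ in self._sizes)

d = AxisDict(a=16, b=32)
print(d["a"], d["b"], d.keys())
```

Storing the entries in a tuple rather than a dict is what makes such a wrapper expressible as a compile-time constant.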

Misc

This section will eventually be moved, but outlined below are the current roadmap and the limitations of triton lib.

ToDo

  • Implement dot einstein notation ops
  • Build a PyPI package
  • Create Documentation
  • Fix associative scan operation in tlib

References

This package is partly built on einx, whose copyright notice has been included in the project and built upon.
