
A library for test-time training.


TTT MLP

TTT is a repository for test-time training kernels.

Currently, we support only non-causal TTT-MLP kernels with a head dimension of 64. These kernels support remat (rematerialization) automatically.

Here is an example of how to invoke the kernels.

import ttt_mlp


# TTT-MLP forward kernel
ttt_mlp.ttt_forward(
    XQ_batch.contiguous(),
    XK_batch.contiguous(),
    XV_batch.contiguous(),
    last_eta.contiguous(),
    ttt_norm_weight.contiguous(),
    ttt_norm_bias.contiguous(),
    W1_init.contiguous(),
    b1_init.contiguous(),
    W2_init.contiguous(),
    b2_init.contiguous(),
    W1_checkpoints.contiguous(),
    b1_checkpoints.contiguous(),
    W2_checkpoints.contiguous(),
    b2_checkpoints.contiguous(),
    XQW_batch.contiguous(),
    checkpoint_group_size
)

ttt_mlp.ttt_backward(
    # Forward inputs
    XQ_batch.contiguous(),
    XK_batch.contiguous(),
    XV_batch.contiguous(),
    last_eta.contiguous(),
    ttt_norm_weight.contiguous(),
    ttt_norm_bias.contiguous(),
    # Checkpoints
    W1_checkpoints.contiguous(),
    b1_checkpoints.contiguous(),
    W2_checkpoints.contiguous(),
    b2_checkpoints.contiguous(),
    XQW_batch.contiguous(),
    # Rematted Buffers
    W1_init_group.contiguous(),
    b1_init_group.contiguous(),
    W2_init_group.contiguous(),
    b2_init_group.contiguous(),
    x_hat_ln_group.contiguous(),
    std_ln_group.contiguous(),
    X2_group.contiguous(),
    Z1_group.contiguous(),
    Z1_bar_group.contiguous(),
    X2_bar_group.contiguous(),
    grad_l_wrt_Z2_group.contiguous(),
    grad_l_wrt_Z1_group.contiguous(),
    x_hat_fused_group.contiguous(),
    grad_x_hat_fused_group.contiguous(),
    grad_output_fused_group.contiguous(),
    std_fused_group.contiguous(),
    # Upstream grads
    grad_L_W1_last.contiguous(),
    grad_L_b1_last.contiguous(),
    grad_L_W2_last.contiguous(),
    grad_L_b2_last.contiguous(),
    grad_L_XQW_batch.contiguous(),
    # Output grads
    grad_L_ttt_norm_weight.contiguous(),
    grad_L_ttt_norm_bias.contiguous(),
    grad_L_W1_init.contiguous(),
    grad_L_b1_init.contiguous(),
    grad_L_W2_init.contiguous(),
    grad_L_b2_init.contiguous(),
    grad_L_last_eta.contiguous(),
    grad_L_XQ.contiguous(),
    grad_L_XK.contiguous(),
    grad_L_XV.contiguous(),
    checkpoint_group_size
)

Note that these kernels do not support non-contiguous tensors.
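Because every tensor argument has to be contiguous, it can be convenient to normalize the arguments in one place instead of calling `.contiguous()` on each one. The following is a minimal sketch, assuming the inputs behave like PyTorch tensors (they expose `is_contiguous()` and `contiguous()`). `ensure_contiguous` and `StubTensor` are hypothetical names used for illustration only; neither is part of `ttt_mlp`.

```python
def ensure_contiguous(*args):
    """Return args with every tensor-like value made contiguous.

    Hypothetical helper, not part of ttt_mlp. Values without an
    is_contiguous() method (e.g. an integer group size) pass through unchanged.
    """
    result = []
    for arg in args:
        if hasattr(arg, "is_contiguous") and not arg.is_contiguous():
            arg = arg.contiguous()
        result.append(arg)
    return tuple(result)


class StubTensor:
    """Stand-in for a tensor so the sketch runs without PyTorch or a GPU."""

    def __init__(self, contiguous):
        self._contiguous = contiguous

    def is_contiguous(self):
        return self._contiguous

    def contiguous(self):
        # A real tensor would copy its data into a dense layout here.
        return StubTensor(contiguous=True)


view = StubTensor(contiguous=False)   # e.g. the result of a transpose
dense = StubTensor(contiguous=True)   # already dense, should not be copied
fixed_view, same_dense, group_size = ensure_contiguous(view, dense, 4)
```

With real tensors, the same helper could wrap the full argument list, for example `ttt_mlp.ttt_forward(*ensure_contiguous(XQ_batch, ..., checkpoint_group_size))`.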

ThunderKittens

This repository is forked from ThunderKittens (https://github.com/HazyResearch/ThunderKittens). ThunderKittens was used and modified for kernel development.

Installation

Installation requires the CUDA drivers or toolkit (v12.3+) and g++ v10+.
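To check an environment against these minimums before building, a small version comparison may help. This is only a sketch: `meets_minimum` is a hypothetical helper, and the version strings below are made-up examples. How you obtain the real strings (for instance from your compiler's and toolkit's version output) is left to you.

```python
def parse_version(text):
    """Turn a dotted version string such as '12.3' into a tuple of ints."""
    return tuple(int(part) for part in text.strip().split("."))


def meets_minimum(found, required):
    """Return True if the version string found is at least required."""
    found_parts = parse_version(found)
    required_parts = parse_version(required)
    # Pad the shorter tuple with zeros so '12' compares equal to '12.0'.
    width = max(len(found_parts), len(required_parts))
    found_parts += (0,) * (width - len(found_parts))
    required_parts += (0,) * (width - len(required_parts))
    return found_parts >= required_parts


# Example version strings (illustrative values, not read from a real system).
cuda_ok = meets_minimum("12.4", "12.3")   # toolkit newer than the minimum
gxx_ok = meets_minimum("9.4.0", "10")     # compiler older than the minimum
```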

Pip

pip install ttt_mlp

From source

source env.src
python setup.py install

Notes on implementation

These kernels use distributed shared memory to implement tensor-parallelism and sharding. The hidden states are sharded across SMs to save shared memory.
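As a rough illustration of the sharding idea (not the kernels' actual partitioning or memory layout), the sketch below splits the hidden dimension of a two-layer MLP into shards, computes a partial output per shard, and sums the partials. Each shard only needs its own slice of the weights, which is what lets the state be spread across workers. The `tanh` activation is a stand-in, biases are omitted, and all names are hypothetical.

```python
import math


def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def mlp(x, w1, w2):
    """Two-layer MLP without biases: tanh(x @ w1) @ w2."""
    hidden = [[math.tanh(v) for v in row] for row in matmul(x, w1)]
    return matmul(hidden, w2)


def sharded_mlp(x, w1, w2, num_shards):
    """Same result as mlp(), computed one hidden-dimension shard at a time."""
    hidden_dim = len(w2)
    assert hidden_dim % num_shards == 0, "hidden dim must divide evenly"
    step = hidden_dim // num_shards
    out = [[0.0] * len(w2[0]) for _ in x]
    for shard in range(num_shards):
        lo, hi = shard * step, (shard + 1) * step
        w1_shard = [row[lo:hi] for row in w1]   # columns lo..hi of w1
        w2_shard = w2[lo:hi]                    # rows lo..hi of w2
        partial = mlp(x, w1_shard, w2_shard)
        # Reduction step: sum the partial outputs from every shard.
        out = [[o + p for o, p in zip(ro, rp)] for ro, rp in zip(out, partial)]
    return out


x = [[0.5, -1.0], [2.0, 0.25]]
w1 = [[0.1, 0.2, -0.3, 0.4], [0.5, -0.6, 0.7, 0.8]]
w2 = [[1.0, 0.0], [0.5, -0.5], [0.25, 0.75], [-1.0, 2.0]]
full = mlp(x, w1, w2)
sharded = sharded_mlp(x, w1, w2, num_shards=2)
```

The split works because the activation is elementwise, so each hidden unit can be computed independently of the units held by other shards.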

These kernels also use input staging and pipelining to hide the latency of global memory reads.
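The general pattern behind this kind of pipelining is to start loading the next chunk of input while the current one is being processed. The following CPU-side sketch is an analogy only, using a background thread in place of asynchronous GPU loads. `load_chunk` and `pipelined_sum_of_squares` are hypothetical names, and none of this reflects the kernels' actual code.

```python
from concurrent.futures import ThreadPoolExecutor


def load_chunk(data, index, size):
    """Stand-in for a slow read from global memory."""
    return data[index * size:(index + 1) * size]


def pipelined_sum_of_squares(data, size):
    """Process data chunk by chunk, prefetching the next chunk during compute."""
    num_chunks = (len(data) + size - 1) // size
    if num_chunks == 0:
        return 0
    total = 0
    with ThreadPoolExecutor(max_workers=1) as pool:
        pending = pool.submit(load_chunk, data, 0, size)
        for index in range(num_chunks):
            chunk = pending.result()          # wait for the staged chunk
            if index + 1 < num_chunks:
                # Kick off the next load before computing on this chunk.
                pending = pool.submit(load_chunk, data, index + 1, size)
            total += sum(value * value for value in chunk)
    return total


data = list(range(10))
result = pipelined_sum_of_squares(data, size=4)
```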

We also use mixed precision: the matmuls run in bf16 to make use of tensor cores, while hidden state (and gradient) accumulation and layer norm computation are kept in float32.
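One likely reason to keep accumulation in float32 is that bf16 carries only 8 bits of significand precision, so small updates added to a larger running value can be rounded away entirely. The sketch below emulates both formats in plain Python to show the effect. It is an illustration under that assumption, not the kernels' code, and the helper names are hypothetical.

```python
import struct


def round_to_float32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack(">f", struct.pack(">f", x))[0]


def round_to_bf16(x):
    """Round a Python float to the nearest bf16 value (NaN is not handled)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round to nearest, ties to even
    bits &= 0xFFFF0000                    # keep only the top 16 bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]


def accumulate(start, increment, steps, rounder):
    """Add increment to start repeatedly, rounding after every operation."""
    total = rounder(start)
    step = rounder(increment)
    for _ in range(steps):
        total = rounder(total + step)
    return total


# 100 updates of 0.001 on top of 1.0: the exact answer is 1.1.
bf16_total = accumulate(1.0, 0.001, 100, round_to_bf16)
f32_total = accumulate(1.0, 0.001, 100, round_to_float32)
```

Here `bf16_total` stays at exactly 1.0, because each update is smaller than half the spacing between bf16 values near 1.0, while `f32_total` ends up close to 1.1.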
