Skip to main content

A library for test-time training.

Project description

TTT

TTT is a repository for test-time training kernels.

Currently, we only support non-causal TTT-MLP kernels with head dimension of 64. Remat is automatically supported with these kernels.

Here is an example on how to invoke the kernels.

import test_time_training as ttt


# Both ttt-mlp
ttt.ttt_forward(
    XQ_batch.contiguous(),
    XK_batch.contiguous(),
    XV_batch.contiguous(),
    last_eta.contiguous(),
    ttt_norm_weight.contiguous(),
    ttt_norm_bias.contiguous(),
    W1_init.contiguous(),
    b1_init.contiguous(),
    W2_init.contiguous(),
    b2_init.contiguous(),
    W1_checkpoints.contiguous(),
    b1_checkpoints.contiguous(),
    W2_checkpoints.contiguous(),
    b2_checkpoints.contiguous(),
    XQW_batch.contiguous(),
    checkpoint_group_size
)

ttt.ttt_backward(
    # Forward inputs
    XQ_batch.contiguous(),
    XK_batch.contiguous(),
    XV_batch.contiguous(),
    last_eta.contiguous(),
    ttt_norm_weight.contiguous(),
    ttt_norm_bias.contiguous(),
    # Checkpoints
    W1_checkpoints.contiguous(),
    b1_checkpoints.contiguous(),
    W2_checkpoints.contiguous(),
    b2_checkpoints.contiguous(),
    XQW_batch.contiguous(),
    # Rematted Buffers
    W1_init_group.contiguous(),
    b1_init_group.contiguous(),
    W2_init_group.contiguous(),
    b2_init_group.contiguous(),
    x_hat_ln_group.contiguous(),
    std_ln_group.contiguous(),
    X2_group.contiguous(),
    Z1_group.contiguous(),
    Z1_bar_group.contiguous(),
    X2_bar_group.contiguous(),
    grad_l_wrt_Z2_group.contiguous(),
    grad_l_wrt_Z1_group.contiguous(),
    x_hat_fused_group.contiguous(),
    grad_x_hat_fused_group.contiguous(),
    grad_output_fused_group.contiguous(),
    std_fused_group.contiguous(),
    # Upstream grads
    grad_L_W1_last.contiguous(),
    grad_L_b1_last.contiguous(),
    grad_L_W2_last.contiguous(),
    grad_L_b2_last.contiguous(),
    grad_L_XQW_batch.contiguous(),
    # Output grads
    grad_L_ttt_norm_weight.contiguous(),
    grad_L_ttt_norm_bias.contiguous(),
    grad_L_W1_init.contiguous(),
    grad_L_b1_init.contiguous(),
    grad_L_W2_init.contiguous(),
    grad_L_b2_init.contiguous(),
    grad_L_last_eta.contiguous(),
    grad_L_XQ.contiguous(),
    grad_L_XK.contiguous(),
    grad_L_XV.contiguous(),
    checkpoint_group_size
)

Note that these kernels do not support non-contiguous tensors.

Thunderkittens

This repository is forked from Thunderkittens (https://github.com/HazyResearch/ThunderKittens). Thunderkittens was used and modified for kernel development.

Installation

Installation requires CUDA drivers or toolkit (v 12.3+) and g++ v10+

Pip

pip install test_time_training

From source

source env.src
python setup.py install

Notes on implementation

These kernels use distributed shared memory to implement tensor-parallelism and sharding. The hidden states are sharded across SMs to save shared memory.

These kernels also use input staging and pipelining to hide latencies for global reads.

We also used mixed precision to perform the matmuls in bf16 for tensor core usage and also kept hidden state (and grads) accumulation and layer norm computation in float32.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

test_time_training-0.3.0.tar.gz (6.8 MB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

test_time_training-0.3.0-cp312-cp312-manylinux_2_37_x86_64.whl (7.6 MB view details)

Uploaded CPython 3.12manylinux: glibc 2.37+ x86-64

File details

Details for the file test_time_training-0.3.0.tar.gz.

File metadata

  • Download URL: test_time_training-0.3.0.tar.gz
  • Upload date:
  • Size: 6.8 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for test_time_training-0.3.0.tar.gz
Algorithm Hash digest
SHA256 43a3c1f5ed6f3ed81dcf36fb2ce1b5b8a2a388bcd34582cbf2ee3ba1fa735ea5
MD5 2f5a11468ca8623c24331f1192a976cb
BLAKE2b-256 5bf68f804efe6d2386e34c418d2038a98c6d4d3911844d7989ec12741077d7c8

See more details on using hashes here.

File details

Details for the file test_time_training-0.3.0-cp312-cp312-manylinux_2_37_x86_64.whl.

File metadata

File hashes

Hashes for test_time_training-0.3.0-cp312-cp312-manylinux_2_37_x86_64.whl
Algorithm Hash digest
SHA256 0d7272897f89c653b49cb50c84c91cff6e64232c9eaa5a9af161528bdf96dc30
MD5 e5bd7f88a2a22dd66e115b15ec6e0836
BLAKE2b-256 1aeb677d1ccd3f6fa08ac9b6933bb40682ddde6889f6b9a9da5dba58b9ecfecc

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page