A library for test-time training.

TTT

TTT is a repository for test-time training kernels.

Currently, we only support non-causal TTT-MLP kernels with a head dimension of 64. Rematerialization (remat) is supported automatically by these kernels.

Here is an example of how to invoke the kernels.

import test_time_training as ttt


# Both kernels below (forward and backward) are TTT-MLP
ttt.ttt_forward(
    XQ_batch.contiguous(),
    XK_batch.contiguous(),
    XV_batch.contiguous(),
    last_eta.contiguous(),
    ttt_norm_weight.contiguous(),
    ttt_norm_bias.contiguous(),
    W1_init.contiguous(),
    b1_init.contiguous(),
    W2_init.contiguous(),
    b2_init.contiguous(),
    W1_checkpoints.contiguous(),
    b1_checkpoints.contiguous(),
    W2_checkpoints.contiguous(),
    b2_checkpoints.contiguous(),
    XQW_batch.contiguous(),
    checkpoint_group_size
)

ttt.ttt_backward(
    # Forward inputs
    XQ_batch.contiguous(),
    XK_batch.contiguous(),
    XV_batch.contiguous(),
    last_eta.contiguous(),
    ttt_norm_weight.contiguous(),
    ttt_norm_bias.contiguous(),
    # Checkpoints
    W1_checkpoints.contiguous(),
    b1_checkpoints.contiguous(),
    W2_checkpoints.contiguous(),
    b2_checkpoints.contiguous(),
    XQW_batch.contiguous(),
    # Rematted Buffers
    W1_init_group.contiguous(),
    b1_init_group.contiguous(),
    W2_init_group.contiguous(),
    b2_init_group.contiguous(),
    x_hat_ln_group.contiguous(),
    std_ln_group.contiguous(),
    X2_group.contiguous(),
    Z1_group.contiguous(),
    Z1_bar_group.contiguous(),
    X2_bar_group.contiguous(),
    grad_l_wrt_Z2_group.contiguous(),
    grad_l_wrt_Z1_group.contiguous(),
    x_hat_fused_group.contiguous(),
    grad_x_hat_fused_group.contiguous(),
    grad_output_fused_group.contiguous(),
    std_fused_group.contiguous(),
    # Upstream grads
    grad_L_W1_last.contiguous(),
    grad_L_b1_last.contiguous(),
    grad_L_W2_last.contiguous(),
    grad_L_b2_last.contiguous(),
    grad_L_XQW_batch.contiguous(),
    # Output grads
    grad_L_ttt_norm_weight.contiguous(),
    grad_L_ttt_norm_bias.contiguous(),
    grad_L_W1_init.contiguous(),
    grad_L_b1_init.contiguous(),
    grad_L_W2_init.contiguous(),
    grad_L_b2_init.contiguous(),
    grad_L_last_eta.contiguous(),
    grad_L_XQ.contiguous(),
    grad_L_XK.contiguous(),
    grad_L_XV.contiguous(),
    checkpoint_group_size
)

Note that these kernels do not support non-contiguous tensors.
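It is worth checking contiguity before calling into the kernels. The sketch below illustrates the idea with NumPy (so it runs standalone); with PyTorch tensors the equivalent calls are `tensor.is_contiguous()` and `tensor.contiguous()`, as used in the example above.

```python
import numpy as np

# A transposed view is typically not C-contiguous.
x = np.arange(12, dtype=np.float32).reshape(3, 4)
xt = x.T

print(x.flags["C_CONTIGUOUS"])   # True
print(xt.flags["C_CONTIGUOUS"])  # False

# np.ascontiguousarray copies only when needed,
# analogous to torch's .contiguous().
xt_contig = np.ascontiguousarray(xt)
print(xt_contig.flags["C_CONTIGUOUS"])  # True
```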

ThunderKittens

This repository is forked from ThunderKittens (https://github.com/HazyResearch/ThunderKittens), which was used and modified for kernel development.

Installation

Installation requires the CUDA driver or toolkit (v12.3+) and g++ v10+.

Pip

pip install test_time_training

From source

source env.src
python setup.py install

Notes on implementation

These kernels use distributed shared memory to implement tensor parallelism and sharding: the hidden states are sharded across SMs to reduce per-SM shared-memory usage.
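The sharding idea can be sketched conceptually in NumPy (this is an illustration of the math, not the kernel code): a hidden-state weight split column-wise across a hypothetical number of SM shards still yields the full result when the partial outputs are concatenated.

```python
import numpy as np

rng = np.random.default_rng(0)
NUM_SHARDS = 4  # hypothetical number of SMs holding hidden-state shards

X = rng.standard_normal((8, 64)).astype(np.float32)
W1 = rng.standard_normal((64, 256)).astype(np.float32)

# Shard the weight column-wise: each "SM" holds one slice of W1.
shards = np.split(W1, NUM_SHARDS, axis=1)

# Each shard computes its partial output independently...
partials = [X @ w for w in shards]

# ...and concatenating the partials recovers the unsharded result.
full = np.concatenate(partials, axis=1)
print(np.allclose(full, X @ W1, atol=1e-5))
```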

These kernels also use input staging and pipelining to hide the latency of global memory reads.

The kernels use mixed precision: matrix multiplications are performed in bf16 to utilize tensor cores, while hidden-state (and gradient) accumulation and layer norm computation are kept in float32.
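The numeric pattern (not the kernel itself) can be illustrated in NumPy, using float16 as a stand-in for bf16 since NumPy has no bfloat16 dtype: multiply in low precision, accumulate into a float32 buffer, and the result stays close to the full-precision reference.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((32, 32)).astype(np.float32)
b = rng.standard_normal((32, 32)).astype(np.float32)

# Cast down for the matmul (float16 stands in for bf16 here)...
low = a.astype(np.float16) @ b.astype(np.float16)

# ...but accumulate the product into a float32 buffer.
acc = np.zeros((32, 32), dtype=np.float32)
acc += low.astype(np.float32)

# Compare against the full-precision reference.
ref = a @ b
max_err = float(np.abs(acc - ref).max())
print(max_err)
```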

Download files

Source distribution:

  • test_time_training-0.6.0.tar.gz (6.8 MB)

Built distributions:

  • test_time_training-0.6.0-cp312-cp312-manylinux_2_37_x86_64.whl (7.6 MB; CPython 3.12, manylinux: glibc 2.37+, x86-64)
  • test_time_training-0.6.0-cp312-cp312-manylinux_2_34_x86_64.whl (7.6 MB; CPython 3.12, manylinux: glibc 2.34+, x86-64)
  • test_time_training-0.6.0-cp312-cp312-manylinux1_x86_64.whl (917.4 kB; CPython 3.12)

File details

test_time_training-0.6.0.tar.gz (source, 6.8 MB; uploaded via twine/6.1.0 on CPython/3.12.8; Trusted Publishing: no)

  • SHA256: 5fc3237c05b1176ca3054c7ffd533a31ab699afaa8edc7683e52f11d7c0e439c
  • MD5: ad5f85571fbfd53838343b7f8216e570
  • BLAKE2b-256: 1619b66cba0ced233aa673be6aa9a6f679ff0c4c10b5c128d3213ffc091d2924

test_time_training-0.6.0-cp312-cp312-manylinux_2_37_x86_64.whl

  • SHA256: 8937750f7f99023c6daf12e285b6f6a6854de46963dc76942787d885d1322d4d
  • MD5: 44156540b21d877f60c666bb279e4cea
  • BLAKE2b-256: 67118b35283445cf876d75eed31ea2ce9b03f4eae14443f10b17360a12e0295a

test_time_training-0.6.0-cp312-cp312-manylinux_2_34_x86_64.whl

  • SHA256: cb65fc347a5e960bb2fca25de6ccc7c1d6ad9c5e79c13d97de38627c7cc44559
  • MD5: 5f6a4f50e12d1d471d3c5f6ef246e039
  • BLAKE2b-256: 425b99a8c4f761804327faef0e7e0decf88d692ce0d6ce241fe7b80171457185

test_time_training-0.6.0-cp312-cp312-manylinux1_x86_64.whl

  • SHA256: ca75e066e4d1777d9900c9d11248b84000f903173b6f20400f0e25b991398f04
  • MD5: 88b3f8a8e4c29b77914923c5ec5593e7
  • BLAKE2b-256: 79fe00205070e345edaec9ae0c0d5266defe7790d27e10b4879bcdf003e563ab
