Skip to main content

A library for test-time training kernels written in thunderkittens

Project description

TTT-tk

TTT-tk is a repository for test-time training kernels written in Thunderkittens.

Currently, we only support non-causal TTT-MLP kernels with head dimension of 64. Remat is automatically supported with these kernels.

Here is an example on how to invoke the kernels.

import ttt_tk


# Both ttt-mlp
ttt_tk.ttt_forward(
    XQ_batch.contiguous(),
    XK_batch.contiguous(),
    XV_batch.contiguous(),
    last_eta.contiguous(),
    ttt_norm_weight.contiguous(),
    ttt_norm_bias.contiguous(),
    W1_init.contiguous(),
    b1_init.contiguous(),
    W2_init.contiguous(),
    b2_init.contiguous(),
    W1_checkpoints.contiguous(),
    b1_checkpoints.contiguous(),
    W2_checkpoints.contiguous(),
    b2_checkpoints.contiguous(),
    XQW_batch.contiguous(),
    checkpoint_group_size
)

ttt_tk.ttt_backward(
    # Forward inputs
    XQ_batch.contiguous(),
    XK_batch.contiguous(),
    XV_batch.contiguous(),
    last_eta.contiguous(),
    ttt_norm_weight.contiguous(),
    ttt_norm_bias.contiguous(),
    # Checkpoints
    W1_checkpoints.contiguous(),
    b1_checkpoints.contiguous(),
    W2_checkpoints.contiguous(),
    b2_checkpoints.contiguous(),
    XQW_batch.contiguous(),
    # Rematted Buffers
    W1_init_group.contiguous(),
    b1_init_group.contiguous(),
    W2_init_group.contiguous(),
    b2_init_group.contiguous(),
    x_hat_ln_group.contiguous(),
    std_ln_group.contiguous(),
    X2_group.contiguous(),
    Z1_group.contiguous(),
    Z1_bar_group.contiguous(),
    X2_bar_group.contiguous(),
    grad_l_wrt_Z2_group.contiguous(),
    grad_l_wrt_Z1_group.contiguous(),
    x_hat_fused_group.contiguous(),
    grad_x_hat_fused_group.contiguous(),
    grad_output_fused_group.contiguous(),
    std_fused_group.contiguous(),
    # Upstream grads
    grad_L_W1_last.contiguous(),
    grad_L_b1_last.contiguous(),
    grad_L_W2_last.contiguous(),
    grad_L_b2_last.contiguous(),
    grad_L_XQW_batch.contiguous(),
    # Output grads
    grad_L_ttt_norm_weight.contiguous(),
    grad_L_ttt_norm_bias.contiguous(),
    grad_L_W1_init.contiguous(),
    grad_L_b1_init.contiguous(),
    grad_L_W2_init.contiguous(),
    grad_L_b2_init.contiguous(),
    grad_L_last_eta.contiguous(),
    grad_L_XQ.contiguous(),
    grad_L_XK.contiguous(),
    grad_L_XV.contiguous(),
    checkpoint_group_size
)

Note that these kernels do not support non-contiguous tensors.

Thunderkittens

This repository is forked from Thunderkittens (https://github.com/HazyResearch/ThunderKittens). Thunderkittens was used and modified for kernel development.

Installation

Installation requires CUDA drivers or toolkit (v 12.3+) and g++ v10+

Pip

pip install ttt_tk

From source

source env.src
python setup.py install

Notes on implementation

These kernels use distributed shared memory to implement tensor-parallelism and sharding. The hidden states are sharded across SMs to save shared memory.

These kernels also use input staging and pipelining to hide latencies for global reads.

We also used mixed precision to perform the matmuls in bf16 for tensor core usage and also kept hidden state (and grads) accumulation and layer norm computation in float32.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

ttt_tk-0.2.0.tar.gz (6.9 MB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

ttt_tk-0.2.0-cp312-cp312-manylinux1_x86_64.whl (1.8 MB view details)

Uploaded CPython 3.12

File details

Details for the file ttt_tk-0.2.0.tar.gz.

File metadata

  • Download URL: ttt_tk-0.2.0.tar.gz
  • Upload date:
  • Size: 6.9 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for ttt_tk-0.2.0.tar.gz
Algorithm Hash digest
SHA256 4e9c9433ae89bc7548830a55599c40a5eb71d3799a85567250686621321e37d5
MD5 2a7df70dc73354084b8bfa7ae03bc9e1
BLAKE2b-256 7e093ad4ad6a3beea0244c52ac8bd99c0f79cff3d412806ac4beaaaaf2ac9405

See more details on using hashes here.

File details

Details for the file ttt_tk-0.2.0-cp312-cp312-manylinux1_x86_64.whl.

File metadata

File hashes

Hashes for ttt_tk-0.2.0-cp312-cp312-manylinux1_x86_64.whl
Algorithm Hash digest
SHA256 de396581b499169cd31c1d201157a19db22d70b130395b2192ae1a8a868cab27
MD5 e209eb28e22eda3872dbbd6073226e9d
BLAKE2b-256 78eadd0d7226b629fc0228e8066502fdf5b2a6373b9f1bba1d7872414455be40

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page