A library for test-time training.
TTT-MLP
TTT is a repository for test-time training kernels.
Currently, we only support non-causal TTT-MLP kernels with a head dimension of 64. Rematerialization (remat) is automatically supported by these kernels.
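As a rough conceptual sketch of what rematerialization with checkpoints means (hypothetical helper names, not the kernel API): the forward pass saves the hidden state every `checkpoint_group_size` steps, and the backward pass recomputes the intermediate states within each group from its checkpoint instead of storing all of them.

```python
# Conceptual sketch of checkpoint-based rematerialization (hypothetical
# helpers, not the ttt_mlp API). `step` advances the hidden state by one input.
def forward_with_checkpoints(state, inputs, step, checkpoint_group_size):
    checkpoints = []
    for i, x in enumerate(inputs):
        if i % checkpoint_group_size == 0:
            checkpoints.append(state)  # save the state at each group boundary
        state = step(state, x)
    return state, checkpoints

def remat_group(checkpoint, group_inputs, step):
    # Recompute (rematerialize) the intermediate states inside one group
    # on demand during the backward pass.
    states, state = [], checkpoint
    for x in group_inputs:
        states.append(state)
        state = step(state, x)
    return states
```

Memory scales with the number of checkpoints rather than the sequence length, at the cost of recomputing each group once in the backward pass.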
Here is an example of how to invoke the kernels:
```python
import ttt_mlp

# Forward pass
ttt_mlp.ttt_forward(
    XQ_batch.contiguous(),
    XK_batch.contiguous(),
    XV_batch.contiguous(),
    last_eta.contiguous(),
    ttt_norm_weight.contiguous(),
    ttt_norm_bias.contiguous(),
    W1_init.contiguous(),
    b1_init.contiguous(),
    W2_init.contiguous(),
    b2_init.contiguous(),
    W1_checkpoints.contiguous(),
    b1_checkpoints.contiguous(),
    W2_checkpoints.contiguous(),
    b2_checkpoints.contiguous(),
    XQW_batch.contiguous(),
    checkpoint_group_size
)
```
```python
ttt_mlp.ttt_backward(
    # Forward inputs
    XQ_batch.contiguous(),
    XK_batch.contiguous(),
    XV_batch.contiguous(),
    last_eta.contiguous(),
    ttt_norm_weight.contiguous(),
    ttt_norm_bias.contiguous(),
    # Checkpoints
    W1_checkpoints.contiguous(),
    b1_checkpoints.contiguous(),
    W2_checkpoints.contiguous(),
    b2_checkpoints.contiguous(),
    XQW_batch.contiguous(),
    # Rematted buffers
    W1_init_group.contiguous(),
    b1_init_group.contiguous(),
    W2_init_group.contiguous(),
    b2_init_group.contiguous(),
    x_hat_ln_group.contiguous(),
    std_ln_group.contiguous(),
    X2_group.contiguous(),
    Z1_group.contiguous(),
    Z1_bar_group.contiguous(),
    X2_bar_group.contiguous(),
    grad_l_wrt_Z2_group.contiguous(),
    grad_l_wrt_Z1_group.contiguous(),
    x_hat_fused_group.contiguous(),
    grad_x_hat_fused_group.contiguous(),
    grad_output_fused_group.contiguous(),
    std_fused_group.contiguous(),
    # Upstream grads
    grad_L_W1_last.contiguous(),
    grad_L_b1_last.contiguous(),
    grad_L_W2_last.contiguous(),
    grad_L_b2_last.contiguous(),
    grad_L_XQW_batch.contiguous(),
    # Output grads
    grad_L_ttt_norm_weight.contiguous(),
    grad_L_ttt_norm_bias.contiguous(),
    grad_L_W1_init.contiguous(),
    grad_L_b1_init.contiguous(),
    grad_L_W2_init.contiguous(),
    grad_L_b2_init.contiguous(),
    grad_L_last_eta.contiguous(),
    grad_L_XQ.contiguous(),
    grad_L_XK.contiguous(),
    grad_L_XV.contiguous(),
    checkpoint_group_size
)
```
Note that these kernels do not support non-contiguous tensors.
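As an illustration of the contiguity requirement (shown in NumPy so it is self-contained; in PyTorch, `.contiguous()` plays the same role as the explicit copy below):

```python
import numpy as np

# A transpose returns a non-contiguous view over strided memory; an explicit
# copy restores row-major (C) contiguity, analogous to Tensor.contiguous().
a = np.arange(6).reshape(2, 3)
b = a.T
assert not b.flags["C_CONTIGUOUS"]   # view, not contiguous
c = np.ascontiguousarray(b)          # contiguous copy of the same values
assert c.flags["C_CONTIGUOUS"]
```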
ThunderKittens
This repository is forked from ThunderKittens (https://github.com/HazyResearch/ThunderKittens), which was used and modified for kernel development.
Installation
Installation requires CUDA drivers or the CUDA toolkit (v12.3+) and g++ v10+.
Pip
```shell
pip install ttt_mlp
```
From source
```shell
source env.src
python setup.py install
```
Notes on implementation
These kernels use distributed shared memory to implement tensor parallelism and sharding: the hidden states are sharded across SMs to save shared memory.
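A minimal NumPy sketch of the column-sharding idea (purely conceptual; the kernels implement this in distributed shared memory across SMs, not with host-side arrays):

```python
import numpy as np

# Tensor-parallel matmul sketch: the weight is column-sharded across workers
# (here standing in for SMs); each worker computes its slice of the output,
# and the slices concatenate to the full result.
rng = np.random.default_rng(0)
X = rng.standard_normal((8, 64))
W = rng.standard_normal((64, 256))
shards = np.split(W, 4, axis=1)               # 4 shards of shape (64, 64)
Y_parallel = np.concatenate([X @ s for s in shards], axis=1)
assert np.allclose(Y_parallel, X @ W)         # matches the unsharded matmul
```

Each worker only needs to hold a quarter of the weight, which is the memory saving the sharding buys.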
These kernels also use input staging and pipelining to hide the latency of global memory reads.
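A minimal sketch of the double-buffering idea behind input staging (pure Python and purely conceptual: on the GPU, staging the next tile into shared memory overlaps with compute on the current one):

```python
# Double-buffering sketch (hypothetical helper, not the kernel code): while
# one buffer is being consumed, the next input tile is staged into the other.
def pipelined_sum(tiles):
    buffers = [None, None]
    buffers[0] = tiles[0]                        # prologue: stage first tile
    total = 0
    for i in range(len(tiles)):
        if i + 1 < len(tiles):
            buffers[(i + 1) % 2] = tiles[i + 1]  # stage next tile ("load")
        total += sum(buffers[i % 2])             # consume current tile ("compute")
    return total

total = pipelined_sum([[1, 2], [3, 4], [5]])     # 15
```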
We also use mixed precision: the matmuls are performed in bf16 for tensor-core utilization, while hidden-state (and gradient) accumulation and layer-norm computation are kept in float32.
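A small illustration of why the accumulation precision matters (NumPy has no bf16, so float16 stands in for the low-precision format here; the point carries over):

```python
import numpy as np

x = np.full(4096, 0.1, dtype=np.float16)

# Low-precision accumulation: once the running sum is large, the addend
# falls below half a unit in the last place and the sum stops growing.
acc_half = np.float16(0.0)
for v in x:
    acc_half = np.float16(acc_half + v)

# float32 accumulation stays close to the true sum (~409.5).
acc_fp32 = np.float32(0.0)
for v in x:
    acc_fp32 += np.float32(v)
```

This is why the kernels keep accumulation in float32 even though the matmul inputs are bf16.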
File details
Details for the file ttt_mlp-0.3.0.tar.gz.
File metadata
- Download URL: ttt_mlp-0.3.0.tar.gz
- Upload date:
- Size: 86.0 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.8
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
81ca561d62c6233b1230bca6c341ccd2c04c05d560d7bf7aee2620aec016463e
|
|
| MD5 |
ffc29b2abcdf5bec8e4e86d794b283e7
|
|
| BLAKE2b-256 |
cc45cf376b4dc8d67a6b3b105705a5448b496414f6ef456d66a0d0da34425520
|
File details
Details for the file ttt_mlp-0.3.0-cp312-cp312-manylinux_2_37_x86_64.whl.
File metadata
- Download URL: ttt_mlp-0.3.0-cp312-cp312-manylinux_2_37_x86_64.whl
- Upload date:
- Size: 10.3 MB
- Tags: CPython 3.12, manylinux: glibc 2.37+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.8
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
2e93e3ed147802942889acc54dd2e876b01c5c35b9cc5abdecb2e5e2f6172187
|
|
| MD5 |
d2cb9945bffaf2aeea73c4018473d18b
|
|
| BLAKE2b-256 |
0a92863dd682d4bb1f8054571322fc4f43ad177254ad783d5f8c16c1c0edd224
|