Progressive Action Tokenizer framework with FSQ quantization
PAT: Progressive Action Tokenizer
pat is a PyTorch training framework for a progressive action tokenizer inspired by OAT (vendor/oat/oat/tokenizer/oat/), with:
- register-based encoder
- FSQ action quantization
- progressive decoding via nested token dropout
- tokenizer training loop with validation, reconstruction eval, EMA, and checkpoints
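FSQ quantization and nested token dropout can be illustrated in plain Python. This is a minimal sketch under stated assumptions, not PAT's implementation. The names `fsq_quantize`, `fsq_index` and `nested_dropout_mask` and the level counts are hypothetical. The bounding step uses one simple FSQ variant (tanh, rescale, round). The straight-through gradient trick that training needs is omitted.

```python
import math
import random


def fsq_quantize(z, levels):
    """Finite scalar quantization of one latent vector (simple variant).

    Each dimension is squashed into (-1, 1) with tanh, rescaled to
    [0, L - 1] for its level count L, and rounded to the nearest integer,
    so the implicit codebook has prod(levels) entries and no learned table.
    """
    if len(z) != len(levels):
        raise ValueError("z and levels must have the same length")
    return [round((math.tanh(v) + 1.0) / 2.0 * (n - 1)) for v, n in zip(z, levels)]


def fsq_index(codes, levels):
    """Pack per-dimension codes into a single token id (mixed-radix number)."""
    index = 0
    for code, n in zip(codes, levels):
        index = index * n + code
    return index


def nested_dropout_mask(n_tokens, rng):
    """Keep a random-length prefix of the token sequence during training.

    Sampling k in [1, n_tokens] and dropping every token after position k
    pushes coarse information into the earliest tokens, so decoding from
    any prefix still gives a usable reconstruction.
    """
    keep = rng.randint(1, n_tokens)  # randint is inclusive on both ends
    return [i < keep for i in range(n_tokens)]


levels = [8, 5, 5, 5]  # hypothetical: 8 * 5 * 5 * 5 = 1000 codes per token
codes = fsq_quantize([10.0, -10.0, 0.0, 0.0], levels)  # [7, 0, 2, 2]
token_id = fsq_index(codes, levels)                     # 887
mask = nested_dropout_mask(8, random.Random(0))         # True...True, False...False
```

Because the FSQ codebook is implicit, its size is just the product of the per-dimension levels (1000 here), and no embedding table has to be learned.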
Install
pip install -e .
Architecture comparison
1️⃣ Model architecture comparison
UAT ActionCodec
Input (b, seq_len, action_dim)
↓
Perceiver Encoder
├── EmbodimentEmbedding (robot-type embedding)
├── Input Projection (max_action_dim → encoder_dim)
├── Positional Embedding (Fourier/Sinusoidal)
├── CLS Tokens (learnable query vectors)
└── Transformer Blocks (Cross-Attention + Self-Attention)
↓
Latent (b, n_tokens_per_quantizer, z_dim)
↓
Vector Quantizer (VQ/RVQ)
↓
Perceiver Decoder
├── EmbodimentEmbedding
├── Input Projection (z_dim → decoder_dim)
├── Positional Embedding
└── Transformer Blocks
↓
Output (b, seq_len, max_action_dim)
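Key features:
- CLS token mechanism: learnable CLS tokens act as queries and extract information from the input through cross-attention.
- Embodiment-aware: each robot type has its own embedding, and new robot types can be added dynamically.
- Flexible attention: causal masking and self-attention can be enabled or disabled through configuration.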
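The CLS-token mechanism comes down to cross-attention in which a fixed set of query tokens attends over every input timestep, so the latent length is set by the number of queries rather than by seq_len. The sketch below shows single-head scaled dot-product cross-attention in plain Python. It is an illustration under assumptions, and the function names and sizes are hypothetical. A real Perceiver block also has learned query/key/value projections, multiple heads and MLP layers.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries, keys, values):
    """Single-head scaled dot-product cross-attention without projections.

    queries: n_q vectors of size d (the CLS / latent tokens)
    keys:    n_kv vectors of size d (the embedded action timesteps)
    values:  n_kv vectors of size d_v
    Returns n_q vectors of size d_v, one summary per query token,
    regardless of how many timesteps the input has.
    """
    d = len(queries[0])
    d_v = len(values[0])
    outputs = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        outputs.append([sum(w * v[j] for w, v in zip(weights, values)) for j in range(d_v)])
    return outputs


# Hypothetical sizes: 4 query tokens summarizing a 20-step input with d = 3.
queries = [[0.1 * (i + 1)] * 3 for i in range(4)]
inputs = [[math.sin(t), math.cos(t), t / 20] for t in range(20)]
latent = cross_attention(queries, inputs, inputs)  # 4 vectors of size 3
```

Doubling the input to 40 timesteps would still yield 4 latent vectors, which is what lets a fixed token budget cover embodiments with different sequence lengths.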
VQVAE
↓
CNN Encoder
├── Conv1d (in_channels → ch)
├── DownSample Blocks (3 levels, 2x downsampling each)
│ ├── ConvBlock + Downsample
│ └── ConvBlock
└── Conv1d (→ z_channels)
↓
Latent (b, z_channels, seq_len/8)
↓
Residual Vector Quantizer (RVQ)
↓
CNN Decoder
├── Conv1d (z_channels → ch*8)
├── UpSample Blocks (3 levels, 2x upsampling each)
│ ├── ConvBlock + Upsample
│ └── ConvBlock
└── Conv1d (→ in_channels)
↓
Output (b, seq_len, action_dim)
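Key features:
- Pure CNN architecture: convolutions extract spatiotemporal features.
- Fixed compression ratio: the time dimension is compressed 8x (three 2x downsampling steps).
- MultiVQVAE variant: splits the action into position, rotation and gripper parts (pos/rot/grip), each encoded by its own independent VQVAE.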
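Residual vector quantization, as used in the VQVAE path, quantizes a vector in stages: each codebook encodes the residual left over by the previous ones. A minimal plain-Python sketch with toy codebooks; the function names and codebook values are hypothetical and this is not the project's code.

```python
def nearest_code(codebook, x):
    """Index of the codebook entry with the smallest squared L2 distance to x."""
    return min(
        range(len(codebook)),
        key=lambda i: sum((c - v) ** 2 for c, v in zip(codebook[i], x)),
    )


def rvq_encode(x, codebooks):
    """Residual vector quantization of one latent vector.

    Stage k quantizes whatever stages 0..k-1 failed to explain, so each
    extra codebook refines the reconstruction. Returns one index per
    codebook and the summed reconstruction.
    """
    residual = list(x)
    quantized = [0.0] * len(x)
    indices = []
    for codebook in codebooks:
        idx = nearest_code(codebook, residual)
        code = codebook[idx]
        indices.append(idx)
        quantized = [q + c for q, c in zip(quantized, code)]
        residual = [r - c for r, c in zip(residual, code)]
    return indices, quantized


# Toy 2-D codebooks (hypothetical values, not learned ones).
codebooks = [
    [[0.0, 0.0], [1.0, 1.0]],              # coarse stage
    [[0.0, 0.0], [0.5, 0.0], [0.0, 0.5]],  # refinement stage
]
indices, recon = rvq_encode([1.4, 1.1], codebooks)  # [1, 1], [1.5, 1.0]
```

With n_codebooks=4 in the VQVAE config, each latent vector would be described by four indices ordered from coarse to fine.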
2️⃣ Construction comparison
UAT configuration
config = UATConfig(
    embodiment_config={
        "robot_A": {"action_dim": 7, "freq": 20, "duration": 1.0},
        "robot_B": {"action_dim": 10, "freq": 15, "duration": 1.0},
    },
    n_tokens=16,            # total number of tokens
    n_quantizers=1,         # VQ = 1, RVQ >= 2
    z_dim=512,              # latent dimension
    vq_type="vq",           # "vq" or "rvq"
    vq_codebook_size=2048,
    encoder_dim=256,
    encoder_n_layers=6,
    encoder_n_heads=8,
    decoder_dim=256,
    decoder_n_layers=6,
)

model = ActionCodec(config)
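As a quick sanity check on this config, the token layout and the information per action chunk follow from n_tokens, n_quantizers and vq_codebook_size. The helper below is hypothetical. It assumes the n_tokens total is split evenly across quantizers (consistent with the (b, n_tokens_per_quantizer, z_dim) latent) and that every quantizer uses a codebook of vq_codebook_size entries.

```python
import math


def token_budget(n_tokens, n_quantizers, codebook_size):
    """Latent positions per quantizer and bits per action chunk.

    Hypothetical helper: assumes the n_tokens total is divided evenly
    among quantizers and each token indexes a codebook of codebook_size.
    """
    if n_tokens % n_quantizers != 0:
        raise ValueError("n_tokens must be divisible by n_quantizers")
    per_quantizer = n_tokens // n_quantizers
    bits = n_tokens * math.log2(codebook_size)
    return per_quantizer, bits


per_q, bits = token_budget(n_tokens=16, n_quantizers=1, codebook_size=2048)
# per_q == 16 latent positions; bits == 176.0 (16 tokens x 11 bits each)
```

Under the same assumption, switching to vq_type="rvq" with n_quantizers=2 would halve the latent positions to 8 while keeping the 176-bit budget.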
Dynamic extension:
    "robot_C": {"action_dim": 12, "freq": 25, "duration": 1.5}
})
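The embodiment entries determine two shapes: how many timesteps a chunk has and how wide the shared action input must be. The sketch below is hypothetical. It assumes seq_len = freq x duration (rounded) and that every robot's actions are zero-padded to the largest action_dim, which is what the max_action_dim input projection in the diagram suggests; the helper names are illustrative only.

```python
def embodiment_shapes(embodiment_config):
    """Per-robot chunk length and the shared padded action width.

    Hypothetical helper: assumes seq_len = round(freq * duration) and that
    actions are zero-padded to the largest action_dim across robots.
    """
    max_action_dim = max(cfg["action_dim"] for cfg in embodiment_config.values())
    seq_lens = {
        name: round(cfg["freq"] * cfg["duration"])
        for name, cfg in embodiment_config.items()
    }
    return max_action_dim, seq_lens


def pad_actions(actions, max_action_dim):
    """Zero-pad each timestep to max_action_dim and return a validity mask."""
    padded = [list(a) + [0.0] * (max_action_dim - len(a)) for a in actions]
    mask = [[True] * len(a) + [False] * (max_action_dim - len(a)) for a in actions]
    return padded, mask


config = {
    "robot_A": {"action_dim": 7, "freq": 20, "duration": 1.0},
    "robot_B": {"action_dim": 10, "freq": 15, "duration": 1.0},
}
max_dim, seq_lens = embodiment_shapes(config)
# max_dim == 10, seq_lens == {"robot_A": 20, "robot_B": 15}
```

Registering robot_C (action_dim 12) would raise the padded width to 12 under this assumption. Its 25 Hz x 1.5 s gives 37.5 steps, so the actual codebase must round or resample in some way that this README does not show.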
VQVAE configuration
model = VQVAE(
    input_dim=7,                # action dimension
    embedding_dim=256,          # VQ embedding dimension
    cnn_config={
        "hidden_size": 64,
        "output_size": 512,
        "dropout": 0.0,
    },
    num_embeddings=2048,        # codebook size
    action_horizon=24,          # temporal length (timesteps per chunk)
    n_codebooks=4,              # number of RVQ quantizers
    commitment_cost=0.25,
)

# MultiVQVAE (action decomposition)
model = MultiVQVAE(
    input_dim={
        "pos": 6,    # position
        "rot": 12,   # rotation (6D)
        "grip": 2,   # gripper
    },
    embedding_dim=256,
    cnn_config={...},
    n_codebooks={"pos": 6, "rot": 3, "grip": 1},
)
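The two VQVAE configs imply simple bookkeeping: the 8x temporal compression turns a 24-step horizon into 3 latent steps, and MultiVQVAE needs the flat 20-dimensional action split into its pos/rot/grip parts. The sketch below is hypothetical. It assumes the flat action is laid out as pos | rot | grip in the same order as the input_dim dict, that each latent step gets one index per codebook, and that MultiVQVAE uses the same 24-step horizon as the single VQVAE example.

```python
def split_action(action, part_dims):
    """Slice a flat action vector into named parts, in dict order.

    Hypothetical: assumes the flat action is the concatenation of the
    parts in the same order as the part_dims dict (pos | rot | grip).
    """
    parts, start = {}, 0
    for name, dim in part_dims.items():
        parts[name] = action[start:start + dim]
        start += dim
    if start != len(action):
        raise ValueError(f"expected {start} values, got {len(action)}")
    return parts


def tokens_per_chunk(action_horizon, n_codebooks, downsample_factor=8):
    """Indices per chunk if each latent timestep is RVQ-coded with n_codebooks."""
    if action_horizon % downsample_factor != 0:
        raise ValueError("action_horizon must be divisible by the downsample factor")
    return (action_horizon // downsample_factor) * n_codebooks


part_dims = {"pos": 6, "rot": 12, "grip": 2}
n_codebooks = {"pos": 6, "rot": 3, "grip": 1}
parts = split_action(list(range(20)), part_dims)
single = tokens_per_chunk(24, 4)  # 3 latent steps x 4 codebooks = 12
multi = sum(tokens_per_chunk(24, n) for n in n_codebooks.values())  # 3 x (6 + 3 + 1) = 30
```

The check that action_horizon is divisible by 8 mirrors the fixed three-level downsampling; a horizon such as 20 would not map cleanly onto whole latent steps.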
Download files
Source Distribution
pat_vla-0.1.2.tar.gz (89.1 kB)
Built Distribution
pat_vla-0.1.2-py3-none-any.whl (73.9 kB)
File details
Details for the file pat_vla-0.1.2.tar.gz.
File metadata
- Download URL: pat_vla-0.1.2.tar.gz
- Size: 89.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d67cc39cff67ae4cdc2fdc61bc2f9dbff0dafc9ef2522caf2b6618e4adf6202f |
| MD5 | 830bccebd5ba6343b37ebada8b119191 |
| BLAKE2b-256 | db101e2331b15bec93670614925b8f99b8c8b71534868769360a90c3238702f2 |
File details
Details for the file pat_vla-0.1.2-py3-none-any.whl.
File metadata
- Download URL: pat_vla-0.1.2-py3-none-any.whl
- Size: 73.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8a74730c7c8526530092989e6a3778dbaa720f72e6d9865087d760c09fd2ec5b |
| MD5 | e1331c6794494257daea4a517cecc6f1 |
| BLAKE2b-256 | 37211de878a8dbbabcc5d11d606c75545e1f4cf688ae67540fa969f99cced6bd |
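Before installing a manually downloaded file, its SHA256 digest can be compared against the values published for version 0.1.2. A minimal standard-library sketch; the helper names are illustrative, and the digests are copied from the file details of the two distributions.

```python
import hashlib

# SHA256 digests published for the 0.1.2 release files.
EXPECTED_SHA256 = {
    "pat_vla-0.1.2.tar.gz": "d67cc39cff67ae4cdc2fdc61bc2f9dbff0dafc9ef2522caf2b6618e4adf6202f",
    "pat_vla-0.1.2-py3-none-any.whl": "8a74730c7c8526530092989e6a3778dbaa720f72e6d9865087d760c09fd2ec5b",
}


def sha256_of_fileobj(fileobj, chunk_size=1 << 20):
    """Stream a binary file object through SHA-256 and return the hex digest."""
    digest = hashlib.sha256()
    for chunk in iter(lambda: fileobj.read(chunk_size), b""):
        digest.update(chunk)
    return digest.hexdigest()


def sha256_of_file(path):
    """Hex SHA-256 digest of the file at path, read in chunks."""
    with open(path, "rb") as f:
        return sha256_of_fileobj(f)


def verify_download(path, filename):
    """True if the file at path matches the published digest for filename."""
    return sha256_of_file(path) == EXPECTED_SHA256[filename]
```

For example, `verify_download("pat_vla-0.1.2.tar.gz", "pat_vla-0.1.2.tar.gz")` checks the source archive in the current directory. pip can enforce the same check itself through hash-checking mode, with a requirements line of the form `pat_vla==0.1.2 --hash=sha256:<digest>`.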