
An implementation of TRM


Fast, easy-to-play-with Tiny Recursive Models

This is an implementation of the Tiny Recursive Model (TRM).

Train a TRM in a few minutes on an A10. Reproduce the official TRM results. Push the envelope.

(Animation: Sudoku thinking)

Motivation

Recently, recursive models have made a big comeback, notably with Tiny Recursive Models (TRM), which won the 1st paper award at the ARC-AGI 2 contest and, perhaps more importantly, reached an impressive level of performance on several benchmarks such as ARC-AGI 2 and Sudoku Extreme. TRM is much simpler than its predecessor, the Hierarchical Reasoning Model (HRM). However, the official TRM codebase still inherits much of HRM's legacy code.

We propose a clean implementation of TRM. We call it "nano" because it is easy to experiment with, yet incorporates all the important implementation details of TRM. The project uses Hydra, PyTorch Lightning and uv to make experimentation easy. We also provide an introductory video that walks through the code, and small datasets (Sudoku 4x4 and 6x6) that let you train a TRM on an A10 in a few minutes!

This repo reproduces the results on Sudoku Extreme and Maze Hard (87% and 75% exact accuracy on validation, respectively). We hope you will find this repo useful for your own experimentation on TRM.

Installation

This repo uses uv. Just run the uv run python ... commands below and all dependencies will be installed automatically on the first run.

Sudoku Extreme

Generate data:

uv run python scripts/data/build_sudoku_extreme_dataset.py --output-dir ./data/sudoku_extreme_1k_aug_1k --subsample-size 1000 --num-aug 1000 --eval-ratio 0.01
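With --subsample-size 1000 and --num-aug 1000, each of the 1000 sampled puzzles gets many augmented copies, on the order of a million training grids in total. A common way to augment Sudoku without breaking validity is to relabel the digits, shuffle the three bands and the rows inside each band, shuffle the three stacks and the columns inside each stack, and optionally transpose. The sketch below assumes this scheme; it is an illustration, not the code of build_sudoku_extreme_dataset.py, and the helper names are hypothetical.

```python
import random


def is_valid_solution(grid):
    """True if a 9x9 grid (lists of ints 1..9) is a solved Sudoku."""
    full = set(range(1, 10))
    rows_ok = all(set(row) == full for row in grid)
    cols_ok = all({grid[r][c] for r in range(9)} == full for c in range(9))
    boxes_ok = all(
        {grid[r0 + r][c0 + c] for r in range(3) for c in range(3)} == full
        for r0 in (0, 3, 6)
        for c0 in (0, 3, 6)
    )
    return rows_ok and cols_ok and boxes_ok


def random_transform(rng):
    """Sample one validity-preserving transform (hypothetical helper)."""
    digit_map = [0] + rng.sample(range(1, 10), 9)  # index 0 = blank cell, stays blank
    bands = rng.sample(range(3), 3)
    row_order = [b * 3 + i for b in bands for i in rng.sample(range(3), 3)]
    stacks = rng.sample(range(3), 3)
    col_order = [s * 3 + j for s in stacks for j in rng.sample(range(3), 3)]
    transpose = rng.random() < 0.5
    return digit_map, row_order, col_order, transpose


def apply_transform(grid, transform):
    """Apply the same transform to a puzzle or to its solution."""
    digit_map, row_order, col_order, transpose = transform
    out = [[digit_map[grid[r][c]] for c in col_order] for r in row_order]
    if transpose:
        out = [list(col) for col in zip(*out)]
    return out


# Usage: a valid base solution, a puzzle with ~70% of cells blanked, one augmentation.
rng = random.Random(0)
solution = [[(3 * (r % 3) + r // 3 + c) % 9 + 1 for c in range(9)] for r in range(9)]
puzzle = [[v if rng.random() < 0.3 else 0 for v in row] for row in solution]
t = random_transform(rng)
aug_puzzle, aug_solution = apply_transform(puzzle, t), apply_transform(solution, t)
print(is_valid_solution(aug_solution))
```

Because the puzzle and its solution go through the same transform, every clue in the augmented puzzle still agrees with the augmented solution.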

Run a training:

uv run python src/nn/train.py experiment=trm_sudoku_extreme_1k_aug_1k

Training takes ~1 h on an H100 SXM5 and should reach ~87% exact accuracy on validation, the same as the reference implementation.
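Exact accuracy counts a puzzle as solved only if every predicted cell matches the solution, which is much stricter than per-cell accuracy. A minimal sketch of both metrics, assuming predictions and targets are sequences of cell tokens of equal length; this is an illustration, not the metric code from this repository.

```python
def cell_and_exact_accuracy(preds, targets):
    """preds, targets: lists of equal-length sequences of cell tokens (one per puzzle)."""
    total_cells = correct_cells = exact = 0
    for p, t in zip(preds, targets):
        matches = sum(a == b for a, b in zip(p, t))  # assumes len(p) == len(t)
        correct_cells += matches
        total_cells += len(t)
        exact += matches == len(t)  # puzzle counts only if all cells are right
    return correct_cells / total_cells, exact / len(targets)


preds = [[1, 2, 3, 4], [1, 2, 4, 3]]
targets = [[1, 2, 3, 4], [1, 2, 3, 4]]
cell_acc, exact_acc = cell_and_exact_accuracy(preds, targets)
print(cell_acc, exact_acc)  # 0.75 0.5
```

A single wrong cell makes the whole puzzle count as a miss, so exact accuracy is typically well below cell accuracy.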

Maze Hard

Generate data:

uv run python scripts/data/build_maze_dataset.py --output-dir ./data/maze-30x30-hard-1k --num-aug 0 --eval-ratio 1.0
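Each Maze Hard example pairs a 30x30 maze with a path from start to goal; as an assumption about the task, the target is a shortest path. The sketch below computes such a path with breadth-first search on a small grid, using a hypothetical character encoding ('#' wall, '.' open, 'S' start, 'G' goal, 'o' path) that is not necessarily the token encoding used by build_maze_dataset.py.

```python
from collections import deque


def solve_maze(rows):
    """Return the maze with a shortest S->G path marked 'o', or None if G is unreachable."""
    grid = [list(r) for r in rows]
    h, w = len(grid), len(grid[0])
    start = next((r, c) for r in range(h) for c in range(w) if grid[r][c] == "S")
    goal = next((r, c) for r in range(h) for c in range(w) if grid[r][c] == "G")
    parent = {start: None}  # visited cells and how we reached them
    queue = deque([start])
    while queue:
        cur = queue.popleft()
        if cur == goal:
            break
        r, c = cur
        for nr, nc in ((r + 1, c), (r - 1, c), (r, c + 1), (r, c - 1)):
            if 0 <= nr < h and 0 <= nc < w and grid[nr][nc] != "#" and (nr, nc) not in parent:
                parent[(nr, nc)] = cur
                queue.append((nr, nc))
    if goal not in parent:
        return None
    node = parent[goal]  # walk back from the goal, marking intermediate cells
    while node != start:
        grid[node[0]][node[1]] = "o"
        node = parent[node]
    return ["".join(r) for r in grid]


maze = [
    "S.#....",
    ".##.##.",
    "....#G.",
]
for line in solve_maze(maze):
    print(line)
```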

Run a training:

uv run python src/nn/train.py experiment=trm_maze

Training takes ~2 h on an H100 SXM5 and should reach ~75% exact accuracy on validation, the same as the reference implementation.

Small Sudoku datasets

Generate data:

bash bash/generate_sudoku_data.sh

Then choose which dataset (4x4 or 6x6) you want to generate.
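The small variants use smaller boxes: 2x2 for 4x4 and, usually, 2 rows x 3 columns for 6x6. As an illustration of how such a dataset can be built, and not a description of generate_sudoku_data.sh, the sketch below creates a solved grid from a shifted pattern, randomizes it with digit relabeling and band/stack and row/column shuffles that keep it valid, and blanks cells to form a puzzle. The function names and the blank ratio are hypothetical.

```python
import random


def make_solution(box_rows, box_cols, rng):
    """Random solved n x n Sudoku with box_rows x box_cols boxes (n = box_rows * box_cols)."""
    n = box_rows * box_cols
    # Shifted pattern: every row, column and box holds each value 0..n-1 exactly once.
    base = [[(box_cols * (r % box_rows) + r // box_rows + c) % n for c in range(n)] for r in range(n)]
    digits = rng.sample(range(1, n + 1), n)  # relabel 0..n-1 as a random permutation of 1..n
    # box_cols horizontal bands of box_rows rows; box_rows vertical stacks of box_cols columns.
    bands = rng.sample(range(box_cols), box_cols)
    rows = [b * box_rows + i for b in bands for i in rng.sample(range(box_rows), box_rows)]
    stacks = rng.sample(range(box_rows), box_rows)
    cols = [s * box_cols + j for s in stacks for j in rng.sample(range(box_cols), box_cols)]
    return [[digits[base[r][c]] for c in cols] for r in rows]


def is_valid(grid, box_rows, box_cols):
    """True if every row, column and box contains 1..n exactly once."""
    n = box_rows * box_cols
    full = set(range(1, n + 1))
    if any(set(row) != full for row in grid):
        return False
    if any({grid[r][c] for r in range(n)} != full for c in range(n)):
        return False
    return all(
        {grid[r][c] for r in range(r0, r0 + box_rows) for c in range(c0, c0 + box_cols)} == full
        for r0 in range(0, n, box_rows)
        for c0 in range(0, n, box_cols)
    )


def make_puzzle(solution, blank_ratio, rng):
    """Blank cells (0) at random; uniqueness of the solution is NOT checked."""
    return [[0 if rng.random() < blank_ratio else v for v in row] for row in solution]


rng = random.Random(42)
sol6 = make_solution(2, 3, rng)
puzzle6 = make_puzzle(sol6, blank_ratio=0.5, rng=rng)
for row in puzzle6:
    print(row)
```

Transposition is left out on purpose: for 2x3 boxes it would turn them into 3x2 boxes. A real generator would also check that each puzzle has a unique solution.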

Run a training:

uv run python src/nn/train.py experiment=trm_sudoku_4x4

This takes a few minutes on an A10!

ARC-AGI

Download the data from the ARC Prize Kaggle competition page.
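The Kaggle data typically comes as JSON files mapping each task id to a few "train" input/output demonstration pairs and one or more "test" inputs, with the test outputs in a separate solutions file. The sketch below only shows how to walk that structure; the layout described in the docstring and the file names in the comments are assumptions to check against the files you download, and parse_tasks is a hypothetical helper.

```python
import json


def load_json(path):
    with open(path) as f:
        return json.load(f)


def parse_tasks(challenges, solutions=None):
    """Turn ARC challenge/solution dicts into {task_id: (train_pairs, test_inputs, test_outputs)}.

    Assumed layout: challenges[task_id] = {"train": [{"input": grid, "output": grid}, ...],
    "test": [{"input": grid}, ...]} and solutions[task_id] = [output_grid, ...],
    where a grid is a list of rows of integers 0-9 (colors).
    """
    solutions = solutions or {}
    tasks = {}
    for task_id, task in challenges.items():
        train_pairs = [(ex["input"], ex["output"]) for ex in task["train"]]
        test_inputs = [ex["input"] for ex in task["test"]]
        tasks[task_id] = (train_pairs, test_inputs, solutions.get(task_id))
    return tasks


# With the real files you would call something like (file names are assumptions):
# parse_tasks(load_json("arc-agi_training_challenges.json"),
#             load_json("arc-agi_training_solutions.json"))
challenges = {"demo": {"train": [{"input": [[0, 1]], "output": [[1, 0]]}],
                       "test": [{"input": [[1, 1, 0]]}]}}
solutions = {"demo": [[[0, 1, 1]]]}
train_pairs, test_inputs, test_outputs = parse_tasks(challenges, solutions)["demo"]
print(len(train_pairs), test_inputs[0], test_outputs[0])
```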

Visualizations

Sudoku:

  • Evaluate and generate a gif: uv run python src/nn/evaluate.py +checkpoint=./checkpoints/smooth-sunset-204.ckpt +data_dir=./data/sudoku_extreme_1k_aug_1k +visualize=true +save_gif=true +min_steps=9

Maze:

  • Generate a dataset with test data (not just validation data): uv run python scripts/data/build_maze_dataset.py --output-dir ./data/maze-30x30-hard-1k --num-aug 0 --eval-ratio 0.5
  • Evaluate and generate a gif: uv run python src/nn/evaluate.py +checkpoint=./checkpoints/stellar-shape-194.ckpt +data_dir=./data/maze-30x30-hard-1k +visualize=true +save_gif=true +min_steps=9 +task=maze

Technical Notes

Codebase structure

  • src/nn/train.py: main training script
  • src/nn/models/trm.py: TRM model
  • src/nn/configs/experiments: main experiment configurations
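For orientation before reading src/nn/models/trm.py, here is a forward-only sketch of the recursion described in the TRM paper: one small shared network repeatedly refines a latent state z from the input x, the current answer y and z, then updates y from y and z alone. This inner loop is repeated T times, and deep supervision carries (y, z) across several steps. The tiny random NumPy network, the summation of inputs, the zero initial states, the sizes and the names (net, latent_recursion, deep_recursion) are all assumptions of this sketch, not the repository's code; the defaults n=6 and T=3 mirror the paper's reported setting but should be checked against the experiment configs.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 16  # hidden size (hypothetical; real models are much wider)
W1 = rng.normal(scale=D ** -0.5, size=(D, D))
W2 = rng.normal(scale=D ** -0.5, size=(D, D))


def rms_norm(h, eps=1e-6):
    return h / np.sqrt(np.mean(h * h, axis=-1, keepdims=True) + eps)


def net(h):
    """Stand-in for the single shared tiny network: residual MLP + RMS normalization."""
    return rms_norm(h + np.tanh(h @ W1) @ W2)


def latent_recursion(x, y, z, n=6):
    for _ in range(n):
        z = net(x + y + z)  # refine the latent state from input, current answer and latent
    y = net(y + z)  # update the answer from answer and latent only (x is not fed in)
    return y, z


def deep_recursion(x, y, z, n=6, T=3):
    # In TRM training the first T - 1 repetitions run without gradients and only
    # the last one is back-propagated; this forward-only sketch simply runs all T.
    for _ in range(T):
        y, z = latent_recursion(x, y, z, n)
    return y, z


x = rng.normal(size=(81, D))  # e.g. one embedded token per Sudoku cell
y = np.zeros_like(x)  # initial answer state (zeros: an assumption of this sketch)
z = np.zeros_like(x)  # initial latent state (zeros: an assumption of this sketch)
for step in range(3):  # deep-supervision steps (up to 16 in the TRM paper)
    y, z = deep_recursion(x, y, z)
    # Training would decode y into per-cell logits here, compute the loss,
    # take an optimizer step, and carry the detached (y, z) into the next step.
print(y.shape)
```

The normalization in net is a choice of this sketch that keeps the many repeated applications numerically bounded.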

Installing AdamATan2

This project uses the standard AdamW optimizer. If you want to use AdamATan2 instead and have trouble installing it, here is how to build it from source:

# Clone the repo
cd /tmp
git clone https://github.com/imoneoi/adam-atan2.git

cd adam-atan2/

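# Build and install into the nano-trm virtual environment (adjust /home/ubuntu/nano-trm to your own checkout path)
uv pip install --python /home/ubuntu/nano-trm/.venv/bin/python --verbose --no-cache-dir --no-build-isolation -e .

# Test: this import should succeed without errors
cd /home/ubuntu/nano-trm
uv run python -c "from adam_atan2 import AdamATan2"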

AdamATan2 is not needed to reproduce the official results on Sudoku Extreme and Maze Hard. It might be needed for harder problems or deeper models (e.g. ARC-AGI 2).
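The idea behind AdamATan2 is to replace Adam's division m_hat / (sqrt(v_hat) + eps) with atan2(m_hat, sqrt(v_hat)), which removes the eps hyperparameter and bounds each update element by pi/2. The scalar sketch below compares one bias-corrected direction under each rule. It leaves out learning rate, weight decay and any scale constants the adam-atan2 package may apply, so treat it as an illustration rather than that package's exact update; adam_direction and use_atan2 are hypothetical names.

```python
import math


def adam_direction(grads, beta1=0.9, beta2=0.999, eps=1e-8, use_atan2=False):
    """Bias-corrected update direction for one scalar parameter after a gradient sequence.

    Learning rate and weight decay are left out; the caller would multiply by -lr.
    """
    if not grads:
        raise ValueError("need at least one gradient")
    m = v = 0.0
    for g in grads:
        m = beta1 * m + (1 - beta1) * g  # first-moment EMA
        v = beta2 * v + (1 - beta2) * g * g  # second-moment EMA
    t = len(grads)
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    if use_atan2:
        # atan2 is defined at (0, 0) and bounded by pi/2 in magnitude, so no eps is needed.
        return math.atan2(m_hat, math.sqrt(v_hat))
    return m_hat / (math.sqrt(v_hat) + eps)


for grads in ([2.0], [1e-12], [0.0, 0.0]):
    print(grads, adam_direction(grads), adam_direction(grads, use_atan2=True))
```

With tiny gradients such as 1e-12, eps dominates Adam's denominator and shrinks the step by several orders of magnitude, whereas atan2 is unchanged when both arguments are rescaled and still returns pi/4. The first case also shows that atan2 maps a single-step Adam direction of 1 to pi/4, so an implementation may multiply the atan2 output by a constant.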

Contact

Follow me on X

My homepage

Download files


Source Distribution

nano_trm-0.1.0.tar.gz (11.9 MB)

Built Distribution


nano_trm-0.1.0-py3-none-any.whl (129.5 kB)

File details

Details for the file nano_trm-0.1.0.tar.gz.

File metadata

  • Download URL: nano_trm-0.1.0.tar.gz
  • Size: 11.9 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for nano_trm-0.1.0.tar.gz
Algorithm Hash digest
SHA256 9355ecd8363da9b57773d5a240b9af59e7ece1d7bef0f0c1c23072e63b169249
MD5 f0860510f11c7d6981e70057b10a02f3
BLAKE2b-256 735c11f3049a57e74cfa1c3ef4482961512c2c25fdf3d2c55268b9bc3c82c9f4


File details

Details for the file nano_trm-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: nano_trm-0.1.0-py3-none-any.whl
  • Size: 129.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for nano_trm-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 0e92e78e80075dc04b3c24005efb4dc9081462c599e2d3dc411b5edfe7e9fd32
MD5 df47726139203362a215350ae3b7c59f
BLAKE2b-256 e7e05c704c0987b2cef1dc0979761cf604e0a883dfdf3cc44ecb556936747b4b

