
cattorch

Export PyTorch neural networks to Scratch sprites.

cattorch transpiles a torch.nn.Module into a .sprite3 file that can be imported directly into any Scratch project. The generated sprite uses only standard Scratch blocks, so no extensions or modifications are required.

cattorch does not export training scripts; train your model in PyTorch first, then export it as a Scratch sprite.

Install

pip install cattorch

Requires Python 3.10+ and PyTorch 2.0+.

Usage

import torch
import torch.nn as nn
from cattorch import transpile

class TwoLayerNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 3)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = TwoLayerNet()
# train your model first! then:
# transpile(model, example_input, sprite_name)
transpile(model, torch.randn(1, 4), "two_layer_net")
# => two_layer_net.sprite3

# optionally reduce file size by rounding weights
transpile(model, torch.randn(1, 4), "two_layer_net", sig_figs=6)
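The sig_figs option shrinks the sprite by rounding each weight to a fixed number of significant figures before it is written into the sprite's lists; fewer digits means shorter decimal strings. The rounding itself can be sketched in plain Python (round_sig here is an illustrative helper, not part of cattorch's API):

```python
import math

def round_sig(x: float, sig_figs: int = 6) -> float:
    """Round x to sig_figs significant figures."""
    if x == 0.0:
        return 0.0
    exponent = math.floor(math.log10(abs(x)))
    return round(x, sig_figs - 1 - exponent)

# Shorter decimal strings in the generated weight lists
# mean a smaller .sprite3 file.
print(round_sig(0.123456789, 6))   # 0.123457
print(round_sig(-1234.56789, 6))   # -1234.57
```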

cattorch uses torch.export under the hood, which requires a single code path with no data-dependent control flow. If your model has conditional returns (e.g. returning loss during training), add an inference-only forward method:

# won't work: conditional return
def forward(self, x, targets=None):
    logits = self.head(x)
    if targets is None:
        return logits
    return logits, F.cross_entropy(logits, targets)

# will work: single return path
def forward_inference(self, x):
    return self.head(x)

model.eval()
model.forward = model.forward_inference
transpile(model, example_input, "my_model")

Some modules (e.g. HuggingFace transformer blocks) return tuples instead of plain tensors. torch.export will fail if a downstream layer receives a tuple where it expects a tensor. Unpack the output in your wrapper's forward method:

# won't work: block returns (hidden_states, attention_weights, ...)
x = block(x)

# will work: extract the tensor you need
x = block(x)[0]

In Scratch, the sprite reads its input from a list called input and writes results to a list called output. It is up to you to add logic that fills the input list and triggers the generated blocks.

If the model takes multiple input tensors, the additional inputs are named input_1, input_2, etc.
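For example, a model whose forward takes two tensors would read the first from input and the second from input_1. A sketch (TwoInputNet is a hypothetical model; the commented transpile call assumes a tuple of example inputs is accepted, as with torch.export):

```python
import torch
import torch.nn as nn

class TwoInputNet(nn.Module):
    """Hypothetical model with two input tensors."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 3)

    def forward(self, a, b):
        # In the generated sprite, `a` is read from the list `input`
        # and `b` from the list `input_1`.
        return self.fc(torch.cat([a, b], dim=1))

model = TwoInputNet().eval()
out = model(torch.randn(1, 4), torch.randn(1, 4))
print(out.shape)  # torch.Size([1, 3])
# Export (assumption: transpile accepts a tuple of example inputs):
# transpile(model, (torch.randn(1, 4), torch.randn(1, 4)), "two_input_net")
```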

Supported operations

  • Convolution: nn.Conv1d, nn.Conv2d (with and without bias, stride, padding)
  • Pooling: nn.MaxPool1d/2d, nn.AvgPool1d/2d, nn.AdaptiveAvgPool2d
  • Linear layers: nn.Linear (with and without bias)
  • Matrix multiply: @ / torch.matmul
  • Activations: F.relu, torch.sigmoid, torch.tanh, F.gelu (tanh approx. only), F.silu, F.leaky_relu, F.elu
  • Normalization: nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, nn.RMSNorm, torch.rsqrt
  • Softmax: F.softmax (any dim)
  • Embedding: nn.Embedding
  • Masking: masked_fill (for causal attention masks via register_buffer)
  • Arithmetic: +, -, * (tensor and scalar), / (scalar), unary -, torch.pow
  • Reduction: torch.mean (along a dim)
  • Tensor creation: torch.arange, torch.ones, torch.zeros, torch.full, torch.ones_like, torch.zeros_like
  • Shape: view, reshape, flatten, contiguous, clone (no-ops on flat data)
  • Transpose: transpose, permute, .T (arbitrary dimensions)
  • Split / chunk: split, split_with_sizes, chunk
  • Concatenation: torch.cat (any dim, any number of inputs)
  • Slice: tensor[:n]-style slicing along any dimension

These operations are sufficient for architectures like MLPs, CNNs, and transformer LLMs, including multi-head attention, combined QKV projections, rotary position embeddings (RoPE), causal masking, pre-norm blocks with residual connections, and SwiGLU-style gating. RNN support is planned.
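As a concrete example, here is a small CNN sketch that restricts itself to the supported set (requires torch; the layer sizes are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCNN(nn.Module):
    """Small CNN built only from operations in the supported list."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(4 * 4 * 4, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # Conv2d, relu, MaxPool2d
        x = x.flatten(1)                      # flatten (no-op on flat data)
        return F.softmax(self.fc(x), dim=-1)  # Linear, softmax

model = TinyCNN().eval()
out = model(torch.randn(1, 1, 8, 8))
print(out.shape)  # torch.Size([1, 10])
# transpile(model, torch.randn(1, 1, 8, 8), "tiny_cnn")
```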

Tokenizers

cattorch can also transpile HuggingFace tokenizers into Scratch sprites, so the full text → token IDs → model → token IDs → text pipeline can run inside a Scratch project. Two tokenizer types are supported:

  • CharTokenizer — character-level lookup. Each character maps to one ID.
  • BPETokenizer — byte-pair encoding. Merges are applied iteratively over the full input string, including spaces.
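The iterative merge loop that the BPETokenizer sprite performs can be illustrated in plain Python. This is a reference sketch, not cattorch's internal code, and the merge table and vocabulary below are made up for the example:

```python
def bpe_encode(text, merges, vocab):
    """Apply BPE merges over the full input string, spaces included."""
    tokens = list(text)  # start from individual characters
    ranks = {pair: i for i, pair in enumerate(merges)}
    while True:
        # find the highest-priority (lowest-rank) adjacent pair
        best, best_rank = None, None
        for i in range(len(tokens) - 1):
            r = ranks.get((tokens[i], tokens[i + 1]))
            if r is not None and (best_rank is None or r < best_rank):
                best, best_rank = i, r
        if best is None:
            break  # no more applicable merges
        tokens[best:best + 2] = [tokens[best] + tokens[best + 1]]
    return [vocab[t] for t in tokens]

# hypothetical merge table and vocabulary for illustration
merges = [("t", "h"), ("th", "e"), ("e", " ")]
vocab = {"the": 0, "e ": 1, "th": 2, "t": 3, "h": 4, "e": 5, " ": 6, "c": 7, "a": 8}
print(bpe_encode("the cat", merges, vocab))  # [0, 6, 7, 8, 3]
```

Note that because there is no pre-tokenizer, the space between "the" and "cat" is just another character in the stream.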

Off-the-shelf tokenizers from large models will not work here. Production tokenizers like GPT-2's or Llama's use byte-level pre-tokenization, regex splits, and other preprocessing steps that the Scratch templates don't implement, and their 30k–100k+ token vocabularies would cause embeddings to blow past Scratch's 200,000 list item limit. In practice you'll want to train a small custom BPE tokenizer on your own corpus (with no pre-tokenizer), so BPE operates on the raw input string, sized to match the small model you're transpiling.

from transformers import AutoTokenizer
from cattorch import CharTokenizer, BPETokenizer

tokenizer = AutoTokenizer.from_pretrained("my-model")
BPETokenizer(tokenizer).save("my_tokenizer")
# => my_tokenizer.sprite3

cattorch does not train tokenizers itself; the classes only transpile an existing HuggingFace tokenizer. To train a small BPE tokenizer from scratch, use the tokenizers library directly and wrap the result:

from tokenizers import Tokenizer, models, trainers
from transformers import PreTrainedTokenizerFast
from cattorch import BPETokenizer

corpus = ["the cat sat on the mat", "the dog sat on the log"]

tok = Tokenizer(models.BPE())
trainer = trainers.BpeTrainer(vocab_size=100, min_frequency=1, special_tokens=[])
tok.train_from_iterator(corpus, trainer=trainer)

# no pre-tokenizer: BPE operates on the raw input string, including spaces
BPETokenizer(PreTrainedTokenizerFast(tokenizer_object=tok)).save("my_tokenizer")

The generated sprite has two top-level block stacks:

  • Encode: reads the input variable (a string) and writes token IDs to the token_ids list.
  • Decode: reads the token_ids list and writes the decoded string to the output variable.

Token IDs are 0-based, matching PyTorch embedding conventions, so the output of the encode stack can be fed directly into a transpiled model. If you don't care about the tokenizer type, use transpile_tokenizer(tokenizer, name) and cattorch will pick BPETokenizer or CharTokenizer based on the tokenizer's backend.
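What the CharTokenizer sprite's Encode and Decode stacks compute is equivalent to this plain-Python reference (the character vocabulary here is hypothetical):

```python
# hypothetical character vocabulary; IDs are 0-based so they can be
# fed directly into a transpiled nn.Embedding
chars = [" ", "a", "c", "d", "g", "o", "t"]
char_to_id = {c: i for i, c in enumerate(chars)}

def encode(text):
    """Encode stack: input string -> token_ids list."""
    return [char_to_id[c] for c in text]

def decode(token_ids):
    """Decode stack: token_ids list -> output string."""
    return "".join(chars[i] for i in token_ids)

ids = encode("cat dog")
print(ids)          # [2, 1, 6, 0, 3, 5, 4]
print(decode(ids))  # cat dog
```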

Scratch limits

  • Project size: Scratch limits projects to 5 MB. cattorch warns at 4 MB and errors at 5 MB.
  • List length: Scratch lists can hold at most 200,000 items. cattorch raises an error if any weight tensor or intermediate list exceeds this.
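Since each weight tensor is stored as a flat list, you can estimate whether a model fits before exporting by computing per-tensor element counts. A rough stdlib-only check (the helper functions and layer sizes are illustrative, not cattorch's API):

```python
SCRATCH_LIST_LIMIT = 200_000

def linear_params(in_features, out_features, bias=True):
    """Flat element counts for an nn.Linear: weight matrix, then bias vector."""
    weight = in_features * out_features
    return (weight, out_features) if bias else (weight,)

def embedding_params(num_embeddings, embedding_dim):
    """Flat element count for an nn.Embedding weight table."""
    return (num_embeddings * embedding_dim,)

# a 50k-entry embedding at dim 8 already overflows a single list ...
for n in embedding_params(50_000, 8):
    print(n, n <= SCRATCH_LIST_LIMIT)  # 400000 False

# ... while a 512x256 Linear fits comfortably
for n in linear_params(512, 256):
    print(n, n <= SCRATCH_LIST_LIMIT)  # 131072 True / 256 True
```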

License

MIT

Project details


Download files

Download the file for your platform.

Source Distribution

cattorch-0.3.0.tar.gz (62.4 kB view details)

Uploaded Source

Built Distribution


cattorch-0.3.0-py3-none-any.whl (84.6 kB view details)

Uploaded Python 3

File details

Details for the file cattorch-0.3.0.tar.gz.

File metadata

  • Download URL: cattorch-0.3.0.tar.gz
  • Upload date:
  • Size: 62.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for cattorch-0.3.0.tar.gz
Algorithm Hash digest
SHA256 b9b49c530e7990b3655174d305c65754cddc18cc45fe3aa65159174bf0fecbcf
MD5 1cf7b9e0fb5ea4f5aa4316addf8000cd
BLAKE2b-256 5479b25005076862481a55d5777d1d3fc35fe51e15b3754329d13abf726bd398


Provenance

The following attestation bundles were made for cattorch-0.3.0.tar.gz:

Publisher: publish.yml on NormallyNormal/cattorch

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file cattorch-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: cattorch-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 84.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for cattorch-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 20278bf6b9ca52125a37d0c78d95eb8998b567a57555b8ebb58b6565cb15371a
MD5 303546518b09d7fdf3800773d0c1f0a0
BLAKE2b-256 32c61e9d7578d3ea3d65a56b6ee491ebfb5845d22d0c978a6024c421edcbda6f


Provenance

The following attestation bundles were made for cattorch-0.3.0-py3-none-any.whl:

Publisher: publish.yml on NormallyNormal/cattorch

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
