Python model implementations for the SOMA network, matching the Rust runtime exactly


SOMA Models

Python implementations of SOMA network models. These implementations are numerically identical to the Rust runtime — weights trained in Python produce the same outputs when evaluated on-chain.

Both PyTorch and Flax (JAX) are supported as first-class frameworks. Models are serialized to safetensors format, which is the canonical weight exchange format between Python and the Rust runtime.

Install

# PyTorch
uv add soma-models[torch]

# Flax / JAX
uv add soma-models[flax]

# Both
uv add soma-models[all]

Or with pip:

pip install soma-models[torch]   # PyTorch
pip install soma-models[flax]    # Flax / JAX
pip install soma-models[all]     # Both

Versioning

Model architectures are versioned. Each version defines a fixed architecture, hyperparameters, data contract, and scoring function. The on-chain runtime selects the architecture version when evaluating a model, so your weights must match the version you registered with.

New versions may be introduced via protocol upgrades. Previous versions continue to work for models registered under them.


V1

V1 is a pre-norm byte-level transformer. It operates directly on raw bytes — no external tokenizer is needed. The model uses rotary positional embeddings (RoPE), GELU activations, and a next-token prediction objective with a Gaussian uniformity regularizer (SIGReg) to prevent embedding collapse.

Architecture

Input bytes (0–255)
    │
    ▼
Embedding (vocab_size → embedding_dim)
    │
    ▼
Encoder (num_layers × TransformerBlock)
    │   ┌─────────────────────────────────┐
    │   │  Pre-Norm (LayerNorm)           │
    │   │  Multi-Head Attention (RoPE)    │
    │   │  Dropout + Residual             │
    │   │  Pre-Norm (LayerNorm)           │
    │   │  Feed-Forward (GELU)            │
    │   │  Dropout + Residual             │
    │   └─────────────────────────────────┘
    │
    ▼
Final LayerNorm → representations (used for embedding + loss)
    │
    ▼
Linear predictor → logits (used for cross-entropy loss)

Hyperparameters

Parameter Value Description
EMBEDDING_DIM 2048 Dimension of token embeddings and hidden states
NUM_HEADS 8 Number of attention heads (head_dim = 256)
NUM_LAYERS 32 Number of transformer blocks
MAX_SEQ_LEN 8192 Maximum sequence length during on-chain evaluation
PWFF_HIDDEN_DIM 8192 Feed-forward inner dimension (4 × embedding_dim)
VOCAB_SIZE 264 256 byte tokens + 8 special tokens
MAX_WAVELENGTH 10,000 RoPE positional encoding wavelength
SCALE_FACTOR 1.0 RoPE scale factor
BATCH_SIZE 32 Batch size during on-chain evaluation
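The derived quantities noted in the table follow from simple arithmetic. A throwaway sketch (the constant names below mirror the table; they are not an actual soma_models API):

```python
# Values copied from the hyperparameter table above.
EMBEDDING_DIM = 2048
NUM_HEADS = 8
PWFF_HIDDEN_DIM = 8192
VOCAB_SIZE = 264

head_dim = EMBEDDING_DIM // NUM_HEADS         # per-head dimension: 256
ffn_ratio = PWFF_HIDDEN_DIM // EMBEDDING_DIM  # feed-forward expansion: 4x
num_special = VOCAB_SIZE - 256                # special tokens beyond raw bytes: 8
```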

Data Contract

The model operates on raw bytes. During on-chain evaluation, data is processed as follows:

  • Each byte (0–255) is its own token ID
  • Special tokens: PAD = 256, EOS = 257
  • Data is chunked into non-overlapping sequences of MAX_SEQ_LEN (8192) bytes
  • EOS is only placed on the final chunk and only if it is shorter than MAX_SEQ_LEN — it occupies the position immediately after the last data byte. If data length is an exact multiple of MAX_SEQ_LEN, no EOS is appended
  • Any remaining positions after EOS (or after data if no EOS) are filled with PAD
  • Targets are the input token IDs shifted left by 1 (next-token prediction), with PAD appended as the final target
  • Position IDs are global byte offsets for data positions. PAD and EOS positions are clamped to the offset of the last data byte + 1 (they do not continue incrementing)
  • Sequences are batched in groups of BATCH_SIZE (32)

You are free to prepare your training data however you want — different sequence lengths, different batching, different shuffling. But your model will be scored using the contract above, so your training should produce weights that perform well under these conditions.
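The chunking rules above can be sketched in plain Python. This is an illustration of the contract, not the packaged tokenizer (the function name make_sequences is mine), and it omits batching:

```python
PAD, EOS = 256, 257

def make_sequences(data: bytes, max_seq_len: int = 8192):
    """Sketch of the on-chain chunking contract described above."""
    sequences = []
    for start in range(0, len(data), max_seq_len):
        chunk = list(data[start:start + max_seq_len])
        token_ids = list(chunk)
        is_final = start + max_seq_len >= len(data)
        # EOS only on a final chunk that is shorter than max_seq_len
        if is_final and len(chunk) < max_seq_len:
            token_ids.append(EOS)
        # Remaining positions are filled with PAD
        token_ids += [PAD] * (max_seq_len - len(token_ids))
        # Targets: inputs shifted left by one, with PAD appended at the end
        targets = token_ids[1:] + [PAD]
        # Position IDs: global byte offsets for data positions; EOS/PAD
        # positions are clamped to (offset of last data byte + 1)
        clamp = start + len(chunk)
        pos_ids = [min(start + i, clamp) for i in range(max_seq_len)]
        sequences.append((token_ids, targets, pos_ids))
    return sequences
```

For instance, five bytes with max_seq_len=4 produce one full chunk and a final chunk ending in EOS then PAD, with the trailing position IDs clamped; four bytes with max_seq_len=4 produce a single chunk with no EOS at all.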

Scoring (Loss Function)

Models are scored on-chain by the following loss:

loss = cross_entropy + sig_reg_loss

The model with the lowest loss wins. The two components are:

  1. Cross-entropy loss: Standard next-token prediction loss over the vocabulary. PAD tokens (256) are masked out and do not contribute to the loss.

  2. SIGReg loss: A Gaussian uniformity regularizer (LeJEPA) that penalizes embedding collapse. It measures how far the embedding distribution deviates from a standard Gaussian by comparing the characteristic function of projected representations against the Gaussian characteristic function.

SIGReg Parameter Value
SIG_REG_T_MAX 3.0
SIG_REG_SLICES 1024
SIG_REG_POINTS 17
SIG_REG_COEFFICIENT 0.02

SIGReg noise is generated using each framework's native RNG (jax.random via nnx.Rngs for Flax, torch.randn for PyTorch).
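To build intuition for the characteristic-function comparison, here is a one-dimensional sketch. It is my own simplification: the packaged SIGReg projects high-dimensional embeddings onto SIG_REG_SLICES random directions and runs inside each framework's autodiff, none of which is shown here.

```python
import math
import random

def sig_reg_1d(xs, t_max=3.0, num_points=17):
    """Mean squared gap between the empirical characteristic function of xs
    and the standard-Gaussian CF exp(-t^2 / 2), evaluated at num_points
    values of t in [0, t_max]. A 1-D illustration of the SIGReg idea."""
    n = len(xs)
    total = 0.0
    for k in range(num_points):
        t = t_max * k / (num_points - 1)
        # Empirical characteristic function E[exp(i t x)], real and imag parts
        re = sum(math.cos(t * x) for x in xs) / n
        im = sum(math.sin(t * x) for x in xs) / n
        gauss = math.exp(-0.5 * t * t)  # CF of N(0, 1) is real-valued
        total += (re - gauss) ** 2 + im ** 2
    return total / num_points

random.seed(0)
gaussian = [random.gauss(0.0, 1.0) for _ in range(5000)]
collapsed = [0.0] * 5000  # all embeddings identical
```

A collapsed embedding has characteristic function 1 everywhere, so its gap from exp(-t²/2) is large; a Gaussian-distributed embedding scores near zero. That gap is exactly what the regularizer penalizes.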

Tokenizer

The tokenizer implements the on-chain data contract as a framework-agnostic Python module. It converts raw bytes into token_ids, targets, and pos_ids that can be wrapped with any framework's tensor constructor (torch.tensor(), jnp.array(), tf.constant(), etc.).

from soma_models.v1.tokenizer import tokenize

batches = tokenize(raw_bytes)
for batch in batches:
    batch.token_ids   # [batch, seq_len] nested list of ints
    batch.targets     # [batch, seq_len] nested list of ints
    batch.pos_ids     # [batch, seq_len] nested list of ints

The default max_seq_len and batch_size match the on-chain evaluation parameters. You can override them for training:

batches = tokenize(raw_bytes, max_seq_len=2048, batch_size=8)

The final batch may contain fewer than batch_size sequences (matching the Rust DataLoader behaviour).

Usage

Both frameworks expose the same API: a Model, a SIGReg regularizer, and a compute_loss function.

PyTorch

import torch
from soma_models.v1.configs import ModelConfig, SIGRegConfig
from soma_models.v1.tokenizer import tokenize
from soma_models.v1.torch.modules.model import Model
from soma_models.v1.torch.modules.sig_reg import SIGReg
from soma_models.v1.torch.loss import compute_loss

# Initialize
model = Model(ModelConfig(dropout_rate=0.1))
sig_reg = SIGReg(SIGRegConfig())

# Tokenize raw bytes
batches = tokenize(raw_bytes)

# Forward + loss (differentiable)
for batch in batches:
    loss, embedding = compute_loss(
        model, sig_reg,
        token_ids=torch.tensor(batch.token_ids),
        targets=torch.tensor(batch.targets),
    )
    loss.backward()

# Save / load weights
model.save("weights.safetensors")
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0))

Flax

import jax.numpy as jnp
from flax import nnx
from soma_models.v1.configs import ModelConfig, SIGRegConfig
from soma_models.v1.tokenizer import tokenize
from soma_models.v1.flax.modules.model import Model
from soma_models.v1.flax.modules.sig_reg import SIGReg
from soma_models.v1.flax.loss import compute_loss

# Initialize
rngs = nnx.Rngs(0)
model = Model(ModelConfig(dropout_rate=0.1), rngs=rngs)
sig_reg = SIGReg(SIGRegConfig(), rngs=rngs)

# Tokenize raw bytes
batches = tokenize(raw_bytes)

# Forward + loss (differentiable via jax.grad)
for batch in batches:
    loss, embedding = compute_loss(
        model, sig_reg,
        token_ids=jnp.array(batch.token_ids),
        targets=jnp.array(batch.targets),
    )

# Save / load weights
model.save("weights.safetensors")
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0), rngs=rngs)

Weight Serialization

Weights are stored in safetensors format with a canonical key layout. The serde layer handles all framework-specific transformations automatically:

  • LayerNorm: weight/bias (torch) ↔ gamma/beta (safetensors) ↔ scale/bias (flax)
  • Linear: Row-major (torch) ↔ column-major (safetensors/flax)
  • Attention: Split-head (flax) ↔ flat (safetensors/torch)
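The Linear transformation in the middle bullet is a plain matrix transpose between layout conventions: torch's nn.Linear stores weight as (out_features, in_features), while a flax Dense kernel is (in_features, out_features). A minimal sketch of the idea, not the actual serde layer:

```python
def transpose(mat):
    """Convert between the (out_features, in_features) layout of a
    torch-style Linear weight and the (in_features, out_features)
    layout of a flax-style Dense kernel."""
    return [list(col) for col in zip(*mat)]

# A 2x3 torch-style weight (out=2, in=3) ...
torch_weight = [[1, 2, 3],
                [4, 5, 6]]
# ... becomes a 3x2 flax-style kernel.
flax_kernel = transpose(torch_weight)
```

Because transposition is its own inverse, the round trip is lossless, which is what makes cross-framework save/load possible.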

PyTorch

from soma_models.v1.configs import ModelConfig
from soma_models.v1.torch.modules.model import Model

# Save
model.save("weights.safetensors")

# Load
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0))

Flax

from soma_models.v1.configs import ModelConfig
from soma_models.v1.flax.modules.model import Model
from flax import nnx

# Save
model.save("weights.safetensors")

# Load
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0), rngs=nnx.Rngs(0))

Weights are cross-compatible — you can save from one framework and load into the other.
