
SOMA Models

Python implementations of SOMA network models. These implementations are numerically identical to the Rust runtime — weights trained in Python produce the same outputs when evaluated on-chain.

Both PyTorch and Flax (JAX) are supported as first-class frameworks. Models are serialized to safetensors format, which is the canonical weight exchange format between Python and the Rust runtime.

Install

# PyTorch
uv add soma-models[torch]

# Flax / JAX
uv add soma-models[flax]

# Both
uv add soma-models[all]

Or with pip:

pip install soma-models[torch]   # PyTorch
pip install soma-models[flax]    # Flax / JAX
pip install soma-models[all]     # Both

Versioning

Model architectures are versioned. Each version defines a fixed architecture, hyperparameters, data contract, and scoring function. The on-chain runtime selects the architecture version when evaluating a model, so your weights must match the version you registered with.

New versions may be introduced via protocol upgrades. Previous versions continue to work for models registered under them.


V1

V1 is a pre-norm byte-level transformer. It operates directly on raw bytes — no external tokenizer is needed. The model uses rotary positional embeddings (RoPE), GELU activations, and a next-token prediction objective with a Gaussian uniformity regularizer (SIGReg) to prevent embedding collapse.

Architecture

Input bytes (0–255)
    │
    ▼
Embedding (vocab_size → embedding_dim)
    │
    ▼
Encoder (num_layers × TransformerBlock)
    │   ┌─────────────────────────────────┐
    │   │  Pre-Norm (LayerNorm)           │
    │   │  Multi-Head Attention (RoPE)    │
    │   │  Dropout + Residual             │
    │   │  Pre-Norm (LayerNorm)           │
    │   │  Feed-Forward (GELU)            │
    │   │  Dropout + Residual             │
    │   └─────────────────────────────────┘
    │
    ▼
Final LayerNorm → representations (used for embedding + loss)
    │
    ▼
Linear predictor → logits (used for cross-entropy loss)
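
The block structure in the diagram can be sketched in PyTorch. This is an illustrative sketch only — RoPE is omitted for brevity, and `PreNormBlock` is a hypothetical name, not the package's actual module:

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Illustrative pre-norm transformer block (RoPE omitted)."""

    def __init__(self, dim: int = 2048, num_heads: int = 8,
                 ff_dim: int = 8192, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim), nn.GELU(), nn.Linear(ff_dim, dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm: normalize *before* each sublayer, then add the residual.
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.dropout(attn_out)
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x
```

The pre-norm arrangement (LayerNorm before each sublayer rather than after) is what makes deep stacks like the 24-layer V1 encoder trainable without warmup tricks.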

Hyperparameters

Parameter Value Description
EMBEDDING_DIM 2048 Dimension of token embeddings and hidden states
NUM_HEADS 8 Number of attention heads (head_dim = 256)
NUM_LAYERS 24 Number of transformer blocks
MAX_SEQ_LEN 1024 Maximum sequence length during on-chain evaluation
PWFF_HIDDEN_DIM 8192 Feed-forward inner dimension (4 × embedding_dim)
VOCAB_SIZE 264 256 byte tokens + 8 special tokens
MAX_WAVELENGTH 10,000 RoPE positional encoding wavelength
SCALE_FACTOR 1.0 RoPE scale factor
BATCH_SIZE 16 Batch size during on-chain evaluation
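
MAX_WAVELENGTH and SCALE_FACTOR parameterize the rotary embeddings. A minimal NumPy sketch of one common RoPE convention follows — note that implementations differ in how they pair features (interleaved vs. split-half), so treat this as illustrative rather than the package's exact layout:

```python
import numpy as np

def rope(x: np.ndarray, pos_ids: np.ndarray,
         max_wavelength: float = 10_000.0, scale: float = 1.0) -> np.ndarray:
    """Apply rotary position embeddings to x of shape [seq, head_dim].

    Each consecutive feature pair is rotated by an angle that grows with
    position and shrinks geometrically with feature index.
    """
    head_dim = x.shape[-1]
    freqs = 1.0 / max_wavelength ** (np.arange(0, head_dim, 2) / head_dim)
    angles = (pos_ids[:, None] * scale) * freqs[None, :]  # [seq, head_dim/2]
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```

Because RoPE is a pure rotation, it preserves vector norms and encodes relative position directly in attention dot products.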

Data Contract

The model operates on raw bytes. During on-chain evaluation, data is processed as follows:

  • Each byte (0–255) is its own token ID
  • Special tokens: PAD = 256, EOS = 257
  • Data is chunked into non-overlapping sequences of MAX_SEQ_LEN (1024) bytes
  • EOS is placed only on the final chunk, and only if that chunk is shorter than MAX_SEQ_LEN — it occupies the position immediately after the last data byte. If the data length is an exact multiple of MAX_SEQ_LEN, no EOS is appended
  • Any remaining positions after EOS (or after data if no EOS) are filled with PAD
  • Targets are the input token IDs shifted left by 1 (next-token prediction), with PAD appended as the final target
  • Position IDs are global byte offsets for data positions. PAD and EOS positions are clamped to the offset of the last data byte + 1 (they do not continue incrementing)
  • Sequences are batched in groups of BATCH_SIZE (16)

You are free to prepare your training data however you want — different sequence lengths, different batching, different shuffling. But your model will be scored using the contract above, so your training should produce weights that perform well under these conditions.
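
The contract above can be sketched in plain Python. This is an illustrative reimplementation of the bullets, not the package's tokenizer, and batching is omitted:

```python
PAD, EOS = 256, 257
MAX_SEQ_LEN = 1024

def make_sequences(data: bytes, max_seq_len: int = MAX_SEQ_LEN):
    """Chunk raw bytes into (token_ids, targets, pos_ids) per the contract."""
    sequences = []
    for start in range(0, len(data), max_seq_len):
        chunk = data[start:start + max_seq_len]
        n = len(chunk)
        token_ids = list(chunk)
        if n < max_seq_len:
            token_ids.append(EOS)          # EOS only on a short final chunk
        token_ids += [PAD] * (max_seq_len - len(token_ids))
        # Targets: inputs shifted left by one, PAD appended as the final target.
        targets = token_ids[1:] + [PAD]
        # Position IDs: global byte offsets; EOS/PAD positions clamp to the
        # offset of the last data byte + 1.
        last = start + n
        pos_ids = [min(start + i, last) for i in range(max_seq_len)]
        sequences.append((token_ids, targets, pos_ids))
    return sequences
```

For example, 10 bytes with `max_seq_len=8` yield one full chunk (no EOS) and one short chunk whose third position is EOS followed by PAD.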

Scoring (Loss Function)

Models are scored on-chain by the following loss:

loss = cross_entropy + sig_reg_loss

The model with the lowest loss wins. The two components are:

  1. Cross-entropy loss: Standard next-token prediction loss over the vocabulary. PAD tokens (256) are masked out and do not contribute to the loss.

  2. SIGReg loss: A Gaussian uniformity regularizer (LeJEPA) that penalizes embedding collapse. It measures how far the embedding distribution deviates from a standard Gaussian by comparing the characteristic function of projected representations against the Gaussian characteristic function.
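
The PAD masking in the cross-entropy term can be sketched in NumPy. Whether the mean is taken over non-PAD tokens only (as here) is an assumption about the runtime's reduction:

```python
import numpy as np

PAD = 256

def masked_cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean next-token cross-entropy, ignoring positions whose target is PAD."""
    logits = logits.reshape(-1, logits.shape[-1])
    targets = targets.reshape(-1)
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(targets.size), targets]
    mask = targets != PAD
    return float(nll[mask].mean())
```

With uniform logits over the 264-token vocabulary, the loss is exactly log(264), regardless of how many PAD positions are masked out.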

SIGReg Parameter Value
SIG_REG_T_MAX 3.0
SIG_REG_SLICES 256
SIG_REG_POINTS 17
SIG_REG_COEFFICIENT 0.02

SIGReg noise is generated using each framework's native RNG (jax.random via nnx.Rngs for Flax, torch.randn for PyTorch).
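
A NumPy sketch of a characteristic-function regularizer of this kind follows. The package's exact SIGReg formulation (point spacing, weighting, reduction) may differ; this only illustrates the comparison against the Gaussian characteristic function exp(-t²/2):

```python
import numpy as np

def sig_reg_loss(embeddings: np.ndarray, rng: np.random.Generator,
                 t_max: float = 3.0, num_slices: int = 256,
                 num_points: int = 17) -> float:
    """Penalize deviation of sliced embeddings from a standard Gaussian."""
    n, d = embeddings.shape
    # Random unit directions ("slices") to project the embeddings onto.
    directions = rng.standard_normal((d, num_slices))
    directions /= np.linalg.norm(directions, axis=0, keepdims=True)
    projected = embeddings @ directions                # [n, num_slices]
    t = np.linspace(0.0, t_max, num_points)            # evaluation points
    # Empirical characteristic function per slice: mean of exp(i t x).
    phase = projected[..., None] * t                   # [n, slices, points]
    emp_cf = np.exp(1j * phase).mean(axis=0)           # [slices, points]
    gauss_cf = np.exp(-0.5 * t ** 2)                   # Gaussian CF (real)
    return float(np.mean(np.abs(emp_cf - gauss_cf) ** 2))
```

Collapsed embeddings (all points identical) have a characteristic function of constant modulus 1, so they are penalized far more than well-spread Gaussian embeddings — which is exactly the failure mode the regularizer guards against.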

Tokenizer

The tokenizer implements the on-chain data contract as a framework-agnostic Python module. It converts raw bytes into token_ids, targets, and pos_ids that can be wrapped with any framework's tensor constructor (torch.tensor(), jnp.array(), tf.constant(), etc.).

from soma_models.v1.tokenizer import tokenize

batches = tokenize(raw_bytes)
for batch in batches:
    batch.token_ids   # [batch, seq_len] nested list of ints
    batch.targets     # [batch, seq_len] nested list of ints
    batch.pos_ids     # [batch, seq_len] nested list of ints

The default max_seq_len and batch_size match the on-chain evaluation parameters. You can override them for training:

batches = tokenize(raw_bytes, max_seq_len=512, batch_size=32)

The final batch may contain fewer than batch_size sequences (matching the Rust DataLoader behaviour).

Usage

Both frameworks expose the same API: a Model, a SIGReg regularizer, and a compute_loss function.

PyTorch

import torch
from soma_models.v1.configs import ModelConfig, SIGRegConfig
from soma_models.v1.tokenizer import tokenize
from soma_models.v1.torch.modules.model import Model
from soma_models.v1.torch.modules.sig_reg import SIGReg
from soma_models.v1.torch.loss import compute_loss

# Initialize
model = Model(ModelConfig(dropout_rate=0.1))
sig_reg = SIGReg(SIGRegConfig())

# Tokenize raw bytes
batches = tokenize(raw_bytes)

# Forward + loss (differentiable)
for batch in batches:
    loss, embedding = compute_loss(
        model, sig_reg,
        token_ids=torch.tensor(batch.token_ids),
        targets=torch.tensor(batch.targets),
    )
    loss.backward()

# Save / load weights
model.save("weights.safetensors")
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0))

Flax

import jax.numpy as jnp
from flax import nnx
from soma_models.v1.configs import ModelConfig, SIGRegConfig
from soma_models.v1.tokenizer import tokenize
from soma_models.v1.flax.modules.model import Model
from soma_models.v1.flax.modules.sig_reg import SIGReg
from soma_models.v1.flax.loss import compute_loss

# Initialize
rngs = nnx.Rngs(0)
model = Model(ModelConfig(dropout_rate=0.1), rngs=rngs)
sig_reg = SIGReg(SIGRegConfig(), rngs=rngs)

# Tokenize raw bytes
batches = tokenize(raw_bytes)

# Forward + loss (differentiable via jax.grad)
for batch in batches:
    loss, embedding = compute_loss(
        model, sig_reg,
        token_ids=jnp.array(batch.token_ids),
        targets=jnp.array(batch.targets),
    )

# Save / load weights
model.save("weights.safetensors")
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0), rngs=rngs)

Weight Serialization

Weights are stored in safetensors format with a canonical key layout. The serde layer handles all framework-specific transformations automatically:

  • LayerNorm: weight/bias (torch) ↔ gamma/beta (safetensors) ↔ scale/bias (flax)
  • Linear: Row-major (torch) ↔ column-major (safetensors/flax)
  • Attention: Split-head (flax) ↔ flat (safetensors/torch)
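
The kinds of transformations listed above can be sketched with NumPy. The function names are hypothetical and the exact conventions live in the package's serde layer; this only illustrates the shape changes involved:

```python
import numpy as np

def torch_linear_to_canonical(weight: np.ndarray) -> np.ndarray:
    """PyTorch stores nn.Linear weights as [out_features, in_features];
    transposing gives the [in, out] layout used on the Flax side."""
    return weight.T

def canonical_to_torch_linear(weight: np.ndarray) -> np.ndarray:
    """Inverse of the above — a transpose is its own round trip."""
    return weight.T

def flat_to_split_heads(w: np.ndarray, num_heads: int) -> np.ndarray:
    """Reshape a flat [embed, embed] attention projection into the
    split-head [embed, num_heads, head_dim] layout used by Flax."""
    embed = w.shape[0]
    return w.reshape(embed, num_heads, embed // num_heads)
```

Because these are pure transposes and reshapes, a save/load round trip through the canonical layout is lossless — which is what makes the cross-framework compatibility below possible.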

PyTorch

from soma_models.v1.configs import ModelConfig
from soma_models.v1.torch.modules.model import Model

# Save
model.save("weights.safetensors")

# Load
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0))

Flax

from soma_models.v1.configs import ModelConfig
from soma_models.v1.flax.modules.model import Model
from flax import nnx

# Save
model.save("weights.safetensors")

# Load
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0), rngs=nnx.Rngs(0))

Weights are cross-compatible — you can save from one framework and load into the other.
