
SOMA Models

Python implementations of SOMA network models. These implementations are numerically identical to the Rust runtime — weights trained in Python produce the same outputs when evaluated on-chain.

Both PyTorch and Flax (JAX) are supported as first-class frameworks. Models are serialized to safetensors format, which is the canonical weight exchange format between Python and the Rust runtime.

Install

# PyTorch
uv add soma-models[torch]

# Flax / JAX
uv add soma-models[flax]

# Both
uv add soma-models[all]

Or with pip:

pip install soma-models[torch]   # PyTorch
pip install soma-models[flax]    # Flax / JAX
pip install soma-models[all]     # Both

Versioning

Model architectures are versioned. Each version defines a fixed architecture, hyperparameters, data contract, and scoring function. The on-chain runtime selects the architecture version when evaluating a model, so your weights must match the version you registered with.

New versions may be introduced via protocol upgrades. Previous versions continue to work for models registered under them.


V1

V1 is a pre-norm byte-level transformer. It operates directly on raw bytes — no external tokenizer is needed. The model uses rotary positional embeddings (RoPE), GELU activations, and a next-token prediction objective with a Gaussian uniformity regularizer (SIGReg) to prevent embedding collapse.

Architecture

Input bytes (0–255)
    │
    ▼
Embedding (vocab_size → embedding_dim)
    │
    ▼
Encoder (num_layers × TransformerBlock)
    │   ┌─────────────────────────────────┐
    │   │  Pre-Norm (LayerNorm)           │
    │   │  Multi-Head Attention (RoPE)    │
    │   │  Dropout + Residual             │
    │   │  Pre-Norm (LayerNorm)           │
    │   │  Feed-Forward (GELU)            │
    │   │  Dropout + Residual             │
    │   └─────────────────────────────────┘
    │
    ▼
Final LayerNorm → representations (used for embedding + loss)
    │
    ▼
Linear predictor → logits (used for cross-entropy loss)
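
As a concrete reference, here is a minimal PyTorch sketch of one such block. The class and argument names are illustrative, not the package's actual modules (those live under soma_models.v1.torch.modules), and the rotary embedding applied to queries and keys inside attention is elided for brevity:

import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    # Illustrative pre-norm transformer block matching the diagram above.
    # The real model also applies RoPE inside attention, elided here.
    def __init__(self, dim=2048, heads=8, ff_dim=8192, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, ff_dim), nn.GELU(), nn.Linear(ff_dim, dim))
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        h = self.norm1(x)                               # pre-norm before attention
        h, _ = self.attn(h, h, h, attn_mask=attn_mask)  # multi-head attention
        x = x + self.drop(h)                            # dropout + residual
        h = self.ff(self.norm2(x))                      # pre-norm feed-forward (GELU)
        return x + self.drop(h)                         # dropout + residual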

Hyperparameters

Parameter         Value    Description
EMBEDDING_DIM     2048     Dimension of token embeddings and hidden states
NUM_HEADS         8        Number of attention heads (head_dim = 256)
NUM_LAYERS        24       Number of transformer blocks
MAX_SEQ_LEN       1024     Maximum sequence length during on-chain evaluation
PWFF_HIDDEN_DIM   8192     Feed-forward inner dimension (4 × embedding_dim)
VOCAB_SIZE        264      256 byte tokens + 8 special tokens
MAX_WAVELENGTH    10,000   RoPE positional encoding wavelength
SCALE_FACTOR      1.0      RoPE scale factor
BATCH_SIZE        16       Batch size during on-chain evaluation
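
The derived quantities in the table follow directly from each other; a quick sanity check:

EMBEDDING_DIM, NUM_HEADS = 2048, 8
assert EMBEDDING_DIM // NUM_HEADS == 256   # head_dim per attention head
assert 4 * EMBEDDING_DIM == 8192           # PWFF_HIDDEN_DIM
assert 256 + 8 == 264                      # VOCAB_SIZE = byte tokens + special tokens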

Data Contract

The model operates on raw bytes. During on-chain evaluation, data is processed as follows:

  • Each byte (0–255) is its own token ID
  • Special tokens: PAD = 256, EOS = 257
  • Data is chunked into non-overlapping sequences of MAX_SEQ_LEN (1024) bytes
  • EOS is only placed on the final chunk and only if it is shorter than MAX_SEQ_LEN — it occupies the position immediately after the last data byte. If data length is an exact multiple of MAX_SEQ_LEN, no EOS is appended
  • Any remaining positions after EOS (or after data if no EOS) are filled with PAD
  • Targets are the input token IDs shifted left by 1 (next-token prediction), with PAD appended as the final target
  • Position IDs are global byte offsets for data positions. PAD and EOS positions are clamped to the offset of the last data byte + 1 (they do not continue incrementing)
  • Sequences are batched in groups of BATCH_SIZE (16)

You are free to prepare your training data however you want — different sequence lengths, different batching, different shuffling. But your model will be scored using the contract above, so your training should produce weights that perform well under these conditions.
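
As an illustration only (the shipped tokenizer described below is the canonical implementation), the rules above fit in a few lines of plain Python. The function name and the tuple return shape here are hypothetical:

MAX_SEQ_LEN = 1024
PAD, EOS = 256, 257

def chunk_bytes(data: bytes, max_seq_len: int = MAX_SEQ_LEN):
    # Illustrative sketch of the V1 data contract; use soma_models.v1.tokenizer
    # in practice. Yields (token_ids, targets, pos_ids) per sequence.
    sequences = []
    for start in range(0, len(data), max_seq_len):
        ids = list(data[start:start + max_seq_len])  # each byte is its own token ID
        clamp = start + len(ids)                     # offset of last data byte + 1
        if len(ids) < max_seq_len:                   # short final chunk: EOS, then PAD
            ids.append(EOS)
            ids += [PAD] * (max_seq_len - len(ids))
        targets = ids[1:] + [PAD]                    # shifted left by 1, PAD appended
        pos_ids = [min(start + i, clamp) for i in range(max_seq_len)]
        sequences.append((ids, targets, pos_ids))
    return sequences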

Scoring (Loss Function)

Models are scored on-chain by the following loss:

loss = cross_entropy + sig_reg_loss

The model with the lowest loss wins. The two components are:

  1. Cross-entropy loss: Standard next-token prediction loss over the vocabulary. PAD tokens (256) are masked out and do not contribute to the loss.

  2. SIGReg loss: A Gaussian uniformity regularizer (LeJEPA) that penalizes embedding collapse. It measures how far the embedding distribution deviates from a standard Gaussian by comparing the characteristic function of projected representations against the Gaussian characteristic function.

SIGReg Parameter      Value
SIG_REG_T_MAX         3.0
SIG_REG_SLICES        256
SIG_REG_POINTS        17
SIG_REG_COEFFICIENT   0.02

SIGReg noise is generated using each framework's native RNG (jax.random via nnx.Rngs for Flax, torch.randn for PyTorch).
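
To make the two terms concrete, here is a loose PyTorch sketch. It is not the on-chain implementation (the package's compute_loss is canonical): the SIGReg function below only illustrates the characteristic-function comparison, and it assumes random unit-vector slices, an evenly spaced evaluation grid, and that SIG_REG_COEFFICIENT scales the regularizer term:

import torch
import torch.nn.functional as F

PAD = 256

def cross_entropy_term(logits, targets):
    # Next-token cross-entropy; PAD targets are masked out of the loss.
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=PAD,
    )

def sig_reg_term(reps, t_max=3.0, slices=256, points=17, coeff=0.02):
    # Conceptual sketch only: project representations onto random directions
    # ("slices") and compare the empirical characteristic function against the
    # standard Gaussian's exp(-t^2 / 2) on a grid of `points` values in [0, t_max].
    d = reps.size(-1)
    dirs = torch.randn(d, slices)
    dirs = dirs / dirs.norm(dim=0, keepdim=True)
    proj = reps.reshape(-1, d) @ dirs                  # [n, slices]
    t = torch.linspace(0.0, t_max, points)             # evaluation grid
    angle = proj[..., None] * t                        # [n, slices, points]
    ecf_re, ecf_im = torch.cos(angle).mean(0), torch.sin(angle).mean(0)
    gauss = torch.exp(-0.5 * t ** 2)                   # Gaussian characteristic function
    return coeff * ((ecf_re - gauss) ** 2 + ecf_im ** 2).mean()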

Tokenizer

The tokenizer implements the on-chain data contract as a framework-agnostic Python module. It converts raw bytes into a list of TokenizedSequence items with token_ids, targets, and pos_ids that can be wrapped with any framework's tensor constructor (torch.tensor(), jnp.array(), tf.constant(), etc.).

from soma_models.v1.tokenizer import tokenize

sequences = tokenize(raw_bytes)
for seq in sequences:
    seq.token_ids   # [seq_len] list of ints
    seq.targets     # [seq_len] list of ints
    seq.pos_ids     # [seq_len] list of ints

The default max_seq_len matches the on-chain evaluation parameter. You can override it for training:

sequences = tokenize(raw_bytes, max_seq_len=512)
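Since every sequence comes back at full length (the final chunk is padded), reproducing the on-chain batching of BATCH_SIZE (16) is just stacking. A sketch assuming torch tensors and a hypothetical input file ("data.bin"); any framework's tensor constructor works the same way:

import torch
from soma_models.v1.tokenizer import tokenize

BATCH_SIZE = 16
raw_bytes = open("data.bin", "rb").read()   # hypothetical input file

sequences = tokenize(raw_bytes)
for i in range(0, len(sequences), BATCH_SIZE):
    group = sequences[i:i + BATCH_SIZE]
    token_ids = torch.tensor([s.token_ids for s in group])  # [batch, seq_len]
    targets = torch.tensor([s.targets for s in group])      # [batch, seq_len]
    pos_ids = torch.tensor([s.pos_ids for s in group])      # [batch, seq_len]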

Usage

Both frameworks expose the same API: a Model, a SIGReg regularizer, and a compute_loss function.

PyTorch

import torch
from soma_models.v1.configs import ModelConfig, SIGRegConfig
from soma_models.v1.tokenizer import tokenize
from soma_models.v1.torch.modules.model import Model
from soma_models.v1.torch.modules.sig_reg import SIGReg
from soma_models.v1.torch.loss import compute_loss

# Initialize
model = Model(ModelConfig(dropout_rate=0.1))
sig_reg = SIGReg(SIGRegConfig())

# Tokenize raw bytes
sequences = tokenize(raw_bytes)

# Forward + loss (differentiable)
for seq in sequences:
    loss, embedding = compute_loss(
        model, sig_reg,
        token_ids=torch.tensor([seq.token_ids]),
        targets=torch.tensor([seq.targets]),
    )
    loss.backward()

# Save / load weights
model.save("weights.safetensors")
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0))
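
For training, wrap the loop above in an optimizer step. Continuing the snippet above; the optimizer choice (AdamW) and learning rate are a hypothetical example, not something the package prescribes:

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for seq in sequences:
    optimizer.zero_grad()
    loss, _ = compute_loss(
        model, sig_reg,
        token_ids=torch.tensor([seq.token_ids]),
        targets=torch.tensor([seq.targets]),
    )
    loss.backward()   # gradients flow through both loss terms
    optimizer.step()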

Flax

import jax.numpy as jnp
from flax import nnx
from soma_models.v1.configs import ModelConfig, SIGRegConfig
from soma_models.v1.tokenizer import tokenize
from soma_models.v1.flax.modules.model import Model
from soma_models.v1.flax.modules.sig_reg import SIGReg
from soma_models.v1.flax.loss import compute_loss

# Initialize
rngs = nnx.Rngs(0)
model = Model(ModelConfig(dropout_rate=0.1), rngs=rngs)
sig_reg = SIGReg(SIGRegConfig(), rngs=rngs)

# Tokenize raw bytes
sequences = tokenize(raw_bytes)

# Forward + loss (differentiable via jax.grad)
for seq in sequences:
    loss, embedding = compute_loss(
        model, sig_reg,
        token_ids=jnp.array([seq.token_ids]),
        targets=jnp.array([seq.targets]),
    )

# Save / load weights
model.save("weights.safetensors")
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0), rngs=rngs)
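
To actually take gradients, lift the loss through an nnx transform. A sketch continuing the snippet above, assuming nnx.value_and_grad is available in your Flax version:

seq = sequences[0]   # one tokenized sequence from above

def loss_fn(model):
    loss, _ = compute_loss(
        model, sig_reg,
        token_ids=jnp.array([seq.token_ids]),
        targets=jnp.array([seq.targets]),
    )
    return loss

loss, grads = nnx.value_and_grad(loss_fn)(model)   # gradients w.r.t. model parameters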

Weight Serialization

Weights are stored in safetensors format with a canonical key layout. The serde layer handles all framework-specific transformations automatically:

  • LayerNorm: weight/bias (torch) ↔ gamma/beta (safetensors) ↔ scale/bias (flax)
  • Linear: Row-major (torch) ↔ column-major (safetensors/flax)
  • Attention: Split-head (flax) ↔ flat (safetensors/torch)
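
You normally never touch these transformations yourself, but you can inspect the canonical key layout of a saved file directly with the safetensors library:

from safetensors import safe_open

with safe_open("weights.safetensors", framework="pt") as f:
    for key in f.keys():
        print(key, f.get_slice(key).get_shape())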

PyTorch

from soma_models.v1.configs import ModelConfig
from soma_models.v1.torch.modules.model import Model

# Save
model.save("weights.safetensors")

# Load
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0))

Flax

from soma_models.v1.configs import ModelConfig
from soma_models.v1.flax.modules.model import Model
from flax import nnx

# Save
model.save("weights.safetensors")

# Load
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0), rngs=nnx.Rngs(0))

Weights are cross-compatible — you can save from one framework and load into the other.
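
For example, a model saved from PyTorch loads into Flax with no conversion step (module paths as documented above):

from soma_models.v1.configs import ModelConfig
from soma_models.v1.torch.modules.model import Model as TorchModel
from soma_models.v1.flax.modules.model import Model as FlaxModel
from flax import nnx

TorchModel(ModelConfig(dropout_rate=0.0)).save("weights.safetensors")
flax_model = FlaxModel.load("weights.safetensors", ModelConfig(dropout_rate=0.0), rngs=nnx.Rngs(0))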
