Soma Models
Python implementations of Soma network models. These implementations are numerically identical to the Rust runtime — weights trained in Python produce the same outputs when evaluated on-chain.
Both PyTorch and Flax (JAX) are supported as first-class frameworks. Models are serialized to safetensors format, which is the canonical weight exchange format between Python and the Rust runtime.
Install
# PyTorch
uv add soma-models[torch]
# Flax / JAX
uv add soma-models[flax]
# Both
uv add soma-models[all]
Or with pip:
pip install soma-models[torch] # PyTorch
pip install soma-models[flax] # Flax / JAX
pip install soma-models[all] # Both
Versioning
Model architectures are versioned. Each version defines a fixed architecture, hyperparameters, data contract, and scoring function. The on-chain runtime selects the architecture version when evaluating a model, so your weights must match the version you registered with.
New versions may be introduced via protocol upgrades. Previous versions continue to work for models registered under them.
V1
V1 is a pre-norm byte-level transformer. It operates directly on raw bytes — no external tokenizer is needed. The model uses rotary positional embeddings (RoPE), GELU activations, and a next-token prediction objective with a Gaussian uniformity regularizer (SIGReg) to prevent embedding collapse.
Architecture
Input bytes (0–255)
│
▼
Embedding (vocab_size → embedding_dim)
│
▼
Encoder (num_layers × TransformerBlock)
│ ┌─────────────────────────────────┐
│ │ Pre-Norm (LayerNorm) │
│ │ Multi-Head Attention (RoPE) │
│ │ Dropout + Residual │
│ │ Pre-Norm (LayerNorm) │
│ │ Feed-Forward (GELU) │
│ │ Dropout + Residual │
│ └─────────────────────────────────┘
│
▼
Final LayerNorm → representations (used for embedding + loss)
│
▼
Linear predictor → logits (used for cross-entropy loss)
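As a rough orientation, the block structure above can be expressed as a minimal PyTorch sketch. The class and argument names here are hypothetical, and `nn.MultiheadAttention` does not apply RoPE, so this illustrates only the pre-norm wiring, not the package's actual modules:

```python
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Illustrative pre-norm transformer block (hypothetical names; RoPE omitted)."""

    def __init__(self, dim: int, num_heads: int, ffn_dim: int, dropout: float):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(), nn.Linear(ffn_dim, dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        # Normalize *before* attention, then add the residual.
        h = self.norm1(x)
        h, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + self.dropout(h)
        # Normalize again before the GELU feed-forward, then add the residual.
        return x + self.dropout(self.ffn(self.norm2(x)))
```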
Hyperparameters
| Parameter | Value | Description |
|---|---|---|
| `EMBEDDING_DIM` | 2048 | Dimension of token embeddings and hidden states |
| `NUM_HEADS` | 8 | Number of attention heads (head_dim = 256) |
| `NUM_LAYERS` | 32 | Number of transformer blocks |
| `MAX_SEQ_LEN` | 8192 | Maximum sequence length during on-chain evaluation |
| `PWFF_HIDDEN_DIM` | 8192 | Feed-forward inner dimension (4 × embedding_dim) |
| `VOCAB_SIZE` | 264 | 256 byte tokens + 8 special tokens |
| `MAX_WAVELENGTH` | 10,000 | RoPE positional encoding wavelength |
| `SCALE_FACTOR` | 1.0 | RoPE scale factor |
| `BATCH_SIZE` | 32 | Batch size during on-chain evaluation |
Data Contract
The model operates on raw bytes. During on-chain evaluation, data is processed as follows:
- Each byte (0–255) is its own token ID
- Special tokens: PAD = 256, EOS = 257
- Data is chunked into non-overlapping sequences of `MAX_SEQ_LEN` (8192) bytes
- EOS is only placed on the final chunk and only if it is shorter than `MAX_SEQ_LEN`; it occupies the position immediately after the last data byte. If the data length is an exact multiple of `MAX_SEQ_LEN`, no EOS is appended
- Any remaining positions after EOS (or after the data if no EOS) are filled with PAD
- Targets are the input token IDs shifted left by 1 (next-token prediction), with PAD appended as the final target
- Position IDs are global byte offsets for data positions. PAD and EOS positions are clamped to the offset of the last data byte + 1 (they do not continue incrementing)
- Sequences are batched in groups of `BATCH_SIZE` (32)
You are free to prepare your training data however you want — different sequence lengths, different batching, different shuffling. But your model will be scored using the contract above, so your training should produce weights that perform well under these conditions.
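As a concrete trace of these rules, using the `tokenize` helper described below with a toy `max_seq_len` (the expected values are hand-derived from the contract above, so treat them as a sketch):

```python
from soma_models.v1.tokenizer import tokenize

PAD, EOS = 256, 257

# Six data bytes chunked with a toy max_seq_len of 4 (on-chain uses 8192).
batches = list(tokenize(b"hello!", max_seq_len=4, batch_size=32))
batch = batches[0]

# Hand-traced expectation per the rules above:
#   sequence 0 (full chunk, no EOS or PAD):
#     token_ids [104, 101, 108, 108]  targets [101, 108, 108, PAD]  pos_ids [0, 1, 2, 3]
#   sequence 1 (final chunk, shorter than max_seq_len):
#     token_ids [111, 33, EOS, PAD]   targets [33, EOS, PAD, PAD]   pos_ids [4, 5, 6, 6]
print(batch.token_ids, batch.targets, batch.pos_ids)
```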
Scoring (Loss Function)
Models are scored on-chain by the following loss:
loss = cross_entropy + sig_reg_loss
The model with the lowest loss wins. The two components are:

- **Cross-entropy loss**: Standard next-token prediction loss over the vocabulary. PAD tokens (256) are masked out and do not contribute to the loss.
- **SIGReg loss**: A Gaussian uniformity regularizer (LeJEPA) that penalizes embedding collapse. It measures how far the embedding distribution deviates from a standard Gaussian by comparing the characteristic function of projected representations against the Gaussian characteristic function.
| SIGReg Parameter | Value |
|---|---|
| `SIG_REG_T_MAX` | 3.0 |
| `SIG_REG_SLICES` | 1024 |
| `SIG_REG_POINTS` | 17 |
| `SIG_REG_COEFFICIENT` | 0.02 |
SIGReg noise is generated using each framework's native RNG (`jax.random` via `nnx.Rngs` for Flax, `torch.randn` for PyTorch).
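A simplified sketch of the scored quantity, assuming a plain Monte-Carlo estimate of the characteristic-function discrepancy over random slices (the exact statistic, integration weights, and RNG handling live in the package's `SIGReg` modules, so use `compute_loss` for parity):

```python
import torch
import torch.nn.functional as F

PAD = 256

def masked_cross_entropy(logits, targets):
    # Next-token loss over the 264-way vocabulary; PAD targets are ignored.
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=PAD
    )

def sig_reg_sketch(embeddings, t_max=3.0, num_slices=1024, num_points=17):
    # Project representations onto random unit directions ("slices").
    d = embeddings.size(-1)
    directions = torch.randn(d, num_slices, device=embeddings.device)
    directions = directions / directions.norm(dim=0, keepdim=True)
    proj = embeddings.reshape(-1, d) @ directions       # [n, slices]
    t = torch.linspace(0.0, t_max, num_points, device=embeddings.device)
    # Empirical characteristic function of each slice, evaluated on the t grid.
    angles = proj.unsqueeze(-1) * t                     # [n, slices, points]
    ecf_real = torch.cos(angles).mean(dim=0)
    ecf_imag = torch.sin(angles).mean(dim=0)
    gauss_cf = torch.exp(-0.5 * t**2)                   # N(0, 1) characteristic function
    return ((ecf_real - gauss_cf) ** 2 + ecf_imag**2).mean()

def loss_sketch(logits, targets, embeddings, coefficient=0.02):
    return masked_cross_entropy(logits, targets) + coefficient * sig_reg_sketch(embeddings)
```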
Tokenizer
The tokenizer implements the on-chain data contract as a framework-agnostic Python module. It converts raw bytes into `token_ids`, `targets`, and `pos_ids` that can be wrapped with any framework's tensor constructor (`torch.tensor()`, `jnp.array()`, `tf.constant()`, etc.).
from soma_models.v1.tokenizer import tokenize

batches = tokenize(raw_bytes)
for batch in batches:
    batch.token_ids  # [batch, seq_len] nested list of ints
    batch.targets    # [batch, seq_len] nested list of ints
    batch.pos_ids    # [batch, seq_len] nested list of ints
The default `max_seq_len` and `batch_size` match the on-chain evaluation parameters. You can override them for training:
batches = tokenize(raw_bytes, max_seq_len=2048, batch_size=8)
The final batch may contain fewer than `batch_size` sequences (matching the Rust DataLoader behaviour).
Usage
Both frameworks expose the same API: a `Model`, a `SIGReg` regularizer, and a `compute_loss` function.
PyTorch
import torch
from soma_models.v1.configs import ModelConfig, SIGRegConfig
from soma_models.v1.tokenizer import tokenize
from soma_models.v1.torch.modules.model import Model
from soma_models.v1.torch.modules.sig_reg import SIGReg
from soma_models.v1.torch.loss import compute_loss

# Initialize
model = Model(ModelConfig(dropout_rate=0.1))
sig_reg = SIGReg(SIGRegConfig())

# Tokenize raw bytes
batches = tokenize(raw_bytes)

# Forward + loss (differentiable)
for batch in batches:
    loss, embedding = compute_loss(
        model, sig_reg,
        token_ids=torch.tensor(batch.token_ids),
        targets=torch.tensor(batch.targets),
    )
    loss.backward()

# Save / load weights
model.save("weights.safetensors")
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0))
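Putting it together, a minimal training loop might look like this (the optimizer and learning rate are illustrative choices, not prescribed by the package; this assumes `Model` is a `torch.nn.Module`):

```python
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)  # illustrative hyperparameters

model.train()
for batch in tokenize(raw_bytes):
    loss, _ = compute_loss(
        model, sig_reg,
        token_ids=torch.tensor(batch.token_ids),
        targets=torch.tensor(batch.targets),
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

model.save("weights.safetensors")
```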
Flax
import jax.numpy as jnp
from flax import nnx
from soma_models.v1.configs import ModelConfig, SIGRegConfig
from soma_models.v1.tokenizer import tokenize
from soma_models.v1.flax.modules.model import Model
from soma_models.v1.flax.modules.sig_reg import SIGReg
from soma_models.v1.flax.loss import compute_loss

# Initialize
rngs = nnx.Rngs(0)
model = Model(ModelConfig(dropout_rate=0.1), rngs=rngs)
sig_reg = SIGReg(SIGRegConfig(), rngs=rngs)

# Tokenize raw bytes
batches = tokenize(raw_bytes)

# Forward + loss (differentiable via jax.grad)
for batch in batches:
    loss, embedding = compute_loss(
        model, sig_reg,
        token_ids=jnp.array(batch.token_ids),
        targets=jnp.array(batch.targets),
    )

# Save / load weights
model.save("weights.safetensors")
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0), rngs=rngs)
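A corresponding Flax sketch using `optax` and nnx's lifted gradient transform (the `nnx.Optimizer` signature has changed across Flax releases, so treat this as a rough pattern rather than exact code):

```python
import optax

optimizer = nnx.Optimizer(model, optax.adamw(3e-4))  # illustrative hyperparameters

for batch in tokenize(raw_bytes):
    def loss_fn(model):
        loss, _ = compute_loss(
            model, sig_reg,
            token_ids=jnp.array(batch.token_ids),
            targets=jnp.array(batch.targets),
        )
        return loss

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)

model.save("weights.safetensors")
```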
Weight Serialization
Weights are stored in safetensors format with a canonical key layout. The serde layer handles all framework-specific transformations automatically:
- LayerNorm: `weight`/`bias` (torch) ↔ `gamma`/`beta` (safetensors) ↔ `scale`/`bias` (flax)
- Linear: Row-major (torch) ↔ column-major (safetensors/flax)
- Attention: Split-head (flax) ↔ flat (safetensors/torch)
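To inspect the canonical layout directly, the checkpoint's keys can be listed with the `safetensors` library:

```python
from safetensors import safe_open

# List the framework-neutral keys of a saved checkpoint.
with safe_open("weights.safetensors", framework="pt") as f:
    for key in f.keys():
        print(key)  # e.g. LayerNorm parameters appear under gamma/beta, per the mapping above
```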
PyTorch
from soma_models.v1.configs import ModelConfig
from soma_models.v1.torch.modules.model import Model
# Save
model.save("weights.safetensors")
# Load
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0))
Flax
from soma_models.v1.configs import ModelConfig
from soma_models.v1.flax.modules.model import Model
from flax import nnx
# Save
model.save("weights.safetensors")
# Load
model = Model.load("weights.safetensors", ModelConfig(dropout_rate=0.0), rngs=nnx.Rngs(0))
Weights are cross-compatible — you can save from one framework and load into the other.
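For example, a checkpoint written by the PyTorch model loads directly into Flax:

```python
from flax import nnx
from soma_models.v1.configs import ModelConfig
from soma_models.v1.flax.modules.model import Model as FlaxModel
from soma_models.v1.torch.modules.model import Model as TorchModel

# Save from PyTorch...
torch_model = TorchModel(ModelConfig(dropout_rate=0.1))
torch_model.save("weights.safetensors")

# ...and load the same file into Flax.
flax_model = FlaxModel.load(
    "weights.safetensors", ModelConfig(dropout_rate=0.0), rngs=nnx.Rngs(0)
)
```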