Tune the initial recurrent state of hybrid models. Zero inference overhead.
S₀ Tuning: Zero-Overhead Adaptation of Hybrid Recurrent-Attention Models
Jack Young
S₀ Tuning optimizes the initial hidden state (S₀) of recurrent layers in hybrid architectures (GatedDeltaNet, Mamba-2). The learned states are injected before each forward pass, adding zero latency at inference, since recurrent models already maintain state. Trained states total 48 MB, and training takes ~3 minutes on a single GPU.
Results
All results on Qwen3.5-4B with 20 optimization steps per task. p-values from Welch's t-test (unequal variances) across independent seed runs.
| Benchmark | Base | + S₀ Tuning | Delta | Seeds | p-value |
|---|---|---|---|---|---|
| HumanEval | 48.8% | 72.2% | +23.6pp | 10 | < 10⁻¹¹ |
| MATH-500 | 51.4% | 56.2% | +4.8pp | 8 | 0.00002 |
| GSM8K | 85.3% | 88.1% | +2.8pp | 10 | 0.0003 |
| Spider (boundary test) | ~72% | ~72% | +0.0pp | 5 (alpha sweep) | n.s. |
The Spider result (no improvement on out-of-distribution SQL) supports the trajectory-steering mechanism: S₀ Tuning biases the model toward solution trajectories already in its distribution, rather than injecting new knowledge.
Scaling. Gains increase with model size: +2.7pp at 0.8B, +23.6pp at 4B, +44.0pp at 9B (HumanEval).
Cross-architecture. On FalconH1-7B (Mamba-2): 71.8% vs 71.4% LoRA baseline (experimental).
Usage
```python
from s0_tuning import S0Trainer

trainer = S0Trainer.from_pretrained("Qwen/Qwen3.5-4B")

# Training data: (full_text, prompt_length) pairs, where prompt_length is
# the number of prompt tokens (used for completion-only loss masking).
prompt = "Q: What is 2+2?\nA:"
answer = " 4"
full_text = prompt + answer
tokens = trainer.tokenizer(prompt)
prompt_length = len(tokens["input_ids"])
data = [(full_text, prompt_length)]

trainer.train(data)
trainer.activate()

output = trainer.generate("Q: What is 2+2?\nA:")
print(output)

trainer.save("./my_s0_states")
```
Loading saved states on a new instance:
```python
trainer = S0Trainer.from_pretrained("Qwen/Qwen3.5-4B")
trainer.load("./my_s0_states")
trainer.activate()
output = trainer.generate("Q: What is 2+2?\nA:")
```
Installation
```
pip install git+https://github.com/jackyoung27/s0-tuning.git
```
Requires PyTorch 2.0+, transformers >= 4.51.0, and a GPU with >= 10 GB VRAM for training.
Supported Models
| Model Family | Architecture | Recurrent Layers | Default Alpha | Status |
|---|---|---|---|---|
| Qwen3.5 | GatedDeltaNet | linear_attention layers | 0.07 | Tested |
| FalconH1 | Mamba-2 | mamba layers | 0.65 | Experimental |
To add a new hybrid architecture, the model needs (1) identifiable recurrent layers and (2) an `initial_state` argument in the recurrent kernel. See `_detect_architecture` and `_patch_gdn` / `_patch_mamba2` in `trainer.py`.
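The patching idea can be sketched as wrapping a recurrent kernel so that, when S₀ Tuning is active, it receives a learned initial state instead of the default zeros. The function and variable names below are illustrative, not the package's actual internals, and the "kernel" is a toy stand-in:

```python
# Hedged sketch: wrap a recurrent kernel so it picks up a learned
# initial_state when the caller does not supply one explicitly.
def make_patched_kernel(kernel, get_state):
    def patched(*args, initial_state=None, **kwargs):
        if initial_state is None:
            initial_state = get_state()  # learned S0 (already alpha-scaled)
        return kernel(*args, initial_state=initial_state, **kwargs)
    return patched

# Toy "kernel": accumulate inputs starting from the initial state.
def toy_kernel(xs, initial_state=0.0):
    s = initial_state
    for x in xs:
        s += x
    return s

learned = {"s0": 0.5}
patched = make_patched_kernel(toy_kernel, lambda: learned["s0"])

print(patched([1.0, 2.0]))                     # learned S0 is injected -> 3.5
print(patched([1.0, 2.0], initial_state=0.0))  # explicit override wins -> 3.0
```

An explicit `initial_state` passed by the caller still takes precedence, so the patch is transparent to any code path that already manages state.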
How It Works
Recurrent models process sequences starting from S₀ = 0. S₀ Tuning creates
learnable state tensors, patches them into the model as initial_state, and
optimizes via next-token prediction loss on correct completions. At inference,
learned states are scaled by alpha (e.g. 0.07 for GatedDeltaNet) to prevent
distribution shift. The result: the model starts each sequence from a
task-informed state rather than zeros, biasing generation toward correct
solution trajectories without modifying any model weights.
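The mechanics can be illustrated with a toy scalar recurrence (not the library's API): a run that starts from alpha times a learned state differs from the zero-start run by a term that decays through the sequence, and injecting it costs no extra compute per token. The decay constant and learned state below are made-up numbers:

```python
# Toy illustration of the S0-Tuning idea: a linear recurrence normally
# starts from S0 = 0; S0 Tuning instead starts it from alpha * S_learned,
# leaving the recurrence weights untouched.
def run_recurrence(xs, decay=0.9, s0=0.0):
    """S_t = decay * S_{t-1} + x_t, returning the final state."""
    s = s0
    for x in xs:
        s = decay * s + x
    return s

xs = [1.0, 0.5, -0.25]

alpha = 0.07      # scaling factor, as for GatedDeltaNet
s_learned = 2.0   # hypothetical tuned initial state

baseline = run_recurrence(xs)                      # starts from zero
tuned = run_recurrence(xs, s0=alpha * s_learned)   # starts from alpha * S_learned

# The two runs differ by exactly decay**len(xs) * alpha * s_learned:
# the injected state propagates (and decays) through the sequence,
# with no additional work at inference time.
print(baseline, tuned)
```

The small alpha keeps the injected state a gentle bias rather than a distribution shift, which mirrors why the learned states steer trajectories instead of overriding the model.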
Configuration
S0Config controls training:
| Parameter | Default | Description |
|---|---|---|
| `n_steps` | 20 | Optimization steps |
| `lr` | 1e-3 | Learning rate |
| `l2_lambda` | 5e-4 | L2 regularization weight |
| `alpha` | None | Scaling factor (auto-detected per architecture if None) |
| `normalize` | False | Normalize states before scaling |
| `grad_clip` | 1.0 | Gradient clipping norm |
| `max_length` | 2048 | Maximum sequence length |
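The documented defaults can be mirrored as a plain dataclass for reference; this sketch is not the package's `S0Config` class, only a restatement of the table above:

```python
# Sketch of the documented S0Config defaults as a plain dataclass.
from dataclasses import dataclass
from typing import Optional

@dataclass
class S0ConfigSketch:
    n_steps: int = 20              # optimization steps
    lr: float = 1e-3               # learning rate
    l2_lambda: float = 5e-4        # L2 regularization weight
    alpha: Optional[float] = None  # scaling factor (auto-detected if None)
    normalize: bool = False        # normalize states before scaling
    grad_clip: float = 1.0         # gradient clipping norm
    max_length: int = 2048         # maximum sequence length

# e.g. longer, gentler training while keeping the remaining defaults:
cfg = S0ConfigSketch(n_steps=40, lr=5e-4)
print(cfg.n_steps, cfg.alpha)
```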
Citation
If you use this codebase or otherwise find our work valuable, please cite:

```bibtex
@article{young2026s0tuning,
  title={S$_0$ Tuning: Zero-Overhead Adaptation of Hybrid Recurrent-Attention Models},
  author={Young, Jack},
  year={2026}
}
```
License
MIT