
Tune the initial recurrent state of hybrid models. Zero inference overhead.


S₀ Tuning: Zero-Overhead Adaptation of Hybrid Recurrent-Attention Models

Jack Young

Paper | Website

Figure: S₀ Tuning overview. Learned initial states are injected into recurrent layers with zero inference overhead.

S₀ Tuning optimizes the initial hidden state (S₀) of the recurrent layers in hybrid architectures (GatedDeltaNet, Mamba-2). The learned states are injected before each forward pass and add zero latency at inference, since recurrent models already maintain state. In our main HumanEval setting, training uses roughly 48 execution-verified solutions, the trained states total 48 MB, and training takes about 3 minutes on a single GPU.

Results

Main results use Qwen3.5-4B with 20 optimization steps per task. Unless noted otherwise, p-values are from two-sided Welch's t-test across independent seed runs.

Benchmark                 Base     + S₀ Tuning   Delta     Seeds      p-value
HumanEval                 48.8%    72.2%         +23.4pp   10         < 10⁻¹¹
MATH-500                  51.4%    56.2%         +4.8pp    8          0.00002
GSM8K                     85.3%    88.1%         +2.8pp    10         0.0003
Spider (boundary test)    ~72%     ~72%          +0.0pp    5 alphas   n.s.

The Spider result (no improvement on out-of-distribution SQL) supports the trajectory-steering mechanism: S₀ Tuning biases the model toward solution trajectories already in its distribution, rather than injecting new knowledge.

Scaling. Gains increase with model size: +2.6pp at 0.8B, +23.4pp at 4B, and +44.0pp at 9B (HumanEval).

Cross-architecture. On FalconH1-7B (Mamba-2), S₀ Tuning reaches 71.8% vs. 71.4% for LoRA in a 3-seed comparison; the difference is statistically indistinguishable at this sample size.

LoRA framing. The main Qwen comparison is against the strongest LoRA baseline, rank 24 (+12.7pp). In a separate matched-budget comparison, rank-64 LoRA degrades performance by 15.5pp in this small-data regime.

Usage

from s0 import S0Trainer

trainer = S0Trainer.from_pretrained("Qwen/Qwen3.5-4B")

prompt = "Q: What is 2+2?\nA:"
answer = " 4"
full_text = prompt + answer

# prompt_length = number of prompt tokens (for completion-only loss masking)
tokens = trainer.tokenizer(prompt)
prompt_length = len(tokens["input_ids"])

data = [(full_text, prompt_length)]
trainer.train(data)   # optimize S₀ on the verified example(s)
trainer.activate()    # inject the learned states into the model's recurrent layers

output = trainer.generate("Q: What is 2+2?\nA:")
print(output)

trainer.save("./my_s0_states")
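
The prompt_length value drives completion-only loss masking: the loss is computed on the answer tokens only, not the prompt. A minimal sketch of how such masking typically works (mask_prompt_labels is a hypothetical illustration, not part of the s0 API):

import torch

def mask_prompt_labels(input_ids: torch.Tensor, prompt_length: int) -> torch.Tensor:
    # Labels mirror the inputs; prompt positions are set to -100,
    # the label value PyTorch's cross-entropy loss ignores.
    labels = input_ids.clone()
    labels[:prompt_length] = -100
    return labels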

Loading saved states on a new instance:

trainer = S0Trainer.from_pretrained("Qwen/Qwen3.5-4B")
trainer.load("./my_s0_states")
trainer.activate()
output = trainer.generate("Q: What is 2+2?\nA:")

Installation

pip install s0-tuning

Requires PyTorch 2.0+, transformers >= 4.51.0, and a GPU with >= 10 GB VRAM for training.

Supported Models

Model Family   Architecture    Recurrent Layers          Default Alpha   Status
Qwen3.5        GatedDeltaNet   linear_attention layers   0.07            Tested
FalconH1       Mamba-2         mamba layers              0.65            Experimental

To add a new hybrid architecture, the model needs (1) identifiable recurrent layers and (2) an initial_state argument in the recurrent kernel. See _detect_architecture and _patch_gdn / _patch_mamba2 in trainer.py.
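
As a rough illustration, a patch can wrap each recurrent layer's forward so it receives an alpha-scaled learned state instead of the default zeros. This is a hypothetical sketch, not the package internals; find_recurrent_layers stands in for whatever _detect_architecture identifies:

def patch_recurrent_layers(model, learned_states, alpha):
    # Route alpha-scaled learned states into each recurrent layer
    # through its initial_state argument.
    for layer, s0 in zip(find_recurrent_layers(model), learned_states):
        orig_forward = layer.forward

        def forward(*args, _orig=orig_forward, _s0=s0, **kwargs):
            kwargs.setdefault("initial_state", alpha * _s0)
            return _orig(*args, **kwargs)

        layer.forward = forward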

How It Works

Recurrent models process sequences starting from S₀ = 0. S₀ Tuning creates learnable state tensors, patches them into the model as initial_state, and optimizes via next-token prediction loss on correct completions. At inference, learned states are scaled by alpha (e.g. 0.07 for GatedDeltaNet) to prevent distribution shift. The result: the model starts each sequence from a task-informed state rather than zeros, biasing generation toward correct solution trajectories without modifying any model weights.
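
In pseudocode, the training loop looks roughly like the following sketch, using the defaults from the configuration table below; state_shapes, next_token_loss, model, and batch are stand-ins for package internals:

import torch

# One learnable state per recurrent layer, initialized at zero so the
# first step reproduces the model's default behavior.
states = [torch.zeros(shape, requires_grad=True) for shape in state_shapes]
opt = torch.optim.Adam(states, lr=1e-3)

for step in range(20):  # n_steps
    # The patched model reads `states` as initial_state in its recurrent layers.
    loss = next_token_loss(model, batch, completion_only=True)
    loss = loss + 5e-4 * sum(s.pow(2).sum() for s in states)  # l2_lambda
    opt.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(states, max_norm=1.0)  # grad_clip
    opt.step()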

Configuration

S0Config controls training:

Parameter    Default   Description
n_steps      20        Optimization steps
lr           1e-3      Learning rate
l2_lambda    5e-4      L2 regularization weight
alpha        None      Scaling factor (auto-detected per architecture if None)
normalize    False     Normalize states before scaling
grad_clip    1.0       Gradient clipping norm
max_length   2048      Maximum sequence length
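
For example, to override the defaults (assuming S0Config is importable from s0 and accepted by S0Trainer.from_pretrained; check the package source if your version differs):

from s0 import S0Config, S0Trainer

config = S0Config(n_steps=40, lr=5e-4, alpha=0.07)
trainer = S0Trainer.from_pretrained("Qwen/Qwen3.5-4B", config=config)
trainer.train(data)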

Citation

If you use this codebase or otherwise find our work valuable, please cite:

@article{young2026s0tuning,
  title={S$_0$ Tuning: Zero-Overhead Adaptation of Hybrid Recurrent-Attention Models},
  author={Young, Jack},
  year={2026}
}

License

MIT
