
Tune the initial recurrent state of hybrid models. Zero inference overhead.



S₀ Tuning: Zero-Overhead Adaptation of Hybrid Recurrent-Attention Models

Jack Young

Paper | Website

S₀ Tuning overview: learned initial states injected into recurrent layers with zero inference overhead

S₀ Tuning optimizes the initial hidden state (S₀) of recurrent layers in hybrid architectures such as GatedDeltaNet and Mamba-2. The learned states are injected before each forward pass and add zero latency at inference, since recurrent models already maintain a state. Trained states total 48 MB, and training takes ~3 minutes on a single GPU.

Results

All results on Qwen3.5-4B with 20 optimization steps per task. p-values from Welch's t-test (unequal variances) across independent seed runs.

| Benchmark | Base | + S₀ Tuning | Delta | Seeds | p-value |
|---|---|---|---|---|---|
| HumanEval | 48.8% | 72.2% | +23.4pp | 10 | < 10⁻¹¹ |
| MATH-500 | 51.4% | 56.2% | +4.8pp | 8 | 0.00002 |
| GSM8K | 85.3% | 88.1% | +2.8pp | 10 | 0.0003 |
| Spider (boundary test) | ~72% | ~72% | +0.0pp | 5 alphas | n.s. |

The Spider result (no improvement on out-of-distribution SQL) supports the trajectory-steering mechanism: S₀ Tuning biases the model toward solution trajectories already in its distribution, rather than injecting new knowledge.

Scaling. Gains increase with model size: +2.7pp at 0.8B, +23.4pp at 4B, +44.0pp at 9B (HumanEval).

Cross-architecture. On FalconH1-7B (Mamba-2), S₀ Tuning reaches 71.8% vs. a 71.4% LoRA baseline (experimental).

Usage

```python
from s0_tuning import S0Trainer

trainer = S0Trainer.from_pretrained("Qwen/Qwen3.5-4B")

# Training pair: prompt plus the correct completion.
prompt = "Q: What is 2+2?\nA:"
answer = " 4"
full_text = prompt + answer

# prompt_length = number of prompt tokens (for completion-only loss masking)
tokens = trainer.tokenizer(prompt)
prompt_length = len(tokens["input_ids"])

data = [(full_text, prompt_length)]
trainer.train(data)
trainer.activate()

output = trainer.generate("Q: What is 2+2?\nA:")
print(output)

trainer.save("./my_s0_states")
```
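The `prompt_length` bookkeeping above implements completion-only loss masking: loss is computed on the answer tokens but not the prompt. A minimal sketch of what that masking produces (the helper below is illustrative, not part of the `s0_tuning` API):

```python
def completion_labels(token_ids, prompt_length, ignore_index=-100):
    """Next-token-prediction labels with prompt positions masked out.

    Position i predicts token i+1, so the labels are the inputs shifted
    left by one; the first prompt_length - 1 targets still fall inside
    the prompt and get ignore_index so they contribute no loss.
    """
    labels = token_ids[1:]
    return [ignore_index] * (prompt_length - 1) + labels[prompt_length - 1:]
```

For example, with a 3-token prompt `[10, 11, 12]` and a 2-token completion `[20, 21]`, only the completion tokens survive as loss targets: `completion_labels([10, 11, 12, 20, 21], 3)` returns `[-100, -100, 20, 21]`.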

Loading saved states on a new instance:

```python
trainer = S0Trainer.from_pretrained("Qwen/Qwen3.5-4B")
trainer.load("./my_s0_states")
trainer.activate()
output = trainer.generate("Q: What is 2+2?\nA:")
```

Installation

```bash
pip install git+https://github.com/jackyoung27/s0-tuning.git
```

Requires PyTorch 2.0+, transformers >= 4.51.0, and a GPU with >= 10 GB VRAM for training.

Supported Models

| Model Family | Architecture | Recurrent Layers | Default Alpha | Status |
|---|---|---|---|---|
| Qwen3.5 | GatedDeltaNet | linear_attention layers | 0.07 | Tested |
| FalconH1 | Mamba-2 | mamba layers | 0.65 | Experimental |

To add a new hybrid architecture, the model needs (1) identifiable recurrent layers and (2) an initial_state argument in the recurrent kernel. See _detect_architecture and _patch_gdn / _patch_mamba2 in trainer.py.
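The general shape of such a patch can be sketched as follows. This is a hedged illustration only: the `initial_state` kwarg name comes from requirement (2) above, but everything else (the wrapper structure, the alpha scaling at call time) is an assumption, not the actual `_patch_gdn` / `_patch_mamba2` code.

```python
import functools

def patch_initial_state(layer, learned_state, alpha):
    """Wrap a recurrent layer's forward so it starts from alpha * learned_state
    instead of zeros. Illustrative sketch, not the library's implementation."""
    original_forward = layer.forward

    @functools.wraps(original_forward)
    def forward(*args, **kwargs):
        # Only inject when the caller did not pass an explicit initial state;
        # the scaled state is computed per call so gradients flow during training.
        kwargs.setdefault("initial_state", alpha * learned_state)
        return original_forward(*args, **kwargs)

    layer.forward = forward
```

The same wrapper pattern works for any kernel that accepts an initial-state argument, which is why requirement (2) is the main constraint on new architectures.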

How It Works

Recurrent models process sequences starting from S₀ = 0. S₀ Tuning creates learnable state tensors, patches them into the model as initial_state, and optimizes via next-token prediction loss on correct completions. At inference, learned states are scaled by alpha (e.g. 0.07 for GatedDeltaNet) to prevent distribution shift. The result: the model starts each sequence from a task-informed state rather than zeros, biasing generation toward correct solution trajectories without modifying any model weights.
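As a toy illustration of the optimization described above, here is a 1-D stand-in: gradient descent on a task loss plus the L2 penalty, starting from the default S₀ = 0. The scalar and the hand-supplied gradient are simplifications; the real objective is next-token prediction over tensor-valued states.

```python
def tune_s0(loss_grad, n_steps=20, lr=1e-3, l2_lambda=5e-4):
    """Toy 1-D version of the S0 optimization loop (illustrative only)."""
    s0 = 0.0
    for _ in range(n_steps):
        # Gradient of loss(s0) + l2_lambda * s0**2 with respect to s0.
        grad = loss_grad(s0) + 2.0 * l2_lambda * s0
        s0 -= lr * grad
    return s0
```

With a quadratic loss pulling toward 1, e.g. `tune_s0(lambda s: 2.0 * (s - 1.0))`, twenty small steps nudge s0 partway toward the target while the L2 term keeps it close to zero, mirroring how the learned states stay a mild perturbation of the default.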

Configuration

S0Config controls training:

| Parameter | Default | Description |
|---|---|---|
| n_steps | 20 | Optimization steps |
| lr | 1e-3 | Learning rate |
| l2_lambda | 5e-4 | L2 regularization weight |
| alpha | None | Scaling factor (auto-detected per architecture if None) |
| normalize | False | Normalize states before scaling |
| grad_clip | 1.0 | Gradient clipping norm |
| max_length | 2048 | Maximum sequence length |
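Overriding the defaults might look like the sketch below. Note the `config` keyword on `from_pretrained` is an assumption (check the package's docstrings for the exact parameter name); the field names come from the table above.

```python
from s0_tuning import S0Trainer, S0Config

config = S0Config(
    n_steps=40,       # more optimization steps for a harder task
    lr=5e-4,
    l2_lambda=5e-4,   # keeps learned states close to zero
    alpha=None,       # auto-detected: 0.07 for GatedDeltaNet, 0.65 for Mamba-2
    grad_clip=1.0,
    max_length=2048,
)
trainer = S0Trainer.from_pretrained("Qwen/Qwen3.5-4B", config=config)
```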

Citation

If you use this codebase, or otherwise find our work valuable, please cite:

```bibtex
@article{young2026s0tuning,
  title={S$_0$ Tuning: Zero-Overhead Adaptation of Hybrid Recurrent-Attention Models},
  author={Young, Jack},
  year={2026}
}
```

License

MIT
