MLX generation parity utils with HF/Torch-compatible sampling, processors, and steering hooks
Project description
mlx-gen-parity
A small, reusable MLX decoding and training library that brings HF/Torch `generate()` feature parity and clean persona steering to Apple Silicon. It reuses mlx-lm primitives (caches, projections, speculative decoding) and fills in the missing parity pieces.
Features
- HF-style `GenerationConfig` with processors/warpers: repetition penalty, no-repeat n-grams, frequency/presence penalties, bad words, `min_new_tokens`, `typical_p`, `epsilon_cutoff` (see the sketch after this list).
- Constraints: `force_words_ids` (strict start + continuation), `suppress_tokens`, `begin_suppress_tokens`, multiple `eos_token_id`s, forced BOS/EOS, and per-position `forced_decoder_ids`.
- Modes: sampling (fast path via mlx-lm), beam search (`num_beams`, `length_penalty`, `early_stopping`), speculative decoding (mlx-lm), sliding KV cache (`max_kv_size`).
- Hooks: `ResidualInjectionHook` (sampling) and `LogitBiasHook` (sampling/beam); `SoftPromptHook` for training.
- Training (MLX): `loss_forward`, `xent_loss` (label smoothing), mixed-precision compute (bf16) with fp32 master weights.
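To make the processor surface concrete, here is a minimal sketch combining several of the options above. The field names follow the HF conventions this library mirrors; treat the exact names and defaults as assumptions until checked against the source.

```python
from mlx_gen_parity import GenerationConfig

# Hedged sketch: field names follow HF conventions; defaults are assumptions.
cfg = GenerationConfig(
    max_tokens=64,
    temperature=0.8,
    repetition_penalty=1.1,    # penalize already-generated tokens
    no_repeat_ngram_size=3,    # forbid repeating any 3-gram
    min_new_tokens=8,          # keep EOS suppressed for the first 8 tokens
    typical_p=0.9,             # typical-decoding warper
)
```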
Install
- From PyPI (recommended): `pip install mlx-gen-parity`
- Dependencies (if not already installed): `pip install mlx mlx-lm transformers`
- From source (editable): `pip install -e .`
Models from Hugging Face
- If the repo provides MLX weights (e.g., in mlx-community), you can load directly: `load('mlx-community/<model>')`.
- For standard HF (PyTorch) repos, convert once using mlx-lm:
  - Python: `from mlx_gen_parity.interop import convert_hf_to_mlx; convert_hf_to_mlx('Qwen/Qwen3-0.6B', quantize=False, local_out='mlx_qwen3_0_6b')`
  - CLI: `mlx_lm.convert --hf-path Qwen/Qwen3-0.6B --mlx-path mlx_qwen3_0_6b`
- Then load with `load('mlx_qwen3_0_6b')`.
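End to end, the convert-then-load flow looks like this (a sketch using the calls shown above):

```python
from mlx_gen_parity.interop import convert_hf_to_mlx
from mlx_lm import load

# One-time conversion of a PyTorch HF repo into MLX weights on disk.
convert_hf_to_mlx('Qwen/Qwen3-0.6B', quantize=False, local_out='mlx_qwen3_0_6b')

# Later runs skip the conversion and load the local MLX weights directly.
model, tokenizer = load('mlx_qwen3_0_6b')
```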
Basic usage
```python
from mlx_gen_parity import GenerationConfig, generate
from mlx_lm import load

model, tokenizer = load('mlx_qwen3_0_6b')
cfg = GenerationConfig(max_tokens=64, temperature=0.7, top_p=0.95, seed=17)
out = generate(model, tokenizer, 'Hello MLX parity', cfg)
print(out['text'])
```
Beam and constraints
```python
cfg = GenerationConfig(max_tokens=64, temperature=0.0, num_beams=4, early_stopping=True, length_penalty=0.2,
                       force_words_ids=[tokenizer.encode(' cat')], min_new_tokens=8,
                       bad_words_ids=[[tokenizer.eos_token_id]], suppress_tokens=[tokenizer.eos_token_id])
out = generate(model, tokenizer, 'The', cfg)
```
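For intuition about `length_penalty`: under HF-style beam scoring, which this library targets for parity, finished beams are ranked by length-normalized log-probability. A sketch of that normalization (an assumption based on HF's documented behavior, not this library's source):

```python
def beam_score(sum_logprobs: float, length: int, length_penalty: float) -> float:
    # HF-style normalization: since sum_logprobs is negative, a larger
    # length_penalty promotes longer sequences, a negative one shorter ones,
    # and 0.0 disables normalization entirely.
    return sum_logprobs / (length ** length_penalty)
```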
Speculative decoding
```python
cfg = GenerationConfig(max_tokens=64, temperature=0.7, top_p=0.95,
                       use_speculative=True, draft_model_id='mlx_qwen3_0_6b', num_draft_tokens=3)
out = generate(model, tokenizer, 'Speculative test', cfg)
```
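Conceptually, the draft model proposes `num_draft_tokens` tokens and the target model verifies them, keeping the accepted prefix. A greedy-verification sketch of the idea (illustrative only; the real loop lives in mlx-lm, which verifies all drafted tokens in a single target forward pass):

```python
def speculative_step(target_next, draft_next, ctx, k=3):
    # target_next/draft_next: functions mapping a list of token ids to the
    # greedy next-token id of the target and draft models respectively.
    proposal = []
    for _ in range(k):
        proposal.append(draft_next(ctx + proposal))  # cheap draft proposals
    accepted = []
    for tok in proposal:
        if target_next(ctx + accepted) == tok:
            accepted.append(tok)  # target agrees: keep the drafted token
        else:
            # Disagreement: take the target's token instead and stop early.
            accepted.append(target_next(ctx + accepted))
            break
    return accepted
```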
Persona steering
```python
import mlx.core as mx
from mlx_gen_parity import LogitBiasHook

# Register a random unit-scale steering vector as an extra model parameter.
H = model.args.hidden_size
model['_persona_v'] = mx.random.normal((H,)) * (1.0 / (H ** 0.5))

cfg = GenerationConfig(max_tokens=64, temperature=0.7)
out = generate(model, tokenizer, 'Summarize MLX', cfg,
               hooks=[LogitBiasHook(param_key='_persona_v', alpha=1.2)])
```
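`ResidualInjectionHook` (see Features) is the sampling-only counterpart. Assuming it shares `LogitBiasHook`'s `(param_key, alpha)` constructor, which is an unverified assumption, swapping it in would look like:

```python
from mlx_gen_parity import ResidualInjectionHook

# Assumption: same constructor shape as LogitBiasHook; sampling mode only.
out = generate(model, tokenizer, 'Summarize MLX', cfg,
               hooks=[ResidualInjectionHook(param_key='_persona_v', alpha=1.2)])
```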
Training (MLX)
```python
from mlx.optimizers import AdamW

from mlx_gen_parity import TrainingConfig, train_step, SoftPromptHook

# Fall back to -100 (ignored by the loss) if the tokenizer has no pad token.
pad_id = getattr(tokenizer, 'pad_token_id', None) or -100
opt = AdamW(learning_rate=2e-4)
batch = {'tokens': ...}  # mx.array of token ids, shape [B, T]
cfg = TrainingConfig(dtype='bf16', loss_scale=1024.0)
loss = train_step(model, batch, opt, cfg,
                  hooks=[SoftPromptHook(n_virtual=10, param_key='_soft_prompt')],
                  pad_id=pad_id)
```
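A minimal outer loop around `train_step` might look like this; the dataset iterator is hypothetical, and only the `train_step` call is from the API above:

```python
# Hypothetical dataset: an iterable of {'tokens': mx.array [B, T]} batches.
for step, batch in enumerate(dataset):
    loss = train_step(model, batch, opt, cfg,
                      hooks=[SoftPromptHook(n_virtual=10, param_key='_soft_prompt')],
                      pad_id=pad_id)
    if step % 50 == 0:
        print(f'step {step}: loss {float(loss):.4f}')
```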
Parity testing
- Torch vs MLX: `python -m mlx_gen_parity.tests.parity_hf --hf-model Qwen/Qwen3-0.6B --mlx-model ./mlx_qwen3_0_6b --prompt 'hello'`
- Suite (8 prompts): `python -m mlx_gen_parity.tests.parity_suite --hf-model Qwen/Qwen3-0.6B --mlx-model ./mlx_qwen3_0_6b`
CLI wrapper
```bash
mlxgp-generate \
  --model Qwen/Qwen3-0.6B \
  --prompt "Hello MLX" \
  --max-tokens 64 --temp 0.7 --top-p 0.95 \
  --num-beams 1 --no-repeat-ngram-size 2
```
Performance bench
```bash
python -m mlx_gen_parity.tests.perf_bench --hf-model Qwen/Qwen3-0.6B --mlx-model ./mlx_qwen3_0_6b --prompt "Hello performance" --max-tokens 64
```
Releases
- Bump version across files: `make bump-version PART=patch` (or `minor`/`major`).
- Create and push a git tag (vX.Y.Z): `make git-release`. This tags and pushes the repo; PyPI packaging can be added later.
Notes
- Parity targets control-surface equivalence (constraints, stops, finish reasons, determinism); exact token streams may differ across frameworks/devices. See the determinism check below.
- Sampling fast path reuses mlx-lm’s decoding loop and caches for best performance on Apple Silicon.
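Determinism, one of the parity targets above, can be checked directly: with a fixed `seed`, the sampling path should reproduce the same text run to run. A quick sketch using the API from Basic usage:

```python
cfg = GenerationConfig(max_tokens=32, temperature=0.7, top_p=0.95, seed=17)
a = generate(model, tokenizer, 'Hello MLX parity', cfg)['text']
b = generate(model, tokenizer, 'Hello MLX parity', cfg)['text']
assert a == b  # same seed and config should yield identical text
```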