
MLX generation parity utils with HF/Torch-compatible sampling, processors, and steering hooks


mlx-gen-parity

A small, reusable MLX decoding and training library that brings HF/Torch generate() feature parity and persona steering to Apple Silicon. It reuses mlx-lm primitives (KV caches, projections, speculative decoding) and implements the parity features that mlx-lm does not provide.

Features

  • HF-style GenerationConfig with processors/warpers: repetition penalty, no-repeat n-grams, frequency/presence penalties, bad-words, min_new_tokens, typical_p, epsilon_cutoff.
  • Constraints: force_words_ids (strict start + continuation), suppress_tokens, begin_suppress_tokens, multiple eos_token_ids, forced BOS/EOS and per-position forced_decoder_ids.
  • Modes: sampling (fast path via mlx-lm), beam search (num_beams, length_penalty, early_stopping), speculative decoding (via mlx-lm), and a sliding-window KV cache (max_kv_size).
  • Hooks: ResidualInjectionHook (sampling) and LogitBiasHook (sampling/beam); SoftPromptHook for training.
  • Training (MLX): loss_forward, xent_loss (label smoothing), mixed-precision compute (bf16) with fp32 master weights.
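To make a few of the processor names above concrete, here is a minimal pure-Python sketch of them acting on a plain list of logits. The repetition penalty follows the HF convention (divide positive logits, multiply negative ones); the frequency/presence formula follows the common OpenAI-style definition, which is an assumption about what this library uses. The helper names are hypothetical and are not the mlx-gen-parity API.

```python
from collections import Counter


def apply_repetition_penalty(logits, prev_ids, penalty):
    """HF-style: positive logits of seen tokens are divided by `penalty`,
    negative ones multiplied, so both become less likely."""
    out = list(logits)
    for tok in set(prev_ids):
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out


def apply_frequency_presence(logits, prev_ids, frequency_penalty, presence_penalty):
    """OpenAI-style (assumed): subtract count * frequency_penalty, plus
    presence_penalty once for every token that has appeared at all."""
    out = list(logits)
    for tok, count in Counter(prev_ids).items():
        out[tok] -= count * frequency_penalty + presence_penalty
    return out


def banned_ngram_tokens(prev_ids, n):
    """Tokens that would complete an n-gram already present in prev_ids."""
    if n <= 0 or len(prev_ids) < n:
        return set()
    prefix = tuple(prev_ids[len(prev_ids) - n + 1:])  # last n-1 tokens
    banned = set()
    for i in range(len(prev_ids) - n + 1):
        if tuple(prev_ids[i:i + n - 1]) == prefix:
            banned.add(prev_ids[i + n - 1])
    return banned


def mask_tokens(logits, banned):
    """Set banned token logits to -inf so they can never be selected."""
    return [float("-inf") if i in banned else x for i, x in enumerate(logits)]
```

With `no_repeat_ngram_size=3` and history `[5, 6, 7, 5, 6]`, token 7 is banned because `5 6 7` already occurred.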

Install

  • From PyPI (recommended):
pip install mlx-gen-parity
  • Dependencies (if not already installed):
pip install mlx mlx-lm transformers
  • From source (editable):
pip install -e .

Models from Hugging Face

  • If the repo provides MLX weights (e.g., in mlx-community), you can load directly: load('mlx-community/<model>').
  • For standard HF (PyTorch) repos, convert once using mlx-lm:
    • Python: from mlx_gen_parity.interop import convert_hf_to_mlx; convert_hf_to_mlx('Qwen/Qwen3-0.6B', quantize=False, local_out='mlx_qwen3_0_6b')
    • CLI: mlx_lm.convert --hf-path Qwen/Qwen3-0.6B --mlx-path mlx_qwen3_0_6b
    • Then load with load('mlx_qwen3_0_6b').

Basic usage

from mlx_gen_parity import GenerationConfig, generate
from mlx_lm import load

model, tokenizer = load('mlx_qwen3_0_6b')
cfg = GenerationConfig(max_tokens=64, temperature=0.7, top_p=0.95, seed=17)
out = generate(model, tokenizer, 'Hello MLX parity', cfg)
print(out['text'])
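For reference, `temperature`, `top_p`, and `seed` above follow the usual HF sampling semantics. The sketch below is a simplified pure-Python illustration, not the library's or mlx-lm's code: `sample_top_p` is a hypothetical name, and a seeded `random.Random` stands in for the config's `seed`.

```python
import math
import random


def sample_top_p(logits, temperature=0.7, top_p=0.95, rng=None):
    """Temperature scaling followed by nucleus (top-p) filtering.

    temperature == 0 is treated as greedy argmax."""
    if temperature == 0:
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for a numerically stable softmax
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the smallest set of most-probable tokens whose mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    # Sample among the kept tokens in proportion to their probabilities.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
```

Two generators built with the same seed produce the same draws, which is the determinism that `seed` is meant to provide on a single device.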

Beam and constraints

cfg = GenerationConfig(max_tokens=64, temperature=0.0, num_beams=4, early_stopping=True, length_penalty=0.2,
                       force_words_ids=[tokenizer.encode(' cat')], min_new_tokens=8,
                       bad_words_ids=[[tokenizer.eos_token_id]], suppress_tokens=[tokenizer.eos_token_id])
out = generate(model, tokenizer, 'The', cfg)
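The `length_penalty` and `min_new_tokens` options interact with beam scoring roughly as in the toy sketch below. It assumes HF-style length normalization (a finished hypothesis scores `sum_logprobs / generated_length ** length_penalty`), ignores `early_stopping` and forced-word constraints, and runs over a caller-supplied `next_logprobs` function instead of a model. `beam_search` is a hypothetical name, not the library's implementation.

```python
import math


def beam_search(next_logprobs, bos, eos, num_beams=4, max_new_tokens=8,
                length_penalty=1.0, min_new_tokens=0):
    """Simplified beam search; next_logprobs(seq) -> log-probs over the vocab."""
    beams = [([bos], 0.0)]
    finished = []
    for step in range(max_new_tokens):
        candidates = []
        for seq, score in beams:
            for tok, lp in enumerate(next_logprobs(seq)):
                if tok == eos and step < min_new_tokens:
                    continue  # min_new_tokens: EOS is masked until enough tokens exist
                candidates.append((seq + [tok], score + lp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates:
            if seq[-1] == eos:
                gen_len = len(seq) - 1  # generated tokens, excluding the BOS prompt
                finished.append((seq, score / gen_len ** length_penalty))
            else:
                beams.append((seq, score))
            if len(beams) == num_beams:
                break
        if not beams:
            break
    # Beams still open at max_new_tokens are normalized the same way.
    for seq, score in beams:
        finished.append((seq, score / (len(seq) - 1) ** length_penalty))
    return max(finished, key=lambda f: f[1])
```

Because log-probs are negative, a larger `length_penalty` divides by a bigger denominator and favors longer hypotheses; values below the default of 1.0 (such as 0.2 above) favor shorter ones.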

Speculative decoding

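# The draft model should be smaller and faster than the target; reusing the target checkpoint here only exercises the code path.
cfg = GenerationConfig(max_tokens=64, temperature=0.7, top_p=0.95,
                       use_speculative=True, draft_model_id='mlx_qwen3_0_6b', num_draft_tokens=3)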
out = generate(model, tokenizer, 'Speculative test', cfg)
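Speculative decoding preserves the target model's output while letting a cheaper draft model propose `num_draft_tokens` tokens at a time. Below is a sketch of the greedy (temperature 0) variant over toy next-token functions; with `temperature > 0`, the exact-match check is replaced by a rejection-sampling acceptance rule. This is an illustration, not mlx-lm's implementation, and `speculative_greedy` is a hypothetical name.

```python
def speculative_greedy(target_next, draft_next, prompt, max_new_tokens, num_draft_tokens=3):
    """Greedy speculative decoding; the result equals greedy decoding with target_next."""
    seq = list(prompt)
    while len(seq) - len(prompt) < max_new_tokens:
        # 1. The draft model proposes a short continuation autoregressively.
        draft = []
        for _ in range(num_draft_tokens):
            draft.append(draft_next(seq + draft))
        # 2. The target verifies each drafted position in order.
        accepted = []
        for tok in draft:
            expected = target_next(seq + accepted)
            if tok != expected:
                accepted.append(expected)  # first mismatch: keep the target's token
                break
            accepted.append(tok)
        else:
            accepted.append(target_next(seq + accepted))  # all matched: one bonus token
        seq.extend(accepted)
    return seq[: len(prompt) + max_new_tokens]
```

In a real system the target scores all drafted positions in one forward pass, which is where the speedup comes from; the per-position calls here only show the acceptance logic.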

Persona steering

import mlx.core as mx
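from mlx_gen_parity import GenerationConfig, LogitBiasHook, generate
H = model.args.hidden_size
# Random persona direction scaled by 1/sqrt(H), stored on the model under the key the hook reads via param_key
model['_persona_v'] = mx.random.normal((H,)) * (1.0/(H**0.5))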
cfg = GenerationConfig(max_tokens=64, temperature=0.7)
out = generate(model, tokenizer, 'Summarize MLX', cfg, hooks=[LogitBiasHook(param_key='_persona_v', alpha=1.2)])
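One plausible reading of the steering hooks, stated as an assumption because their internals are not documented here: `LogitBiasHook` adds `alpha * (W_unembed @ v)` to the next-token logits, while `ResidualInjectionHook` adds `alpha * v` to a hidden state. The pure-Python sketch below shows both with hypothetical helper names and plain lists instead of `mx.array`.

```python
def matvec(W, v):
    """Dense matrix-vector product for list-of-rows matrices."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def logit_bias_from_vector(logits, W_unembed, v, alpha):
    """Add alpha * (W_unembed @ v) to the logits.

    W_unembed: vocab_size x hidden_size output projection (lm_head)."""
    bias = matvec(W_unembed, v)
    return [l + alpha * b for l, b in zip(logits, bias)]


def residual_injection(h, v, alpha):
    """Add alpha * v to a hidden state on the residual stream."""
    return [x + alpha * y for x, y in zip(h, v)]
```

For a purely linear output head applied to the same hidden state, the two give identical logits. In real models a final norm sits before `lm_head` and injection usually happens at an intermediate layer, so the results differ.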

Training (MLX)

from mlx_gen_parity import TrainingConfig, train_step, SoftPromptHook
from mlx.optimizers import AdamW
pad_id = getattr(tokenizer, 'pad_token_id', None)
if pad_id is None:  # fall back to -100 only when no pad token exists; 0 is a valid id
    pad_id = -100
opt = AdamW(learning_rate=2e-4)
batch = {'tokens': ...}  # mx.array [B, T]
cfg = TrainingConfig(dtype='bf16', loss_scale=1024.0)
loss = train_step(model, batch, opt, cfg, hooks=[SoftPromptHook(n_virtual=10, param_key='_soft_prompt')], pad_id=pad_id)
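The loss side of `train_step` can be sketched as next-token cross-entropy with label smoothing and pad masking. This assumes `xent_loss` follows the PyTorch `CrossEntropyLoss(label_smoothing=..., ignore_index=...)` definition, where the target distribution is `(1 - eps) * one_hot + eps / V`; `xent_loss_sketch` and `shift_for_next_token` are hypothetical names, and plain lists replace `mx.array`.

```python
import math


def shift_for_next_token(tokens):
    """Inputs are tokens[:-1]; targets are tokens[1:] (predict the next token)."""
    return tokens[:-1], tokens[1:]


def xent_loss_sketch(logits, targets, pad_id=-100, label_smoothing=0.0):
    """Mean token cross-entropy; logits is [T][V], targets is [T].

    Positions whose target equals pad_id are skipped (like ignore_index)."""
    total, count = 0.0, 0
    for row, y in zip(logits, targets):
        if y == pad_id:
            continue
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # stable log-sum-exp
        log_probs = [x - log_z for x in row]
        nll = -log_probs[y]
        smooth = -sum(log_probs) / len(row)  # cross-entropy against a uniform target
        total += (1 - label_smoothing) * nll + label_smoothing * smooth
        count += 1
    return total / max(count, 1)
```

Label smoothing raises the loss of very confident correct predictions, which discourages the model from pushing logits to extremes.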

Parity testing

  • Torch vs MLX: python -m mlx_gen_parity.tests.parity_hf --hf-model Qwen/Qwen3-0.6B --mlx-model ./mlx_qwen3_0_6b --prompt 'hello'
  • Suite (8 prompts): python -m mlx_gen_parity.tests.parity_suite --hf-model Qwen/Qwen3-0.6B --mlx-model ./mlx_qwen3_0_6b
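A Torch-vs-MLX parity check usually reduces to a few comparisons per prompt. The helpers below are a hypothetical sketch, not the contents of `parity_hf` or `parity_suite`: they find the first step where two token streams diverge and check whether one decoding step agrees on the greedy token with logits within an absolute tolerance. The `atol=1e-2` default is an arbitrary placeholder.

```python
def first_divergence(ids_a, ids_b):
    """Index of the first position where two token streams differ, or None."""
    for i, (a, b) in enumerate(zip(ids_a, ids_b)):
        if a != b:
            return i
    if len(ids_a) == len(ids_b):
        return None
    return min(len(ids_a), len(ids_b))  # one stream is a strict prefix of the other


def argmax(row):
    return max(range(len(row)), key=row.__getitem__)


def compare_step(logits_a, logits_b, atol=1e-2):
    """True if both frameworks pick the same greedy token and logits agree within atol."""
    max_abs_diff = max(abs(a - b) for a, b in zip(logits_a, logits_b))
    return argmax(logits_a) == argmax(logits_b) and max_abs_diff <= atol
```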

CLI wrapper

mlxgp-generate \
  --model Qwen/Qwen3-0.6B \
  --prompt "Hello MLX" \
  --max-tokens 64 --temp 0.7 --top-p 0.95 \
  --num-beams 1 --no-repeat-ngram-size 2

Performance bench

python -m mlx_gen_parity.tests.perf_bench --hf-model Qwen/Qwen3-0.6B --mlx-model ./mlx_qwen3_0_6b --prompt "Hello performance" --max-tokens 64
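A throughput bench of this kind typically times several runs after a warm-up and reports tokens per second. The sketch below assumes a hypothetical `generate_fn(prompt, max_tokens)` that returns a list of token ids; since MLX evaluates arrays lazily, a real MLX bench should force evaluation (for example with `mx.eval`) before stopping the timer.

```python
import time


def tokens_per_second(generate_fn, prompt, max_tokens, warmup=1, runs=3):
    """Median decode throughput over `runs` timed calls after `warmup` untimed ones."""
    for _ in range(warmup):
        generate_fn(prompt, max_tokens)  # first call pays compile and cache setup costs
    rates = []
    for _ in range(runs):
        start = time.perf_counter()
        ids = generate_fn(prompt, max_tokens)
        elapsed = max(time.perf_counter() - start, 1e-9)  # guard against a zero timer delta
        rates.append(len(ids) / elapsed)
    rates.sort()
    return rates[len(rates) // 2]
```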

Releases

  • Bump version across files:
    • make bump-version PART=patch (or minor/major)
  • Create and push a git tag (vX.Y.Z):
    • make git-release
    • This tags and pushes the repo; PyPI packaging can be added later.
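The Makefile targets are not shown here. Assuming they follow standard semantic versioning, `PART=patch|minor|major` maps to a bump like the hypothetical helper below, and the git tag is the new version prefixed with `v`.

```python
import re


def bump_version(version, part="patch"):
    """Return the next MAJOR.MINOR.PATCH version for part in {'major', 'minor', 'patch'}."""
    match = re.fullmatch(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"not a MAJOR.MINOR.PATCH version: {version!r}")
    major, minor, patch = map(int, match.groups())
    if part == "major":
        return f"{major + 1}.0.0"  # lower components reset to zero
    if part == "minor":
        return f"{major}.{minor + 1}.0"
    if part == "patch":
        return f"{major}.{minor}.{patch + 1}"
    raise ValueError(f"unknown part: {part!r}")


def release_tag(version):
    """Git tag name in the vX.Y.Z form."""
    return f"v{version}"
```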

Notes

  • Parity targets control-surface equivalence: constraints, stops, finish reasons, and determinism. Exact token streams may still differ across frameworks and devices.
  • Sampling fast path reuses mlx-lm’s decoding loop and caches for best performance on Apple Silicon.
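The control-surface properties listed above can be checked on a finished token stream without comparing exact tokens. This is a hypothetical checker, not part of the library; the `"stop"`/`"length"` finish-reason labels follow the convention used by mlx-lm and OpenAI-style APIs, and the `min_new_tokens` rule assumes HF semantics (EOS may not appear before that many new tokens).

```python
def contains(seq, sub):
    """True if `sub` occurs as a contiguous run inside `seq`."""
    return any(seq[i:i + len(sub)] == sub for i in range(len(seq) - len(sub) + 1))


def check_control_surface(new_ids, eos_ids, min_new_tokens=0, banned_ids=(), forced_words=()):
    """Report which constraints a generated stream (prompt excluded) satisfies.

    forced_words is a list of token-id lists that must each appear."""
    new_ids = list(new_ids)
    eos_positions = [i for i, t in enumerate(new_ids) if t in eos_ids]
    stopped = bool(eos_positions)
    return {
        "finish_reason": "stop" if stopped else "length",
        "eos_only_at_end": not stopped or eos_positions == [len(new_ids) - 1],
        "min_new_tokens_ok": not stopped or eos_positions[0] >= min_new_tokens,
        "no_banned_tokens": not any(t in banned_ids for t in new_ids),
        "forced_words_present": all(contains(new_ids, list(w)) for w in forced_words),
    }
```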

Download files

Source Distribution

mlx_gen_parity-0.1.1.tar.gz (31.7 kB)

Uploaded Source

Built Distribution

mlx_gen_parity-0.1.1-py3-none-any.whl (36.5 kB)

Uploaded Python 3

File details

Details for the file mlx_gen_parity-0.1.1.tar.gz.

File metadata

  • Download URL: mlx_gen_parity-0.1.1.tar.gz
  • Upload date:
  • Size: 31.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for mlx_gen_parity-0.1.1.tar.gz
Algorithm Hash digest
SHA256 94908b824af5698f8c8f0281cfd67cf4eb315dcd4f2e22660f9d37bb86393af7
MD5 90e9f39b931769c1026cdee379f8708b
BLAKE2b-256 8fa6a8fd86fa8381d1c19771ebbaba418baa1518d988a9f5b11cb272106cbd39

File details

Details for the file mlx_gen_parity-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: mlx_gen_parity-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 36.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for mlx_gen_parity-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 ace8af053a4ba972de87dd15d67e20f550b39cdb2b35824ef24081ca21dcc1e4
MD5 c5047572f4276b30b3491bc4980a8733
BLAKE2b-256 f33183afe0333fb1fc97f46472230a0c0e18e12768fec5d14285044a9e8d5eb5