MLX generation parity utils with HF/Torch-compatible sampling, processors, and steering hooks
mlx-genkit (formerly mlx-gen-parity)
A small, reusable MLX generation and training toolkit that brings HF/Torch generate() feature parity and clean persona steering to Apple Silicon. It reuses mlx-lm primitives (caches, projections, speculative decoding) and fills in the missing parity pieces.
Features
- HF-style GenerationConfig with processors/warpers: repetition penalty, no-repeat-ngrams, frequency/presence, bad-words, min_new_tokens, typical_p, epsilon_cutoff (a combined example follows this list).
- Constraints: force_words_ids (strict start + continuation), suppress_tokens, begin_suppress_tokens, multiple eos_token_ids, forced BOS/EOS, and per-position forced_decoder_ids.
- Modes: sampling (fast path via mlx-lm), beam (num_beams, length_penalty, early_stopping), speculative (mlx-lm), sliding KV (max_kv_size).
- Hooks: ResidualInjectionHook (sampling) and LogitBiasHook (sampling/beam); SoftPromptHook for training.
- Training (MLX): loss_forward, xent_loss (label smoothing), mixed-precision compute (bf16) with fp32 master weights.
- Training utilities: sequence_logprob, token_kl for scoring and policy KL.
- Model helpers: ema_update, build_action_mask, stable_softmax; best-effort clone_reference.
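As a combined example, the sketch below exercises several of the processors listed above in one GenerationConfig. The field names for the penalty options (repetition_penalty, no_repeat_ngram_size) follow HF conventions and are assumptions here; check GenerationConfig for the canonical names.
from mlx_genkit import GenerationConfig, generate
from mlx_lm import load
model, tokenizer = load('mlx_qwen3_0_6b')
cfg = GenerationConfig(max_tokens=64, temperature=0.7, top_p=0.95, typical_p=0.9,
                       min_new_tokens=8, suppress_tokens=[tokenizer.eos_token_id],
                       repetition_penalty=1.1,   # assumed field name (HF-style)
                       no_repeat_ngram_size=2)   # assumed field name (matches the CLI flag)
out = generate(model, tokenizer, 'List three MLX features:', cfg)
print(out['text'])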
Install
- From PyPI (recommended): pip install mlx-genkit
- Dependencies (if not already installed): pip install mlx mlx-lm transformers
- From source (editable): pip install -e .
Models from Hugging Face
- If the repo provides MLX weights (e.g., in mlx-community), you can load directly: load('mlx-community/<model>').
- For standard HF (PyTorch) repos, convert once using mlx-lm:
  - Python: from mlx_genkit.interop import convert_hf_to_mlx; convert_hf_to_mlx('Qwen/Qwen3-0.6B', quantize=False, local_out='mlx_qwen3_0_6b')
  - CLI: mlx_lm.convert --hf-path Qwen/Qwen3-0.6B --mlx-path mlx_qwen3_0_6b
- Then load with load('mlx_qwen3_0_6b').
Auto-convert loader
- You can pass either an HF repo id or a local MLX path to auto_load, which converts once and caches under ./mlx_cache/<sanitized_repo_id>:
from mlx_genkit.loader import auto_load
model, tokenizer, local_path = auto_load('Qwen/Qwen3-0.6B')
print('Loaded from', local_path) # e.g., ./mlx_cache/Qwen_Qwen3-0.6B
Basic usage
from mlx_genkit import GenerationConfig, generate
from mlx_lm import load
model, tokenizer = load('mlx_qwen3_0_6b')
cfg = GenerationConfig(max_tokens=64, temperature=0.7, top_p=0.95, seed=17)
out = generate(model, tokenizer, 'Hello MLX parity', cfg)
print(out['text'])
Beam and constraints
cfg = GenerationConfig(max_tokens=64, temperature=0.0, num_beams=4, early_stopping=True, length_penalty=0.2,
force_words_ids=[tokenizer.encode(' cat')], min_new_tokens=8,
bad_words_ids=[[tokenizer.eos_token_id]], suppress_tokens=[tokenizer.eos_token_id])
out = generate(model, tokenizer, 'The', cfg)
Speculative decoding
cfg = GenerationConfig(max_tokens=64, temperature=0.7, top_p=0.95,
use_speculative=True, draft_model_id='mlx_qwen3_0_6b', num_draft_tokens=3)
out = generate(model, tokenizer, 'Speculative test', cfg)
Persona steering
import mlx.core as mx
from mlx_genkit import LogitBiasHook
H = model.args.hidden_size
model['_persona_v'] = mx.random.normal((H,)) * (1.0/(H**0.5))
cfg = GenerationConfig(max_tokens=64, temperature=0.7)
out = generate(model, tokenizer, 'Summarize MLX', cfg, hooks=[LogitBiasHook(param_key='_persona_v', alpha=1.2)])
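ResidualInjectionHook offers an alternative steering path during sampling, injecting the vector into the residual stream rather than biasing logits. The sketch below assumes its constructor mirrors LogitBiasHook (a param_key plus a scaling alpha); verify the actual signature before relying on it.
import mlx.core as mx
from mlx_genkit import GenerationConfig, generate, ResidualInjectionHook
H = model.args.hidden_size
model['_persona_v'] = mx.random.normal((H,)) * (1.0 / (H ** 0.5))
cfg = GenerationConfig(max_tokens=64, temperature=0.7)
# param_key/alpha are assumed to mirror LogitBiasHook's constructor
out = generate(model, tokenizer, 'Summarize MLX', cfg, hooks=[ResidualInjectionHook(param_key='_persona_v', alpha=4.0)])
print(out['text'])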
Training (MLX)
from mlx_genkit import TrainingConfig, train_step, SoftPromptHook
from mlx.optimizers import AdamW
pad_id = getattr(tokenizer, 'pad_token_id', -100) or -100
opt = AdamW(learning_rate=2e-4)
batch = {'tokens': ...} # mx.array [B, T]
cfg = TrainingConfig(dtype='bf16', loss_scale=1024.0)
loss = train_step(model, batch, opt, cfg, hooks=[SoftPromptHook(n_virtual=10, param_key='_soft_prompt')], pad_id=pad_id)
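To put train_step in context, here is a minimal loop over a toy batch. How you tokenize and pad is up to you; the right-padding with pad_id shown here is an assumption, not a library requirement.
import mlx.core as mx
from mlx.optimizers import AdamW
from mlx_genkit import TrainingConfig, train_step
texts = ['Hello MLX', 'Persona steering on Apple Silicon']
pad_id = getattr(tokenizer, 'pad_token_id', -100) or -100
ids = [tokenizer.encode(t) for t in texts]
T = max(len(x) for x in ids)
tokens = mx.array([x + [pad_id] * (T - len(x)) for x in ids])  # [B, T], right-padded (assumption)
opt = AdamW(learning_rate=2e-4)
cfg = TrainingConfig(dtype='bf16', loss_scale=1024.0)
for step in range(10):
    loss = train_step(model, {'tokens': tokens}, opt, cfg, pad_id=pad_id)
    print('step', step, 'loss', loss)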
Utilities
from mlx_genkit import sequence_logprob, token_kl, ema_update, build_action_mask
# Per-sample mean log-prob on supervised positions (labels == -100 are ignored)
lp = sequence_logprob(model, batch_tokens, labels) # [B]
# KL(pi || pref) averaged over supervised positions
kl = token_kl(model, ref_model, batch_tokens, labels) # [B]
# EMA update of a target model from a source model
ema_update(target_model, model, decay=0.999)
# Supervised mask after prompt
mask = build_action_mask(prompt_lens=[12, 20], seq_len=T) # [B, T] bool
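As a worked example of combining these helpers, the sketch below masks prompt positions to -100 with build_action_mask so that sequence_logprob scores only completion tokens. The mx.where step is one assumed way to construct labels, not a library requirement.
import mlx.core as mx
from mlx_genkit import sequence_logprob, build_action_mask
B, T = batch_tokens.shape  # batch_tokens as in the snippet above
mask = build_action_mask(prompt_lens=[12, 20], seq_len=T)  # [B, T] bool, True after the prompt
labels = mx.where(mask, batch_tokens, -100)  # prompt positions ignored during scoring
lp = sequence_logprob(model, batch_tokens, labels)  # [B] mean log-prob over completions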
Parity testing
- Torch vs MLX: python -m mlx_gen_parity.tests.parity_hf --hf-model Qwen/Qwen3-0.6B --mlx-model ./mlx_qwen3_0_6b --prompt 'hello'
- Suite (8 prompts): python -m mlx_gen_parity.tests.parity_suite --hf-model Qwen/Qwen3-0.6B --mlx-model ./mlx_qwen3_0_6b
CLI wrapper
mlxgk-generate \
--model Qwen/Qwen3-0.6B \
--prompt "Hello MLX" \
--max-tokens 64 --temp 0.7 --top-p 0.95 \
--num-beams 1 --no-repeat-ngram-size 2
Performance bench
python -m mlx_gen_parity.tests.perf_bench --hf-model Qwen/Qwen3-0.6B --mlx-model ./mlx_qwen3_0_6b --prompt "Hello performance" --max-tokens 64
Releases
- Bump version across files: make bump-version PART=patch (or minor/major)
- Create and push a git tag (vX.Y.Z): make git-release
  - This tags and pushes the repo; PyPI packaging can be added later.
Notes
- Parity targets control‑surface equivalence: constraints, stops, finish reasons, determinism; token streams may differ across frameworks/devices.
- Sampling fast path reuses mlx-lm’s decoding loop and caches for best performance on Apple Silicon.
Renaming
- This project was previously published as mlx-gen-parity. The Python import mlx_gen_parity continues to work for compatibility, but new code should prefer mlx_genkit and pip install mlx-genkit.
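For example, both imports resolve to the same package:
import mlx_genkit
import mlx_gen_parity  # legacy import path, kept for compatibility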
Known limitations
- Residual injection uses Python-level patching; highly optimized/compiled paths may bypass it. Use forward_with_hidden(..., strict=True) when you need deterministic capture/injection semantics.
- Some MLX model classes may not accept input_embeddings (used for soft prompts in training); in those cases, the library falls back gracefully to a standard token-only forward.
- Beam search applies processors on raw logits and then normalizes (HF behavior). Earlier parity reports in this repo may reflect the previous implementation on normalized logprobs.
Tips
- When running examples directly from the repo, make sure you're using the local sources: pip install -e . or run with PYTHONPATH=.
- Parity/perf harnesses will download HF models; ensure network access and sufficient disk space.