mlx-genkit

MLX generation parity utils with HF/Torch-compatible sampling, processors, and steering hooks

A small, reusable MLX generation and training toolkit that brings HF/Torch generate() feature parity and clean persona steering to Apple Silicon. It reuses mlx-lm primitives (caches, projections, speculative decoding) and fills in the missing parity pieces.

Features

  • HF-style GenerationConfig with processors/warpers: repetition penalty, no-repeat n-grams, frequency/presence penalties, bad-words, min_new_tokens, typical_p, epsilon_cutoff (see the config sketch after this list).
  • Constraints: force_words_ids (strict start + continuation), suppress_tokens, begin_suppress_tokens, multiple eos_token_ids, forced BOS/EOS and per-position forced_decoder_ids.
  • Modes: sampling (fast path via mlx-lm), beam (num_beams, length_penalty, early_stopping), speculative (mlx-lm), sliding KV (max_kv_size).
  • Hooks: ResidualInjectionHook (sampling) and LogitBiasHook (sampling/beam); SoftPromptHook for training.
  • Training (MLX): loss_forward, xent_loss (label smoothing), mixed-precision compute (bf16) with fp32 master weights.
  • Training utilities: sequence_logprob, token_kl for scoring and policy KL.
  • Model helpers: ema_update, build_action_mask, stable_softmax; best-effort clone_reference.
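
For illustration, several of these knobs can be combined on one config. A minimal sketch; repetition_penalty and no_repeat_ngram_size are assumed to mirror the HF parameter names (cf. the CLI's --no-repeat-ngram-size flag below), so verify against the package:

from mlx_genkit import GenerationConfig

cfg = GenerationConfig(
    max_tokens=64,
    temperature=0.8,
    repetition_penalty=1.1,   # assumed HF-style field name
    no_repeat_ngram_size=3,   # assumed; matches the CLI flag below
    typical_p=0.9,
    min_new_tokens=8,
)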

Install

  • From PyPI (recommended):
pip install mlx-genkit
  • Dependencies (if not already installed):
pip install mlx mlx-lm transformers
  • From source (editable):
pip install -e .

Models from Hugging Face

  • If the repo provides MLX weights (e.g., in mlx-community), you can load directly: load('mlx-community/<model>').
  • For standard HF (PyTorch) repos, convert once using mlx-lm (see the combined sketch after this list):
    • Python: from mlx_genkit.interop import convert_hf_to_mlx; convert_hf_to_mlx('Qwen/Qwen3-0.6B', quantize=False, local_out='mlx_qwen3_0_6b')
    • CLI: mlx_lm.convert --hf-path Qwen/Qwen3-0.6B --mlx-path mlx_qwen3_0_6b
    • Then load with load('mlx_qwen3_0_6b').
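
Putting conversion and loading together (same model id and output path as above):

from mlx_genkit.interop import convert_hf_to_mlx
from mlx_lm import load

# One-time conversion of a standard HF (PyTorch) repo to MLX weights
convert_hf_to_mlx('Qwen/Qwen3-0.6B', quantize=False, local_out='mlx_qwen3_0_6b')
model, tokenizer = load('mlx_qwen3_0_6b')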

Auto-convert loader

  • You can pass either an HF repo id or a local MLX path to auto_load, which will convert once and cache under ./mlx_cache/<sanitized_repo_id>:
from mlx_genkit.loader import auto_load
model, tokenizer, local_path = auto_load('Qwen/Qwen3-0.6B')
print('Loaded from', local_path)  # e.g., ./mlx_cache/Qwen_Qwen3-0.6B

Basic usage

from mlx_genkit import GenerationConfig, generate
from mlx_lm import load

model, tokenizer = load('mlx_qwen3_0_6b')
cfg = GenerationConfig(max_tokens=64, temperature=0.7, top_p=0.95, seed=17)
out = generate(model, tokenizer, 'Hello MLX parity', cfg)
print(out['text'])

Chat prompts (auto chat template)

# If you pass a list of HF-style messages, mlx-genkit will automatically
# apply the tokenizer's chat template when available.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize MLX in 3 bullets."},
]
cfg = GenerationConfig(max_tokens=64, temperature=0.7)
out = generate(model, tokenizer, messages, cfg)
print(out['text'])

Auto-apply chat template for plain prompts

# You can also provide a plain string and have mlx-genkit wrap it
# using the model's chat template when available. This is enabled
# automatically when the tokenizer defines a chat_template; override it
# via GenerationConfig (auto_chat_template), or force it with assume_user_chat.
cfg = GenerationConfig(max_tokens=64, temperature=0.7, auto_chat_template=True, system_prompt="You are helpful")
# Equivalent explicit flag:
# cfg = GenerationConfig(max_tokens=64, temperature=0.7, assume_user_chat=True, system_prompt="You are helpful")
out = generate(model, tokenizer, "Summarize MLX in 3 bullets.", cfg)
print(out['text'])

Beam and constraints

cfg = GenerationConfig(max_tokens=64, temperature=0.0, num_beams=4, early_stopping=True, length_penalty=0.2,
                       force_words_ids=[tokenizer.encode(' cat')], min_new_tokens=8,
                       bad_words_ids=[[tokenizer.eos_token_id]], suppress_tokens=[tokenizer.eos_token_id])
out = generate(model, tokenizer, 'The', cfg)

Speculative decoding

cfg = GenerationConfig(max_tokens=64, temperature=0.7, top_p=0.95,
                       use_speculative=True, draft_model_id='mlx_qwen3_0_6b', num_draft_tokens=3)
out = generate(model, tokenizer, 'Speculative test', cfg)

Persona steering

import mlx.core as mx
from mlx_genkit import LogitBiasHook
H = model.args.hidden_size
model['_persona_v'] = mx.random.normal((H,)) * (1.0/(H**0.5))
cfg = GenerationConfig(max_tokens=64, temperature=0.7)
out = generate(model, tokenizer, 'Summarize MLX', cfg, hooks=[LogitBiasHook(param_key='_persona_v', alpha=1.2)])
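
ResidualInjectionHook (sampling-only, per the feature list) steers through the residual stream rather than the logits. A minimal sketch, assuming it accepts the same param_key/alpha arguments as LogitBiasHook above (its exact constructor is not shown here):

from mlx_genkit import ResidualInjectionHook

# Hypothetical constructor shape; verify param_key/alpha against the package
hook = ResidualInjectionHook(param_key='_persona_v', alpha=1.2)
out = generate(model, tokenizer, 'Summarize MLX', cfg, hooks=[hook])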

Training (MLX)

from mlx_genkit import TrainingConfig, train_step, SoftPromptHook
from mlx.optimizers import AdamW
# Use the tokenizer's pad id when present; otherwise fall back to -100 (the ignored label index).
# Note: `pad_token_id or -100` would wrongly discard a valid pad id of 0.
pad_id = tokenizer.pad_token_id if getattr(tokenizer, 'pad_token_id', None) is not None else -100
opt = AdamW(learning_rate=2e-4)
batch = {'tokens': ...}  # mx.array [B, T]
cfg = TrainingConfig(dtype='bf16', loss_scale=1024.0)
loss = train_step(model, batch, opt, cfg, hooks=[SoftPromptHook(n_virtual=10, param_key='_soft_prompt')], pad_id=pad_id)
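
A minimal multi-step loop over that setup (batches here is a hypothetical iterable of {'tokens': mx.array [B, T]} dicts, not part of the library):

# Sketch of an epoch; train_step returns the scalar loss as in the example above
for step, batch in enumerate(batches):
    loss = train_step(model, batch, opt, cfg, pad_id=pad_id)
    if step % 50 == 0:
        print(f'step {step}: loss {loss}')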

Utilities

from mlx_genkit import sequence_logprob, token_kl, ema_update, build_action_mask

# Per-sample mean log-prob on supervised positions (labels == -100 are ignored)
lp = sequence_logprob(model, batch_tokens, labels)  # [B]

# KL(pi || pref) averaged over supervised positions
kl = token_kl(model, ref_model, batch_tokens, labels)  # [B]

# EMA update of a target model from a source model
ema_update(target_model, model, decay=0.999)

# Supervised mask after prompt
mask = build_action_mask(prompt_lens=[12, 20], seq_len=T)  # [B, T] bool
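
These utilities compose: the action mask can build the labels that sequence_logprob and token_kl expect, with -100 marking ignored positions as noted above. A sketch, assuming labels share the [B, T] token layout:

import mlx.core as mx

# Supervise only post-prompt positions; everything outside the mask is ignored
labels = mx.where(mask, batch_tokens, -100)
lp = sequence_logprob(model, batch_tokens, labels)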

Parity testing

  • Torch vs MLX: python -m mlx_genkit.tests.parity_hf --hf-model Qwen/Qwen3-0.6B --mlx-model ./mlx_qwen3_0_6b --prompt 'hello'
  • Suite (8 prompts): python -m mlx_genkit.tests.parity_suite --hf-model Qwen/Qwen3-0.6B --mlx-model ./mlx_qwen3_0_6b

CLI wrapper

mlxgk-generate \
  --model Qwen/Qwen3-0.6B \
  --prompt "Hello MLX" \
  --max-tokens 64 --temp 0.7 --top-p 0.95 \
  --num-beams 1 --no-repeat-ngram-size 2

CLI chat and stop strings

  • Chat: --messages-json '[{"role":"user","content":"hi"}]' (auto-applies template)
  • Auto chat for plain prompts: add --auto-chat (or disable with --no-auto-chat); optional --system "You are helpful"
  • Force treating plain prompts as user messages: --assume-user-chat (equivalent to --auto-chat)
  • Stop strings: use --stop or the alias --stop-strings (comma-separated); see the combined example after this list
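
For example, auto-chat on a plain prompt with a system prompt and a stop string (all flags from the lists above):

mlxgk-generate \
  --model Qwen/Qwen3-0.6B \
  --prompt "Summarize MLX in 3 bullets." \
  --auto-chat --system "You are helpful" \
  --stop "###" \
  --max-tokens 64 --temp 0.7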

Defaults

  • The CLI will, by default, auto-apply chat templates when the loaded tokenizer exposes a chat template (has apply_chat_template and a non-empty chat_template). Use --no-auto-chat to turn this off.

Prefetch and convert (download)

# Download an HF repo and convert to MLX format without loading into memory.
# Prints the local path, e.g., ./mlx_cache/Qwen_Qwen2-7B-Instruct
mlxgk-download --model Qwen/Qwen2-7B-Instruct

# Options
#  --cache-dir DIR           Cache location (default: ./mlx_cache)
#  --quantize                Quantize during conversion
#  --trust-remote-code       Allow custom code from the repo
#  --force                   Reconvert and overwrite existing cache

Performance bench

python -m mlx_genkit.tests.perf_bench --hf-model Qwen/Qwen3-0.6B --mlx-model ./mlx_qwen3_0_6b --prompt "Hello performance" --max-tokens 64

Releases

  • Bump version across files (defaults to patch):
    • make bump-version (use PART=minor or PART=major to override)
  • Create and push a git tag (vX.Y.Z):
    • make git-release
    • This tags and pushes the repo; PyPI packaging can be added later.

Notes

  • Parity targets control-surface equivalence (constraints, stops, finish reasons, determinism); token streams may differ across frameworks/devices.
  • Sampling fast path reuses mlx-lm’s decoding loop and caches for best performance on Apple Silicon.

Known limitations

  • Residual injection uses Python-level patching; highly optimized/compiled paths may bypass it. Use forward_with_hidden(..., strict=True) when you need deterministic capture/injection semantics (see the sketch after this list).
  • Some MLX model classes may not accept input_embeddings (used for soft prompts in training). In those cases, the library now falls back gracefully to standard token-only forward.
  • Beam search applies processors on raw logits and then normalizes (HF behavior). Earlier parity reports in this repo may reflect the previous implementation on normalized logprobs.
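
A sketch of that strict path; the import location, arguments, and return shape of forward_with_hidden are assumptions (only the strict=True flag is documented above):

from mlx_genkit import forward_with_hidden  # import path assumed

# tokens: an mx.array of input ids (hypothetical variable);
# the (logits, hidden) return shape is also an assumption
logits, hidden = forward_with_hidden(model, tokens, strict=True)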

Tips

  • When running examples directly from the repo, make sure you’re using the local sources: pip install -e . or run with PYTHONPATH=..
  • Parity/perf harnesses will download HF models; ensure network access and sufficient disk space.
