Convert StableHLO models into Apple Core ML format
stablehlo-coreml
StableHLO is the portability layer used by ML frameworks like JAX and PyTorch. This library converts StableHLO programs into Apple's Core ML format via coremltools, enabling deployment on Apple hardware (iOS, macOS, etc.).
Installation
```bash
pip install stablehlo-coreml
```
Requires Python 3.9+ and targets iOS/macOS 18+.
Supported Frameworks
Models can be exported from any framework that produces StableHLO:
- JAX / Flax / Equinox: via jax.export
- PyTorch: via torchax, which traces the model into JAX, followed by jax.export to StableHLO
The test suite validates against a broad set of models, including full HuggingFace Transformers such as TinyLlama, T5, DistilBERT, GPT-2, BERT, and Whisper, as well as vision models like ResNet, EfficientNet, ViT, ConvNeXt, and more.
For a real-world example, see gemma-coreml-chat, which exports Google's Gemma 4 model to Core ML using this library.
Converting a Model
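To convert a StableHLO module:
```python
import coremltools as ct
from stablehlo_coreml.converter import convert
from stablehlo_coreml import DEFAULT_HLO_PIPELINE

# hlo_module is an MLIR module holding the StableHLO program
# (see "Obtaining a StableHLO Module from JAX" below).
mil_program = convert(hlo_module, minimum_deployment_target=ct.target.iOS18)

# Lower the MIL program to a Core ML model using the HLO pass pipeline.
cml_model = ct.convert(
    mil_program,
    source="milinternal",
    minimum_deployment_target=ct.target.iOS18,
    pass_pipeline=DEFAULT_HLO_PIPELINE,
)
```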
Obtaining a StableHLO Module from JAX
```python
import jax
from jax._src.lib.mlir import ir
from jax._src.interpreters import mlir as jax_mlir
from jax.export import export
import jax.numpy as jnp

def jax_function(a, b):
    return jnp.einsum("ij,jk -> ik", a, b)

context = jax_mlir.make_ir_context()
# Example inputs fix the exported shapes: (2, 4) and (4, 3).
input_shapes = (jnp.zeros((2, 4)), jnp.zeros((4, 3)))
jax_exported = export(jax.jit(jax_function))(*input_shapes)
# Parse the exported StableHLO text into an MLIR module for convert().
hlo_module = ir.Module.parse(jax_exported.mlir_module(), context=context)
```
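Because jax_function is a plain matrix product, a framework-free reference can serve as an oracle when spot-checking a converted model on small inputs. This is a minimal sketch; matmul_reference and allclose are hypothetical helpers written for illustration, not part of stablehlo-coreml.

```python
# Framework-free reference for jax_function: einsum("ij,jk -> ik") is a
# plain matrix product, out[i][k] = sum over j of a[i][j] * b[j][k].


def matmul_reference(a, b):
    """Multiply two matrices given as nested lists of floats."""
    inner = len(b)
    if any(len(row) != inner for row in a):
        raise ValueError("inner dimensions of a and b do not match")
    cols = len(b[0])
    return [[sum(row[j] * b[j][k] for j in range(inner)) for k in range(cols)]
            for row in a]


def allclose(x, y, rtol=1e-5, atol=1e-6):
    """True if every |x - y| <= atol + rtol * |y| for equally shaped matrices."""
    if len(x) != len(y) or any(len(rx) != len(ry) for rx, ry in zip(x, y)):
        return False
    return all(abs(u - v) <= atol + rtol * abs(v)
               for rx, ry in zip(x, y) for u, v in zip(rx, ry))


# Same shapes as the export example: (2, 4) @ (4, 3) -> (2, 3).
a = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
b = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
expected = matmul_reference(a, b)
```

To spot-check a conversion, the same a and b would be fed to the Core ML model (prediction requires macOS), and its output, converted to nested lists, compared against expected with allclose.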
For the JAX example to work, you will additionally need to install absl-py and flatbuffers as dependencies.
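A quick way to confirm those extra dependencies are importable before running the example is to look them up with importlib. This is a minimal sketch; missing_modules is a hypothetical helper, and it assumes the import names absl (from absl-py) and flatbuffers.

```python
import importlib.util

# Import names of the extra packages the JAX example needs.
# Assumption: absl-py installs the top-level module "absl",
# and flatbuffers installs "flatbuffers".
JAX_EXAMPLE_MODULES = ("absl", "flatbuffers")


def missing_modules(names):
    """Return the module names that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(JAX_EXAMPLE_MODULES)
if missing:
    print("Missing modules:", ", ".join(missing),
          "(try: pip install absl-py flatbuffers)")
```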
Dynamic / symbolic shapes
JAX models exported with symbolic dimensions are supported. Symbolic dimensions
are propagated automatically through GetDimensionSizeOp, DynamicBroadcastInDimOp,
DynamicIotaOp, and shape-assertion CustomCallOps, producing Core ML models with
flexible input shapes.
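```python
import jax
import jax.numpy as jnp
from jax.export import export, symbolic_shape

# jax_function as defined above; "batch" is a symbolic leading dimension.
jax_exported = export(jax.jit(jax_function))(
    jax.ShapeDtypeStruct(symbolic_shape("batch, 4"), jnp.float32),
    jax.ShapeDtypeStruct((4, 3), jnp.float32),
)
# Parse and convert the module as in the previous sections.
```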
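To illustrate what propagating a symbolic dimension means for this example, the sketch below applies a toy shape-inference rule for a 2-D matmul in which a dimension is either an int or a symbol name. It is a hypothetical illustration, not the converter's implementation (which operates on StableHLO ops such as GetDimensionSizeOp); parse_shape only handles plain names and integers, not symbolic expressions.

```python
from typing import Tuple, Union

Dim = Union[int, str]  # a concrete size, or a symbol such as "batch"


def parse_shape(spec: str) -> Tuple[Dim, ...]:
    """Parse a spec like "batch, 4" into ("batch", 4)."""
    dims = []
    for part in spec.split(","):
        part = part.strip()
        dims.append(int(part) if part.isdigit() else part)
    return tuple(dims)


def matmul_shape(a: Tuple[Dim, ...], b: Tuple[Dim, ...]) -> Tuple[Dim, ...]:
    """Result shape of (m, k) @ (k, n); symbolic m and n pass through unchanged."""
    (m, k1), (k2, n) = a, b
    # Two distinct concrete contracted sizes are a definite mismatch.
    # A fuller implementation would also track symbol/int equalities.
    if isinstance(k1, int) and isinstance(k2, int) and k1 != k2:
        raise ValueError(f"contracted dims differ: {k1} vs {k2}")
    return (m, n)


# The output keeps the symbolic batch dimension: ("batch", 3).
out_shape = matmul_shape(parse_shape("batch, 4"), (4, 3))
```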
When converting to a Core ML model, specify a RangeDim for each symbolic
dimension so that the model accepts a range of sizes at inference time:
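```python
# mil_program is the output of convert() on the symbolically exported module.
# RangeDim(1, 2048, 1): lower bound 1, upper bound 2048, default 1.
cml_model = ct.convert(
    mil_program,
    source="milinternal",
    minimum_deployment_target=ct.target.iOS18,
    pass_pipeline=DEFAULT_HLO_PIPELINE,
    inputs=[
        ct.TensorType(name="_arg0", shape=(ct.RangeDim(1, 2048, 1), 4)),
        ct.TensorType(name="_arg1", shape=(4, 3)),
    ],
)
```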
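Which concrete input shapes such a flexible declaration admits can be reasoned about with a small checker. This is a minimal sketch under stated assumptions: Range is a hypothetical stand-in that mirrors only the bounds of ct.RangeDim(lower, upper, default), bounds are treated as inclusive, and None marks an unbounded upper limit in this sketch only.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Union


@dataclass(frozen=True)
class Range:
    """Hypothetical stand-in for the bounds of ct.RangeDim(lower, upper, default)."""
    lower: int
    upper: Optional[int]  # None means unbounded in this sketch
    default: int


def accepts(spec: Sequence[Union[int, Range]], shape: Sequence[int]) -> bool:
    """True if a concrete shape fits a flexible shape spec (inclusive bounds)."""
    if len(spec) != len(shape):
        return False
    for want, got in zip(spec, shape):
        if isinstance(want, Range):
            if got < want.lower or (want.upper is not None and got > want.upper):
                return False
        elif want != got:
            return False
    return True


# Mirrors the first input above: shape=(ct.RangeDim(1, 2048, 1), 4).
arg0_spec = (Range(1, 2048, 1), 4)
```

With this spec, a batch of 2048 rows is accepted while 2049 is rejected, and the trailing dimension must be exactly 4.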
See tests/test_symbolic_shapes.py for
symbolic matmul, batched einsum, and multi-axis patterns (for example
transformer-style projections).
Examples in the test suite
The tests/ directory has end-to-end export and conversion examples:
- PyTorch (torchax): tests/pytorch/test_pytorch.py (export_to_stablehlo_module, HuggingFace Transformers, and torchvision models)
- JAX: tests/test_jax.py
- Flax / Equinox: tests/test_flax.py, tests/test_equinox.py
Development
- coremltools supports Python up to 3.13, so do not run hatch with a newer version. The interpreter can be selected with, e.g., export HATCH_PYTHON=python3.13
- Run the tests using hatch run test:pytest tests
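A small pre-flight check can encode the interpreter range stated in this README (Python 3.9+, with coremltools supporting up to 3.13). This is a minimal sketch; python_supported and the bound constants are hypothetical and are not part of the package or of hatch.

```python
import sys

# Bounds taken from this README: the package requires Python 3.9+,
# and coremltools supports up to Python 3.13.
MIN_PYTHON = (3, 9)
MAX_PYTHON = (3, 13)


def python_supported(version=None) -> bool:
    """True if (major, minor) lies within [MIN_PYTHON, MAX_PYTHON]."""
    major_minor = tuple((version or sys.version_info)[:2])
    return MIN_PYTHON <= major_minor <= MAX_PYTHON


if __name__ == "__main__" and not python_supported():
    # Warn only; selecting the interpreter is left to HATCH_PYTHON.
    print(f"Python {sys.version_info.major}.{sys.version_info.minor} is outside "
          "3.9 to 3.13; e.g. export HATCH_PYTHON=python3.13")
```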
Download files
File details
Details for the file stablehlo_coreml-0.1.1.tar.gz.
File metadata
- Download URL: stablehlo_coreml-0.1.1.tar.gz
- Upload date:
- Size: 53.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: Hatch/1.16.5 cpython/3.12.3 HTTPX/0.28.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f74e2aca385ac62278909d2b70e74675b7646d1bba13af7ee08cb29aac5a128d |
| MD5 | f5e04837b345d46cb6dc08f26bf061bd |
| BLAKE2b-256 | 878fd27ea96b00cbf4377efea093e5c002a2d927567410a60f5ab3c5521dfb80 |
File details
Details for the file stablehlo_coreml-0.1.1-py3-none-any.whl.
File metadata
- Download URL: stablehlo_coreml-0.1.1-py3-none-any.whl
- Upload date:
- Size: 34.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: Hatch/1.16.5 cpython/3.12.3 HTTPX/0.28.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f080275860ea50ee3818dd676111dc7d6295862597b311398f5cb113da440cb7 |
| MD5 | ef7329e34f9c2491dcc97e881e1f39a9 |
| BLAKE2b-256 | c4b2828f56ae1a098085b0be034c470b445cbfbf190dbae241792862998e6104 |
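A downloaded file can be checked against the published SHA256 digests by streaming it through hashlib and comparing hex strings. This is a minimal sketch; sha256_of and verify are hypothetical helper names, and only the digest values come from the file details above.

```python
import hashlib

# Published SHA256 digests for release 0.1.1 (from the file details above).
EXPECTED_SHA256 = {
    "stablehlo_coreml-0.1.1.tar.gz":
        "f74e2aca385ac62278909d2b70e74675b7646d1bba13af7ee08cb29aac5a128d",
    "stablehlo_coreml-0.1.1-py3-none-any.whl":
        "f080275860ea50ee3818dd676111dc7d6295862597b311398f5cb113da440cb7",
}


def sha256_of(path, chunk_size=1 << 20):
    """Hash a file in chunks so large files need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected_hex):
    """True if the file's SHA256 matches the expected hex digest (any case)."""
    return sha256_of(path) == expected_hex.lower()
```

For example, verify("stablehlo_coreml-0.1.1.tar.gz", EXPECTED_SHA256["stablehlo_coreml-0.1.1.tar.gz"]) returns True only for an intact download. pip can also enforce digests itself through its hash-checking mode (--require-hashes).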