
stablehlo-coreml

Convert StableHLO models into Apple Core ML format.

StableHLO is the portability layer used by ML frameworks like JAX and PyTorch. This library converts StableHLO programs into Apple's Core ML format via coremltools, enabling deployment on Apple hardware (iOS, macOS, etc.).

Installation

pip install stablehlo-coreml

Requires Python 3.9+ and targets iOS/macOS 18+.

Supported Frameworks

Models can be exported from any framework that produces StableHLO:

  • JAX / Flax / Equinox — via jax.export
  • PyTorch — via torchax to trace the model into JAX, then jax.export to StableHLO

The test suite validates conversion against a broad set of models, including full Hugging Face Transformers models such as TinyLlama, T5, DistilBERT, GPT-2, BERT, and Whisper, as well as vision models such as ResNet, EfficientNet, ViT, ConvNeXt, and others.

For a real-world example, see gemma-coreml-chat, which exports Google's Gemma 4 model to Core ML using this library.

Converting a Model

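Conversion is a two-step process: convert translates a StableHLO module into a Core ML MIL program, and ct.convert then builds a Core ML model from that program using the library's DEFAULT_HLO_PIPELINE pass pipeline: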

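import coremltools as ct
from stablehlo_coreml.converter import convert
from stablehlo_coreml import DEFAULT_HLO_PIPELINE

# hlo_module is an MLIR module containing StableHLO (see the JAX example below).
# Step 1: translate StableHLO into a Core ML MIL program.
mil_program = convert(hlo_module, minimum_deployment_target=ct.target.iOS18)
# Step 2: build the Core ML model from the MIL program.
cml_model = ct.convert(
    mil_program,
    source="milinternal",
    minimum_deployment_target=ct.target.iOS18,
    pass_pipeline=DEFAULT_HLO_PIPELINE,
)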

Obtaining a StableHLO Module from JAX

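import jax
import jax.numpy as jnp
from jax._src.lib.mlir import ir
from jax._src.interpreters import mlir as jax_mlir
from jax.export import export

def jax_function(a, b):
    # Contracts over the shared index j: an ordinary matrix product.
    return jnp.einsum("ij,jk -> ik", a, b)

# MLIR context with JAX's dialects registered, used to parse the exported module.
context = jax_mlir.make_ir_context()
# Example arguments fix the input shapes and dtypes of the exported function.
example_args = (jnp.zeros((2, 4)), jnp.zeros((4, 3)))
jax_exported = export(jax.jit(jax_function))(*example_args)
# Parse the exported StableHLO text into an MLIR module for convert().
hlo_module = ir.Module.parse(jax_exported.mlir_module(), context=context)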

The JAX example additionally requires the absl-py and flatbuffers packages.
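When checking a converted model, it helps to have a reference for what the exported function computes. The einsum "ij,jk -> ik" in jax_function sums over the shared index j, which is an ordinary matrix product. The following is a dependency-free reference for that contraction plus an allclose-style comparison; the function names and the 1e-3 default tolerances are illustrative assumptions, not part of this library. Core ML programs may compute in reduced precision, so an exact equality check against JAX outputs is usually too strict.

```python
# Pure-Python reference for the einsum "ij,jk -> ik" used in jax_function.
# Hypothetical helpers for sanity-checking converted-model outputs; the
# default tolerances are illustrative assumptions, not library values.
from typing import List, Sequence

Matrix = List[List[float]]


def einsum_ij_jk_ik(a: Sequence[Sequence[float]], b: Sequence[Sequence[float]]) -> Matrix:
    """Compute out[i][k] = sum_j a[i][j] * b[j][k], i.e. a plain matrix product."""
    n_i, n_j = len(a), len(a[0])
    if len(b) != n_j:
        raise ValueError(f"contracted dimension mismatch: {n_j} vs {len(b)}")
    n_k = len(b[0])
    return [
        [sum(a[i][j] * b[j][k] for j in range(n_j)) for k in range(n_k)]
        for i in range(n_i)
    ]


def all_close(x: Matrix, y: Matrix, rtol: float = 1e-3, atol: float = 1e-3) -> bool:
    """Element-wise |x - y| <= atol + rtol * |y|, the same rule numpy.allclose uses."""
    if len(x) != len(y) or any(len(rx) != len(ry) for rx, ry in zip(x, y)):
        return False  # shapes differ
    return all(
        abs(p - q) <= atol + rtol * abs(q)
        for rx, ry in zip(x, y)
        for p, q in zip(rx, ry)
    )


# Same shapes as the export example: (2, 4) x (4, 3) -> (2, 3).
a = [[float(i * 4 + j) for j in range(4)] for i in range(2)]
b = [[1.0 if j == k else 0.0 for k in range(3)] for j in range(4)]
out = einsum_ij_jk_ik(a, b)
print(len(out), len(out[0]))  # 2 3
```

In practice, x would be the Core ML prediction converted to nested lists and y the JAX result for the same inputs.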

Dynamic / symbolic shapes

JAX models exported with symbolic dimensions are supported. Symbolic dimensions are propagated automatically through GetDimensionSizeOp, DynamicBroadcastInDimOp, DynamicIotaOp, and shape-assertion CustomCallOps, producing Core ML models with flexible input shapes.

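import jax
import jax.numpy as jnp
from jax.export import export, symbolic_shape

# jax_function is the einsum function from the previous example.
# "batch" becomes a symbolic (dynamic) leading dimension; 4 stays static.
jax_exported = export(jax.jit(jax_function))(
    jax.ShapeDtypeStruct(symbolic_shape("batch, 4"), jnp.float32),
    jax.ShapeDtypeStruct((4, 3), jnp.float32),
)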
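Each symbolic name in the export spec later needs concrete bounds on the Core ML side. As an illustration, the hypothetical helper range_shape_from_spec maps a spec string such as "batch, 4" to a shape in which every named dimension is replaced by the (lower, upper, default) triple you would pass to ct.RangeDim. It handles only bare names and integers; jax.export also accepts arithmetic dimension expressions, which this sketch rejects.

```python
# Hypothetical helper (not part of stablehlo-coreml or JAX): turn a
# jax.export-style shape string such as "batch, 4" into a shape where each
# symbolic dimension carries the bounds intended for ct.RangeDim.
from typing import Dict, NamedTuple, Tuple, Union


class RangeSpec(NamedTuple):
    lower: int
    upper: int
    default: int


Dim = Union[int, RangeSpec]


def range_shape_from_spec(spec: str, bounds: Dict[str, Tuple[int, int, int]]) -> Tuple[Dim, ...]:
    shape = []
    for token in (t.strip() for t in spec.split(",")):
        if token.isdigit():
            shape.append(int(token))  # static dimension
        elif token.isidentifier():
            if token not in bounds:
                raise KeyError(f"no bounds given for symbolic dim {token!r}")
            lower, upper, default = bounds[token]
            if not lower <= default <= upper:
                raise ValueError(f"default {default} outside [{lower}, {upper}]")
            shape.append(RangeSpec(lower, upper, default))
        else:
            # Expressions like "2*batch" are valid in JAX but out of scope here.
            raise ValueError(f"unsupported dimension expression {token!r}")
    return tuple(shape)


shape = range_shape_from_spec("batch, 4", {"batch": (1, 2048, 1)})
print(shape)  # (RangeSpec(lower=1, upper=2048, default=1), 4)
```

A RangeSpec s translates to ct.RangeDim(s.lower, s.upper, s.default); with the bounds used here that is ct.RangeDim(1, 2048, 1), matching the conversion call that follows.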

When converting to a Core ML model, specify a ct.RangeDim for each symbolic dimension so that the model accepts a range of sizes at inference time. Here the batch dimension may range from 1 to 2048, with a default of 1:

cml_model = ct.convert(
    mil_program,
    source="milinternal",
    minimum_deployment_target=ct.target.iOS18,
    pass_pipeline=DEFAULT_HLO_PIPELINE,
    inputs=[
        ct.TensorType(name="_arg0", shape=(ct.RangeDim(1, 2048, 1), 4)),
        ct.TensorType(name="_arg1", shape=(4, 3)),
    ],
)
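A flexible input still only accepts shapes inside its declared bounds, so it can be worth validating arrays before calling predict on the converted model. The hypothetical shape_fits helper encodes the two inputs declared above, writing the ranged dimension as a (lower, upper) tuple that mirrors ct.RangeDim(1, 2048, 1). It is a client-side sketch, not part of coremltools or this library.

```python
# Hypothetical pre-flight check: does a concrete array shape fit the shapes
# declared via ct.TensorType? Ranged dims are (lower, upper) tuples, mirroring
# ct.RangeDim(1, 2048, 1); upper=None stands for an unbounded range.
from typing import Optional, Sequence, Tuple, Union

DeclaredDim = Union[int, Tuple[int, Optional[int]]]


def shape_fits(declared: Sequence[DeclaredDim], concrete: Sequence[int]) -> bool:
    if len(declared) != len(concrete):
        return False  # rank must match exactly
    for want, got in zip(declared, concrete):
        if isinstance(want, tuple):
            lower, upper = want
            if got < lower or (upper is not None and got > upper):
                return False
        elif got != want:
            return False  # static dimension must match exactly
    return True


declared_inputs = {
    "_arg0": [(1, 2048), 4],  # batch dimension is flexible
    "_arg1": [4, 3],          # fully static
}
print(shape_fits(declared_inputs["_arg0"], (32, 4)))    # True
print(shape_fits(declared_inputs["_arg0"], (4096, 4)))  # False: above upper bound
```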

See tests/test_symbolic_shapes.py for symbolic matmul, batched einsum, and multi-axis patterns (for example transformer-style projections).

Examples in the test suite

The tests/ directory contains end-to-end export and conversion examples.

Development

  • coremltools supports Python up to 3.13, so do not run hatch with a newer interpreter. You can select the interpreter with, e.g., export HATCH_PYTHON=python3.13
  • Run tests using hatch run test:pytest tests
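The version constraints stated in this README (Python 3.9+ for the package, at most 3.13 for coremltools) can be checked before creating a hatch environment. A minimal sketch; the helper name is hypothetical and the bounds are copied from the text, so they may drift as coremltools adds support for newer interpreters.

```python
# Sketch: check the running interpreter against the range stated in this
# README (>= 3.9 for stablehlo-coreml, <= 3.13 for coremltools).
# The bounds are copied from the text and may change in later releases.
import sys
from typing import Tuple

MIN_PYTHON = (3, 9)
MAX_PYTHON = (3, 13)


def python_supported(version: Tuple[int, int]) -> bool:
    # Tuples compare element-wise, so (3, 10) sits between (3, 9) and (3, 13).
    return MIN_PYTHON <= version <= MAX_PYTHON


if __name__ == "__main__":
    current = tuple(sys.version_info[:2])
    label = f"{current[0]}.{current[1]}"
    if python_supported(current):
        print(f"Python {label} OK")
    else:
        print(f"Python {label} unsupported; try e.g. export HATCH_PYTHON=python3.13")
```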



Download files


Source Distribution

stablehlo_coreml-0.1.1.tar.gz (53.4 kB)

Uploaded Source

Built Distribution


stablehlo_coreml-0.1.1-py3-none-any.whl (34.9 kB)

Uploaded Python 3

File details

Details for the file stablehlo_coreml-0.1.1.tar.gz.

File metadata

  • Download URL: stablehlo_coreml-0.1.1.tar.gz
  • Upload date:
  • Size: 53.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: Hatch/1.16.5 cpython/3.12.3 HTTPX/0.28.1

File hashes

Hashes for stablehlo_coreml-0.1.1.tar.gz
Algorithm Hash digest
SHA256 f74e2aca385ac62278909d2b70e74675b7646d1bba13af7ee08cb29aac5a128d
MD5 f5e04837b345d46cb6dc08f26bf061bd
BLAKE2b-256 878fd27ea96b00cbf4377efea093e5c002a2d927567410a60f5ab3c5521dfb80


File details

Details for the file stablehlo_coreml-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: stablehlo_coreml-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 34.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: Hatch/1.16.5 cpython/3.12.3 HTTPX/0.28.1

File hashes

Hashes for stablehlo_coreml-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 f080275860ea50ee3818dd676111dc7d6295862597b311398f5cb113da440cb7
MD5 ef7329e34f9c2491dcc97e881e1f39a9
BLAKE2b-256 c4b2828f56ae1a098085b0be034c470b445cbfbf190dbae241792862998e6104
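To confirm that a downloaded file matches the SHA256 digests published here, hash it and compare. The sketch uses only hashlib and assumes the file sits in the current directory under its published name. pip can enforce the same check itself in hash-checking mode, via a requirements line such as stablehlo-coreml==0.1.1 --hash=sha256:<digest> together with --require-hashes.

```python
# Sketch: verify a downloaded distribution against the SHA256 digests listed
# on this page, using only the standard library. Assumes the file was saved
# in the current directory under its published name.
import hashlib
from pathlib import Path

EXPECTED_SHA256 = {
    "stablehlo_coreml-0.1.1.tar.gz": "f74e2aca385ac62278909d2b70e74675b7646d1bba13af7ee08cb29aac5a128d",
    "stablehlo_coreml-0.1.1-py3-none-any.whl": "f080275860ea50ee3818dd676111dc7d6295862597b311398f5cb113da440cb7",
}


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hash the file in chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path) -> bool:
    expected = EXPECTED_SHA256.get(path.name)
    if expected is None:
        raise KeyError(f"no published digest for {path.name}")
    return sha256_of(path) == expected


if __name__ == "__main__":
    wheel = Path("stablehlo_coreml-0.1.1-py3-none-any.whl")
    if wheel.exists():
        print("OK" if verify(wheel) else "HASH MISMATCH")
```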

