MLX-native SAM models for segmentation and video tracking
mlx-sam
MLX-native SAM models for Apple Silicon. The first supported model family is Meta SAM 2.1 for interactive image segmentation and video object tracking.
The goal of this repo is practical local video segmentation: load an MLX SAM2 checkpoint, click objects in a video, add positive/negative corrections, track forward or backward from the edited frame, and render masks back to reviewable overlay videos. The default runtime is Python 3.14 + MLX and does not install PyTorch.
What You Can Do
- Segment an image from points or boxes.
- Track one or more objects through a video with SAM2 memory.
- Add positive and negative correction clicks across frames.
- Start from the middle of a clip and propagate forward, backward, or both.
- Use box prompts in the video flow.
- Render `.npy` or `.npz` masks as overlay videos for visual inspection.
- Convert SAM2.1 Hugging Face checkpoints into MLX `.safetensors`.
- Load converted checkpoints from local disk or Hugging Face.
The public video predictor mirrors the official SAM2 method names where the implemented behavior matches closely:
- `SAM2VideoPredictor.from_pretrained(...)`
- `init_state(...)`
- `add_new_points_or_box(...)`
- `add_new_points(...)`
- `add_new_mask(...)`
- `propagate_in_video(...)`
- `clear_all_prompts_in_frame(...)`
- `reset_state(...)`
Benchmarks
Indicative results on the author's development machine with `facebook/sam2.1-hiera-small`.
Generated reports live under `outputs/benchmarks/`.
| Workload | Torch/MPS | MLX | Result |
|---|---|---|---|
| Image encoder | 104 ms/frame | 81 ms/frame | MLX 1.28x faster |
| Cached prompt decode | n/a | 4 ms | Interactive mask decode |
| Full image + prompt | n/a | 85 ms | Image embedding plus prompt |
| Full dog video, post-prompt propagation | 331 ms/frame | 189 ms/frame | MLX 1.75x faster with feature precompute |
| Full dog video, total run | 100.5 s | 94.8 s | MLX about 1.06x faster end to end |
| Raw propagation, no save/overlay/final resize | 407 ms/frame | 287 ms/frame | MLX 1.42x faster |
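The Result column is the ratio of the Torch/MPS time to the MLX time for the same workload. A quick check of that arithmetic, using only the numbers from the table:

```python
# Torch/MPS and MLX timings copied from the benchmark table.
rows = {
    "Image encoder": (104.0, 81.0),  # ms/frame
    "Full dog video, post-prompt propagation": (331.0, 189.0),  # ms/frame
    "Full dog video, total run": (100.5, 94.8),  # seconds
    "Raw propagation, no save/overlay/final resize": (407.0, 287.0),  # ms/frame
}


def speedup(torch_time: float, mlx_time: float) -> float:
    """How many times faster MLX is than Torch/MPS on the same workload."""
    return torch_time / mlx_time


speedups = {name: round(speedup(t, m), 2) for name, (t, m) in rows.items()}
for name, value in speedups.items():
    print(f"{name}: {value}x")

# 189 ms/frame of propagation is roughly 5.3 frames per second.
mlx_propagation_fps = 1000.0 / 189.0
```

The ratio is unitless, so rows measured in ms/frame and in total seconds can be compared the same way.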
The fastest video path uses batched image-feature precompute and bfloat16 memory attention:
```bash
uv run python scripts/benchmark_video_memory_mlx.py \
  --model-id facebook/sam2.1-hiera-small \
  --weights checkpoints/sam2.1_hiera_small_image_segmenter.safetensors \
  --frames-dir outputs/video_memory_multiclick/small_frames_full \
  --precompute-image-features \
  --feature-batch-size 4 \
  --memory-dtype bfloat16 \
  --memory-attention-dtype bfloat16 \
  --output-mask outputs/video_memory_multiclick/small_mlx_masks_full_precompute_bf16_attn.npy \
  --output-video outputs/video_memory_multiclick/small_mlx_overlay_full_precompute_bf16_attn.mp4 \
  --report outputs/benchmarks/video_memory_multiclick_small_mlx_full_precompute_bf16_attn.json \
  --points 625 429 700 470 300 250 950 610 \
  --labels 1 1 0 0
```
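`--precompute-image-features` together with `--feature-batch-size 4` suggests that the image encoder runs once per batch of frames before propagation, and that per-frame features are then looked up from a cache. Below is a minimal sketch of that batching pattern in plain Python. `precompute_features` and `encode_batch` are hypothetical stand-ins for illustration, not functions exported by `mlx_sam`:

```python
from typing import Callable, Dict, List, Sequence, TypeVar

Frame = TypeVar("Frame")
Feature = TypeVar("Feature")


def precompute_features(
    frames: Sequence[Frame],
    encode_batch: Callable[[List[Frame]], List[Feature]],
    batch_size: int = 4,
) -> Dict[int, Feature]:
    """Encode frames in fixed-size batches and cache the result per frame index."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    cache: Dict[int, Feature] = {}
    for start in range(0, len(frames), batch_size):
        batch = list(frames[start:start + batch_size])  # the last batch may be shorter
        features = encode_batch(batch)  # one encoder call per batch
        for offset, feature in enumerate(features):
            cache[start + offset] = feature
    return cache


# Usage with a toy "encoder" that records the size of each batch it receives.
calls: List[int] = []


def fake_encoder(batch: List[str]) -> List[str]:
    calls.append(len(batch))
    return [f"feat({frame})" for frame in batch]


frames = [f"frame{i}" for i in range(10)]
cache = precompute_features(frames, fake_encoder, batch_size=4)
print(calls)     # [4, 4, 2]
print(cache[9])  # feat(frame9)
```

Batching trades memory for fewer encoder launches. The propagation loop then only reads cached features, which is consistent with the table's note that the 1.75x result uses feature precompute.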
Torch comparison command:
```bash
uv run --extra torch-parity python scripts/benchmark_video_memory_torch.py \
  --model-id facebook/sam2.1-hiera-small \
  --checkpoint checkpoints/sam2.1_hiera_small.pt \
  --frames-dir outputs/video_memory_multiclick/small_frames_full \
  --output-mask outputs/video_memory_multiclick/small_torch_masks_full.npy \
  --output-video outputs/video_memory_multiclick/small_torch_overlay_full.mp4 \
  --report outputs/benchmarks/video_memory_multiclick_small_torch_full.json \
  --points 625 429 700 470 300 250 950 610 \
  --labels 1 1 0 0
```
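Both commands pass clicks as a flat `--points x1 y1 x2 y2 ...` list, with one `--labels` entry per click. Following the usual SAM convention, label `1` marks a positive (foreground) click and `0` a negative (background) click, so `1 1 0 0` means two positive and two negative clicks. Below is a small sketch of how such flat arguments pair up. `pair_clicks` is a hypothetical helper and not part of the scripts:

```python
from typing import List, Sequence, Tuple

Click = Tuple[float, float, int]  # (x, y, label)


def pair_clicks(points: Sequence[float], labels: Sequence[int]) -> List[Click]:
    """Turn a flat [x1, y1, x2, y2, ...] list plus labels into (x, y, label) clicks."""
    if len(points) % 2 != 0:
        raise ValueError("points must contain an even number of values (x, y pairs)")
    if len(points) // 2 != len(labels):
        raise ValueError("need exactly one label per (x, y) point")
    if any(label not in (0, 1) for label in labels):
        raise ValueError("labels must be 1 (positive) or 0 (negative)")
    xs, ys = points[0::2], points[1::2]
    return [(float(x), float(y), int(l)) for x, y, l in zip(xs, ys, labels)]


clicks = pair_clicks([625, 429, 700, 470, 300, 250, 950, 610], [1, 1, 0, 0])
positives = [(x, y) for x, y, label in clicks if label == 1]
negatives = [(x, y) for x, y, label in clicks if label == 0]
print(positives)  # [(625.0, 429.0), (700.0, 470.0)]
print(negatives)  # [(300.0, 250.0), (950.0, 610.0)]
```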
Other benchmark entry points:
```bash
uv run --extra torch-parity python scripts/benchmark_image_encoder.py --warmup 3 --runs 10
uv run python scripts/benchmark_prompt_segmenter.py --warmup 3 --runs 20
uv run python scripts/track_video_memory.py --frames 150 \
  --report outputs/benchmarks/video_memory_latency_150f.json
```
Install
```bash
uv sync --python 3.14
```
Torch is only used for conversion and comparison fixtures:
```bash
uv sync --python 3.14 --extra torch-parity
```
Reference repositories may exist locally for development, but they are not runtime dependencies:
- `third_party/sam2`
- `references/mlx-vlm`
Quick API
```python
import numpy as np
from mlx_sam import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained(
    "avbiswas/sam2.1-hiera-small-mlx-fp32"
)
state = predictor.init_state("third_party/sam2/demo/data/gallery/01_dog.mp4")

# One positive click (label 1) on object 1 at frame 0.
frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
    state,
    frame_idx=0,
    obj_id=1,
    points=np.array([[625.0, 429.0]], dtype=np.float32),
    labels=np.array([1], dtype=np.int32),
)

for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
    # masks is a NumPy float32 array shaped O,1,H,W in original video resolution.
    pass
```
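Because propagation can run forward or backward from an edited frame, outputs may arrive out of frame order, and a bidirectional edit produces two passes that share the start frame. Below is a minimal, framework-free sketch that merges one or more passes into a frame-ordered mapping. `merge_propagation` is a hypothetical helper, and strings stand in for the per-object mask arrays:

```python
from typing import Any, Dict, Iterable, Sequence, Tuple

Output = Tuple[int, Sequence[int], Sequence[Any]]  # (frame_idx, obj_ids, masks)


def merge_propagation(*passes: Iterable[Output]) -> Dict[int, Dict[int, Any]]:
    """Merge propagation passes into {frame_idx: {obj_id: mask}}, sorted by frame.

    Later passes overwrite earlier ones for the same (frame, object) pair,
    so pass the one you trust most last.
    """
    merged: Dict[int, Dict[int, Any]] = {}
    for outputs in passes:
        for frame_idx, obj_ids, masks in outputs:
            if len(obj_ids) != len(masks):
                raise ValueError(f"frame {frame_idx}: obj_ids and masks differ in length")
            per_frame = merged.setdefault(frame_idx, {})
            for obj_id, mask in zip(obj_ids, masks):
                per_frame[obj_id] = mask
    return dict(sorted(merged.items()))


forward = [(5, [1], ["m5"]), (6, [1], ["m6"]), (7, [1], ["m7"])]
backward = [(5, [1], ["m5b"]), (4, [1], ["m4"]), (3, [1], ["m3"])]  # decreasing order
frames = merge_propagation(forward, backward)
print(list(frames))  # [3, 4, 5, 6, 7]
print(frames[5][1])  # m5b (the backward pass overwrote frame 5)
```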
Local checkpoint loading:
```python
from mlx_sam import SAM2VideoPredictor

predictor = SAM2VideoPredictor(
    checkpoint="checkpoints/sam2.1_hiera_small_image_segmenter.safetensors"
)
```
Image Segmentation
Run one prompted frame and write an overlay:
```bash
uv run python scripts/predict_image_mask.py \
  --point 500 610 \
  --output-video outputs/image_prompt_overlay.mp4 \
  --output-mask outputs/image_prompt_mask.npy
```
Coordinates are in the resized 1024x1024 SAM input space.
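A click taken from the original image must be mapped into that 1024x1024 space first. Below is a minimal sketch. It assumes a plain per-axis resize to 1024x1024 without aspect-ratio padding, which is how the official SAM2 transform resizes. Check the script's own preprocessing before relying on it:

```python
from typing import Tuple

SAM_INPUT_SIZE = 1024  # side length of the square SAM input (assumed per-axis resize)


def to_sam_coords(x: float, y: float, width: int, height: int,
                  size: int = SAM_INPUT_SIZE) -> Tuple[float, float]:
    """Map a pixel in an original (width x height) image into the size x size input."""
    if width <= 0 or height <= 0:
        raise ValueError("image dimensions must be positive")
    return x * size / width, y * size / height


def from_sam_coords(x: float, y: float, width: int, height: int,
                    size: int = SAM_INPUT_SIZE) -> Tuple[float, float]:
    """Inverse mapping, from the size x size input back to original pixels."""
    return x * width / size, y * height / size


print(to_sam_coords(960, 540, 1920, 1080))    # (512.0, 512.0)
print(from_sam_coords(500, 610, 1920, 1080))  # (937.5, 643.359375)
```

Because each axis is scaled independently, a non-square image is stretched. The x and y scale factors therefore differ, which is why the sketch needs both `width` and `height`.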
Video Tracking
SAM2 memory tracker:
```bash
uv run python scripts/track_video_memory.py --frames 289 \
  --point 500 610 \
  --output-video outputs/dog_memory_overlay_full_v3.mp4 \
  --output-mask outputs/dog_memory_masks_full_v3.npy \
  --report outputs/benchmarks/dog_memory_latency_full_v3.json
```
The current memory tracker uses:
- first-frame point, box, or mask prompts
- SAM2 memory encoder and memory attention
- object pointers
- dynamic multimask fallback on unstable single-mask tracking outputs
- conditioning-frame memory plus SAM2-style frame-indexed temporal memory selection
- click-frame masks binarized before memory encoding, matching the official postprocessing path
- shared image features for multi-object tracking
- forward, backward, and bidirectional correction replay
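In the official SAM2 implementation, the dynamic multimask fallback scores a single-mask output by its stability: the area of mask logits above `+delta` divided by the area above `-delta`. If that ratio falls below a threshold, the tracker uses the multimask candidate with the highest predicted IoU instead. The sketch below applies that rule to flat lists of logits. The defaults `delta=0.05` and `threshold=0.98` are the values from the official SAM2.1 configs; whether `mlx_sam` uses identical values is an assumption:

```python
from typing import List, Sequence, Tuple


def stability_score(logits: Sequence[float], delta: float = 0.05) -> float:
    """Area of logits > +delta divided by area of logits > -delta (1.0 if both empty)."""
    area_inner = sum(1 for v in logits if v > delta)
    area_outer = sum(1 for v in logits if v > -delta)
    return area_inner / area_outer if area_outer > 0 else 1.0


def pick_mask(
    single_logits: Sequence[float],
    multi_logits: List[Sequence[float]],
    multi_ious: Sequence[float],
    threshold: float = 0.98,
    delta: float = 0.05,
) -> Tuple[Sequence[float], str]:
    """Keep the single-mask output if stable, else the best-scoring multimask output."""
    if stability_score(single_logits, delta) >= threshold:
        return single_logits, "single"
    best = max(range(len(multi_ious)), key=lambda i: multi_ious[i])
    return multi_logits[best], f"multimask[{best}]"


stable = [4.0, 3.0, -5.0, -6.0]      # every pixel is far from the 0 boundary
unstable = [4.0, 0.01, -0.02, -6.0]  # two pixels sit right at the boundary
candidates = [[4.0, 4.0, -4.0, -4.0], [4.0, -4.0, -4.0, -4.0], [4.0, 4.0, 4.0, -4.0]]
ious = [0.7, 0.9, 0.6]
print(pick_mask(stable, candidates, ious)[1])    # single
print(pick_mask(unstable, candidates, ious)[1])  # multimask[1]
```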
Overlay Utility
Render masks onto a video:
```bash
uv run python scripts/overlay_masks.py \
  --masks outputs/dog_memory_masks_full_v3.npy \
  --output outputs/dog_memory_overlay_from_masks.mp4
```
The overlay script accepts .npy or .npz masks shaped T,H,W or T,1,H,W.
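The two accepted layouts differ only by a singleton channel axis, so a loader can squeeze `T,1,H,W` down to `T,H,W` before drawing. To stay dependency-free, the sketch below works on shape tuples and single RGB pixels. The overlay color and alpha are illustrative assumptions, not the script's documented defaults:

```python
from typing import Tuple

RGB = Tuple[int, int, int]


def normalize_mask_shape(shape: Tuple[int, ...]) -> Tuple[int, int, int]:
    """Accept (T, H, W) or (T, 1, H, W) and return (T, H, W); reject anything else."""
    if len(shape) == 3:
        return shape[0], shape[1], shape[2]
    if len(shape) == 4 and shape[1] == 1:
        return shape[0], shape[2], shape[3]
    raise ValueError(f"expected T,H,W or T,1,H,W masks, got shape {shape}")


def blend_pixel(pixel: RGB, inside_mask: bool,
                color: RGB = (30, 144, 255), alpha: float = 0.5) -> RGB:
    """Alpha-blend an overlay color (assumed values) into a pixel inside the mask."""
    if not inside_mask:
        return pixel
    r, g, b = (round((1 - alpha) * p + alpha * c) for p, c in zip(pixel, color))
    return r, g, b


print(normalize_mask_shape((10, 1, 48, 64)))  # (10, 48, 64)
print(blend_pixel((100, 100, 100), True))     # (65, 122, 178)
```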
Synthetic overlays are intended only for smoke-testing the video writer and must be requested explicitly:
```bash
uv run python scripts/overlay_masks.py --synthetic-smoke-test
```
Convert Weights
Convert from Hugging Face:
```bash
uv run --extra torch-parity mlx-sam-convert \
  --hf-id facebook/sam2.1-hiera-small \
  --output-dir checkpoints
```
Supported source ids:
```
facebook/sam2.1-hiera-tiny
facebook/sam2.1-hiera-small
facebook/sam2.1-hiera-base-plus
facebook/sam2.1-hiera-large
```
Convert a local Torch checkpoint:
```bash
uv run --extra torch-parity mlx-sam-convert \
  --checkpoint checkpoints/sam2.1_hiera_small.pt \
  --model-id facebook/sam2.1-hiera-small \
  --output checkpoints/sam2.1_hiera_small_image_segmenter.safetensors
```
The converted checkpoint includes the Hiera image encoder, FPN neck, prompt encoder, mask decoder, object pointer projection, memory encoder, and memory attention. Generated checkpoints are ignored by git.
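To check which components ended up in a converted file, you can read the safetensors header directly. The format starts with an 8-byte little-endian unsigned integer giving the length of a JSON header. That header maps each tensor name to its dtype, shape and byte offsets, plus an optional `__metadata__` entry. The sketch below uses only the standard library. Grouping by the first dotted name segment is a guess at how the tensors are named, since the converted key names are not documented here:

```python
import json
import struct
from collections import Counter
from typing import Dict


def read_safetensors_header(path: str) -> Dict[str, dict]:
    """Return the JSON header of a .safetensors file without loading tensor data."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # u64, little-endian
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form string metadata
    return header


def count_by_prefix(header: Dict[str, dict]) -> Counter:
    """Count tensors per top-level name prefix (assumed dotted naming)."""
    return Counter(name.split(".", 1)[0] for name in header)


# Usage, with the output path from the conversion command:
# header = read_safetensors_header("checkpoints/sam2.1_hiera_small_image_segmenter.safetensors")
# for prefix, n in sorted(count_by_prefix(header).items()):
#     print(prefix, n)
```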
The old script path remains as a compatibility wrapper:
```bash
uv run --extra torch-parity python scripts/convert_image_encoder_weights.py \
  --checkpoint checkpoints/sam2.1_hiera_small.pt \
  --model-id facebook/sam2.1-hiera-small
```
Feature Regression
Run MLX feature scenarios and compare against Torch fixtures:
```bash
uv run python scripts/run_feature_regression.py --frames 130
```
Regenerate official Torch fixtures first:
```bash
uv run python scripts/run_feature_regression.py --refresh-torch --frames 130
```
Compare existing outputs without rerunning MLX:
```bash
uv run python scripts/run_feature_regression.py --skip-mlx --frames 130
```
Covered scenarios:
- `multi_object`
- `box_prompt`
- `negative_clicks`
- `cross_frame_corrections`
- `bidirectional_middle`
Current MLX-vs-Torch feature benchmark results on the 130-frame dog-gallery fixture:
| Scenario | Mean IoU | Presence |
|---|---|---|
| `multi_object` | 0.973 | 260 / 260 |
| `box_prompt` | 0.953 | 129 / 130 |
| `negative_clicks` | 0.972 | 130 / 130 |
| `cross_frame_corrections` | 0.974 | 130 / 130 |
| `bidirectional_middle` | 0.924 | 128 / 130 |
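Below is a minimal sketch of how the two columns could be computed. It treats binary masks as flat lists and assumes that "presence" counts (frame, object) pairs where MLX and Torch agree on whether the mask is non-empty. That reading fits `260 / 260` for two objects over 130 frames. Scoring two empty masks as IoU 1.0 is also an assumption, and the regression script may handle that case differently:

```python
from typing import List, Sequence, Tuple


def mask_iou(a: Sequence[int], b: Sequence[int]) -> float:
    """Intersection over union of two equal-length binary masks (1 = object)."""
    if len(a) != len(b):
        raise ValueError("masks must have the same number of pixels")
    inter = sum(1 for x, y in zip(a, b) if x and y)
    union = sum(1 for x, y in zip(a, b) if x or y)
    return inter / union if union else 1.0  # both empty: assumed perfect agreement


def compare(pairs: List[Tuple[Sequence[int], Sequence[int]]]) -> Tuple[float, int, int]:
    """Return (mean IoU, presence matches, total) over (mlx_mask, torch_mask) pairs."""
    if not pairs:
        raise ValueError("no mask pairs to compare")
    ious = [mask_iou(m, t) for m, t in pairs]
    present = sum(1 for m, t in pairs if any(m) == any(t))
    return sum(ious) / len(ious), present, len(pairs)


pairs = [
    ([1, 1, 0, 0], [1, 1, 0, 0]),  # identical
    ([1, 1, 0, 0], [1, 0, 0, 0]),  # IoU 1/2
    ([0, 0, 0, 0], [0, 1, 0, 0]),  # MLX lost the object: IoU 0, presence mismatch
]
mean_iou, matched, total = compare(pairs)
print(f"mean IoU {mean_iou:.3f}, presence {matched} / {total}")  # mean IoU 0.500, presence 2 / 3
```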
On `02_cups.mp4`, the same bidirectional benchmark with a center-cup prompt at frame 120 agrees more closely with Torch:
- `bidirectional_middle` mean IoU: 0.979
- Presence: 130 / 130
- Report: `outputs/cups_feature_benchmarks_130f/bidirectional_middle_mlx_vs_torch.json`
- Overlay: `outputs/cups_feature_benchmarks_130f/bidirectional_middle_mlx_overlay.mp4`
Feature reports and inspection overlays are written under `outputs/feature_benchmarks_130f/`.
Parity Validation
Parity is used as a guardrail, not as the main product surface. The current
full-video dog run compares MLX against the official Torch SAM2VideoPredictor
on all 289 frames:
- Mean mask IoU over all frames: about 0.977
- Median mask IoU on non-empty Torch frames: about 0.979
- Presence match: 289 / 289 frames
- Official Torch overlay: `outputs/torch_sam2_dog_overlay_full_twitter.mp4`
- MLX overlay: `outputs/dog_memory_overlay_full_v3_twitter.mp4`
- Comparison report: `outputs/benchmarks/dog_memory_mlx_vs_torch_full_v3.json`
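The two headline statistics cover different frame sets. The mean includes every frame, while the median skips frames where the Torch mask is empty. Below is a minimal aggregation sketch over hypothetical per-frame records `(iou, torch_mask_nonempty)`. The record layout is illustrative and is not the report's JSON schema:

```python
import statistics
from typing import List, Tuple

Record = Tuple[float, bool]  # (per-frame IoU, Torch mask non-empty?)


def summarize(records: List[Record]) -> Tuple[float, float]:
    """Return (mean IoU over all frames, median IoU over non-empty Torch frames)."""
    if not records:
        raise ValueError("no frames to summarize")
    mean_all = statistics.fmean(iou for iou, _ in records)
    nonempty = [iou for iou, torch_nonempty in records if torch_nonempty]
    median_nonempty = statistics.median(nonempty) if nonempty else float("nan")
    return mean_all, median_nonempty


records = [(0.98, True), (0.97, True), (0.95, True), (1.0, False)]
mean_all, median_nonempty = summarize(records)
print(round(mean_all, 3), median_nonempty)  # 0.975 0.97
```

Excluding empty Torch frames from the median keeps trivially "perfect" empty-vs-empty frames from inflating the score.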
Image and prompt parity fixtures:
```bash
uv run --extra torch-parity python scripts/export_torch_image_embeddings.py --frames 2
uv run python scripts/compare_image_embeddings.py
uv run --extra torch-parity python scripts/export_torch_prompt_mask.py
uv run python scripts/compare_prompt_mask.py
```
Current low-level parity results:
- Image `vision_features` max abs error: about 1.63e-05
- Prompted low-res masks max abs error: about 4.67e-05
- Prompted IoU max abs error: about 4.77e-07
Reports are written under `outputs/parity/`.
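Each low-level figure is the largest absolute elementwise difference between the MLX and Torch outputs for the same tensor. Below is a dependency-free sketch of that metric with a tolerance check. The `1e-4` tolerance is an illustrative choice, not a threshold used by the comparison scripts:

```python
from typing import Sequence


def max_abs_error(mlx_values: Sequence[float], torch_values: Sequence[float]) -> float:
    """Largest |a - b| over two flattened tensors of the same size."""
    if len(mlx_values) != len(torch_values):
        raise ValueError("tensors must have the same number of elements")
    return max((abs(a - b) for a, b in zip(mlx_values, torch_values)), default=0.0)


err = max_abs_error([0.10, -1.25, 3.0], [0.10002, -1.25, 2.99999])
print(f"{err:.2e}", err < 1e-4)  # 2.00e-05 True
```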
Runtime Dependency Boundary
Default runtime should not include Torch:
```bash
uv sync --python 3.14
uv run python - <<'PY'
import importlib.util as u
print({m: bool(u.find_spec(m)) for m in ["torch", "torchvision", "hydra", "iopath", "mlx", "cv2"]})
PY
```
Expected output:
```
{'torch': False, 'torchvision': False, 'hydra': False, 'iopath': False, 'mlx': True, 'cv2': True}
```
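The same check can be kept as a reusable guard, for example in a test suite, so that an accidental Torch dependency shows up as a failure. Below is a sketch using only `importlib.util.find_spec`, as in the snippet above. The function name `boundary_violations` is hypothetical:

```python
import importlib.util
from typing import Iterable, List


def boundary_violations(forbidden: Iterable[str], required: Iterable[str]) -> List[str]:
    """List top-level modules that are importable but forbidden, or required but missing."""
    problems = [f"{m} should not be installed" for m in forbidden
                if importlib.util.find_spec(m) is not None]
    problems += [f"{m} is missing" for m in required
                 if importlib.util.find_spec(m) is None]
    return problems


# Default-runtime expectation from this README:
# assert not boundary_violations(["torch", "torchvision", "hydra", "iopath"], ["mlx", "cv2"])
print(boundary_violations(["json"], ["definitely_not_a_real_module_xyz"]))
```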
File details

`mlx_sam-0.1.0.tar.gz` (source distribution)
- Size: 24.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.10.2 (macOS)

| Algorithm | Hash digest |
|---|---|
| SHA256 | d0dce8e9a9a1a29f8c0f55db5ed11e53ad57fa0e1bccd1577f4b909aaddcfae9 |
| MD5 | 1712a1c5d4236eb9b2369886ee23699c |
| BLAKE2b-256 | 5e0a1be74257044b32809f86cb8a515666a71836cc986c30a82bd5979ae53fdc |

`mlx_sam-0.1.0-py3-none-any.whl` (built distribution)
- Size: 31.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.10.2 (macOS)

| Algorithm | Hash digest |
|---|---|
| SHA256 | 8a70687318b50e4533d8eefef98be0a1d9cc6fdcddc9de8632f031202c6968ea |
| MD5 | 7bb16a23475e5a39a96cd29ce6ef2a08 |
| BLAKE2b-256 | 5df2010f2f94b401bfdd5c889b59a644a032765578ba1ac4797af99049bf842b |