CanViT inference on Apple Silicon via MLX

CanViT-MLX

MLX implementation of CanViT, the Canvas Vision Transformer, for native Apple Silicon inference.

CanViT: Toward Active-Vision Foundation Models (arXiv:2603.22570)

Install

pip install canvit-mlx

Or from source:

git clone https://github.com/yberreby/CanViT-MLX.git
cd CanViT-MLX
uv sync

Quickstart

import mlx.core as mx
from canvit_mlx import load_from_hf_hub, load_and_preprocess, Viewpoint, extract_glimpse_at_viewpoint

model = load_from_hf_hub("canvit/canvitb16-add-vpe-pretrain-g128px-s512px-in21k-dv3b16-2026-02-02-mlx")
image = load_and_preprocess("test_data/Cat03.jpg", target_size=512)

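# Start from a fresh recurrent state with a 32x32 canvas grid
state = model.init_state(batch_size=1, canvas_grid_size=32)
# Viewpoint covering the whole image
vp = Viewpoint.full_scene(batch_size=1)
# Sample a 128 px glimpse of the region selected by the viewpoint
glimpse = extract_glimpse_at_viewpoint(image, vp, glimpse_size_px=128)
# Process the glimpse; out.state holds the updated canvas and recurrent CLS token
out = model(glimpse, state, vp)
# MLX evaluates lazily: force computation of the outputs we need
mx.eval(out.state.canvas, out.state.recurrent_cls, out.local_patches)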

# Canvas spatial features (linearly decodable into dense predictions)
canvas_spatial = model.get_spatial(out.state.canvas)  # [1, 1024, 1024]
print(canvas_spatial.shape)
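To make the viewpoint/glimpse idea concrete, here is a minimal NumPy sketch of crop-and-resize glimpse extraction. It is not the library's extract_glimpse_at_viewpoint: the helper name extract_glimpse_sketch, the (center, scale) parameterization in normalized [-1, 1] coordinates, the [H, W, C] image layout, and nearest-neighbor sampling are all assumptions made for illustration.

```python
import numpy as np


def extract_glimpse_sketch(image, center, scale, glimpse_size_px):
    """Hypothetical nearest-neighbor crop-and-resize of a square viewpoint region.

    Assumed parameterization (not the canvit_mlx API):
      center = (cy, cx) in normalized [-1, 1] image coordinates,
      scale  = side of the region as a fraction of the image side (1.0 = full scene).
    """
    h, w = image.shape[:2]
    cy, cx = center
    # Normalized positions of the glimpse pixel centers, spanning (-1, 1)
    offsets = (np.arange(glimpse_size_px) + 0.5) / glimpse_size_px * 2.0 - 1.0
    ys = cy + scale * offsets
    xs = cx + scale * offsets
    # Map normalized coordinates to integer pixel indices (nearest neighbor)
    rows = np.clip(np.floor((ys + 1.0) / 2.0 * h), 0, h - 1).astype(int)
    cols = np.clip(np.floor((xs + 1.0) / 2.0 * w), 0, w - 1).astype(int)
    # Advanced indexing over the first two axes keeps the channel axis
    return image[rows[:, None], cols[None, :]]


image = np.random.rand(512, 512, 3).astype(np.float32)
full = extract_glimpse_sketch(image, center=(0.0, 0.0), scale=1.0, glimpse_size_px=128)
zoom = extract_glimpse_sketch(image, center=(0.0, 0.0), scale=0.5, glimpse_size_px=128)
print(full.shape, zoom.shape)  # (128, 128, 3) (128, 128, 3)
```

Under these assumptions, a full-scene viewpoint on a 512 px image with a 128 px glimpse keeps every fourth pixel (a 4x downsample), while a smaller scale covers a sub-region at higher resolution with the same glimpse size.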

Classification

from pathlib import Path
from canvit_mlx import CanViTForImageClassification, Viewpoint, extract_glimpse_at_viewpoint, load_and_preprocess

clf = CanViTForImageClassification.from_pretrained_with_probe(
    pretrained_weights=Path("weights/canvitb16-add-vpe-pretrain-g128px-s512px-in21k-dv3b16-2026-02-02.safetensors"),
    pretrained_config=Path("weights/canvitb16-add-vpe-pretrain-g128px-s512px-in21k-dv3b16-2026-02-02.json"),
    probe_weights=Path("path/to/probe.safetensors"),
)

image = load_and_preprocess("test_data/Cat03.jpg", target_size=512)
state = clf.init_state(batch_size=1, canvas_grid_size=32)
vp = Viewpoint.full_scene(batch_size=1)
glimpse = extract_glimpse_at_viewpoint(image, vp, glimpse_size_px=128)
# Returns class logits and the updated recurrent state
logits, new_state = clf(glimpse, state, vp)
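The classifier returns raw logits. A minimal sketch for turning them into top-k class probabilities, assuming logits has shape [batch, num_classes] and a float32 dtype (MLX arrays of standard dtypes convert to NumPy via np.asarray). The helper top_k_predictions is hypothetical, and the mapping from class index to label depends on the probe and is not shown here.

```python
import numpy as np


def top_k_predictions(logits, k=5):
    """Hypothetical helper: (class_index, probability) pairs for the first batch element.

    Assumes logits has shape [batch, num_classes].
    """
    x = np.asarray(logits, dtype=np.float64)[0]
    # Numerically stable softmax: subtract the max before exponentiating
    e = np.exp(x - x.max())
    probs = e / e.sum()
    # Indices sorted by descending probability
    top = np.argsort(probs)[::-1][:k]
    return [(int(i), float(probs[i])) for i in top]


print(top_k_predictions(np.array([[1.0, 3.0, 2.0]]), k=2))
```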

Demos

# Run the basic demo with its defaults
uv run --group demos python demos/basic.py
# Use a specific image and a 64x64 canvas grid
uv run --group demos python demos/basic.py --image test_data/Cat03.jpg --canvas-grid 64
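The canvas grid size and glimpse size set how many tokens the model handles. A small arithmetic sketch, assuming a square canvas of canvas_grid_size x canvas_grid_size tokens and non-overlapping 16 px patches (the "b16" in the checkpoint name is read here as patch size 16, which is an assumption). The helper canvit_token_counts is hypothetical.

```python
def canvit_token_counts(canvas_grid_size, glimpse_size_px, patch_size=16):
    """Hypothetical helper: (canvas tokens, glimpse patches) implied by the settings.

    Assumes a square canvas grid and a square glimpse split into
    non-overlapping patch_size x patch_size patches.
    """
    if glimpse_size_px % patch_size:
        raise ValueError("glimpse size must be a multiple of the patch size")
    canvas_tokens = canvas_grid_size ** 2
    patches_per_side = glimpse_size_px // patch_size
    return canvas_tokens, patches_per_side ** 2


print(canvit_token_counts(32, 128))  # (1024, 64)
print(canvit_token_counts(64, 128))  # (4096, 64)
```

Under these assumptions, the default canvas_grid_size=32 gives 1024 canvas tokens, consistent with the middle axis of the [1, 1024, 1024] shape in the Quickstart, and --canvas-grid 64 quadruples that to 4096.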

Converting weights

Convert a PyTorch checkpoint from the Hugging Face Hub to MLX format:

uv run python convert.py
uv run python convert.py --verify  # also runs a PyTorch vs. MLX numerical comparison
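As an illustration of what a PyTorch vs. MLX numerical comparison can check, here is a minimal NumPy sketch that summarizes the agreement between two output arrays. It is not convert.py's implementation; the helper compare_outputs and its default tolerances are hypothetical choices.

```python
import numpy as np


def compare_outputs(reference, candidate, atol=1e-4, rtol=1e-3):
    """Hypothetical helper: summarize agreement between two arrays.

    reference/candidate could be, e.g., PyTorch and MLX outputs converted to NumPy.
    """
    ref = np.asarray(reference, dtype=np.float64)
    cand = np.asarray(candidate, dtype=np.float64)
    if ref.shape != cand.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {cand.shape}")
    abs_err = np.abs(ref - cand)
    # Relative error, guarded against division by values near zero
    rel_err = abs_err / np.maximum(np.abs(ref), 1e-12)
    return {
        "max_abs": float(abs_err.max()),
        "max_rel": float(rel_err.max()),
        # Elementwise |cand - ref| <= atol + rtol * |ref|
        "allclose": bool(np.allclose(cand, ref, atol=atol, rtol=rtol)),
    }


a = np.ones((2, 3))
print(compare_outputs(a, a + 1e-6))
```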

License

MIT

Download files

Source Distribution

canvit_mlx-0.1.0.tar.gz (445.7 kB)

Uploaded Source

Built Distribution

canvit_mlx-0.1.0-py3-none-any.whl (16.4 kB)

Uploaded Python 3

File details

Details for the file canvit_mlx-0.1.0.tar.gz.

File metadata

  • Download URL: canvit_mlx-0.1.0.tar.gz
  • Size: 445.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv 0.11.3 (uv publish, CI, Ubuntu 24.04)

File hashes

Hashes for canvit_mlx-0.1.0.tar.gz
Algorithm Hash digest
SHA256 c2624db0bedc37c9c9180369af52e11421e3ce1af6aef9a1dfb7c967ce7a1bee
MD5 6935eaebd413495d63dc03d7f568111f
BLAKE2b-256 7f6f204c3031765b3755d5040200e558dc584e7b4390eb3134978bfc7e7cd061

File details

Details for the file canvit_mlx-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: canvit_mlx-0.1.0-py3-none-any.whl
  • Size: 16.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv 0.11.3 (uv publish, CI, Ubuntu 24.04)

File hashes

Hashes for canvit_mlx-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 d6a1b9aa1e237185da627d2eaa400282b7e0f6223a696422cc3be203f8bb58c2
MD5 81117fe54d0c044c2a8dd77f4b44756f
BLAKE2b-256 e8a0fe36d76f8dfdceb0c7afa0cee2245a5f70c1788cd03940721ee15a2f5ba7
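To check a downloaded file against the SHA256 digests listed above, a minimal standard-library sketch could look like this. The helper names sha256_of_file and verify_download are hypothetical; the digests are copied from the tables above.

```python
import hashlib
import os

# SHA256 digests copied from the file details above
EXPECTED_SHA256 = {
    "canvit_mlx-0.1.0.tar.gz": "c2624db0bedc37c9c9180369af52e11421e3ce1af6aef9a1dfb7c967ce7a1bee",
    "canvit_mlx-0.1.0-py3-none-any.whl": "d6a1b9aa1e237185da627d2eaa400282b7e0f6223a696422cc3be203f8bb58c2",
}


def sha256_of_file(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 in chunks and return the hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path):
    """True if the file's SHA256 matches the published digest for its file name.

    Raises KeyError for file names not listed in EXPECTED_SHA256.
    """
    expected = EXPECTED_SHA256[os.path.basename(path)]
    return sha256_of_file(path) == expected
```

pip can enforce the same check at install time in hash-checking mode: list the package in a requirements file as canvit-mlx==0.1.0 with one --hash=sha256:... option per distribution file, then install with pip install --require-hashes -r requirements.txt.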
