
CanViT inference on Apple Silicon via MLX


CanViT-MLX

Experimental: an MLX implementation of CanViT, the Canvas Vision Transformer, for Apple Silicon. The reference implementation is CanViT-PyTorch. The API may change or break at any time.

CanViT: Toward Active-Vision Foundation Models (arXiv:2603.22570)

Install

uv add "canvit-mlx[hub]"

Quickstart

import mlx.core as mx
from canvit_mlx import load_from_hf_hub, load_and_preprocess, Viewpoint, extract_glimpse_at_viewpoint

model = load_from_hf_hub("canvit/canvitb16-add-vpe-pretrain-g128px-s512px-in21k-dv3b16-2026-02-02-mlx")
image = load_and_preprocess("test_data/Cat03.jpg", target_size=512)

state = model.init_state(batch_size=1, canvas_grid_size=32)
vp = Viewpoint.full_scene(batch_size=1)
glimpse = extract_glimpse_at_viewpoint(image, vp, glimpse_size_px=128)
out = model(glimpse, state, vp)
mx.eval(out.state.canvas, out.state.recurrent_cls, out.local_patches)

# Canvas spatial features (linearly decodable into dense predictions)
canvas_spatial = model.get_spatial(out.state.canvas)  # [1, 1024, 1024]
print(canvas_spatial.shape)
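The sizes in the quickstart fit together with simple arithmetic. The sketch below assumes a patch size of 16 px (inferred from the `b16` in the checkpoint name, not stated in this README) and reads the `[1, 1024, 1024]` shape as presumably `[batch, tokens, channels]`; `token_budget` and `TokenBudget` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TokenBudget:
    glimpse_patches: int          # patch tokens produced from one glimpse
    canvas_tokens: int            # spatial tokens held on the canvas
    canvas_cell_px: float         # image pixels spanned by one canvas cell (per side)
    full_scene_downsample: float  # image size / glimpse size for a full-scene view


def token_budget(image_px: int, glimpse_px: int, canvas_grid: int, patch_px: int = 16) -> TokenBudget:
    """Hypothetical helper: derive token counts from the quickstart settings.

    Assumes a 16 px patch size (from "b16" in the checkpoint name).
    """
    if glimpse_px % patch_px:
        raise ValueError("glimpse size must be a multiple of the patch size")
    per_side = glimpse_px // patch_px
    return TokenBudget(
        glimpse_patches=per_side * per_side,
        canvas_tokens=canvas_grid * canvas_grid,
        canvas_cell_px=image_px / canvas_grid,
        full_scene_downsample=image_px / glimpse_px,
    )


for grid in (32, 64):
    print(grid, token_budget(image_px=512, glimpse_px=128, canvas_grid=grid))
```

With `canvas_grid_size=32`, the canvas holds 32 × 32 = 1024 tokens, consistent with the 1024 in the middle of the `get_spatial` shape, and each cell spans 16 px of the 512 px image. A 128 px glimpse yields 8 × 8 = 64 patches. The demo's `--canvas-grid 64` option would presumably raise the canvas to 4096 tokens.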

Classification

from pathlib import Path
from canvit_mlx import CanViTForImageClassification, Viewpoint, extract_glimpse_at_viewpoint, load_and_preprocess

clf = CanViTForImageClassification.from_pretrained_with_probe(
    pretrained_weights=Path("weights/canvitb16-add-vpe-pretrain-g128px-s512px-in21k-dv3b16-2026-02-02.safetensors"),
    pretrained_config=Path("weights/canvitb16-add-vpe-pretrain-g128px-s512px-in21k-dv3b16-2026-02-02.json"),
    probe_weights=Path("path/to/probe.safetensors"),
)

image = load_and_preprocess("test_data/Cat03.jpg", target_size=512)
state = clf.init_state(batch_size=1, canvas_grid_size=32)
vp = Viewpoint.full_scene(batch_size=1)
glimpse = extract_glimpse_at_viewpoint(image, vp, glimpse_size_px=128)
logits, new_state = clf(glimpse, state, vp)
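To turn raw logits into ranked predictions, a softmax followed by a top-k selection is the usual step. This is a dependency-free sketch; `top_k` is a hypothetical helper, and it assumes `logits` has shape `[batch, num_classes]`, which this README does not state. Mapping indices to class names depends on the probe you load.

```python
import math


def top_k(logits: list[float], k: int = 5) -> list[tuple[int, float]]:
    """Hypothetical helper: return the k (class_index, probability) pairs
    with the highest softmax probability, best first."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]
```

For example, `top_k(logits[0].tolist(), k=5)` would rank the first batch item, converting the MLX array to a Python list first.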

Demos

uv run --group demos python demos/basic.py
uv run --group demos python demos/basic.py --image test_data/Cat03.jpg --canvas-grid 64

Converting weights

Convert a PyTorch checkpoint from the Hugging Face Hub to MLX format:

uv run python convert.py
uv run python convert.py --verify  # also runs a PyTorch-vs-MLX numerical comparison

License

MIT



Download files


Source Distribution

canvit_mlx-0.1.2.tar.gz (433.4 kB)

Uploaded Source

Built Distribution


canvit_mlx-0.1.2-py3-none-any.whl (16.3 kB)

Uploaded Python 3

File details

Details for the file canvit_mlx-0.1.2.tar.gz.

File metadata

  • Download URL: canvit_mlx-0.1.2.tar.gz
  • Upload date:
  • Size: 433.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv 0.11.3 (uv publish, CI, Ubuntu 24.04)

File hashes

Hashes for canvit_mlx-0.1.2.tar.gz
Algorithm Hash digest
SHA256 b28c0f8dcbb50207696d2bf24fbd57ec976e59481e8f08b3f83c749467d815de
MD5 1df3139ca53bc2c0479d35bbc77a7fa6
BLAKE2b-256 d4553c5cbfefd2ab134dacc1ee7bf46441efaaa4ecc9b13984a270d93a2a0e92


File details

Details for the file canvit_mlx-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: canvit_mlx-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 16.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv 0.11.3 (uv publish, CI, Ubuntu 24.04)

File hashes

Hashes for canvit_mlx-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 37364ad27e263ca0fe8af279bb66bbd435ea8a0b475ab7ba4d239a3121c3ff7e
MD5 96820eaa32719f99adbe30f324fa4d88
BLAKE2b-256 1f5f33fbd721b82e87c7e467ca48ecfee1c54ce6a3af0816cfecdf6c014effc8

