
CanViT (Canvas Vision Transformer) -- JAX / Flax NNX


CanViT-NNX

Experimental: a port of CanViT to JAX using Flax NNX. The reference implementation is CanViT-PyTorch. The API may break at any time.

JAX / Flax NNX implementation of the Canvas Vision Transformer (CanViT).

Checkpoints: canvit/canvit-nnx on HuggingFace.

Install

uv add canvit-nnx

We recommend uv for dependency management.

Usage

Classification

from canvit_nnx import CanViTForImageClassification, Viewpoint, sample_at_viewpoint

clf = CanViTForImageClassification.from_pretrained(
    "canvit/canvitb16-add-vpe-finetune-g128px-s512px-in1k-2026-04-06-nnx"
)
state = clf.init_state(batch_size=1, canvas_grid_size=32)

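# `image`: a loaded, preprocessed image batch (not defined here; scripts/demo.py runs end to end)
vp = Viewpoint.full_scene(batch_size=1)
glimpse = sample_at_viewpoint(spatial=image, viewpoint=vp, glimpse_size_px=128)
logits, state = clf(glimpse, state, vp)
pred = logits.argmax(axis=-1)  # predicted ImageNet-1k class index per batch item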
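The package layout describes sample_at_viewpoint as bilinear glimpse extraction. As a rough illustration of that idea (not the package's code), the plain-JAX sketch below crops and resizes one [H, W, C] image with bilinear interpolation. The helper name bilinear_glimpse, the channels-last layout, and the viewpoint parameterization (a center in normalized [-1, 1] coordinates plus a half-extent scale, where 1.0 covers the whole image) are all assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def bilinear_glimpse(image, center, scale, glimpse_size_px):
    """Hypothetical helper: crop-and-resize one [H, W, C] image bilinearly.

    center: (y, x) in normalized [-1, 1] image coordinates (assumption).
    scale: half-extent of the glimpse in the same units; 1.0 = full image.
    """
    h, w, _ = image.shape
    # Normalized coordinates of the glimpse's pixel centers, spanning [-1, 1].
    t = (jnp.arange(glimpse_size_px) + 0.5) / glimpse_size_px * 2.0 - 1.0
    ys = center[0] + scale * t
    xs = center[1] + scale * t
    # Map normalized coords to pixel-center coords (align_corners=False style).
    py = (ys + 1.0) * h / 2.0 - 0.5
    px = (xs + 1.0) * w / 2.0 - 0.5
    grid_y, grid_x = jnp.meshgrid(py, px, indexing="ij")

    def sample_channel(channel):
        # order=1 is bilinear interpolation; "nearest" clamps at the borders.
        return map_coordinates(channel, [grid_y, grid_x], order=1, mode="nearest")

    # Sample each channel independently, keeping channels last.
    return jax.vmap(sample_channel, in_axes=2, out_axes=2)(image)


image = jax.random.uniform(jax.random.PRNGKey(0), (64, 64, 3))
full = bilinear_glimpse(image, center=(0.0, 0.0), scale=1.0, glimpse_size_px=64)
zoom = bilinear_glimpse(image, center=(0.5, -0.5), scale=0.25, glimpse_size_px=32)
print(full.shape, zoom.shape)  # (64, 64, 3) (32, 32, 3)
```

In this sketch, center (0, 0) with scale 1.0 is analogous to a full-scene viewpoint; smaller scales zoom in, so a fixed-size glimpse covers less of the scene at higher resolution.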

Pretrained backbone (dense features)

import jax.numpy as jnp
from canvit_nnx import from_pretrained, Viewpoint, sample_at_viewpoint

model = from_pretrained(
    "canvit/canvitb16-add-vpe-pretrain-g128px-s512px-in21k-dv3b16-2026-02-02-nnx"
)
state = model.init_state(batch_size=1, canvas_grid_size=32)

# `image`: a loaded, preprocessed image batch (not defined here)
vp = Viewpoint.full_scene(batch_size=1)
glimpse = sample_at_viewpoint(spatial=image, viewpoint=vp, glimpse_size_px=128)
out = model(glimpse, state, vp)
state = out.state

# Layer-normalize canvas features (per canvas position, over channels, no learned
# affine) before downstream use (PCA, probing, etc.)
canvas = model.get_spatial(state.canvas)
mean = canvas.mean(axis=-1, keepdims=True)
canvas = (canvas - mean) / jnp.sqrt(canvas.var(axis=-1, keepdims=True) + 1e-5)
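The normalization in the example above is a layer norm without learned scale or shift. The sketch below checks it against jax.nn.standardize and then projects the normalized features onto their top three principal components, a common way to visualize dense features as an RGB image. It assumes the canvas features are shaped [batch, tokens, channels] with 32 × 32 tokens and 768 channels (assumptions, not confirmed by this page) and uses random data in place of real features.

```python
import jax
import jax.numpy as jnp


def layernorm_no_affine(x, eps=1e-5):
    # Same computation as the usage example: normalize over the channel axis.
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(x.var(axis=-1, keepdims=True) + eps)


def pca_rgb(features, n_components=3):
    """Project [N, D] features onto their top principal components, scaled to [0, 1]."""
    centered = features - features.mean(axis=0, keepdims=True)
    # Right singular vectors of the centered data are the principal directions.
    _, _, vt = jnp.linalg.svd(centered, full_matrices=False)
    proj = centered @ vt[:n_components].T
    lo, hi = proj.min(axis=0), proj.max(axis=0)
    return (proj - lo) / (hi - lo + 1e-8)


key = jax.random.PRNGKey(0)
# Stand-in features: 1 image, 32 x 32 canvas tokens, 768 channels (assumed width).
canvas = jax.random.normal(key, (1, 32 * 32, 768)) * 3.0 + 1.0
normed = layernorm_no_affine(canvas)

# jax.nn.standardize performs the same normalization, up to float error.
assert jnp.allclose(normed, jax.nn.standardize(canvas, axis=-1, epsilon=1e-5), atol=1e-4)

rgb = pca_rgb(normed[0]).reshape(32, 32, 3)
print(rgb.shape)  # (32, 32, 3)
```

Fitting the PCA on the normalized features (rather than the raw ones) keeps a few high-magnitude channels from dominating the projection.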

Structure

canvit_nnx/
  model.py            — CanViT architecture (self-contained, no external model deps)
  classification.py   — CanViTForImageClassification
  viewpoint.py        — sample_at_viewpoint (bilinear glimpse extraction)
  hub.py              — HuggingFace Hub loading (JAX-native safetensors, no torch)

scripts/
  demo.py                  — classification demo
  verify.py                — cross-framework numerical verification
  convert_from_pytorch.py  — one-time PyTorch -> JAX checkpoint conversion
  push_to_hub.py           — convert + push to HuggingFace
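convert_from_pytorch.py performs a one-time PyTorch to JAX checkpoint conversion. The script itself isn't shown on this page; the NumPy sketch below illustrates the main layout change such conversions usually involve. PyTorch's nn.Linear stores `weight` as [out_features, in_features] and computes x @ W.T + b, while Flax NNX's nnx.Linear stores `kernel` as [in_features, out_features] and computes x @ kernel + bias. The helper convert_linear and the parameter names are hypothetical, and the state dict is a stand-in already converted to NumPy arrays.

```python
import numpy as np


def convert_linear(torch_params, prefix):
    """Hypothetical helper: map one PyTorch nn.Linear entry to Flax NNX layout."""
    return {
        # PyTorch weight is [out_features, in_features]; the Flax kernel is [in, out].
        "kernel": np.ascontiguousarray(torch_params[f"{prefix}.weight"].T),
        "bias": torch_params[f"{prefix}.bias"],
    }


rng = np.random.default_rng(0)
# Stand-in state dict with a small 16 -> 48 linear layer (hypothetical names).
torch_params = {
    "mlp.fc1.weight": rng.standard_normal((48, 16)).astype(np.float32),
    "mlp.fc1.bias": rng.standard_normal(48).astype(np.float32),
}
flax_params = convert_linear(torch_params, "mlp.fc1")

# Both layouts must produce the same outputs for the same input.
x = rng.standard_normal((4, 16)).astype(np.float32)
y_torch = x @ torch_params["mlp.fc1.weight"].T + torch_params["mlp.fc1.bias"]
y_flax = x @ flax_params["kernel"] + flax_params["bias"]
print(np.abs(y_torch - y_flax).max())
```

Convolutions need the analogous permutation: PyTorch's Conv2d weight is [out, in, kh, kw], while a Flax conv kernel is [kh, kw, in, out].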

Tests

uv run pytest tests/ -v

Verify against PyTorch

uv run --extra verify python scripts/verify.py
uv run --extra verify python scripts/verify.py --image path/to/image.jpg
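scripts/verify.py checks the JAX port numerically against the PyTorch reference. The sketch below shows the kind of summary such a cross-framework check typically reports when comparing two output arrays. The helper name compare_outputs and the default tolerances are assumptions, not the script's actual thresholds, and the "PyTorch" and "JAX" outputs are synthetic.

```python
import numpy as np


def compare_outputs(reference, candidate, atol=1e-4, rtol=1e-3):
    """Hypothetical helper: summarize the agreement between two float arrays."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    abs_err = np.abs(reference - candidate)
    # Relative error, guarded against division by near-zero reference values.
    rel_err = abs_err / np.maximum(np.abs(reference), 1e-12)
    return {
        "max_abs_err": float(abs_err.max()),
        "mean_abs_err": float(abs_err.mean()),
        "max_rel_err": float(rel_err.max()),
        "allclose": bool(np.allclose(candidate, reference, atol=atol, rtol=rtol)),
    }


rng = np.random.default_rng(0)
ref = rng.standard_normal((1, 1000)).astype(np.float32)  # e.g. reference logits
jax_out = ref + rng.normal(scale=1e-6, size=ref.shape).astype(np.float32)
report = compare_outputs(ref, jax_out)
print(report)
```

Reporting the maximum and mean absolute error alongside a pass/fail flag makes it easier to tell a uniform float32 drift from a single badly ported layer.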

Demo

uv run --extra demo python scripts/demo.py --image path/to/image.jpg

Citation

@article{berreby2026canvit,
  title={CanViT: Toward Active-Vision Foundation Models},
  author={Berreby, Yoha{\"i}-Eliel and Du, Sabrina and Durand, Audrey and Krishna, B. Suresh},
  year={2026},
  eprint={2603.22570},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}



Download files


Source Distribution

canvit_nnx-0.1.0.tar.gz (11.2 kB)

Built Distribution


canvit_nnx-0.1.0-py3-none-any.whl (11.9 kB)

File details

Details for the file canvit_nnx-0.1.0.tar.gz.

File metadata

  • Download URL: canvit_nnx-0.1.0.tar.gz
  • Size: 11.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv 0.11.3 (uv publish, run in CI on Ubuntu 24.04)

File hashes

Hashes for canvit_nnx-0.1.0.tar.gz
Algorithm     Hash digest
SHA256        ea9abc7fff5999f67e106a17d4c31cf5c748ae90af904e21228c4e17582b3f1b
MD5           4cd259f463460a12ea6fbb1fdf6c25f3
BLAKE2b-256   996dc4b3eade9cb87578799c3944813d8e699a13ce8fd55a0acdc1f891eedab2


File details

Details for the file canvit_nnx-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: canvit_nnx-0.1.0-py3-none-any.whl
  • Size: 11.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv 0.11.3 (uv publish, run in CI on Ubuntu 24.04)

File hashes

Hashes for canvit_nnx-0.1.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        c435a5bdbdd8d1cfd4516676ffe32ee22e3679e174bf7a9c067a1fada6546188
MD5           38dfc4dbc9788a8c67f4827a6060d9d6
BLAKE2b-256   a68ba85f16272fe9a573dfcc617fe7e32523d191eb3d7dd6bb04b77220fd4117

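The published SHA256 digests can be used to check a downloaded file before installing it. Below is a minimal sketch using only the standard library's hashlib; the helper names sha256_of and verify are illustrative, and the file path passed to verify is wherever the file was saved.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digests published for the 0.1.0 release files.
EXPECTED = {
    "canvit_nnx-0.1.0.tar.gz": "ea9abc7fff5999f67e106a17d4c31cf5c748ae90af904e21228c4e17582b3f1b",
    "canvit_nnx-0.1.0-py3-none-any.whl": "c435a5bdbdd8d1cfd4516676ffe32ee22e3679e174bf7a9c067a1fada6546188",
}


def verify(path):
    """True if the file's SHA256 matches the published digest for its file name."""
    path = Path(path)
    return sha256_of(path) == EXPECTED[path.name]
```

For example, `verify("canvit_nnx-0.1.0-py3-none-any.whl")` returns True only for an untampered download of the wheel. Reading the file in chunks keeps memory use constant regardless of file size.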
