CanViT (Canvas Vision Transformer) -- JAX / Flax NNX
CanViT-NNX
Experimental port of CanViT to JAX using Flax NNX. The reference implementation is CanViT-PyTorch. Expect breaking changes at any time.
JAX / Flax NNX implementation of the Canvas Vision Transformer (CanViT).
Checkpoints: canvit/canvit-nnx on HuggingFace.
Install
uv add canvit-nnx
We recommend uv for dependency management.
Usage
Classification
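from canvit_nnx import CanViTForImageClassification, Viewpoint, sample_at_viewpoint
# `image`: a batched, preprocessed image array (scripts/demo.py shows an end-to-end run)
clf = CanViTForImageClassification.from_pretrained(
    "canvit/canvitb16-add-vpe-finetune-g128px-s512px-in1k-2026-04-06-nnx"
)
# Initial recurrent state; each call returns the updated state
state = clf.init_state(batch_size=1, canvas_grid_size=32)
# Viewpoint covering the whole image
vp = Viewpoint.full_scene(batch_size=1)
# Resample the region selected by `vp` into a 128 px glimpse
glimpse = sample_at_viewpoint(spatial=image, viewpoint=vp, glimpse_size_px=128)
logits, state = clf(glimpse, state, vp)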
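Because each call takes the current `state` and returns an updated one, the classifier can in principle be run over several glimpses in sequence. The helper below is a hypothetical sketch of that loop, assuming the returned state can be fed straight back in, as the `logits, state = clf(...)` pattern suggests. Constructing viewpoints other than `Viewpoint.full_scene` is not covered here, so the sketch takes a ready-made list; `classify_with_glimpses` is not a `canvit_nnx` function.

```python
from typing import Any, Callable, List, Sequence, Tuple


def classify_with_glimpses(
    clf: Callable[[Any, Any, Any], Tuple[Any, Any]],
    sample_fn: Callable[..., Any],
    image: Any,
    state: Any,
    viewpoints: Sequence[Any],
    glimpse_size_px: int = 128,
) -> Tuple[List[Any], Any]:
    """Hypothetical helper: one classifier step per viewpoint, threading the state.

    Returns the logits produced after every glimpse and the final state.
    """
    if not viewpoints:
        raise ValueError("need at least one viewpoint")
    all_logits = []
    for vp in viewpoints:
        # Resample the part of the scene selected by this viewpoint.
        glimpse = sample_fn(spatial=image, viewpoint=vp, glimpse_size_px=glimpse_size_px)
        # Feed the previous state in and keep the updated one for the next step.
        logits, state = clf(glimpse, state, vp)
        all_logits.append(logits)
    return all_logits, state
```

For example, `classify_with_glimpses(clf, sample_at_viewpoint, image, clf.init_state(batch_size=1, canvas_grid_size=32), [Viewpoint.full_scene(batch_size=1)])` repeats the single-glimpse example, and the last entry of the returned list holds the logits after the final glimpse. Passing the model and sampler in as arguments keeps the loop independent of how either is constructed.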
Pretrained backbone (dense features)
import jax.numpy as jnp
from canvit_nnx import from_pretrained, Viewpoint, sample_at_viewpoint
# `image`: a batched, preprocessed image array (scripts/demo.py shows an end-to-end run)
model = from_pretrained(
    "canvit/canvitb16-add-vpe-pretrain-g128px-s512px-in21k-dv3b16-2026-02-02-nnx"
)
state = model.init_state(batch_size=1, canvas_grid_size=32)
vp = Viewpoint.full_scene(batch_size=1)
glimpse = sample_at_viewpoint(spatial=image, viewpoint=vp, glimpse_size_px=128)
out = model(glimpse, state, vp)
state = out.state
# Layer-normalize canvas features over the channel axis (no learned scale or shift)
# before downstream use (PCA, probing, etc.)
canvas = model.get_spatial(state.canvas)
mean = canvas.mean(axis=-1, keepdims=True)
var = canvas.var(axis=-1, keepdims=True)
canvas = (canvas - mean) / jnp.sqrt(var + 1e-5)
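The comment above names PCA as one downstream use. Below is a minimal JAX sketch of projecting the normalized `canvas` onto its top three principal components for an RGB-style visualization. It assumes the last axis of `canvas` is the channel axis and flattens every leading axis (batch and spatial) into samples; `pca_project` and `to_rgb` are hypothetical helpers, not part of `canvit_nnx`.

```python
import jax
import jax.numpy as jnp


def pca_project(features: jax.Array, n_components: int = 3) -> jax.Array:
    """Project features of shape (..., D) onto their top principal components.

    Returns an array of shape (..., n_components).
    """
    lead_shape = features.shape[:-1]
    flat = features.reshape(-1, features.shape[-1])  # (N, D)
    centered = flat - flat.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, _, vt = jnp.linalg.svd(centered, full_matrices=False)
    projected = centered @ vt[:n_components].T  # (N, n_components)
    return projected.reshape(*lead_shape, n_components)


def to_rgb(projected: jax.Array) -> jax.Array:
    """Min-max scale each component independently into [0, 1] for display."""
    axes = tuple(range(projected.ndim - 1))
    lo = projected.min(axis=axes, keepdims=True)
    hi = projected.max(axis=axes, keepdims=True)
    return (projected - lo) / (hi - lo + 1e-8)
```

With `canvas` from the example, `to_rgb(pca_project(canvas))` yields values in [0, 1]; whether the result can be shown directly as an image depends on the spatial layout that `get_spatial` returns, which is not documented here.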
Structure
canvit_nnx/
    model.py: CanViT architecture (self-contained, no external model deps)
    classification.py: CanViTForImageClassification
    viewpoint.py: sample_at_viewpoint (bilinear glimpse extraction)
    hub.py: HuggingFace Hub loading (JAX-native safetensors, no torch)
scripts/
    demo.py: classification demo
    verify.py: cross-framework numerical verification
    convert_from_pytorch.py: one-time PyTorch -> JAX checkpoint conversion
    push_to_hub.py: convert + push to HuggingFace
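`viewpoint.py` is described as bilinear glimpse extraction. The sketch below illustrates that idea with `jax.scipy.ndimage.map_coordinates`; it is not the package's implementation. The viewpoint parameterization (a center in [-1, 1] scene coordinates plus a scale giving the window side as a fraction of the scene side), the channels-last (H, W, C) layout, and the helper name `bilinear_glimpse` are all assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def bilinear_glimpse(
    image: jax.Array,
    center_yx: tuple[float, float],
    scale: float,
    glimpse_size_px: int,
) -> jax.Array:
    """Hypothetical sketch: bilinearly resample a square window of an (H, W, C) image.

    center_yx: window center in [-1, 1] scene coordinates (assumed convention).
    scale: window side as a fraction of the scene side (1.0 = full scene).
    """
    h, w, _ = image.shape
    # Output pixel centers in [-1, 1], shrunk to the window and shifted to its center.
    t = (jnp.arange(glimpse_size_px) + 0.5) / glimpse_size_px * 2.0 - 1.0
    ys = center_yx[0] + scale * t
    xs = center_yx[1] + scale * t
    # Normalized coords -> continuous pixel indices (align_corners=False convention).
    py = (ys + 1.0) / 2.0 * h - 0.5
    px = (xs + 1.0) / 2.0 * w - 0.5
    grid_y, grid_x = jnp.meshgrid(py, px, indexing="ij")

    def sample_channel(channel: jax.Array) -> jax.Array:
        # order=1 is bilinear; "nearest" clamps samples that fall past the border.
        return map_coordinates(channel, [grid_y, grid_x], order=1, mode="nearest")

    # Map over channels (moved to the front), then restore channels-last.
    out = jax.vmap(sample_channel)(jnp.moveaxis(image, -1, 0))
    return jnp.moveaxis(out, 0, -1)
```

With `scale=1.0` and a glimpse as large as the image, the sampling grid lands exactly on pixel centers, so the output reproduces the input; smaller scales zoom in around the chosen center.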
Tests
uv run pytest tests/ -v
Verify against PyTorch
uv run --extra verify python scripts/verify.py
uv run --extra verify python scripts/verify.py --image path/to/image.jpg
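Cross-framework verification ultimately compares arrays computed by both implementations from the same input. A hypothetical NumPy helper for summarizing such a comparison is sketched below; the default `atol`/`rtol` values are placeholders, not the thresholds `scripts/verify.py` applies.

```python
import numpy as np


def compare_outputs(
    reference: np.ndarray,
    candidate: np.ndarray,
    atol: float = 1e-4,
    rtol: float = 1e-3,
) -> dict:
    """Summarize how closely `candidate` (e.g. JAX) matches `reference` (e.g. PyTorch)."""
    ref = np.asarray(reference, dtype=np.float64)
    cand = np.asarray(candidate, dtype=np.float64)
    if ref.shape != cand.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {cand.shape}")
    abs_err = np.abs(ref - cand)
    # Small constant avoids division by zero where the reference is exactly 0.
    rel_err = abs_err / (np.abs(ref) + 1e-12)
    return {
        "max_abs_err": float(abs_err.max()),
        "max_rel_err": float(rel_err.max()),
        # Same criterion as np.allclose: |cand - ref| <= atol + rtol * |ref|
        "allclose": bool(np.allclose(cand, ref, atol=atol, rtol=rtol)),
    }
```

Convert a PyTorch tensor with `t.detach().cpu().numpy()` and a JAX array with `np.asarray(x)` before passing them in. Reporting the maximum errors alongside the pass/fail flag makes it easier to tell small float32 drift from a genuine porting bug.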
Demo
uv run --extra demo python scripts/demo.py --image path/to/image.jpg
Citation
@article{berreby2026canvit,
title={CanViT: Toward Active-Vision Foundation Models},
author={Berreby, Yoha{\"i}-Eliel and Du, Sabrina and Durand, Audrey and Krishna, B. Suresh},
year={2026},
eprint={2603.22570},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Download files
Source distribution: canvit_nnx-0.1.0.tar.gz
Built distribution: canvit_nnx-0.1.0-py3-none-any.whl
File details
Details for the file canvit_nnx-0.1.0.tar.gz.
File metadata
- Download URL: canvit_nnx-0.1.0.tar.gz
- Upload date:
- Size: 11.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.11.3 (publish, Ubuntu 24.04 "noble", CI)
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ea9abc7fff5999f67e106a17d4c31cf5c748ae90af904e21228c4e17582b3f1b |
| MD5 | 4cd259f463460a12ea6fbb1fdf6c25f3 |
| BLAKE2b-256 | 996dc4b3eade9cb87578799c3944813d8e699a13ce8fd55a0acdc1f891eedab2 |
File details
Details for the file canvit_nnx-0.1.0-py3-none-any.whl.
File metadata
- Download URL: canvit_nnx-0.1.0-py3-none-any.whl
- Upload date:
- Size: 11.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.11.3 (publish, Ubuntu 24.04 "noble", CI)
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c435a5bdbdd8d1cfd4516676ffe32ee22e3679e174bf7a9c067a1fada6546188 |
| MD5 | 38dfc4dbc9788a8c67f4827a6060d9d6 |
| BLAKE2b-256 | a68ba85f16272fe9a573dfcc617fe7e32523d191eb3d7dd6bb04b77220fd4117 |
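To check a downloaded file against the SHA256 digests listed above, a short standard-library sketch follows; the helper names `sha256_of_file` and `verify_sha256` are illustrative.

```python
import hashlib


def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 digest of a file, read in chunks so memory use stays bounded."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path: str, expected_hex: str) -> bool:
    """True if the file's SHA256 matches the published digest (case-insensitive)."""
    return sha256_of_file(path) == expected_hex.strip().lower()
```

For an intact copy of the source distribution, `verify_sha256("canvit_nnx-0.1.0.tar.gz", "ea9abc7fff5999f67e106a17d4c31cf5c748ae90af904e21228c4e17582b3f1b")` should return `True`.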