MuViT: Multi-Resolution Vision Transformers for Learning Across Scales in Microscopy

Official implementation of MuViT (CVPR 2026), a vision transformer architecture designed to process gigapixel microscopy images by jointly modelling multiple scales with a single encoder. For technical details, please check the preprint (coming very soon!).

This repository contains the implementation of the MuViT architecture, along with the multi-resolution Masked Autoencoder (MAE) pre-training framework.

Overview

Figure: MuViT architecture overview.

Modern microscopy yields gigapixel images capturing structures with hierarchical organization, spanning from individual cell morphology to broad tissue architecture. A central challenge in analyzing these images is that models must trade off effective context against spatial resolution: standard CNNs and ViTs typically operate on single-resolution crops, building hierarchical feature pyramids from a single view.

To tackle this, MuViT is designed to jointly process fields of view (FOVs) of the same image at different physical resolutions within a unified encoder. This is achieved by jointly feeding the different scales to the model and adding consistent world-coordinate RoPE, a simple yet effective mechanism that ensures the same physical location receives the same positional encoding across scales. This enables the attention mechanism to work across scales, allowing integration of wide-field context (e.g. anatomical) with high-resolution detail (e.g. cellular) for solving dense computer vision tasks, like segmentation.
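As a rough illustration of the world-coordinate idea (not the actual RoPE implementation), the pixel-to-world mapping implied by a crop's bounding box can be sketched as follows; the bounding boxes, crop sizes, and helper names are made up for the example:

```python
def pixel_to_world(bbox_min, bbox_max, n_pixels, i):
    """World coordinate of the center of pixel i in a crop covering
    [bbox_min, bbox_max) with n_pixels pixels along one axis."""
    scale = (bbox_max - bbox_min) / n_pixels  # world units per pixel
    return bbox_min + (i + 0.5) * scale

def world_to_pixel(bbox_min, bbox_max, n_pixels, w):
    """Inverse mapping: (fractional) pixel index of world coordinate w."""
    scale = (bbox_max - bbox_min) / n_pixels
    return (w - bbox_min) / scale - 0.5

# The same physical point w = 100.0 seen by two crops at different levels:
# a level-1 crop covering world [64, 192) with 128 pixels (1 world px / pixel),
# and a level-8 crop covering world [0, 1024) with 128 pixels (8 world px / pixel).
w = 100.0
i1 = world_to_pixel(64.0, 192.0, 128, w)   # 35.5
i8 = world_to_pixel(0.0, 1024.0, 128, w)   # 12.0

# Different pixel indices, but both tokens map back to the SAME world
# coordinate, so a positional encoding computed from world coordinates
# agrees across scales.
assert pixel_to_world(64.0, 192.0, 128, i1) == w
assert pixel_to_world(0.0, 1024.0, 128, i8) == w
```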

Furthermore, MuViT extends the Masked Autoencoder (MAE) pre-training framework to a multi-resolution setting to learn powerful representations from unlabeled large-scale data. This produces highly informative, scale-consistent features that substantially accelerate convergence and improve sample efficiency on downstream tasks.

Installation

Simply clone the repository, create a new Python environment (with conda, mamba, or similar) and install the repository in editable mode:

mamba create -y -n muvit python=3.12
mamba activate muvit
git clone git@github.com:weigertlab/muvit.git
pip install -e ./muvit

Usage

Creating a MuViT dataset

All PyTorch datasets used with MuViT should inherit from muvit.data.MuViTDataset, which runs sanity checks (e.g. on the output format) to ensure consistency. It requires implementing the following methods and properties (check the implementation of the MuViTDataset class for more details):

from typing import Tuple

from muvit.data import MuViTDataset

class MyMuViTDataset(MuViTDataset):
    def __init__(self):
        super().__init__()

    def __len__(self) -> int:
        # number of samples in the dataset
        return 42  # change accordingly

    @property
    def n_channels(self) -> int:
        # number of channels in the input images
        return 1  # change accordingly

    @property
    def levels(self) -> Tuple[int, ...]:
        # resolution levels (in ascending order)
        return (1, 8, 32)  # change accordingly

    @property
    def ndim(self) -> int:
        # number of spatial dimensions
        return 2  # change accordingly

    def __getitem__(self, idx) -> dict:
        # should return a dictionary like
        return {
            "img": img,    # torch tensor of shape (L, C, Y, X)
            "bbox": bbox,  # torch tensor of shape (L, 2, Nd), where Nd is the
                           # number of spatial dimensions (e.g. 2)
        }

Bounding box format

The bbox (bounding box) tensor defines the exact physical extent (field of view) of each image crop within a shared world-coordinate system, which we define as the highest-resolution pixel space. For a single dataset sample, it must have the shape $(L, 2, N_d)$, where $L$ is the number of resolution levels and $N_d$ is the number of spatial dimensions (e.g., 2). The second dimension, always of size 2, represents the boundaries of the crop: index 0 contains the minimum coordinates (top-left, i.e., [y_min, x_min]) and index 1 contains the maximum coordinates (bottom-right, i.e., [y_max, x_max]). Providing them as accurately as possible is crucial, as MuViT relies on them to geometrically align the different resolutions.
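As an illustrative sketch (a hypothetical helper, not part of muvit), aligned bounding boxes for concentric crops of a fixed pixel size can be built like this, assuming square 2D crops centered on the same world point:

```python
def make_bboxes(center, crop_size, levels):
    """Concentric bounding boxes in world coordinates (highest-resolution
    pixel space) for square crops of `crop_size` pixels per side.

    Returns a nested list of shape (L, 2, 2):
    [level][0] = [y_min, x_min], [level][1] = [y_max, x_max].
    """
    bboxes = []
    for level in levels:
        extent = crop_size * level          # physical extent of this crop
        half = extent / 2
        mins = [c - half for c in center]   # top-left corner
        maxs = [c + half for c in center]   # bottom-right corner
        bboxes.append([mins, maxs])
    return bboxes

bboxes = make_bboxes(center=(4096.0, 4096.0), crop_size=128, levels=(1, 8, 32))
# level 1 covers 128x128 world pixels, level 8 covers 1024x1024,
# and level 32 covers 4096x4096, all centered on the same physical point
assert bboxes[0] == [[4032.0, 4032.0], [4160.0, 4160.0]]
assert bboxes[2] == [[2048.0, 2048.0], [6144.0, 6144.0]]
```

In practice these values would be converted to a torch tensor of shape $(L, 2, N_d)$ and returned from `__getitem__` alongside the image crops.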

Multiscale MAE pre-training

In order to pre-train an MAE model on your created dataset, you can simply instantiate the MuViTMAE2d class and pass the dataloaders to its .fit method. Most of the parameters are customizable (e.g. number of layers, patch size, etc.). For more information please check the implementation of the MuViTMAE2d class. We use PyTorch Lightning to handle the training logic. For example:

import torch

from muvit.data import MuViTDataset
from muvit.mae import MuViTMAE2d

class MyMuViTDataset(MuViTDataset):
    # implement the dataset as shown above
    pass

train_ds = MyMuViTDataset(args1)
val_ds = MyMuViTDataset(args2)

model = MuViTMAE2d(
    in_channels=train_ds.n_channels,
    levels=train_ds.levels,
    patch_size=8,
    num_layers=12,
    num_layers_decoder=4,
    # ... other parameters
)

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=16, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=16, shuffle=False)
model.fit(train_dl, val_dl, output="/path/to/pretrained", num_epochs=100)  # ... other parameters

Using a pre-trained encoder

After pre-training the MAE model, you can use the encoder for downstream tasks or feature extraction. To get the encoder from the MAE pre-trained model, you can simply load it using our helper function and access it via the encoder attribute:

from muvit.mae import MuViTMAE2d

encoder = MuViTMAE2d.from_folder("/path/to/pretrained").encoder

which returns a MuViTEncoder PyTorch module that is pluggable into any downstream pipeline. The encoder expects an input tensor of shape $(B,L,C,Y,X)$ (where $L$ denotes the number of resolution levels) along with the world coordinates, given as a bounding-box tensor of shape $(B,L,2,2)$ (for 2D). Note that omitting an explicit bounding box may result in undefined behaviour. The output of the encoder is a tensor of shape $(B,N,D)$, where $N$ is the number of tokens and $D$ is the embedding dimension.

The method compute_features() of an encoder will run a forward pass on a given multi-scale tensor and corresponding bounding boxes and return the features in a spatially structured format $(B,L,D,H',W')$ where $H'=\frac{H}{P}$ and $W'=\frac{W}{P}$, with $P$ being the patch size.
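These shapes follow from simple patch arithmetic. The sketch below works through the bookkeeping for square crops; the numbers (crop size, embedding dimension) are chosen for the example, and it ignores any extra tokens (e.g. a class token) the encoder might add:

```python
def feature_shapes(batch, levels, crop_size, patch_size, embed_dim):
    """Token count and spatial feature-map shape for square crops."""
    side = crop_size // patch_size                 # H' = H / P (and W' = W / P)
    tokens_per_level = side * side                 # patches per resolution level
    n_tokens = len(levels) * tokens_per_level      # N: tokens over all levels
    encoder_out = (batch, n_tokens, embed_dim)                  # (B, N, D)
    features_out = (batch, len(levels), embed_dim, side, side)  # (B, L, D, H', W')
    return encoder_out, features_out

enc, feat = feature_shapes(batch=4, levels=(1, 8, 32), crop_size=128,
                           patch_size=8, embed_dim=384)
assert enc == (4, 768, 384)         # 3 levels x (16 x 16) = 768 tokens
assert feat == (4, 3, 384, 16, 16)  # compute_features() layout
```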

Citation

If you use this code for your research, please cite the following article:

TODO
