Implementation of popular vision models in JAX

Equimo: Modern Vision Models in JAX/Equinox

WARNING: This is a research library implementing recent computer vision models. The implementations are based on paper descriptions and may not be exact replicas of the original implementations. Use with caution in production environments.

Equimo (Equinox Image Models) provides JAX/Equinox implementations of recent computer vision models, currently focusing on, but not limited to, transformer and state-space architectures.

Features

  • Pure JAX/Equinox implementations
  • Focus on recent architectures (2023-2025 papers)
  • Modular design for easy experimentation
  • Extensive documentation and type hints
  • Experimental support for text embedding

Installation

From PyPI

pip install equimo

From Source

git clone https://github.com/clementpoiret/equimo.git
cd equimo
pip install -e .

Implemented Models

Beyond standard ViTs (e.g., DINOv2 or SigLIP), Equimo provides other state-of-the-art architectures:

| Model | Paper | Year | Status |
| FasterViT | FasterViT: Fast Vision Transformers with Hierarchical Attention | 2023 | |
| Castling-ViT | Castling-ViT: Compressing Self-Attention via Switching Towards Linear-Angular Attention During Vision Transformer Inference | 2023 | Partial* |
| MLLA | Mamba-like Linear Attention | 2024 | |
| PartialFormer | Efficient Vision Transformers with Partial Attention | 2024 | |
| SHViT | SHViT: Single-Head Vision Transformer with Memory Efficient Macro Design | 2024 | |
| VSSD | VSSD: Vision Mamba with Non-Causal State Space Duality | 2024 | |
| ReduceFormer | ReduceFormer: Attention with Tensor Reduction by Summation | 2024 | |
| LowFormer | LowFormer: Hardware Efficient Design for Convolutional Transformer Backbones | 2024 | |
| DINOv3 | DINOv3 | 2025 | |
| FreeNet | FreeNet: Liberating Depth-Wise Separable Operations for Building Faster Mobile Vision Architectures | 2025 | |

*: Only the Linear Angular Attention module is implemented. It is straightforward to build a ViT around it, but doing so may require an additional __call__ kwarg to control the sparse_reg bool.

Basic Usage

import jax

import equimo.models as em

# Create a model (e.g. `faster_vit_0_224`)
key = jax.random.PRNGKey(0)
model = em.FasterViT(
    img_size=224,
    in_channels=3,
    dim=64,
    in_dim=64,
    depths=[2, 3, 6, 5],
    num_heads=[2, 4, 8, 16],
    hat=[False, False, True, False],
    window_size=[7, 7, 7, 7],
    ct_size=2,
    key=key,
)

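# Generate random input with a fresh subkey instead of reusing the init key
key, x_key, call_key = jax.random.split(key, 3)
x = jax.random.normal(x_key, (3, 224, 224))

# Run inference on a single, unbatched image of shape (channels, height, width)
output = model(x, enable_dropout=False, key=call_key)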
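The example above runs on a single unbatched image. A common way to process a batch in JAX is to map the per-sample function over the leading axis with jax.vmap, as the text-encoder example further down does. The sketch below shows the pattern only: toy_model is a hypothetical stand-in, not an Equimo model, and wrapping a real model would also require fixing extra arguments such as enable_dropout (for instance with a lambda), which is an assumption not shown here.

```python
import jax
import jax.numpy as jnp


def toy_model(x, key):
    """Hypothetical stand-in for a single-image model: (C, H, W) -> (C,)."""
    # `key` only mirrors a per-call PRNG key argument; it is unused here.
    del key
    return x.mean(axis=(1, 2))  # global average pool per channel


key = jax.random.PRNGKey(0)
data_key, call_key = jax.random.split(key)

# A batch of 4 images, each shaped like the single-image input above.
batch = jax.random.normal(data_key, (4, 3, 224, 224))

# One key per sample, so stochastic layers would not share randomness.
sample_keys = jax.random.split(call_key, batch.shape[0])

# Map over the leading axis of both the images and the keys.
batched_model = jax.vmap(toy_model, in_axes=(0, 0))
outputs = batched_model(batch, sample_keys)  # shape (4, 3)
```

Passing in_axes=(0, None) instead would share a single key across the whole batch, which is the choice made for the text encoder below.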

Working with text embeddings

Warning: this is experimental and can break or change at any time.

equimo.experimental.text has been available since v0.3.0. It allows working with both text and images. It is especially useful for models like SigLIP or TIPS, although only TIPS is currently supported.

Text tokenization currently relies on tensorflow_text, so install Equimo with the text group, for example uv add equimo[text].

Here is a very simple example of zero-shot classification based on comparing text and image embeddings:

import jax
from einops import rearrange

from equimo.experimental.text import Tokenizer
from equimo.io import load_image, load_model
from equimo.utils import PCAVisualizer, normalize, plot_image_and_feature_map

# Random demo inputs
key = jax.random.PRNGKey(42)
image = load_image("./demo.jpg", size=448)
text = [
    "A baby discovering happiness",
    "A computer",
]

# Loading pretrained models
image_encoder = load_model("vit", "tips_vits14_hr")
text_encoder = load_model("experimental.textencoder", "tips_vits14_hr_text")

# Encoding text and image
ids, paddings = Tokenizer(identifier="sentencepiece_tips").tokenize(text, max_len=64)

text_embedding = normalize(
    jax.vmap(text_encoder, in_axes=(0, 0, None))(ids, paddings, key)
)
image_embedding = jax.vmap(image_encoder.norm)(image_encoder.features(image, key))
cls_token = normalize(image_embedding[0])
spatial_features = rearrange(
    image_embedding[2:], "(h w) d -> h w d", h=int(448 / 14), w=int(448 / 14)
)

# Getting probabilities based on Cosine Similarity
cos_sim = jax.nn.softmax(
    ((cls_token[None, :] @ text_embedding.T) / text_encoder.temperature), axis=-1
)

# Plot the results
label_idxs = jax.numpy.argmax(cos_sim, axis=-1)
cos_sim_max = jax.numpy.max(cos_sim, axis=-1)
label_predicted = text[label_idxs[0]]
similarity = cos_sim_max[0]
pca_obj = PCAVisualizer(spatial_features)
image_pca = pca_obj(spatial_features)

plot_image_and_feature_map(
    image.transpose(1, 2, 0),
    image_pca,
    "./out.png",
    "Input Image",
    f"{label_predicted}, prob: {similarity * 100:.2f}%",
)

This produces the following result:

[Figure: Output of TIPS zero-shot classification]
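The scoring step in the example above can be tried without any pretrained weights. The following is a minimal sketch with synthetic embeddings: the vectors, the l2_normalize helper, and the temperature value of 0.07 are illustrative assumptions, whereas the real example uses equimo.utils.normalize and reads the temperature from text_encoder.temperature.

```python
import jax
import jax.numpy as jnp


def l2_normalize(x, axis=-1, eps=1e-12):
    """Scale vectors to unit L2 norm along `axis` (illustrative helper)."""
    return x / jnp.maximum(jnp.linalg.norm(x, axis=axis, keepdims=True), eps)


# Synthetic embeddings (made-up values): one image and two captions.
image_embedding = l2_normalize(jnp.array([0.9, 0.1, 0.4]))
text_embeddings = l2_normalize(
    jnp.array(
        [
            [1.0, 0.0, 0.5],  # caption 0: points roughly the same way as the image
            [-0.2, 1.0, 0.0],  # caption 1: nearly orthogonal to the image
        ]
    )
)

temperature = 0.07  # assumed value, for illustration only

# Dot products of unit vectors are cosine similarities, shape (1, 2).
cos_sim = image_embedding[None, :] @ text_embeddings.T

# Dividing by a small temperature sharpens the softmax over captions.
probs = jax.nn.softmax(cos_sim / temperature, axis=-1)

# Index of the most likely caption for the single image.
best = int(jnp.argmax(probs, axis=-1)[0])
```

Because both sides are unit-normalized, the matrix product is bounded in [-1, 1]. The temperature then decides how strongly the softmax favors the best-matching caption.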

Saving and Loading Models

Equimo provides utilities for saving models locally and loading pre-trained models from the official repository.

Saving Models Locally

from pathlib import Path
from equimo.io import save_model

# Save model with compression (creates .tar.lz4 file)
save_model(
    Path("path/to/save/model"),
    model,  # can be any model you created using Equimo
    model_config,
    torch_hub_cfg,  # can be an empty list; it mainly tracks where the weights come from
    compression=True
)

# Save model without compression (creates directory)
save_model(
    Path("path/to/save/model"),
    model,
    model_config,
    torch_hub_cfg,
    compression=False
)

Loading Models

from pathlib import Path

from equimo.io import load_model

# Load a pre-trained model from the official repository
model = load_model(cls="vit", identifier="dinov2_vits14_reg")

# Load a local model (compressed)
model = load_model(cls="vit", path=Path("path/to/model.tar.lz4"))

# Load a local model (uncompressed directory)
model = load_model(cls="vit", path=Path("path/to/model/"))

Parameters passed to the model can be overridden at load time, for example:

model = load_model(
    cls="vit",
    identifier="siglip2_vitb16_256",
    dynamic_img_size=True,  # passed to the VisionTransformer class
)

List of pretrained models

Pretrained weights are hosted in Equimo's repository on Hugging Face, and a model identifier selects which file to download.

Identifiers are filenames without their extensions, such as:

  • dinov2_vitb14
  • dinov2_vits14_reg
  • siglip2_vitl16_512
  • siglip2_vitso400m16_384
  • tips_vitg14_lr
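As a small illustration of the naming rule above, the sketch below derives an identifier from a weight filename by stripping all of its extensions. The helper identifier_from_filename is hypothetical and not part of Equimo, and the .tar.lz4 extension is assumed from the compressed save format described earlier.

```python
from pathlib import Path


def identifier_from_filename(filename: str) -> str:
    """Strip every extension from a weight filename (hypothetical helper)."""
    name = Path(filename).name  # drop any leading directories
    # Join all suffixes, e.g. ".tar" + ".lz4", so both are removed at once.
    suffixes = "".join(Path(name).suffixes)
    return name[: len(name) - len(suffixes)] if suffixes else name


# Assumed example filename, following the .tar.lz4 format used by save_model.
identifier = identifier_from_filename("dinov2_vits14_reg.tar.lz4")
```

The resulting string is what would be passed as the identifier argument of load_model.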

Contributing

Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change.

License

This project is licensed under the MIT License - see the LICENSE file for details.

Citation

If you use Equimo in your research, please cite:

@software{equimo2024,
  author = {Clément POIRET},
  title = {Equimo: Modern Vision Models in JAX/Equinox},
  year = {2024},
  publisher = {GitHub},
  url = {https://github.com/clementpoiret/equimo}
}
