Implementations of popular vision models in Jax (Flax NNX)

Jimmy Vision

Jimmy Vision is a Jax-based library that provides implementations of popular computer vision models. It is designed to be flexible, efficient, and easy to use for researchers and practitioners in deep learning.

[!WARNING] Jimmy is not yet ready for production use. It is a work in progress, intended for experimentation.

Features

  • Implementation of DinoV2 (Vision Transformer) models
  • Implementation of MambaVision models (work in progress)
  • Support for loading pre-trained weights from PyTorch models (DinoV2 only)
  • Flexible model configuration and customization
  • Efficient Jax-based computations

Installation

Install from PyPI

# cpu
pip install jimmy-vision

# cuda12
pip install jimmy-vision[cuda12]
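
After installing, you can check that JAX picked up the expected backend. A quick sanity check (the exact device names depend on your environment):

import jax

# With the cuda12 extra you should see a GPU device; on CPU-only installs, a CPU device.
print(jax.devices())  # e.g. [CudaDevice(id=0)] or [CpuDevice(id=0)]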

Cloning the repo

You can use poetry, nix, or devenv:

git clone git@github.com:clementpoiret/jimmy.git
cd jimmy

# either
nix develop --impure

# or
poetry install -E cuda12

Quick Start

Here's a quick example of how to use Jimmy to load a pre-trained DinoV2 model:

import jax
import jax.numpy as jnp
from flax import nnx

from jimmy.models import DINOV2_VITS14, load_model

# Initialize random number generator
rngs = nnx.Rngs(42)

# Load the model
model = load_model(
    DINOV2_VITS14,
    rngs=rngs,
    pretrained=True,
    url="https://huggingface.co/poiretclement/dinov2_jax/resolve/main/dinov2_vits14.jim",
)

# Create a random input
key = rngs.params()
x = jax.random.normal(key, (1, 518, 518, 3))

# Run inference
output = model(x)
print(output.shape)  # 1, 1370, 384
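
The 1370 tokens are one class token plus (518 / 14)^2 = 1369 patch tokens. Assuming the usual ViT layout with the class token first (an assumption here, so double-check against the implementation), you can split them as follows:

# Assumed layout: class token first, then the 37 x 37 grid of patch tokens
cls_token = output[:, 0]      # (1, 384), global image representation
patch_tokens = output[:, 1:]  # (1, 1369, 384), per-patch features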

Models

Jimmy currently supports the following models:

  • DinoV2 (various sizes: ViT-S/14, ViT-B/14, ViT-L/14, ViT-G/14)
  • MambaVision (coming soon)

To load a specific model, you can use the load_model function with the appropriate configuration:

from jimmy.models import load_model, DINOV2_VITB14

model = load_model(DINOV2_VITB14, rngs=rngs)
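
Since the loaded model is a regular Flax NNX module, you can compile its forward pass with nnx.jit for faster repeated inference. A minimal sketch (the 518 x 518 input shape mirrors the Quick Start and is an assumption for this configuration):

import jax
from flax import nnx

@nnx.jit
def forward(model, x):
    # Traced and compiled on the first call, then reused for identical shapes
    return model(x)

x = jax.random.normal(jax.random.key(0), (1, 518, 518, 3))
features = forward(model, x)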

Custom Models

You can also create custom models by modifying the existing configurations:

custom_config = {
    "name": "custom_dinov2",
    "class": "dinov2",
    "config": {
        "num_heads": 8,
        "embed_dim": 512,
        "mlp_ratio": 4,
        "patch_size": 16,
        "depth": 8,
        "img_size": 224,
        "qkv_bias": True,
    }
}

custom_model = load_model(custom_config, rngs=rngs)
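
To sanity-check a custom configuration, a common idiom is to count its trainable parameters using standard flax.nnx and jax.tree_util calls (a sketch, not a Jimmy API):

import jax
from flax import nnx

# Gather every nnx.Param of the model and sum the sizes of the underlying arrays
params = nnx.state(custom_model, nnx.Param)
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"{n_params / 1e6:.2f}M parameters")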

Advanced Training Example

Here is a toy example that trains Mlla with a MESA loss term based on an Exponential Moving Average (EMA) of the weights.

import jax
from flax import nnx
from optax import adam
from optax.losses import softmax_cross_entropy

from jimmy.models.emamodel import EmaModel
from jimmy.models.mlla import Mlla

x, y = ...  # a batch of images and one-hot labels
criterion = softmax_cross_entropy

# Defines the model
rngs = nnx.Rngs(42)
model = Mlla(
    num_classes=1000,
    depths=[2, 4, 12, 4],
    patch_size=4,
    in_features=3,
    embed_dim=96,
    num_heads=[2, 4, 8, 16],
    layer_window_sizes=[-1, -1, -1, -1],
    rngs=rngs,
)
model.train()

# Defines the wrapper tracking the EMA of the weights
ema_model = EmaModel(model)
ema_model.model.eval()  # To disable dropouts

optimizer = nnx.Optimizer(model, adam(1e-3))


# Core training fn
@nnx.jit
def train(model, ema_model, optimizer, x, y):
    def loss_fn(model, ema_model):
        y_pred = model(x)
        ema_outputs = ema_model(x)

        # MESA term: distill towards the EMA model's softmaxed outputs.
        # In practice you may want a warmup phase, or to start MESA after X epochs.
        loss = criterion(y_pred, y).mean() + 0.3 * criterion(
            y_pred, jax.nn.softmax(ema_outputs)).mean()
        return loss

    loss, grads = nnx.value_and_grad(loss_fn)(model, ema_model)

    optimizer.update(grads)
    params = nnx.state(model, nnx.Param)

    # Updates the moving average
    ema_model.update(params)

    return loss

for i in range(8):
    loss = train(model, ema_model, optimizer, x, y)
    print(i, loss)
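
For intuition, the moving average tracked by EmaModel typically follows the standard EMA update sketched below. This is an illustrative version, not necessarily Jimmy's exact implementation, and the decay value is an assumption:

import jax

def ema_update(ema_params, params, decay=0.999):
    # Leaf-wise update: p_ema <- decay * p_ema + (1 - decay) * p
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p,
        ema_params, params,
    )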

Contributing

Contributions to Jimmy are welcome! Please feel free to submit a Pull Request.

License

Jimmy is released under the MIT License. See the LICENSE file for more details.

Citation

If you use Jimmy in your research, please cite it as follows:

@software{jimmy2024,
  author = {Clément POIRET},
  title = {Jimmy},
  year = {2024},
  url = {https://github.com/clementpoiret/jimmy},
}

For any questions or issues, please open an issue on the GitHub repository.
