Implementations of popular vision models in Jax (Flax NNX)
Jimmy Vision
Jimmy Vision is a Jax-based library that provides implementations of computer vision models. It's designed to be flexible, efficient, and easy to use for researchers and practitioners in deep learning.
[!WARNING] Jimmy is not yet ready for production use. It is a work in progress, intended for experimentation.
Features
- Implementation of DinoV2 (Vision Transformer) models
- Implementation of MambaVision models
- Support for loading pre-trained weights from PyTorch models (DinoV2 only)
- Flexible model configuration and customization
- Efficient Jax-based computations
Installation
Install from PyPI
# cpu
pip install jimmy-vision
# cuda12
pip install jimmy-vision[cuda12]
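After installing, this one-liner should succeed (a minimal smoke test; the package installs as jimmy-vision but imports as jimmy):
python -c "from jimmy.models import load_model"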
Cloning the repo
You can use either Poetry, Nix, or devenv:
git clone git@github.com:clementpoiret/jimmy.git
cd jimmy
# either
nix develop --impure
# or
poetry install -E cuda12
Quick Start
Here's a quick example of how to use Jimmy to load a pre-trained DinoV2 model:
import jax
import jax.numpy as jnp
from flax import nnx
from jimmy.models import DINOV2_VITS14, load_model
# Initialize random number generator
rngs = nnx.Rngs(42)
# Load the model
model = load_model(
    DINOV2_VITS14,
    rngs=rngs,
    pretrained=True,
    url="https://huggingface.co/poiretclement/dinov2_jax/resolve/main/dinov2_vits14.jim",
)
# Create a random input
key = rngs.params()
x = jax.random.normal(key, (1, 518, 518, 3))
# Run inference
output = model(x)
print(output.shape)  # (1, 1370, 384)
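Inputs are channels-last (N, H, W, C), as is conventional in Jax/Flax. The output shape follows from the ViT-S/14 configuration: the 518×518 image is split into (518 / 14)² = 37² = 1369 patches, and a class token is prepended, giving 1370 tokens of embedding dimension 384.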
Models
Jimmy currently supports the following models:
- DinoV2 (various sizes: ViT-S/14, ViT-B/14, ViT-L/14, ViT-G/14)
- MambaVision (coming soon)
To load a specific model, use the load_model function with the appropriate configuration:
from jimmy.models import load_model, DINOV2_VITB14
model = load_model(DINOV2_VITB14, rngs=rngs)
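Note that pretrained weights are currently available for DinoV2 models only (see Features above).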
Custom Models
You can also create custom models by modifying the existing configurations:
custom_config = {
    "name": "custom_dinov2",
    "class": "dinov2",
    "config": {
        "num_heads": 8,
        "embed_dim": 512,
        "mlp_ratio": 4,
        "patch_size": 16,
        "depth": 8,
        "img_size": 224,
        "qkv_bias": True,
    },
}
custom_model = load_model(custom_config, rngs=rngs)
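As a quick sanity check (a sketch assuming the custom model follows the same calling convention as the Quick Start example above), you can run a random input through it and verify the token count implied by the configuration:
import jax
from flax import nnx
from jimmy.models import load_model

rngs = nnx.Rngs(0)
custom_model = load_model(custom_config, rngs=rngs)

# img_size=224 with patch_size=16 yields (224 / 16)^2 = 196 patch tokens
x = jax.random.normal(rngs.params(), (1, 224, 224, 3))
output = custom_model(x)
print(output.shape)  # expected (1, 197, 512): 196 patches + 1 class token, embed_dim=512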
Advanced Training Example
Here is a toy example that trains Mlla with a MESA loss term: a self-distillation penalty between the model's predictions and those of an Exponential Moving Average (EMA) copy of its weights.
from flax import nnx
from jimmy.models.mlla import Mlla
from jimmy.models.emamodel import EmaModel
from optax import adam
from optax.losses import softmax_cross_entropy
x, y = ...  # your training batch (images and one-hot labels)

rngs = nnx.Rngs(0)
criterion = softmax_cross_entropy

# Define the model
model = Mlla(
    num_classes=1000,
    depths=[2, 4, 12, 4],
    patch_size=4,
    in_features=3,
    embed_dim=96,
    num_heads=[2, 4, 8, 16],
    layer_window_sizes=[-1, -1, -1, -1],
    rngs=rngs,
)
model.train()
# Define the wrapper tracking the EMA of the weights
ema_model = EmaModel(model)
ema_model.model.eval()  # disable dropout in the EMA copy
optimizer = nnx.Optimizer(model, adam(1e-3))
# Core training fn
@nnx.jit
def train(model, ema_model, optimizer, x, y):
    def loss_fn(model, ema_model):
        y_pred = model(x)
        ema_outputs = ema_model(x)
        # You may want a warmup phase, or to enable MESA only after X epochs
        loss = criterion(y_pred, y).mean() + 0.3 * criterion(y_pred, ema_outputs).mean()
        return loss

    loss, grads = nnx.value_and_grad(loss_fn)(model, ema_model)
    optimizer.update(grads)

    # Update the moving average
    params = nnx.state(model, nnx.Param)
    ema_model.update(params)

    return loss

for i in range(8):
    loss = train(model, ema_model, optimizer, x, y)
    print(i, loss)
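At evaluation time you would typically predict with the EMA weights rather than the raw ones, since the averaged weights often generalize better. EmaModel is callable (as already used inside loss_fn), so evaluation is a one-liner; x_test below is a hypothetical held-out batch:
# x_test is a hypothetical evaluation batch
test_logits = ema_model(x_test)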
Contributing
Contributions to Jimmy are welcome! Please feel free to submit a Pull Request.
License
Jimmy is released under the MIT License. See the LICENSE file for more details.
Citation
If you use Jimmy in your research, please cite it as follows:
@software{jimmy2024,
  author = {Clément POIRET},
  title = {Jimmy},
  year = {2024},
  url = {https://github.com/clementpoiret/jimmy},
}
For any questions or issues, please open an issue on the GitHub repository.