mlx-image
Image models based on the Apple MLX framework for Apple Silicon machines.
Why?
The Apple MLX framework is a great tool for running machine learning models on Apple Silicon machines.
This repository converts image models from timm/torchvision to the Apple MLX framework. The weights are simply converted from .pth to .npz/.safetensors; the models are not trained again.
How to install
```bash
pip install mlx-image
```
Models
Model weights are available on the mlx-vision community on HuggingFace.
To load a model with pre-trained weights:
```python
from mlxim.model import create_model

# loading weights from HuggingFace (https://huggingface.co/mlx-vision/resnet18-mlxim)
model = create_model("resnet18")  # pretrained weights loaded from HF

# loading weights from a local file
model = create_model("resnet18", weights="path/to/resnet18/model.safetensors")
```
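Once loaded, the model can be called like any other MLX module. Below is a minimal sketch of a forward pass; the NHWC input layout, the 224x224 dummy input, and the (1, 1000) output shape are assumptions based on the ImageNet-1K weights, not guarantees from this README.

```python
import mlx.core as mx
from mlxim.model import create_model

model = create_model("resnet18")
model.eval()

# dummy batch: one 224x224 RGB image (NHWC layout assumed)
x = mx.random.uniform(shape=(1, 224, 224, 3))
logits = model(x)
print(logits.shape)  # expected: (1, 1000) with ImageNet-1K weights
```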
To list all available models:
```python
from mlxim.model import list_models
list_models()
```
Supported models
List of all models available in mlx-image:
- ResNet: resnet18, resnet34, resnet50, resnet101, resnet152, wide_resnet50_2, wide_resnet101_2
- ViT:
  - supervised: vit_base_patch16_224, vit_base_patch16_224.swag_lin, vit_base_patch16_384.swag_e2e, vit_base_patch32_224, vit_large_patch16_224, vit_large_patch16_224.swag_lin, vit_large_patch16_512.swag_e2e, vit_huge_patch14_224.swag_lin, vit_huge_patch14_518.swag_e2e
  - DINO v1: vit_base_patch16_224.dino, vit_small_patch16_224.dino, vit_small_patch8_224.dino, vit_base_patch8_224.dino
  - DINO v2: vit_small_patch14_518.dinov2, vit_base_patch14_518.dinov2, vit_large_patch14_518.dinov2
- Swin: swin_tiny_patch4_window7_224, swin_small_patch4_window7_224, swin_base_patch4_window7_224, swin_v2_tiny_patch4_window8_256, swin_v2_small_patch4_window8_256, swin_v2_base_patch4_window8_256
- RegNet: regnet_x_400mf, regnet_x_800mf, regnet_x_1_6gf, regnet_x_3_2gf, regnet_x_8gf, regnet_x_16gf, regnet_x_32gf, regnet_y_400mf, regnet_y_800mf, regnet_y_1_6gf, regnet_y_3_2gf, regnet_y_8gf, regnet_y_16gf, regnet_y_32gf, regnet_y_128gf
Warning: the regnet_y_128gf model couldn't be tested due to computational limitations.
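Any model name listed above can be passed directly to create_model. For example, to load one of the DINO v1 ViT backbones (a minimal sketch; weights are assumed to be fetched from the matching mlx-vision repository on HuggingFace):

```python
from mlxim.model import create_model

# any name from the list above works; here, a DINO v1 ViT backbone
model = create_model("vit_base_patch16_224.dino")
```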
ImageNet-1K Results
Go to results-imagenet-1k.csv to check every model converted to mlx-image and its performance on ImageNet-1K with different settings.
TL;DR: performance is comparable to that of the original PyTorch implementations.
Similarity to PyTorch and other familiar tools
mlx-image tries to be as close as possible to PyTorch:
- DataLoader -> you can define your own collate_fn and also use num_workers to speed up data loading
- Dataset -> mlx-image already supports LabelFolderDataset (the good old PyTorch ImageFolder) and FolderDataset (a generic folder with images in it; see the sketch after this list)
- ModelCheckpoint -> keeps track of the best model and saves it to disk (similar to PyTorch Lightning). It also suggests early stopping.
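The following is a minimal, hedged sketch of the FolderDataset case; the root_dir argument and the unlabeled-batch handling are assumptions, modeled on the LabelFolderDataset usage shown in the training example below.

```python
from mlxim.data import FolderDataset, DataLoader

# assumption: FolderDataset takes a root_dir like LabelFolderDataset does
dataset = FolderDataset(root_dir="path/to/images")
loader = DataLoader(dataset=dataset, batch_size=32, shuffle=False, num_workers=4)

for batch in loader:
    ...  # run inference on each unlabeled batch here
```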
Training
Training is similar to PyTorch. Here's an example of how to train a model:
```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

from mlxim.model import create_model
from mlxim.data import LabelFolderDataset, DataLoader

train_dataset = LabelFolderDataset(
    root_dir="path/to/train",
    class_map={0: "class_0", 1: "class_1", 2: ["class_2", "class_3"]}
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)
model = create_model("resnet18")  # pretrained weights loaded from HF
optimizer = optim.Adam(learning_rate=1e-3)

def train_step(model, inputs, targets):
    logits = model(inputs)
    loss = mx.mean(nn.losses.cross_entropy(logits, targets))
    return loss

# build the value-and-grad transform once, outside the training loop
train_step_fn = nn.value_and_grad(model, train_step)

model.train()
for epoch in range(10):
    for batch in train_loader:
        x, target = batch
        loss, grads = train_step_fn(model, x, target)
        optimizer.update(model, grads)
        mx.eval(model.state, optimizer.state)
```
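After training, a quick top-1 accuracy pass can reuse the same pieces. This is a hedged sketch: val_loader is a hypothetical DataLoader built like train_loader above, just over a validation split.

```python
def eval_step(model, inputs, targets):
    preds = mx.argmax(model(inputs), axis=-1)
    return mx.mean((preds == targets).astype(mx.float32))

model.eval()
accuracies = []
for batch in val_loader:  # val_loader: hypothetical validation DataLoader
    x, target = batch
    accuracies.append(eval_step(model, x, target))
print("top-1 accuracy:", mx.mean(mx.stack(accuracies)).item())
```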
Validation
The validation.py script is run every time a .pth model is converted to MLX and is used to check whether the converted model performs similarly to the original one on ImageNet-1K.
I use the configuration file config/validation.yaml to set the parameters for the validation script.
You can download the ImageNet-1K validation set from the mlx-vision space on HuggingFace at this link.
Contributing
This is a work in progress, so any help is appreciated.
I am working on it in my spare time, so I can't guarantee frequent updates.
If you love coding and want to contribute, follow the instructions in CONTRIBUTING.md.
Additional Resources
To-Dos
- [ ] inference script (similar to train/validation)
- [ ] DenseNet
- [ ] MobileNet
Contact
If you have any questions, please email riccardomusmeci92@gmail.com.