Apple MLX image models library

mlx-image

Image models based on the Apple MLX framework for Apple Silicon machines.

Why?

The Apple MLX framework is a great tool for running machine learning models on Apple Silicon machines.

This repository converts image models from timm/torchvision to the Apple MLX framework. The weights are only converted from .pth to .npz/.safetensors; the models are not retrained.
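As a rough illustration of what such a conversion can involve, the sketch below reorders convolution kernels from the PyTorch layout (out, in, kH, kW) to the channels-last layout that MLX convolutions expect (out, kH, kW, in). It is not the conversion script used by this repository: the helper name `torch_to_mlx_layout`, the toy parameter names, and the rule "every 4-D array is a convolution kernel" are assumptions made for the example, and it uses NumPy arrays in place of a real checkpoint.

```python
import numpy as np


def torch_to_mlx_layout(state_dict):
    """Hypothetical helper: reorder conv kernels to channels-last."""
    converted = {}
    for name, array in state_dict.items():
        array = np.asarray(array)
        if array.ndim == 4:
            # Assumed to be a conv kernel: (out, in, kH, kW) -> (out, kH, kW, in)
            array = array.transpose(0, 2, 3, 1)
        converted[name] = array
    return converted


# Toy stand-in for the arrays read from a .pth checkpoint
toy_state = {
    "conv1.weight": np.zeros((64, 3, 7, 7), dtype=np.float32),
    "fc.weight": np.zeros((1000, 512), dtype=np.float32),
}
mlx_state = torch_to_mlx_layout(toy_state)

# np.savez("model.npz", **mlx_state) would then write the .npz file
```

Linear weights keep their shape here; a real converter would also need to handle any layer whose parameter names or layouts differ between the two frameworks.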

How to install

pip install mlx-image

Models

Model weights are available on the mlx-vision community on HuggingFace.

To load a model with pre-trained weights:

from mlxim.model import create_model

# loading weights from HuggingFace (https://huggingface.co/mlx-vision/resnet18-mlxim)
model = create_model("resnet18") # pretrained weights loaded from HF

# loading weights from local file
model = create_model("resnet18", weights="path/to/resnet18/model.safetensors")

To list all available models:

from mlxim.model import list_models
list_models()

Supported models

List of all models available in mlx-image:

  • ResNet: resnet18, resnet34, resnet50, resnet101, resnet152, wide_resnet50_2, wide_resnet101_2
  • ViT:
    • supervised: vit_base_patch16_224, vit_base_patch16_224.swag_lin, vit_base_patch16_384.swag_e2e, vit_base_patch32_224, vit_large_patch16_224, vit_large_patch16_224.swag_lin, vit_large_patch16_512.swag_e2e, vit_huge_patch14_224.swag_lin, vit_huge_patch14_518.swag_e2e

    • DINO v1: vit_base_patch16_224.dino, vit_small_patch16_224.dino, vit_small_patch8_224.dino, vit_base_patch8_224.dino

    • DINO v2: vit_small_patch14_518.dinov2, vit_base_patch14_518.dinov2, vit_large_patch14_518.dinov2

  • Swin: swin_tiny_patch4_window7_224, swin_small_patch4_window7_224, swin_base_patch4_window7_224, swin_v2_tiny_patch4_window8_256, swin_v2_small_patch4_window8_256, swin_v2_base_patch4_window8_256
  • RegNet: regnet_x_400mf, regnet_x_800mf, regnet_x_1_6gf, regnet_x_3_2gf, regnet_x_8gf, regnet_x_16gf, regnet_x_32gf, regnet_y_400mf, regnet_y_800mf, regnet_y_1_6gf, regnet_y_3_2gf, regnet_y_8gf, regnet_y_16gf, regnet_y_32gf, regnet_y_128gf
  • EfficientNet: efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7
  • MobileNet: mobilenet_v2, mobilenet_v3_large, mobilenet_v3_small

Warning: The regnet_y_128gf model couldn't be tested due to computational limitations.

ImageNet-1K Results

See results-imagenet-1k.csv for every model converted to mlx-image and its performance on ImageNet-1K under different settings.

TL;DR: performance is comparable to the original PyTorch implementations.

Similarity to PyTorch and other familiar tools

mlx-image tries to be as close as possible to PyTorch:

  • DataLoader -> you can define your own collate_fn and also use num_workers to speed up data loading
  • Dataset -> mlx-image already supports LabelFolderDataset (the good old PyTorch ImageFolder) and FolderDataset (a generic folder with images in it)
  • ModelCheckpoint -> keeps track of the best model and saves it to disk (similar to PyTorch Lightning). It also suggests when to stop training early
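To make the ModelCheckpoint idea concrete, here is a minimal sketch of best-model tracking with a patience-based early-stopping hint. It is not the mlx-image ModelCheckpoint API: the class name `BestModelTracker`, its `patience` and `mode` parameters, and the `should_stop` property are all hypothetical, and the step where weights would be written to disk is only marked with a comment.

```python
class BestModelTracker:
    """Hypothetical stand-in, not the mlx-image ModelCheckpoint API."""

    def __init__(self, patience=3, mode="max"):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.patience = patience
        self.mode = mode
        self.best_metric = None
        self.best_epoch = None
        self.bad_epochs = 0  # epochs since the last improvement

    def _improved(self, metric):
        if self.best_metric is None:
            return True
        if self.mode == "max":
            return metric > self.best_metric
        return metric < self.best_metric

    def step(self, epoch, metric):
        """Record one epoch; return True when it is the new best."""
        if self._improved(metric):
            self.best_metric = metric
            self.best_epoch = epoch
            self.bad_epochs = 0
            return True  # a real checkpoint would save the weights here
        self.bad_epochs += 1
        return False

    @property
    def should_stop(self):
        # Suggest stopping after `patience` epochs without improvement
        return self.bad_epochs >= self.patience


tracker = BestModelTracker(patience=3, mode="max")
history = []
for epoch, accuracy in enumerate([0.50, 0.60, 0.55, 0.58, 0.59, 0.70]):
    is_best = tracker.step(epoch, accuracy)
    history.append((epoch, is_best))
    if tracker.should_stop:
        break
```

In this toy run the loop stops before the last value is seen, which shows the usual trade-off of early stopping: a small patience saves time but can miss a late improvement.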

Training

Training is similar to PyTorch. Here's an example of how to train a model:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlxim.model import create_model
from mlxim.data import LabelFolderDataset, DataLoader

train_dataset = LabelFolderDataset(
    root_dir="path/to/train",
    class_map={0: "class_0", 1: "class_1", 2: ["class_2", "class_3"]}
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)
model = create_model("resnet18") # pretrained weights loaded from HF
optimizer = optim.Adam(learning_rate=1e-3)

def train_step(model, inputs, targets):
    logits = model(inputs)
    loss = mx.mean(nn.losses.cross_entropy(logits, targets))
    return loss

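# build the loss-and-gradient function once, outside the loop
train_step_fn = nn.value_and_grad(model, train_step)

model.train()
for epoch in range(10):
    for batch in train_loader:
        x, target = batch
        # the wrapped function takes the same arguments as train_step
        loss, grads = train_step_fn(model, x, target)
        optimizer.update(model, grads)
        # force evaluation of the lazily computed updates
        mx.eval(model.state, optimizer.state)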
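The class_map argument in the example maps each label index to one folder name or to a list of folder names. One plausible reading, assumed here and not confirmed by the text above, is that a list merges several folders into a single label. The sketch below inverts such a map into a folder-to-label lookup; the helper name `folder_to_label` is hypothetical and not part of mlx-image.

```python
def folder_to_label(class_map):
    """Hypothetical helper: invert {label: folder(s)} into {folder: label}."""
    lookup = {}
    for label, folders in class_map.items():
        if isinstance(folders, str):
            folders = [folders]  # a single folder name
        for folder in folders:
            if folder in lookup:
                raise ValueError(f"folder {folder!r} is mapped to two labels")
            lookup[folder] = label
    return lookup


class_map = {0: "class_0", 1: "class_1", 2: ["class_2", "class_3"]}
lookup = folder_to_label(class_map)
```

Under that assumption, images in class_2 and class_3 would both be trained as label 2.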

Validation

The validation.py script runs every time a .pth model is converted to MLX. It checks that the converted model performs similarly to the original one on ImageNet-1K.

I use the configuration file config/validation.yaml to set the parameters for the validation script.

You can download the ImageNet-1K validation set from the mlx-vision space on HuggingFace.
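Comparing a converted model with the original usually comes down to a classification metric such as top-k accuracy. The sketch below computes it in plain Python on a few made-up logits. It is not taken from validation.py: the function name `topk_accuracy` and the choice of metric are assumptions for illustration.

```python
def topk_accuracy(logits, targets, k=1):
    """Fraction of samples whose target is among the k highest logits."""
    correct = 0
    for row, target in zip(logits, targets):
        # class indices sorted from highest to lowest score
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        if target in ranked[:k]:
            correct += 1
    return correct / len(targets)


# Made-up scores for three samples and three classes
logits = [
    [0.1, 0.7, 0.2],
    [0.5, 0.3, 0.2],
    [0.2, 0.3, 0.5],
]
targets = [1, 1, 2]

top1 = topk_accuracy(logits, targets, k=1)
top2 = topk_accuracy(logits, targets, k=2)
```

Running the same function on the outputs of the original and the converted model, for the same images, gives two numbers that should be close if the conversion preserved the weights.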

Contributing

This is a work in progress, so any help is appreciated.

I am working on it in my spare time, so I can't guarantee frequent updates.

If you love coding and want to contribute, follow the instructions in CONTRIBUTING.md.

Additional Resources

To-Dos

[ ] inference script (similar to train/validation)

[ ] DenseNet

[x] MobileNet

Contact

If you have any questions, please email riccardomusmeci92@gmail.com.
