Apple mlx image models library

mlxim

Image models based on Apple MLX framework for Apple Silicon machines.

Why? 💡

Apple MLX framework is a great tool to run machine learning models on Apple Silicon machines.

This repository converts image models from timm/torchvision to the Apple MLX framework. The weights are converted directly from .pth to .npz; the models are not retrained.

I don't have enough compute power (and time) to train all the models from scratch (someone buy me a maxed-out Mac, please).

How to install

pip install mlx-image

Models

Model weights are available in the mlx-vision space on HuggingFace.

To create a model with weights:

from mlxim.model import create_model

# loading weights from HuggingFace
model = create_model("resnet18")

# loading weights from local file
model = create_model("resnet18", weights="path/to/weights.npz")

To list all available models:

from mlxim.model import list_models
list_models()

[!WARNING] As of today (2024-03-05) mlx does not support nn.Conv2d with groups or dilation greater than 1, so models that rely on them (e.g. resnext, regnet, efficientnet) cannot be converted yet.

ImageNet-1K Results

Go to results-imagenet-1k.csv to check every converted model and its performance on ImageNet-1K under different settings.
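The CSV reports standard classification metrics. As a point of reference, here is a minimal top-k accuracy computation in plain Python; this is an illustrative sketch with a hypothetical helper name, not the code the validation script actually uses:

```python
def topk_accuracy(logits, targets, k=1):
    """Fraction of samples whose true label is among the k highest-scoring
    classes. `logits` is a list of per-sample score lists, `targets` a list
    of integer labels. Sketch only -- the real validation code may differ."""
    hits = 0
    for scores, label in zip(logits, targets):
        # indices of the k largest scores, highest first
        topk = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
        hits += label in topk
    return hits / len(targets)

logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.2, 0.6]]
targets = [1, 2, 0]
print(round(topk_accuracy(logits, targets, k=1), 3))  # -> 0.333
```

Top-5 accuracy, the other number usually reported for ImageNet-1K, is the same computation with k=5.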

Train

Training is similar to PyTorch, thanks to some tools mlx-im provides. Here's an example of how to train a model with mlx-im:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlxim.model import create_model
from mlxim.data import LabelFolderDataset, DataLoader

train_dataset = LabelFolderDataset(
    root_dir="path/to/train",
    class_map={0: "class_0", 1: "class_1", 2: ["class_2", "class_3"]}
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)
model = create_model("resnet18")  # pretrained weights loaded from HF
optimizer = optim.Adam(learning_rate=1e-3)

def train_step(model, inputs, targets):
    logits = model(inputs)
    loss = mx.mean(nn.losses.cross_entropy(logits, targets))
    return loss

# build the value-and-grad function once, outside the loop
train_step_fn = nn.value_and_grad(model, train_step)

model.train()
for epoch in range(10):
    for batch in train_loader:
        x, target = batch
        loss, grads = train_step_fn(model, x, target)
        optimizer.update(model, grads)
        mx.eval(model.state, optimizer.state)
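Note the class_map argument in the example above: a single label index can cover several folder names (here class_2 and class_3 both map to label 2). A rough sketch of that folder-name-to-label resolution, with a hypothetical helper name and not the actual LabelFolderDataset internals:

```python
def resolve_label(folder_name, class_map):
    """Map a folder name to its integer label; a label's value may be a
    single folder name or a list of folder names that share the label.
    Illustrative sketch only -- the real implementation may differ."""
    for label, names in class_map.items():
        names = names if isinstance(names, list) else [names]
        if folder_name in names:
            return label
    raise KeyError(f"no label for folder {folder_name!r}")

class_map = {0: "class_0", 1: "class_1", 2: ["class_2", "class_3"]}
print(resolve_label("class_3", class_map))  # -> 2
```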

Validation

The validation.py script runs every time a .pth model is converted to MLX; it checks that the converted model performs on ImageNet-1K similarly to the original one.

I use the configuration file config/validation.yaml to set the parameters for the validation script.

You can download the ImageNet-1K validation set from the mlx-vision space on HuggingFace.

Similarity to PyTorch and other familiar tools

mlx-im tries to be as close as possible to PyTorch:

  • DataLoader -> you can define your own collate_fn and use num_workers to speed up data loading
  • Dataset -> mlx-im already supports LabelFolderDataset (the good old PyTorch ImageFolder) and FolderDataset (a generic folder of images)
  • ModelCheckpoint -> keeps track of the best model and saves it to disk (similar to PyTorch Lightning); it can also suggest early stopping
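The ModelCheckpoint behavior described above (track the best metric, suggest early stopping) can be sketched in a few lines of plain Python. This is an illustrative stand-in with a made-up class name, not the actual mlx-im API:

```python
class BestMetricTracker:
    """Tracks the best validation metric and counts epochs without
    improvement, as a checkpoint callback might. Illustrative only --
    not the actual mlx-im ModelCheckpoint class."""

    def __init__(self, mode="min", patience=3):
        self.mode = mode            # "min" for losses, "max" for accuracies
        self.patience = patience    # stale epochs tolerated before stopping
        self.best = None
        self.stale_epochs = 0

    def update(self, metric):
        """Return True if this is a new best, i.e. worth saving a checkpoint."""
        improved = (
            self.best is None
            or (self.mode == "min" and metric < self.best)
            or (self.mode == "max" and metric > self.best)
        )
        if improved:
            self.best = metric
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return improved

    @property
    def should_stop(self):
        return self.stale_epochs >= self.patience

tracker = BestMetricTracker(mode="min", patience=2)
for val_loss in [0.9, 0.7, 0.8, 0.85]:
    if tracker.update(val_loss):
        pass  # save weights here, e.g. model.save_weights("best.npz")
print(tracker.best, tracker.should_stop)  # -> 0.7 True
```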

Contributing

This is a work in progress, so any help is appreciated.

I am working on it in my spare time, so I can't guarantee frequent updates.

If you love coding and want to contribute, follow the instructions in CONTRIBUTING.md.

To-Do

[ ] add register_model (similar to timm)

[ ] inference script (similar to train/validation)

[ ] SwinTransformer

[ ] DenseNet

[ ] MobileNet

Contact

If you have any questions, please email riccardomusmeci92@gmail.com.
