
Graph Neural Network library made for Apple Silicon




Documentation | Quickstart | Discord

MLX-graphs is a library for Graph Neural Networks (GNNs) built upon Apple's MLX.

Features

  • Fast GNN training and inference on Apple Silicon

    mlx-graphs has been designed to run GNNs and graph algorithms fast on Apple Silicon chips. All GNN operations fully leverage the GPU and CPU hardware of Macs, thanks to the efficient low-level primitives available in the MLX core library. Initial benchmarks show speedups of up to 10x over other frameworks on large datasets.

  • Scalability to large graphs

    With Apple Silicon's unified memory architecture, objects live in a shared memory space accessible by both the CPU and the GPU. This lets Macs use their entire memory capacity for storing graphs. Consequently, Macs with substantial memory can efficiently train GNNs on large graphs, spanning tens of gigabytes, directly on the Mac's GPU.

  • Multi-device

    Unified memory eliminates the need for time-consuming device-to-device transfers. It also allows specific operations to run explicitly on either the CPU or the GPU without any data-transfer overhead, making computation and resource utilization more efficient.

Installation

mlx-graphs is available on PyPI. To install it, run

pip install mlx-graphs

Build from source

To build and install mlx-graphs from source, start by cloning the GitHub repo

git clone git@github.com:mlx-graphs/mlx-graphs.git && cd mlx-graphs

Create a new virtual environment (for example with python -m venv .venv), activate it, and install the package and its requirements in editable mode

pip install -e .

Usage

Tutorial guides

We provide several notebooks to help you practice with mlx-graphs.

Example

This library has been designed to build GNNs with ease and efficiency. Building a new GNN layer is straightforward: subclass the MessagePassing class. All operations related to message passing are then handled for you and run efficiently on your Mac's GPU, so you can focus exclusively on the GNN logic without worrying about the underlying message passing mechanics.

Here is an example of a custom GraphSAGE convolutional layer that considers edge weights:

import mlx.core as mx
from mlx_graphs.nn.linear import Linear
from mlx_graphs.nn.message_passing import MessagePassing

class SAGEConv(MessagePassing):
    def __init__(
        self, node_features_dim: int, out_features_dim: int, bias: bool = True, **kwargs
    ):
        # Neighbor messages are aggregated with a mean, as in GraphSAGE
        super().__init__(aggr="mean", **kwargs)

        self.node_features_dim = node_features_dim
        self.out_features_dim = out_features_dim

        # Separate projections for the aggregated neighbor features and the node's own features
        self.neigh_proj = Linear(node_features_dim, out_features_dim, bias=False)
        self.self_proj = Linear(node_features_dim, out_features_dim, bias=bias)

    def __call__(
        self, edge_index: mx.array, node_features: mx.array, edge_weights: mx.array
    ) -> mx.array:
        """Forward pass of the custom SAGE layer."""
        neigh_features = self.propagate(  # message passing runs directly on the GPU
            edge_index=edge_index,
            node_features=node_features,
            message_kwargs={"edge_weights": edge_weights},
        )
        neigh_features = self.neigh_proj(neigh_features)

        # Combine the node's own projected features with its aggregated neighborhood
        out_features = self.self_proj(node_features) + neigh_features
        return out_features

    def message(self, src_features: mx.array, dst_features: mx.array, **kwargs) -> mx.array:
        """Message function called by propagate(). Computes messages for all edges in the graph."""
        edge_weights = kwargs.get("edge_weights", None)
        if edge_weights is None:
            # No weights provided: fall back to plain, unweighted messages
            return src_features

        # Scale each source node's features by the weight of its edge
        return edge_weights.reshape(-1, 1) * src_features

Contributing

Why contribute?

The library is at an early stage of development, which means your contributions can have a large impact! Everyone is welcome to contribute: just open an issue 📝 with your idea 💡 and we'll work together on the implementation ✨.

Note: Contributions such as the implementation of new layers and datasets would be very valuable for the library.

Installing test, dev, benchmarks, and docs dependencies

Extra dependencies are specified in pyproject.toml. To install the dependencies needed for testing, development, benchmarking, or building the documentation, run the corresponding command:

pip install -e '.[test]'
pip install -e '.[dev]'
pip install -e '.[benchmarks]'
pip install -e '.[docs]'

For development, you may want to install the latest version of mlx directly from its repository via pip install git+https://github.com/ml-explore/mlx.git

Testing

We encourage you to write tests for all components. Please run pytest to make sure your changes don't introduce breaking changes.

Note: CI is in place to automatically run tests upon opening a PR.

Pre-commit hooks (optional)

To help maintain code quality, you can use pre-commit hooks. Install them by running

pre-commit install

and run them with pre-commit run --all-files.

Note: CI is in place to verify code quality, so pull requests that don't meet those requirements won't pass CI tests.

Why run GNNs on my Mac?

Other frameworks like PyG and DGL also benefit from efficient GNN operations parallelized on the GPU. However, they are not fully optimized for the Mac's GPU and often fall back to CPU execution.

In contrast, mlx-graphs is specifically designed to get the most out of Mac hardware. By building on Apple Silicon, mlx-graphs enables accelerated GPU computation and benefits from unified memory. This removes the need for data transfers between devices and lets the GPU use the entire memory available on the Mac. Consequently, users can handle large graphs directly on the GPU, improving performance and efficiency.
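As a rough back-of-the-envelope sketch of what "large graphs" means in memory terms, the helper below estimates the raw storage of a graph's node features and edge index. The function name and its defaults are illustrative assumptions: float32 node features and 32-bit integer indices in a [2, num_edges] edge index. The estimate ignores activations, gradients, and optimizer state needed during training.

```python
def graph_memory_gb(
    num_nodes: int,
    num_edges: int,
    feature_dim: int,
    feature_bytes: int = 4,  # assumption: float32 node features
    index_bytes: int = 4,  # assumption: 32-bit integer edge indices
) -> float:
    """Rough size in GB of a node feature matrix plus a [2, num_edges] edge index."""
    node_bytes = num_nodes * feature_dim * feature_bytes
    edge_bytes = 2 * num_edges * index_bytes  # one source and one destination index per edge
    return (node_bytes + edge_bytes) / 1e9


# Hypothetical graph: 10M nodes with 128 features each and 1B edges.
print(graph_memory_gb(10_000_000, 1_000_000_000, 128))  # 13.12
```

Training needs additional memory on top of this figure, so treat the result as a lower bound when judging whether a graph fits in a given Mac's unified memory.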


Download files


Source Distribution

mlx_graphs-0.0.7.tar.gz (59.1 kB)

Uploaded Source

Built Distribution

mlx_graphs-0.0.7-py3-none-any.whl (78.2 kB)

Uploaded Python 3

File details

Details for the file mlx_graphs-0.0.7.tar.gz.

File metadata

  • Download URL: mlx_graphs-0.0.7.tar.gz
  • Size: 59.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.0 CPython/3.12.4

File hashes

Hashes for mlx_graphs-0.0.7.tar.gz
Algorithm Hash digest
SHA256 4282b829c0dd1b5c4e000cb360762fe66590edb7f7db551825bc246b107a4464
MD5 9e6d1d6927dbdb54c70f707256a76c10
BLAKE2b-256 040f36ac70ad55a8667332d12b347569f33d99ea6894035e775fd288d8152aef


File details

Details for the file mlx_graphs-0.0.7-py3-none-any.whl.

File metadata

  • Download URL: mlx_graphs-0.0.7-py3-none-any.whl
  • Upload date:
  • Size: 78.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.0 CPython/3.12.4

File hashes

Hashes for mlx_graphs-0.0.7-py3-none-any.whl
Algorithm Hash digest
SHA256 cda50ddf30a152dc5b33d21083d35dbd8e3d48b89631e429e58904944a244ccb
MD5 c3383fb7c058ad4527eb78c3493a5b2d
BLAKE2b-256 61fb9bdcd46a4d8dfd2df607a3bf23c5ab33d87c7a5f9233633604d71d7f413c

