
Graph Neural Network library made for Apple Silicon





MLX-graphs is a library for Graph Neural Networks (GNNs) built upon Apple's MLX.

Features

  • Fast GNN training and inference on Apple Silicon

    mlx-graphs has been designed to run GNNs and graph algorithms fast on Apple Silicon chips. All GNN operations make full use of the GPU and CPU hardware of Macs, thanks to the efficient low-level primitives available in the MLX core library. Initial benchmarks show speedups of up to 10x over other frameworks on large datasets.

  • Scalability to large graphs

    Thanks to Apple Silicon's unified memory architecture, objects live in a shared memory space accessible by both the CPU and GPU. This allows Macs to use their entire memory capacity for storing graphs. Consequently, Macs with large amounts of memory can efficiently train GNNs on large graphs spanning tens of gigabytes, directly on the Mac's GPU.

  • Multi-device

    Unified memory eliminates the need for time-consuming device-to-device transfers. This architecture also enables specific operations to be run explicitly on either the CPU or GPU without incurring any overhead, facilitating more efficient computation and resource utilization.
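As a rough, back-of-the-envelope illustration of the "tens of gigabytes" scale mentioned above, the sketch below estimates the memory needed just to hold a graph's edge index and node features. The byte sizes (8-byte integer indices, 4-byte float32 features) and the example graph size are assumptions chosen for illustration, not measurements from mlx-graphs.

```python
def graph_memory_gb(
    num_nodes: int,
    num_edges: int,
    feature_dim: int,
    index_bytes: int = 8,    # assumption: int64 node indices
    feature_bytes: int = 4,  # assumption: float32 node features
) -> float:
    """Estimate memory (in GB, 1 GB = 1e9 bytes) for edge_index plus node features."""
    # edge_index stores two node indices (source, destination) per edge
    edge_index_bytes = 2 * num_edges * index_bytes
    # one feature vector of length feature_dim per node
    node_feature_bytes = num_nodes * feature_dim * feature_bytes
    return (edge_index_bytes + node_feature_bytes) / 1e9


# Hypothetical graph: 100M nodes, 1B edges, 128-dimensional features
print(graph_memory_gb(100_000_000, 1_000_000_000, 128))  # 67.2
```

This estimate ignores activations, gradients, and optimizer state, which add to the total during training; on a unified-memory Mac, all of it draws from the same memory pool.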

Installation

mlx-graphs is available on PyPI. To install it, run:

pip install mlx-graphs

Build from source

To build and install mlx-graphs from source, start by cloning the GitHub repository:

git clone git@github.com:mlx-graphs/mlx-graphs.git && cd mlx-graphs

Create a new virtual environment, then install the package and its requirements in editable mode:

pip install -e .

Usage

Tutorial guides

We provide several notebooks to help you get started with mlx-graphs.

Example

This library has been designed to make building GNNs easy and efficient. Building a new GNN layer is straightforward: subclass the MessagePassing class. All message-passing operations are then handled for you and run efficiently on your Mac's GPU, so you can focus exclusively on the GNN logic without worrying about the underlying message-passing mechanics.
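To illustrate what "message-passing mechanics" covers, here is a plain-Python sketch of the gather, message, and mean-aggregate steps on a tiny graph. It is not mlx-graphs' implementation, which operates on MLX arrays on the GPU; the function name propagate_mean and the list-based graph format are hypothetical, and the sketch assumes edge_index[0] holds source nodes and edge_index[1] destination nodes.

```python
def propagate_mean(edge_index, node_features, edge_weights=None):
    """Hypothetical gather -> message -> mean-aggregate step in plain Python.

    edge_index is a pair of lists (sources, destinations); node_features is a
    list of equal-length feature vectors (lists of floats).
    """
    sources, destinations = edge_index
    num_nodes = len(node_features)
    dim = len(node_features[0])
    sums = [[0.0] * dim for _ in range(num_nodes)]
    counts = [0] * num_nodes

    for e, (src, dst) in enumerate(zip(sources, destinations)):
        # Message: the source node's features, scaled by the edge weight (1.0 if absent)
        weight = 1.0 if edge_weights is None else edge_weights[e]
        message = [weight * x for x in node_features[src]]
        # Scatter-add the message into the destination node
        for k in range(dim):
            sums[dst][k] += message[k]
        counts[dst] += 1

    # Mean over incoming edges; this sketch leaves nodes without incoming edges at zero
    return [
        [s / counts[i] for s in sums[i]] if counts[i] else [0.0] * dim
        for i in range(num_nodes)
    ]


# Edges 0->1, 1->2, 2->1 with weights 0.5, 1.0, 2.0
print(propagate_mean(([0, 1, 2], [1, 2, 1]),
                     [[1.0, 0.0], [0.0, 2.0], [4.0, 4.0]],
                     [0.5, 1.0, 2.0]))
# [[0.0, 0.0], [4.25, 4.0], [0.0, 2.0]]
```

A library implementation replaces the Python loops with vectorized gather and scatter operations over all edges at once, which is what makes GPU execution efficient.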

Here is an example of a custom GraphSAGE convolutional layer that considers edge weights:

from typing import Optional

import mlx.core as mx
from mlx_graphs.nn.linear import Linear
from mlx_graphs.nn.message_passing import MessagePassing


class SAGEConv(MessagePassing):
    """GraphSAGE convolution with mean aggregation and optional edge weights."""

    def __init__(
        self, node_features_dim: int, out_features_dim: int, bias: bool = True, **kwargs
    ):
        # The base class averages incoming messages ("mean" aggregation)
        super().__init__(aggr="mean", **kwargs)

        self.node_features_dim = node_features_dim
        self.out_features_dim = out_features_dim

        # Separate projections for the aggregated neighborhood and for the node itself
        self.neigh_proj = Linear(node_features_dim, out_features_dim, bias=False)
        self.self_proj = Linear(node_features_dim, out_features_dim, bias=bias)

    def __call__(
        self,
        edge_index: mx.array,
        node_features: mx.array,
        edge_weights: Optional[mx.array] = None,
    ) -> mx.array:
        """Forward pass of the custom SAGE layer."""
        # Message passing runs directly on the GPU
        neigh_features = self.propagate(
            edge_index=edge_index,
            node_features=node_features,
            message_kwargs={"edge_weights": edge_weights},
        )
        neigh_features = self.neigh_proj(neigh_features)

        # Combine the node's own projection with its aggregated neighborhood
        out_features = self.self_proj(node_features) + neigh_features
        return out_features

    def message(
        self, src_features: mx.array, dst_features: mx.array, **kwargs
    ) -> mx.array:
        """Message function called by propagate(). Computes messages for all edges in the graph."""
        edge_weights = kwargs.get("edge_weights", None)
        if edge_weights is None:
            # Unweighted graph: each message is just the source node's features
            return src_features
        # Scale each source node's features by the weight of its edge
        return edge_weights.reshape(-1, 1) * src_features
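Written out per node, the layer above computes the following. This is a sketch assuming mean aggregation over incoming edges, where N(i) is the set of nodes with an edge into node i, w_ji is the weight of edge j -> i (taken as 1 when no weights are passed), and b is the bias of self_proj (absent when bias=False):

```latex
\mathbf{h}_i' = \mathbf{W}_{\text{self}}\,\mathbf{x}_i + \mathbf{b}
  + \mathbf{W}_{\text{neigh}} \left( \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} w_{ji}\,\mathbf{x}_j \right)
```

Because neigh_proj is linear and has no bias, applying it after the mean gives the same result as averaging already-projected neighbor features, while projecting only once per node instead of once per edge.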

Contributing

Why contribute?

The library is at an early stage of development, which means your contributions can have a large impact! Everyone is welcome to contribute: just open an issue 📝 with your idea 💡 and we'll work together on the implementation ✨.

Note: Contributions such as implementations of new layers and datasets would be very valuable for the library.

Installing test, dev, benchmarks, and docs dependencies

Extra dependencies are specified in pyproject.toml. To install the dependencies for testing, development, benchmarking, or building the documentation, run the corresponding command:

pip install -e '.[test]'
pip install -e '.[dev]'
pip install -e '.[benchmarks]'
pip install -e '.[docs]'

For development, you may want to install the latest version of mlx from source via pip install git+https://github.com/ml-explore/mlx.git

Testing

We encourage you to write tests for all components. Tests are not currently run in CI, since that would require runners with Apple Silicon. Please run pytest locally to make sure your changes don't break anything.

Pre-commit hooks (optional)

To check code quality locally, you can use pre-commit hooks. Install them by running

pre-commit install

and run them with pre-commit run --all-files.

Note: CI does check code quality, so pull requests that fail these checks won't pass CI.

Why run GNNs on my Mac?

Other frameworks like PyG and DGL also benefit from efficient GNN operations parallelized on the GPU. However, they are not fully optimized to use the Mac's GPU and often default to CPU execution.

In contrast, mlx-graphs is designed specifically for Mac hardware. By taking advantage of Apple Silicon, it enables accelerated GPU computation and benefits from unified memory. This removes the need for data transfers between devices and lets the GPU use the Mac's entire memory. Consequently, users can process large graphs directly on the GPU, improving performance and efficiency.

Download files


Source Distribution

mlx-graphs-0.0.4.tar.gz (41.0 kB)

Uploaded Source

Built Distribution

mlx_graphs-0.0.4-py3-none-any.whl (50.9 kB)

Uploaded Python 3

File details

Details for the file mlx-graphs-0.0.4.tar.gz.

File metadata

  • Download URL: mlx-graphs-0.0.4.tar.gz
  • Size: 41.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/4.0.2 CPython/3.11.8

File hashes

Hashes for mlx-graphs-0.0.4.tar.gz
Algorithm Hash digest
SHA256 e80bb3bcdae0f5bf9020a2f420046dde2aa46a2d9d780c25d6c60eecd26f84fb
MD5 d41a16a74eeb4a6c577f78d046c64036
BLAKE2b-256 1298addab81b925a356b576f834c6d1f39a36aae70fab259f6ab59255547b598


File details

Details for the file mlx_graphs-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: mlx_graphs-0.0.4-py3-none-any.whl
  • Size: 50.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/4.0.2 CPython/3.11.8

File hashes

Hashes for mlx_graphs-0.0.4-py3-none-any.whl
Algorithm Hash digest
SHA256 f7566788134bef1bf39cc0790c65d3fc477c78cdcccfc550c608f600f0b3adde
MD5 7e24129a543db8672ab419e362c3ddc7
BLAKE2b-256 0d1c1aa4f0e10070274b0075747ba60ee60e712f8b42494a80095a721c859648
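Each downloaded file can be checked against the SHA256 digests listed above. The following is a minimal sketch using Python's standard hashlib module; the function names are illustrative, and the commented usage line assumes the wheel was downloaded to the current directory. pip's hash-checking mode (--require-hashes, with --hash entries in a requirements file) can enforce the same check at install time.

```python
import hashlib


def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks so large files need not fit in memory at once
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_hash(path: str, expected_sha256: str) -> bool:
    """Compare a file's digest with a published SHA256 string (case-insensitive)."""
    return sha256_of_file(path) == expected_sha256.strip().lower()


# Example (assumes the wheel is in the current directory):
# matches_published_hash(
#     "mlx_graphs-0.0.4-py3-none-any.whl",
#     "f7566788134bef1bf39cc0790c65d3fc477c78cdcccfc550c608f600f0b3adde",
# )
```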

