
C++ and Metal extensions for MLX CTC Loss

Project description

mlx_cluster

A C++ extension for generating random walks on homogeneous graphs using MLX.
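The graph is passed in CSR form: row_ptr[v]:row_ptr[v + 1] indexes the neighbours of node v in col, and each walk uses one uniform random number per step to pick a neighbour. Below is a minimal pure-Python reference sketch of such a uniform walk, modelled on torch_cluster's approach. uniform_random_walk is a hypothetical helper, not the extension's code. Two details are assumptions borrowed from torch_cluster: a walk stays in place at a node with no neighbours, and each walk returns walk_length + 1 nodes, including the start node.

```python
from typing import List, Sequence


def uniform_random_walk(
    rowptr: Sequence[int],
    col: Sequence[int],
    start: Sequence[int],
    rand: Sequence[Sequence[float]],
    walk_length: int,
) -> List[List[int]]:
    """Reference uniform random walk on a CSR graph (hypothetical helper).

    Neighbours of node v are col[rowptr[v]:rowptr[v + 1]].
    rand[i][step] in [0, 1) selects the neighbour for walk i at that step.
    """
    walks = []
    for i, node in enumerate(start):
        walk = [node]
        for step in range(walk_length):
            begin, end = rowptr[node], rowptr[node + 1]
            degree = end - begin
            if degree > 0:
                # Scale the uniform sample to an offset in [0, degree);
                # min() guards against a sample of exactly 1.0
                offset = min(int(rand[i][step] * degree), degree - 1)
                node = col[begin + offset]
            # Assumption: a node with no neighbours keeps the walk where it is
            walk.append(node)
        walks.append(walk)
    return walks


# Small directed graph: 0 -> {1, 2}, 1 -> {2}, 2 -> {3}, 3 -> {0}
rowptr = [0, 2, 3, 4, 5]
col = [1, 2, 2, 3, 0]
print(uniform_random_walk(rowptr, col, start=[0], rand=[[0.9, 0.1, 0.5]], walk_length=3))
# → [[0, 2, 3, 0]]
```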

Installation

To install from source, first clone the repository:

git clone https://github.com/vinayhpandya/mlx_cluster.git

After cloning the repository, build the extension in place from the repository root:

cd mlx_cluster
python setup.py build_ext -j8 --inplace

Alternatively, install the library from PyPI with pip:

pip install mlx_cluster

To run the tests, you also need mlx-graphs and torch_geometric installed.

Usage

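import time

import mlx.core as mx
import numpy as np
from mlx_graphs.datasets import PlanetoidDataset
from mlx_graphs.utils.sorting import sort_edge_index
from mlx_cluster import random_walk

# Load the Cora citation graph
cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
num_nodes = cora_dataset.graphs[0].num_nodes

# Start one walk from each of the first 1000 nodes
start = mx.arange(0, 1000)
walk_length = 100
start_time = time.time()

# Sort edges by source node so each node's neighbours are contiguous in col
sorted_edge_index = sort_edge_index(edge_index=edge_index)
row_mlx = sorted_edge_index[0][0]
col_mlx = sorted_edge_index[0][1]

# CSR row pointer: per-node out-degree (zero for nodes without edges), then a prefix sum
counts = np.bincount(np.array(row_mlx), minlength=num_nodes)
row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(counts.cumsum())]).astype(mx.int64)

# One uniform sample per walk step, so rand has walk_length columns
rand = mx.random.uniform(shape=[start.shape[0], walk_length])
walks = random_walk(row_ptr_mlx, col_mlx, start, rand, walk_length, stream=mx.gpu)
mx.eval(walks)  # MLX evaluates lazily; force the walk before reading the timer
print(f"random_walk took {time.time() - start_time:.3f} s")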
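The row_ptr array built above is a per-node edge count followed by a prefix sum. The dependency-free sketch below shows why the count must cover every node id rather than only the ids that appear as sources: a node with no outgoing edges still needs its own (empty) slot. build_row_ptr is a hypothetical helper for illustration, not part of mlx_cluster.

```python
from typing import List, Sequence


def build_row_ptr(sorted_rows: Sequence[int], num_nodes: int) -> List[int]:
    """CSR row pointer from source-node ids sorted in ascending order.

    The result has num_nodes + 1 entries; the edges leaving node v sit at
    positions row_ptr[v] .. row_ptr[v + 1] - 1 of the sorted edge list.
    """
    counts = [0] * num_nodes
    for src in sorted_rows:
        counts[src] += 1  # out-degree; stays 0 for nodes with no outgoing edges
    row_ptr = [0]
    for c in counts:
        row_ptr.append(row_ptr[-1] + c)  # prefix sum gives each node's first edge
    return row_ptr


# Nodes 1 and 3 have no outgoing edges but still get (empty) slots
print(build_row_ptr([0, 0, 2, 2, 2], num_nodes=4))
# → [0, 2, 2, 5, 5]
```

Counting only the ids that appear (as np.unique(..., return_counts=True) does) would give [0, 2, 5] here, which is too short and assigns node 2's edges to node 1.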

TODO

  • Add Metal shaders to optimize the code
  • Benchmark random walk against different frameworks
  • Add more algorithms

Credits:

torch_cluster random walk implementation: random_walk


Download files

Source Distribution

mlx_cluster-0.0.4.tar.gz (3.2 MB)

Uploaded Source

Built Distribution

mlx_cluster-0.0.4-cp311-cp311-macosx_14_0_arm64.whl (3.2 MB)

Uploaded CPython 3.11 macOS 14.0+ ARM64
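A wheel's filename carries its compatibility tags, following the PEP 427 naming convention {distribution}-{version}(-{build tag})-{python tag}-{abi tag}-{platform tag}.whl. Here cp311 / cp311 / macosx_14_0_arm64 corresponds to the "CPython 3.11 macOS 14.0+ ARM64" line above. parse_wheel_filename below is a hypothetical, deliberately simple helper for illustration. For real tooling, packaging.utils.parse_wheel_filename from the packaging project also validates and normalizes the fields.

```python
from typing import Dict


def parse_wheel_filename(filename: str) -> Dict[str, str]:
    """Split a wheel filename into its PEP 427 fields (tag values are not validated)."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 5:
        # No optional build tag
        dist, version, python_tag, abi_tag, platform_tag = parts
        build_tag = ""
    elif len(parts) == 6:
        dist, version, build_tag, python_tag, abi_tag, platform_tag = parts
    else:
        raise ValueError(f"unexpected number of fields in: {filename}")
    return {
        "distribution": dist,
        "version": version,
        "build": build_tag,
        "python": python_tag,
        "abi": abi_tag,
        "platform": platform_tag,
    }


tags = parse_wheel_filename("mlx_cluster-0.0.4-cp311-cp311-macosx_14_0_arm64.whl")
print(tags["python"], tags["abi"], tags["platform"])
# → cp311 cp311 macosx_14_0_arm64
```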

File details

Details for the file mlx_cluster-0.0.4.tar.gz.

File metadata

  • Download URL: mlx_cluster-0.0.4.tar.gz
  • Upload date:
  • Size: 3.2 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.4

File hashes

Hashes for mlx_cluster-0.0.4.tar.gz
Algorithm Hash digest
SHA256 a1df26921a0895dfa115b401b800d7ffae15b29a04c987fc57c71985f5c8f824
MD5 acfed22e6fcb299d54c92963f4468e64
BLAKE2b-256 e156c58a3f7090a720959465f37b8dea2e738bac63ce12a20c4de1e0cf24296f
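The published digests let you check that a downloaded file is intact. The following is a minimal sketch that uses Python's standard hashlib module. The helper names sha256_of and verify_download are hypothetical, and the usage line assumes the file was downloaded to the current directory.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 digest of a file, read in chunks (1 MiB by default) to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: Path, expected_sha256: str) -> bool:
    """True if the file's SHA256 matches the published hex digest (case-insensitive)."""
    return sha256_of(path) == expected_sha256.strip().lower()


# Hypothetical usage after downloading the source distribution:
# verify_download(Path("mlx_cluster-0.0.4.tar.gz"),
#                 "a1df26921a0895dfa115b401b800d7ffae15b29a04c987fc57c71985f5c8f824")
```

pip can also enforce this check itself in hash-checking mode (pip install --require-hashes), where each pinned requirement lists its accepted --hash=sha256: values.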

File details

Details for the file mlx_cluster-0.0.4-cp311-cp311-macosx_14_0_arm64.whl.

File hashes

Hashes for mlx_cluster-0.0.4-cp311-cp311-macosx_14_0_arm64.whl
Algorithm Hash digest
SHA256 40ee1928c21d3bc9659082554c5908b13b78b32294c9dc3703cdf1f7fed4c9cd
MD5 2cedf157c06168ac99ce8a87546deaca
BLAKE2b-256 8f5c2066b3f05d632d2765d2bb451fc9e0fc66a9fe62ceab09abf504158d6233
