C++ and Metal extensions for MLX CTC Loss

Project description

mlx_cluster

A C++ extension for generating random walks on homogeneous graphs using MLX.

Installation

To install from source, clone the repository:

git clone https://github.com/vinayhpandya/mlx_cluster.git

Then build the extension in place from the repository root:

cd mlx_cluster
python setup.py build_ext -j8 --inplace

The usage example below and the tests also require mlx-graphs (pip install mlx-graphs).
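When testing, a cheap property check is that every consecutive pair of nodes in a walk is an edge of the graph. The helper below is a hypothetical sketch (is_valid_walk is not part of mlx_cluster or mlx-graphs). It assumes the graph is given as CSR lists (row_ptr, col) and that walks follow torch_cluster's convention of repeating a node that has no outgoing edges; both are assumptions about this extension's output.

```python
from typing import Sequence


def is_valid_walk(row_ptr: Sequence[int], col: Sequence[int], walk: Sequence[int]) -> bool:
    """Return True if every step of `walk` follows an edge of the CSR graph.

    Hypothetical test helper. A repeated node is also accepted when that node
    has no outgoing edges (assumed torch_cluster-style dead-end behaviour).
    """
    for u, v in zip(walk, walk[1:]):
        # Out-neighbours of u in CSR layout
        neighbours = col[row_ptr[u]:row_ptr[u + 1]]
        if v in neighbours:
            continue
        if u == v and len(neighbours) == 0:
            continue
        return False
    return True


# Toy 4-node graph in CSR form; node 3 has no outgoing edges
row_ptr = [0, 2, 3, 6, 6]
col = [1, 2, 2, 0, 1, 3]
print(is_valid_walk(row_ptr, col, [0, 1, 2, 3, 3]))  # True
print(is_valid_walk(row_ptr, col, [0, 3]))           # False: 0 -> 3 is not an edge
```

To use it on MLX results, convert the arrays to Python lists first, for example with row_ptr_mlx.tolist().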

Usage

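import time

import mlx.core as mx
import numpy as np
from mlx_graphs.datasets import PlanetoidDataset
from mlx_graphs.utils.sorting import sort_edge_index
from mlx_graphs_extension import random_walk

# Load Cora and build a CSR (row pointer + column index) view of its edges
cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
edge_index = cora_dataset.graphs[0].edge_index
num_nodes = cora_dataset.graphs[0].num_nodes
sorted_edge_index = sort_edge_index(edge_index=edge_index)
row_mlx = sorted_edge_index[0][0]
col_mlx = sorted_edge_index[0][1]

# Out-degree of every node; minlength keeps nodes without outgoing edges
counts = np.bincount(np.array(row_mlx), minlength=num_nodes)
row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(counts.cumsum())])

# One walk per start node; rand needs one uniform sample per step
walk_length = 100
start = mx.arange(0, 1000)
rand = mx.random.uniform(shape=[start.shape[0], walk_length])

start_time = time.time()
walks = random_walk(row_ptr_mlx, col_mlx, start, rand, walk_length, stream=mx.gpu)
mx.eval(walks)  # MLX is lazy; force the computation before reading the timer
print(f"Random walks took {time.time() - start_time:.3f} s")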
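The row_ptr/col pair passed to random_walk is a CSR (compressed sparse row) view of the sorted edge list: the out-neighbours of node v are col[row_ptr[v]:row_ptr[v + 1]]. The NumPy sketch below builds it for a small hypothetical four-node graph. It counts degrees with np.bincount and minlength rather than np.unique, because np.unique only reports nodes that actually appear as sources, so a node with no outgoing edges (node 3 here) would drop out of row_ptr and shift or truncate the offsets.

```python
import numpy as np

# Hypothetical toy directed graph with 4 nodes; node 3 has no outgoing edges
row = np.array([0, 0, 1, 2, 2, 2])  # source nodes, already sorted
col = np.array([1, 2, 2, 0, 1, 3])  # destination nodes
num_nodes = 4

# Out-degree per node; minlength keeps trailing nodes with zero out-edges
counts = np.bincount(row, minlength=num_nodes)
row_ptr = np.concatenate([[0], np.cumsum(counts)])

print(row_ptr)  # [0 2 3 6 6]

# Neighbours of node v are col[row_ptr[v]:row_ptr[v + 1]]
for v in range(num_nodes):
    print(v, col[row_ptr[v]:row_ptr[v + 1]].tolist())
```

The loop prints 0 [1, 2], 1 [2], 2 [0, 1, 3] and 3 []. With np.unique the same graph would give a row_ptr of only four entries, [0, 2, 3, 6], with no end offset for node 3.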

TODO

  • Add Metal shaders to optimize the code
  • Benchmark random walk against different frameworks
  • Add more algorithms

Credits

The random walk is based on the torch_cluster random walk implementation (random_walk).
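To cross-check results on small graphs, here is a pure-Python sketch of torch_cluster's uniform random walk (the p = q = 1 case), written from the torch_cluster algorithm rather than from this extension's source. Each step scales the uniform sample rand[n][step] to an offset into the current node's neighbour slice, and a node without outgoing edges repeats itself. That mlx_cluster 0.0.1 behaves identically, including returning walk_length + 1 nodes per walk, is an assumption; uniform_random_walk is a hypothetical helper name.

```python
from typing import List, Sequence


def uniform_random_walk(
    row_ptr: Sequence[int],
    col: Sequence[int],
    start: Sequence[int],
    rand: Sequence[Sequence[float]],
    walk_length: int,
) -> List[List[int]]:
    """Uniform random walks over a CSR graph (torch_cluster style, p = q = 1).

    rand[n][step] is a uniform sample in [0, 1) that picks step `step` of walk n.
    Each returned walk has walk_length + 1 nodes and begins at start[n].
    """
    walks = []
    for n, cur in enumerate(start):
        walk = [cur]
        for step in range(walk_length):
            begin, end = row_ptr[cur], row_ptr[cur + 1]
            deg = end - begin
            if deg > 0:
                # Scale the uniform sample to an offset in [0, deg)
                cur = col[begin + int(rand[n][step] * deg)]
            # A node with no outgoing edges keeps the walk where it is
            walk.append(cur)
        walks.append(walk)
    return walks


# Toy 4-node graph in CSR form; node 3 has no outgoing edges
row_ptr = [0, 2, 3, 6, 6]
col = [1, 2, 2, 0, 1, 3]
rand = [[0.1, 0.9, 0.5], [0.7, 0.2, 0.99]]
print(uniform_random_walk(row_ptr, col, start=[0, 2], rand=rand, walk_length=3))
# [[0, 1, 2, 1], [2, 3, 3, 3]]
```

Because the random numbers are passed in explicitly, feeding the same rand values (converted with .tolist()) to both implementations should produce identical walks if the assumption above holds.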


Download files

Source Distribution

mlx_cluster-0.0.1.tar.gz (68.9 kB)

Uploaded Source

Built Distribution

mlx_cluster-0.0.1-cp311-cp311-macosx_14_0_arm64.whl (59.6 kB)

Uploaded CPython 3.11 macOS 14.0+ ARM64

File details

Details for the file mlx_cluster-0.0.1.tar.gz.

File metadata

  • Download URL: mlx_cluster-0.0.1.tar.gz
  • Upload date:
  • Size: 68.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.4

File hashes

Hashes for mlx_cluster-0.0.1.tar.gz
Algorithm Hash digest
SHA256 efbd2ae16ac30360b4094294e8e62e114fd04f8a81591cf41355f93da08adfdb
MD5 5dae9a7732220c02dc88aa56ce7038f5
BLAKE2b-256 5a1b5681feb955790c779ae141cfd979b9a1810577b487ff913e5bd0bfec653b


File details

Details for the file mlx_cluster-0.0.1-cp311-cp311-macosx_14_0_arm64.whl.

File hashes

Hashes for mlx_cluster-0.0.1-cp311-cp311-macosx_14_0_arm64.whl
Algorithm Hash digest
SHA256 59f4f192e3d770566fe468c09566785d235485c7ac1309188c30d1ab850c4108
MD5 7533233fc7f656c2d81ad2597216c4a7
BLAKE2b-256 34cf8b32f92a691887926747abb111ee811f2695be5e1210639e2ad44ad8a015

