
C++ and Metal extensions for MLX CTC Loss

Project description

mlx_cluster

A C++ extension for generating random walks on homogeneous graphs using MLX.
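
The walk itself can be illustrated with a pure NumPy reference. This is a sketch modelled on torch_cluster's uniform random walk (credited below), not mlx_cluster's actual C++ code. It assumes the same inputs as the usage example: a CSR row pointer, the column (neighbour) array, the start nodes and a (num_walks, walk_length) matrix of uniform samples. The name random_walk_reference is hypothetical.

```python
import numpy as np


def random_walk_reference(row_ptr, col, start, rand):
    """Uniform random walk over a CSR graph (hypothetical NumPy reference).

    row_ptr: (num_nodes + 1,) CSR offsets; col: (num_edges,) neighbour ids
    start:   (num_walks,) start node ids
    rand:    (num_walks, walk_length) uniform samples in [0, 1)
    Returns an int64 array of shape (num_walks, walk_length + 1).
    """
    row_ptr = np.asarray(row_ptr, dtype=np.int64)
    col = np.asarray(col, dtype=np.int64)
    start = np.asarray(start, dtype=np.int64)
    rand = np.asarray(rand, dtype=np.float64)

    num_walks, walk_length = rand.shape
    walks = np.empty((num_walks, walk_length + 1), dtype=np.int64)
    walks[:, 0] = start

    for i in range(num_walks):
        cur = start[i]
        for j in range(walk_length):
            row_start, row_end = row_ptr[cur], row_ptr[cur + 1]
            deg = row_end - row_start
            if deg > 0:
                # Scale the uniform sample by the out-degree to pick a neighbour.
                cur = col[row_start + int(rand[i, j] * deg)]
            # A node without outgoing edges keeps the walk where it is.
            walks[i, j + 1] = cur
    return walks
```

Because every step is driven by rand[i, j], the walks are reproducible for a fixed rand matrix, which makes a reference like this convenient for checking an accelerated implementation against.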

Installation

To install from source, first clone the repository:

git clone https://github.com/vinayhpandya/mlx_cluster.git

After cloning, change into the mlx_cluster directory and build the extension in place:

python setup.py build_ext -j8 --inplace

Alternatively, install the library from PyPI with pip:

pip install mlx_cluster

For testing, you also need mlx-graphs and torch_geometric installed.

Usage

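import time

import mlx.core as mx
import numpy as np
from mlx_graphs.datasets import PlanetoidDataset
from mlx_graphs.utils.sorting import sort_edge_index
from mlx_cluster import random_walk  # module name matches the mlx_cluster package

# Load Cora and take its single graph
cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
num_nodes = int(cora_dataset.graphs[0].num_nodes)

# Sort edges by source node so they can be read as CSR (row_ptr, col)
sorted_edge_index = sort_edge_index(edge_index=edge_index)
row_mlx = sorted_edge_index[0][0]
col_mlx = sorted_edge_index[0][1]

# Out-degree per node; minlength keeps a slot for nodes without outgoing edges
counts_mlx = np.bincount(np.array(row_mlx), minlength=num_nodes)
row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(counts_mlx.cumsum())])

# One walk per start node; rand holds one uniform sample per step
walk_length = 100
start = mx.arange(0, 1000)
rand = mx.random.uniform(shape=[start.shape[0], walk_length])

start_time = time.time()
walks = random_walk(row_ptr_mlx, col_mlx, start, rand, walk_length, stream=mx.gpu)
mx.eval(walks)  # MLX is lazy: force the computation before reading the timer
print(f"random_walk took {time.time() - start_time:.3f} s")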
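
Counting edges per source node with np.unique(..., return_counts=True) only returns counts for nodes that actually appear as a source, so on a graph with nodes lacking outgoing edges the row pointer comes out too short and its offsets shift. A small helper based on np.bincount avoids this. build_row_ptr is a hypothetical name, and it assumes the edge list is already sorted by source node (as the sort_edge_index step above is used for), so that col lines up with the offsets.

```python
import numpy as np


def build_row_ptr(row, num_nodes):
    """CSR row pointer from the (sorted) source-node array of an edge index.

    np.bincount with minlength gives every node a count, including nodes with
    no outgoing edges, so row_ptr always has num_nodes + 1 entries.
    """
    row = np.asarray(row, dtype=np.int64)
    counts = np.bincount(row, minlength=num_nodes)
    # row_ptr[v]:row_ptr[v + 1] is the slice of col holding v's neighbours
    return np.concatenate([[0], np.cumsum(counts)]).astype(np.int64)
```

For example, build_row_ptr([0, 0, 2], 4) yields [0, 2, 2, 3, 3], whereas the np.unique route would give [0, 2, 3] and lose nodes 1 and 3. The result can be wrapped with mx.array(...) and passed as the row pointer to random_walk.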

TODO

  • Add Metal shaders to optimize the code
  • Benchmark random walk against different frameworks
  • Add more algorithms

Credits:

torch_cluster random walk implementation: random_walk

Download files

Source Distribution

mlx_cluster-0.0.2.tar.gz (69.1 kB)

Uploaded Source

Built Distribution

mlx_cluster-0.0.2-cp311-cp311-macosx_14_0_arm64.whl (61.6 kB)

Uploaded CPython 3.11 macOS 14.0+ ARM64

File details

Details for the file mlx_cluster-0.0.2.tar.gz.

File metadata

  • Download URL: mlx_cluster-0.0.2.tar.gz
  • Upload date:
  • Size: 69.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.4

File hashes

Hashes for mlx_cluster-0.0.2.tar.gz
Algorithm Hash digest
SHA256 26250fcdbd20a5c2c2df6f06f15d210ff49a8dacdc13b52df6dc14e040e47c53
MD5 a2807f50e543b0218431d20732cc0c88
BLAKE2b-256 35bb7469986e697efed58a096ec2001771bb11964497c6fb8e22d06140104f17
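
To check a downloaded sdist against the SHA256 digest listed above, the standard hashlib module is enough. This is a minimal sketch; sha256_of and matches_published_hash are illustrative names, and the path is wherever you saved the download.

```python
import hashlib

# SHA256 published above for mlx_cluster-0.0.2.tar.gz
SDIST_SHA256 = "26250fcdbd20a5c2c2df6f06f15d210ff49a8dacdc13b52df6dc14e040e47c53"


def sha256_of(path, chunk_size=1 << 20):
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_hash(path, expected=SDIST_SHA256):
    """True if the file's SHA256 equals the expected (case-insensitive) digest."""
    return sha256_of(path) == expected.lower()
```

Usage: matches_published_hash("mlx_cluster-0.0.2.tar.gz") should return True for an untampered download. pip can also compute the digest directly with `pip hash <file>`, which prints it in pip's `--hash=sha256:...` form.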

File details

Details for the file mlx_cluster-0.0.2-cp311-cp311-macosx_14_0_arm64.whl.

File metadata

File hashes

Hashes for mlx_cluster-0.0.2-cp311-cp311-macosx_14_0_arm64.whl
Algorithm Hash digest
SHA256 1f9e0522fc30afa3a38e185629dcd2d1fc4a622b9460bfd11ee7feb5281a7504
MD5 508f81e7be771c899cbce9f7a598108a
BLAKE2b-256 28a21d1c76bce42bff5af0bb34799754f87641eb13c845cb0b7b1f2f43e061cd
