C++ and Metal extensions for MLX CTC Loss
Project description
mlx_cluster
A C++ extension for generating random walks on homogeneous graphs using MLX.
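A homogeneous graph has a single node type and a single edge type. A uniform random walk on such a graph starts at a node and, at each step, moves to a neighbour chosen uniformly at random. The pure-Python sketch below only illustrates that idea; it is not the extension's C++/Metal code. The function name `uniform_random_walk` and the rule that a node with no neighbours keeps the walker in place are assumptions made for illustration.

```python
import random


def uniform_random_walk(adj, start, walk_length, rng=None):
    """Illustrative sketch: return walk_length + 1 nodes visited from `start`.

    adj maps each node to a list of its neighbours. At every step the next
    node is drawn uniformly from the current node's neighbours; a node with
    no neighbours keeps the walker in place (an assumption for this sketch).
    """
    rng = rng or random.Random()
    walk = [start]
    node = start
    for _ in range(walk_length):
        neighbours = adj.get(node, [])
        if neighbours:
            node = rng.choice(neighbours)
        walk.append(node)
    return walk


# Example: the 4-cycle 0-1-2-3-0, seeded for reproducibility.
adj = {0: [1, 3], 1: [0, 2], 2: [1, 3], 3: [0, 2]}
print(uniform_random_walk(adj, start=0, walk_length=5, rng=random.Random(0)))
```

Every consecutive pair of nodes in the returned walk is an edge of the graph; the extension performs many such walks in parallel from a batch of start nodes.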
Installation
To build from source, clone the repository:
git clone https://github.com/vinayhpandya/mlx_cluster.git
cd mlx_cluster
Then build the extension in place:
python setup.py build_ext -j8 --inplace
Alternatively, install the released package from PyPI with pip:
pip install mlx_cluster
Running the tests additionally requires mlx-graphs and torch_geometric to be installed.
Usage
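import time

import mlx.core as mx
import numpy as np
from mlx_graphs.datasets import PlanetoidDataset
from mlx_graphs.utils.sorting import sort_edge_index
from mlx_graphs_extension import random_walk  # module name as given in this README

# Load Cora and pick 1000 start nodes
cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
start = mx.arange(0, 1000)
start_time = time.time()
edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
num_nodes = cora_dataset.graphs[0].num_nodes

# Sort edges by source node and split them into source (row) and target (col) arrays
sorted_edge_index = sort_edge_index(edge_index=edge_index)
row_mlx = sorted_edge_index[0][0]
col_mlx = sorted_edge_index[0][1]

# CSR row pointer: the neighbours of node v are col_mlx[row_ptr[v]:row_ptr[v + 1]].
# bincount with minlength=num_nodes also counts nodes that have no outgoing edges.
counts_mlx = np.bincount(np.asarray(row_mlx), minlength=num_nodes)
cum_sum_mlx = counts_mlx.cumsum()
row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(cum_sum_mlx)])

# Uniform random numbers in [0, 1) consumed by the walk, shape (num_start_nodes, 100)
rand = mx.random.uniform(shape=[start.shape[0], 100])

# Arguments as in the original example; MLX is lazy, so evaluate before reading the timer
walks = random_walk(row_ptr_mlx, col_mlx, start, rand, 1000, stream=mx.gpu)
mx.eval(walks)
print(f"elapsed: {time.time() - start_time:.3f} s")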
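The `row_ptr_mlx` / `col_mlx` pair in the example is a compressed sparse row (CSR) encoding of the graph: the neighbours of node `v` are `col[row_ptr[v]:row_ptr[v + 1]]`. The NumPy sketch below builds the same layout for a toy graph so the indexing can be checked by hand. The helper name `edge_index_to_csr` is hypothetical, and it assumes the edge index is already sorted by source node (the example sorts it with `sort_edge_index` first). It counts edges per node with `np.bincount(..., minlength=num_nodes)` rather than `np.unique`, because `np.unique` only reports nodes that actually appear as sources, which would shift the row pointer whenever a node has no outgoing edges.

```python
import numpy as np


def edge_index_to_csr(edge_index, num_nodes):
    """Hypothetical helper: convert a 2 x E edge index sorted by source node
    into CSR arrays (row_ptr, col).

    row_ptr has num_nodes + 1 entries; the neighbours of node v are
    col[row_ptr[v]:row_ptr[v + 1]].
    """
    row = np.asarray(edge_index[0])
    col = np.asarray(edge_index[1])
    # Out-degree of every node, including nodes with no outgoing edges.
    counts = np.bincount(row, minlength=num_nodes)
    # Prefix sums with a leading 0 give the start offset of each node's row.
    row_ptr = np.concatenate([[0], np.cumsum(counts)]).astype(np.int64)
    return row_ptr, col.astype(np.int64)


# A small directed graph whose edges are already sorted by source node.
edge_index = np.array([[0, 0, 1, 3], [1, 2, 2, 0]])
row_ptr, col = edge_index_to_csr(edge_index, num_nodes=4)
print(row_ptr)                      # [0 2 3 3 4]
print(col[row_ptr[0]:row_ptr[1]])   # neighbours of node 0: [1 2]
```

Node 2 has no outgoing edges, so `row_ptr[2] == row_ptr[3]` and its neighbour slice is empty; with `np.unique` this node would have been dropped from the counts.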
TODO
- Add Metal shaders to optimize the code
- Benchmark the random walk against other frameworks
- Add more algorithms
Credits:
torch_cluster random walk implementation: random_walk
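In torch_cluster's uniform random walk, each step scales a uniform draw r in [0, 1) to an edge offset: from current node v the walker moves to col[row_ptr[v] + floor(r * deg(v))], and a node with no outgoing edges keeps the walker in place. Assuming `rand` in the usage example plays the role of those draws (an assumption; the mlx_cluster kernel was not checked here), the slow NumPy function below could serve as a CPU reference for sanity-checking the extension's output. The name `reference_random_walk` and its (num_walks, walk_length + 1) output layout follow torch_cluster's convention and are illustrative, not mlx_cluster's API.

```python
import numpy as np


def reference_random_walk(row_ptr, col, start, rand):
    """Illustrative CPU reference for a uniform random walk on a CSR graph.

    rand has shape (len(start), walk_length); rand[i, l] in [0, 1) selects
    the edge taken at step l of walk i (torch_cluster-style rule, assumed).
    Returns an int64 array of shape (len(start), walk_length + 1) whose
    first column is `start`.
    """
    row_ptr = np.asarray(row_ptr, dtype=np.int64)
    col = np.asarray(col, dtype=np.int64)
    start = np.asarray(start, dtype=np.int64)
    rand = np.asarray(rand, dtype=np.float64)
    num_walks, walk_length = rand.shape
    walks = np.empty((num_walks, walk_length + 1), dtype=np.int64)
    walks[:, 0] = start
    for i in range(num_walks):
        node = start[i]
        for l in range(walk_length):
            row_start, row_end = row_ptr[node], row_ptr[node + 1]
            if row_end > row_start:
                # Scale the uniform draw to an edge offset inside node's row.
                edge = row_start + int(rand[i, l] * (row_end - row_start))
                node = col[edge]
            # A node without outgoing edges keeps the walker where it is.
            walks[i, l + 1] = node
    return walks


row_ptr = np.array([0, 2, 3, 3, 4])
col = np.array([1, 2, 2, 0])
rand = np.array([[0.0, 0.6], [0.5, 0.5]])
print(reference_random_walk(row_ptr, col, np.array([0, 2]), rand))
```

Because the draws are passed in rather than generated inside the function, the same `rand` matrix can be fed to both implementations, which makes an element-wise comparison possible if the two use the same sampling rule.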
Download files
Source Distribution
mlx_cluster-0.0.4.tar.gz (3.2 MB)
Built Distribution
mlx_cluster-0.0.4-cp311-cp311-macosx_14_0_arm64.whl (3.2 MB)
File details
Details for the file mlx_cluster-0.0.4.tar.gz.
File metadata
- Download URL: mlx_cluster-0.0.4.tar.gz
- Upload date:
- Size: 3.2 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | a1df26921a0895dfa115b401b800d7ffae15b29a04c987fc57c71985f5c8f824
MD5 | acfed22e6fcb299d54c92963f4468e64
BLAKE2b-256 | e156c58a3f7090a720959465f37b8dea2e738bac63ce12a20c4de1e0cf24296f
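The digests above can be used to check that a downloaded archive is intact. A minimal sketch with Python's standard `hashlib` module streams the file in chunks so a large archive is never held in memory at once; the function name `sha256_of_file` and the 1 MiB chunk size are illustrative choices, and `EXPECTED_SHA256` is simply the SHA256 value listed for the source distribution.

```python
import hashlib

# SHA256 digest listed above for mlx_cluster-0.0.4.tar.gz.
EXPECTED_SHA256 = "a1df26921a0895dfa115b401b800d7ffae15b29a04c987fc57c71985f5c8f824"


def sha256_of_file(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 and return its hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until read() returns b"".
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

For an unmodified copy of the source distribution saved in the current directory, `sha256_of_file("mlx_cluster-0.0.4.tar.gz") == EXPECTED_SHA256` should evaluate to `True`.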
File details
Details for the file mlx_cluster-0.0.4-cp311-cp311-macosx_14_0_arm64.whl.
File metadata
- Download URL: mlx_cluster-0.0.4-cp311-cp311-macosx_14_0_arm64.whl
- Upload date:
- Size: 3.2 MB
- Tags: CPython 3.11, macOS 14.0+ ARM64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 40ee1928c21d3bc9659082554c5908b13b78b32294c9dc3703cdf1f7fed4c9cd
MD5 | 2cedf157c06168ac99ce8a87546deaca
BLAKE2b-256 | 8f5c2066b3f05d632d2765d2bb451fc9e0fc66a9fe62ceab09abf504158d6233