C++ and Metal extensions for MLX CTC Loss
Project description
mlx_cluster
A C++ extension for generating random walks on homogeneous graphs using MLX.
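To make the idea concrete, here is a minimal pure-NumPy sketch of a uniform random walk over a graph stored in CSR form (a row pointer plus a column array, the same kind of inputs `random_walk` receives in the usage example). It assumes, modelled on torch_cluster's uniform walk, that each step consumes one random number in [0, 1) to pick a neighbour and that a node with no outgoing edges keeps the walk in place. The helpers `edge_index_to_csr` and `uniform_random_walk` are hypothetical illustrations, not part of mlx_cluster.

```python
import numpy as np

# Hypothetical reference helpers for illustration; not the mlx_cluster API.


def edge_index_to_csr(row, col, num_nodes):
    """Convert an edge list (row[i] -> col[i]) into CSR arrays (row_ptr, col_sorted)."""
    row = np.asarray(row, dtype=np.int64)
    col = np.asarray(col, dtype=np.int64)
    order = np.argsort(row, kind="stable")           # group edges by source node
    counts = np.bincount(row, minlength=num_nodes)   # out-degree per node, zeros included
    row_ptr = np.concatenate([[0], np.cumsum(counts)])
    return row_ptr, col[order]


def uniform_random_walk(row_ptr, col, start, rand):
    """Uniform random walk: rand has shape [len(start), walk_length], values in [0, 1).

    Returns an array of shape [len(start), walk_length + 1] whose first column is `start`.
    """
    num_walks, walk_length = rand.shape
    walks = np.empty((num_walks, walk_length + 1), dtype=np.int64)
    walks[:, 0] = start
    for n in range(num_walks):
        cur = walks[n, 0]
        for step in range(walk_length):
            lo, hi = row_ptr[cur], row_ptr[cur + 1]
            deg = hi - lo
            if deg > 0:
                # Scale the random number to a neighbour offset; clamp guards float rounding.
                offset = min(int(rand[n, step] * deg), deg - 1)
                cur = col[lo + offset]
            # Nodes without out-neighbours repeat in place.
            walks[n, step + 1] = cur
    return walks


if __name__ == "__main__":
    row_ptr, col_sorted = edge_index_to_csr([0, 0, 1, 2], [1, 2, 2, 0], num_nodes=4)
    rand = np.array([[0.0, 0.5, 0.9], [0.1, 0.2, 0.3]])
    walks = uniform_random_walk(row_ptr, col_sorted, np.array([0, 3]), rand)
    print(walks.tolist())  # [[0, 1, 2, 0], [3, 3, 3, 3]]
```

Node 3 has no outgoing edges in this toy graph, so its walk stays in place; counting degrees with `np.bincount(..., minlength=num_nodes)` is what keeps such nodes in the row pointer.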
Installation
To install the library, first clone the repository:
git clone https://github.com/vinayhpandya/mlx_cluster.git
Then build the extension in place from the repository root:
cd mlx_cluster
python setup.py build_ext -j8 --inplace
Running the tests additionally requires mlx-graphs to be installed.
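Before running the tests, it can help to confirm that the test dependency is importable (the package is published as mlx-graphs and imported as `mlx_graphs`). A minimal sketch using the standard library's `importlib.util.find_spec`; the helper name `has_module` is illustrative only.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Hypothetical helper: True if a top-level module can be imported here."""
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    # Check MLX itself and the optional test dependency mlx-graphs.
    for module in ("mlx", "mlx_graphs"):
        status = "found" if has_module(module) else "missing"
        print(f"{module}: {status}")
```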
Usage
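import time
import mlx.core as mx
import numpy as np
from mlx_graphs.datasets import PlanetoidDataset
from mlx_graphs.utils.sorting import sort_edge_index
from mlx_graphs_extension import random_walk
cora_dataset = PlanetoidDataset(name="cora", base_dir="~")  # Cora citation graph
edge_index = cora_dataset.graphs[0].edge_index
num_nodes = cora_dataset.graphs[0].num_nodes
start = mx.arange(0, 1000)  # one walk from each of the first 1000 nodes
walk_length = 100
sorted_edge_index = sort_edge_index(edge_index=edge_index)  # group edges by source node
row_mlx = sorted_edge_index[0][0]
col_mlx = sorted_edge_index[0][1]
# CSR row pointer: row_ptr[i]:row_ptr[i + 1] indexes node i's neighbours in col_mlx.
# bincount with minlength keeps nodes that have no outgoing edges (np.unique would skip them).
counts_mlx = np.bincount(np.array(row_mlx), minlength=num_nodes)
row_ptr_mlx = mx.concatenate([mx.array([0]), mx.array(counts_mlx.cumsum())])
rand = mx.random.uniform(shape=[start.shape[0], walk_length])  # one random number per step
start_time = time.time()
node_sequence = random_walk(row_ptr_mlx, col_mlx, start, rand, walk_length, stream=mx.gpu)
mx.eval(node_sequence)  # MLX is lazy: force evaluation before reading the clock
print(f"random_walk took {time.time() - start_time:.3f} s")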
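A quick sanity check on the output is to confirm that every consecutive pair of nodes in each walk is an edge of the graph. This sketch assumes the result of `random_walk` converts with `np.array` into an integer array of shape [num_walks, walk_length + 1] (the layout torch_cluster uses) and that nodes without outgoing edges repeat in place; `check_walks` is a hypothetical helper, not part of mlx_cluster.

```python
import numpy as np


def check_walks(row_ptr, col, walks):
    """Hypothetical validator: True if every step in `walks` follows a CSR edge.

    Assumes walks has shape [num_walks, walk_length + 1] and that a node with
    no out-neighbours repeats itself (stay-in-place convention).
    """
    row_ptr = np.asarray(row_ptr)
    col = np.asarray(col)
    walks = np.asarray(walks)
    for walk in walks:
        for u, v in zip(walk[:-1], walk[1:]):
            neighbours = col[row_ptr[u]:row_ptr[u + 1]]
            if neighbours.size == 0:
                if v != u:  # a sink node may only repeat itself
                    return False
            elif v not in neighbours:  # step did not follow an edge
                return False
    return True
```

To use it, convert the row pointer, the column array and the value returned by `random_walk` with `np.array` and pass them in; it returns False at the first step that does not follow an edge.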
TODO
- Add Metal shaders to optimize the code
- Benchmark random walk against different frameworks
- Add more algorithms
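For the benchmarking item, a small framework-agnostic timing harness is one possible starting point. The MLX-specific detail is that MLX evaluates lazily, so the result must be forced (for example with `mx.eval`) before the clock stops; otherwise only graph construction is measured. The `benchmark` function and its `sync` parameter are hypothetical names chosen for this sketch.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark(fn: Callable[[], object], repeats: int = 10, warmup: int = 2,
              sync: Optional[Callable[[object], None]] = None) -> dict:
    """Hypothetical harness: time `fn` over several runs, returning seconds.

    `sync` is applied to each result before the timer stops; for MLX pass
    mx.eval so the lazy computation actually runs inside the timed region.
    """
    for _ in range(warmup):  # warm-up runs (e.g. kernel compilation) are not timed
        out = fn()
        if sync is not None:
            sync(out)
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        out = fn()
        if sync is not None:
            sync(out)
        times.append(time.perf_counter() - t0)
    return {
        "mean": statistics.mean(times),
        "stdev": statistics.stdev(times) if len(times) > 1 else 0.0,
        "min": min(times),
    }
```

With the usage example above, a call might look like `benchmark(lambda: random_walk(row_ptr_mlx, col_mlx, start, rand, walk_length, stream=mx.gpu), sync=mx.eval)`; the same harness can wrap a torch_cluster call for comparison.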
Credits:
The random walk implementation is based on torch_cluster's random_walk.
Download files
Source Distribution
mlx_cluster-0.0.1.tar.gz (68.9 kB)
Built Distribution
mlx_cluster-0.0.1-cp311-cp311-macosx_14_0_arm64.whl (59.6 kB)
File details
Details for the file mlx_cluster-0.0.1.tar.gz.
File metadata
- Download URL: mlx_cluster-0.0.1.tar.gz
- Upload date:
- Size: 68.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | efbd2ae16ac30360b4094294e8e62e114fd04f8a81591cf41355f93da08adfdb
MD5 | 5dae9a7732220c02dc88aa56ce7038f5
BLAKE2b-256 | 5a1b5681feb955790c779ae141cfd979b9a1810577b487ff913e5bd0bfec653b
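The published digests can be used to check a downloaded file. A minimal sketch with Python's `hashlib` that streams the file in chunks and compares the result with the listed SHA256; `sha256_of` and `verify` are illustrative helper names.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 16):
    """Hypothetical helper: SHA256 hex digest of a file, read in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected_sha256):
    """Hypothetical helper: True if the file matches the published SHA256 digest."""
    return sha256_of(path) == expected_sha256.lower()
```

For an intact source archive, `verify("mlx_cluster-0.0.1.tar.gz", "efbd2ae16ac30360b4094294e8e62e114fd04f8a81591cf41355f93da08adfdb")` should return True.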
File details
Details for the file mlx_cluster-0.0.1-cp311-cp311-macosx_14_0_arm64.whl.
File metadata
- Download URL: mlx_cluster-0.0.1-cp311-cp311-macosx_14_0_arm64.whl
- Upload date:
- Size: 59.6 kB
- Tags: CPython 3.11, macOS 14.0+ ARM64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 59f4f192e3d770566fe468c09566785d235485c7ac1309188c30d1ab850c4108
MD5 | 7533233fc7f656c2d81ad2597216c4a7
BLAKE2b-256 | 34cf8b32f92a691887926747abb111ee811f2695be5e1210639e2ad44ad8a015