C++ and Metal extensions for MLX CTC Loss
Project description
mlx_cluster
A C++ extension for generating random walks on homogeneous graphs using MLX.
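The extension's exact sampling rule is not spelled out here. The NumPy sketch below assumes it follows the uniform sampling of torch_cluster's random walk, which this project credits: the graph is stored in CSR form (`row_ptr`, `col`), each step draws one uniform sample in [0, 1) and scales it to pick a neighbour of the current node, and a node with no outgoing edges keeps the walk in place. `uniform_random_walk` is a hypothetical reference helper, not part of mlx_cluster.

```python
import numpy as np


def uniform_random_walk(row_ptr, col, start, rand, walk_length):
    """Uniform random walks over a CSR graph (NumPy reference sketch).

    row_ptr: shape [num_nodes + 1]; the neighbours of v are col[row_ptr[v]:row_ptr[v + 1]]
    col:     neighbour ids grouped by source node
    start:   starting node of each walk, shape [num_walks]
    rand:    uniform samples in [0, 1), shape [num_walks, walk_length]
    Returns node ids of shape [num_walks, walk_length + 1]; column 0 holds the start nodes.
    """
    num_walks = start.shape[0]
    out = np.empty((num_walks, walk_length + 1), dtype=np.int64)
    for n in range(num_walks):
        cur = int(start[n])
        out[n, 0] = cur
        for step in range(walk_length):
            lo, hi = int(row_ptr[cur]), int(row_ptr[cur + 1])
            degree = hi - lo
            if degree > 0:
                # Scale the sample to an index in [0, degree) and move to that neighbour.
                cur = int(col[lo + int(rand[n, step] * degree)])
            # A node with no outgoing edges keeps the walk where it is.
            out[n, step + 1] = cur
    return out


# Tiny directed cycle 0 -> 1 -> 2 -> 0 in CSR form.
row_ptr = np.array([0, 1, 2, 3])
col = np.array([1, 2, 0])
walks = uniform_random_walk(row_ptr, col, np.array([0]), np.zeros((1, 3)), walk_length=3)
print(walks.tolist())  # [[0, 1, 2, 0]]
```

Under this assumption, the random tensor passed to the extension needs one column per step, i.e. shape `[num_walks, walk_length]`.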
Installation
To build from source, clone the repository:
git clone https://github.com/vinayhpandya/mlx_cluster.git
cd mlx_cluster
Then build the extension in place:
python setup.py build_ext -j8 --inplace
Alternatively, install the library from PyPI:
pip install mlx_cluster
For testing, you also need mlx-graphs and torch_geometric installed.
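Independently of mlx-graphs and torch_geometric, a cheap sanity check for any random-walk output is that every step moves along an existing edge. The helper `walks_follow_edges` below is hypothetical (not part of mlx_cluster or mlx-graphs) and assumes that a node without outgoing edges keeps the walk in place; MLX outputs would first be converted with `np.array(...)`.

```python
import numpy as np


def walks_follow_edges(row_ptr, col, walks):
    """Check that every step of every walk is a legal move on a CSR graph.

    A step u -> v is legal if v is one of u's neighbours, or if u has no
    outgoing edges and the walk stays put (v == u).
    """
    for walk in np.asarray(walks):
        for u, v in zip(walk[:-1], walk[1:]):
            neighbours = col[row_ptr[u]:row_ptr[u + 1]]
            if neighbours.size == 0:
                if v != u:
                    return False
            elif v not in neighbours:
                return False
    return True


row_ptr = np.array([0, 1, 2, 3])  # directed cycle 0 -> 1 -> 2 -> 0
col = np.array([1, 2, 0])
print(walks_follow_edges(row_ptr, col, [[0, 1, 2, 0]]))  # True
print(walks_follow_edges(row_ptr, col, [[0, 2]]))        # False
```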
Usage
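import time

import mlx.core as mx
import numpy as np
from mlx_graphs.datasets import PlanetoidDataset
from mlx_graphs.utils.sorting import sort_edge_index
from mlx_graphs_extension import random_walk

# Load Cora and sort its edges by source node
cora_dataset = PlanetoidDataset(name="cora", base_dir="~")
edge_index = cora_dataset.graphs[0].edge_index.astype(mx.int64)
num_nodes = cora_dataset.graphs[0].num_nodes
sorted_edge_index = sort_edge_index(edge_index=edge_index)
row_mlx = sorted_edge_index[0][0]
col_mlx = sorted_edge_index[0][1]

# CSR row pointer: out-degree of every node (0 for nodes without edges), then a running sum
counts_mlx = np.bincount(np.array(row_mlx), minlength=num_nodes)
row_ptr_mlx = mx.concatenate([mx.array([0], dtype=mx.int64), mx.array(counts_mlx.cumsum())])

# One uniform sample per walk per step, so rand has shape [num_walks, walk_length]
walk_length = 100
start = mx.arange(0, 1000)
rand = mx.random.uniform(shape=[start.shape[0], walk_length])

start_time = time.time()
node_sequence = random_walk(row_ptr_mlx, col_mlx, start, rand, walk_length, stream=mx.gpu)
mx.eval(node_sequence)  # MLX is lazy; force the computation before reading the clock
print(f"Random walks took {time.time() - start_time:.3f} s")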
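The usage example turns a sorted edge list into CSR (compressed sparse row) form: after grouping edges by source node, `row_ptr[v]` to `row_ptr[v + 1]` delimits the neighbours of `v` in `col`. The NumPy-only sketch below shows that conversion on its own; `edge_index_to_csr` is a hypothetical helper, not part of mlx_cluster or mlx-graphs, and its outputs would presumably be wrapped with `mx.array(...)` before being passed to `random_walk`.

```python
import numpy as np


def edge_index_to_csr(edge_index, num_nodes):
    """Convert a [2, num_edges] edge list into CSR (row_ptr, col) arrays."""
    row, col = np.asarray(edge_index)
    # Group edges by source node; a stable sort keeps each row's original edge order.
    order = np.argsort(row, kind="stable")
    row, col = row[order], col[order]
    # Out-degree of every node (0 for nodes without outgoing edges);
    # minlength pads the result for nodes above the largest source id.
    counts = np.bincount(row, minlength=num_nodes)
    # row_ptr[v]:row_ptr[v + 1] is the slice of `col` holding v's neighbours.
    row_ptr = np.zeros(num_nodes + 1, dtype=np.int64)
    np.cumsum(counts, out=row_ptr[1:])
    return row_ptr, col


edge_index = np.array([[2, 0, 1, 0],
                       [0, 1, 2, 2]])
row_ptr, col = edge_index_to_csr(edge_index, num_nodes=4)
print(row_ptr.tolist())  # [0, 2, 3, 4, 4]
print(col.tolist())      # [1, 2, 2, 0]
```

Counting with `np.bincount(..., minlength=num_nodes)` rather than `np.unique(..., return_counts=True)` keeps one entry per node, so the row pointer stays aligned with node ids even when some nodes have no outgoing edges.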
TODO
- Add Metal shaders to optimize the code
- Benchmark random walk against different frameworks
- Add more algorithms
Credits:
torch_cluster random walk implementation: random_walk
Download files
Source Distribution
mlx_cluster-0.0.2.tar.gz (69.1 kB)
Built Distribution
mlx_cluster-0.0.2-cp311-cp311-macosx_14_0_arm64.whl (61.6 kB)
File details
Details for the file mlx_cluster-0.0.2.tar.gz.
File metadata
- Download URL: mlx_cluster-0.0.2.tar.gz
- Upload date:
- Size: 69.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 26250fcdbd20a5c2c2df6f06f15d210ff49a8dacdc13b52df6dc14e040e47c53
MD5 | a2807f50e543b0218431d20732cc0c88
BLAKE2b-256 | 35bb7469986e697efed58a096ec2001771bb11964497c6fb8e22d06140104f17
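To check a downloaded archive against the SHA256 digest listed for the source distribution, a standard-library sketch is enough. It assumes the archive sits in the current directory under its published file name; `sha256_of_file` is a hypothetical helper built on `hashlib`.

```python
import hashlib
import os


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 published for mlx_cluster-0.0.2.tar.gz.
EXPECTED_SDIST_SHA256 = "26250fcdbd20a5c2c2df6f06f15d210ff49a8dacdc13b52df6dc14e040e47c53"

if __name__ == "__main__":
    path = "mlx_cluster-0.0.2.tar.gz"  # assumed location: current directory
    if os.path.exists(path):
        print("OK" if sha256_of_file(path) == EXPECTED_SDIST_SHA256 else "MISMATCH")
    else:
        print(f"{path} not found")
```

Reading in fixed-size chunks keeps memory use constant regardless of archive size; the same helper works for the wheel by swapping in its file name and digest.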
File details
Details for the file mlx_cluster-0.0.2-cp311-cp311-macosx_14_0_arm64.whl.
File metadata
- Download URL: mlx_cluster-0.0.2-cp311-cp311-macosx_14_0_arm64.whl
- Upload date:
- Size: 61.6 kB
- Tags: CPython 3.11, macOS 14.0+ ARM64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 1f9e0522fc30afa3a38e185629dcd2d1fc4a622b9460bfd11ee7feb5281a7504
MD5 | 508f81e7be771c899cbce9f7a598108a
BLAKE2b-256 | 28a21d1c76bce42bff5af0bb34799754f87641eb13c845cb0b7b1f2f43e061cd