# MLX-Boosting

GPU-accelerated gradient boosting algorithms for Apple Silicon, built on Apple MLX.
## Features
- XGBoost-style implementation with second-order gradients
- Gradient Boosted Decision Trees (GBDT) for regression and classification
- Decision Trees as standalone estimators
- Optimized for Apple Silicon (M1/M2/M3/M4) using MLX
- Numba JIT compilation for fast tree building
- scikit-learn compatible API
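The "second-order gradients" above refer to Newton-style updates: each tree is fit using both the gradients and the Hessians of the loss, and for squared error the optimal leaf value is `w* = -G / (H + reg_lambda)`, where `G` and `H` are the sums of gradients and Hessians of the samples in the leaf. A minimal NumPy sketch of that formula (an illustration of the math only, not the library's internals):

```python
import numpy as np

def newton_leaf_value(y_true, y_pred, reg_lambda=1.0):
    """Optimal leaf value from first/second-order gradients of squared-error loss.

    For L = 0.5 * (y_pred - y_true)^2:
      gradient g_i = y_pred_i - y_true_i, hessian h_i = 1.
    The optimal leaf weight is w* = -sum(g) / (sum(h) + reg_lambda).
    """
    g = y_pred - y_true           # first-order gradients
    h = np.ones_like(y_true)      # Hessians (constant for squared error)
    return -g.sum() / (h.sum() + reg_lambda)

# Four samples, every residual y_pred - y_true = -2, reg_lambda = 1:
# w* = -(-8) / (4 + 1) = 1.6
w = newton_leaf_value(np.array([3.0, 3.0, 3.0, 3.0]),
                      np.array([1.0, 1.0, 1.0, 1.0]))
```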
## Installation

```bash
pip install mlx-boosting
```

Requirements:

- macOS with Apple Silicon (M1/M2/M3/M4)
- Python 3.10+
## Quick Start
### XGBoost Regressor

```python
import mlx.core as mx
from mlx_boosting import XGBoostRegressor

# Create sample data
X = mx.random.normal((1000, 10))
y = mx.random.normal((1000,))

# Train model
model = XGBoostRegressor(
    n_estimators=100,
    max_depth=6,
    learning_rate=0.1,
)
model.fit(X, y)

# Predict
predictions = model.predict(X)
```
### XGBoost Classifier

```python
import mlx.core as mx
from mlx_boosting import XGBoostClassifier

# Binary classification. Note: mx.random.uniform's first positional
# argument is `low`, so the shape must be passed by keyword.
X = mx.random.normal((1000, 10))
y = (mx.random.uniform(shape=(1000,)) > 0.5).astype(mx.int32)

model = XGBoostClassifier(n_estimators=100, max_depth=6)
model.fit(X, y)

# Predict probabilities
probs = model.predict_proba(X)

# Predict classes
classes = model.predict(X)
```
### Gradient Boosting

```python
import mlx.core as mx
from mlx_boosting import GradientBoostingRegressor, GradientBoostingClassifier

X = mx.random.normal((1000, 10))
y = mx.random.normal((1000,))
y_class = (mx.random.uniform(shape=(1000,)) > 0.5).astype(mx.int32)

# Regression
reg = GradientBoostingRegressor(n_estimators=100, max_depth=4)
reg.fit(X, y)

# Classification
clf = GradientBoostingClassifier(n_estimators=100, max_depth=4)
clf.fit(X, y_class)
```
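Conceptually, each boosting round fits a new tree to the residuals of the current ensemble and adds its predictions scaled by the learning rate. The loop below sketches this idea in NumPy using one-feature decision stumps in place of full trees (an illustration of the algorithm, not how the package implements it):

```python
import numpy as np

def fit_stump(x, residual):
    """Best threshold on a single feature, minimizing squared error."""
    best = (np.inf, 0.0, 0.0, 0.0)
    for t in np.unique(x)[:-1]:               # candidate split points
        left, right = residual[x <= t], residual[x > t]
        sse = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()
        if sse < best[0]:
            best = (sse, t, left.mean(), right.mean())
    return best[1:]                           # threshold, left value, right value

def boost(x, y, n_estimators=200, learning_rate=0.1):
    pred = np.full_like(y, y.mean())          # initial prediction: the mean
    for _ in range(n_estimators):
        t, lv, rv = fit_stump(x, y - pred)    # fit a stump to the residuals
        pred = pred + learning_rate * np.where(x <= t, lv, rv)
    return pred

x = np.linspace(0.0, 1.0, 50)
y = np.sin(2 * np.pi * x)
pred = boost(x, y)
# Training error shrinks as boosting rounds are added.
```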
### Decision Trees

```python
import mlx.core as mx
from mlx_boosting import DecisionTreeRegressor, DecisionTreeClassifier

X = mx.random.normal((1000, 10))
y = mx.random.normal((1000,))

# Standalone decision tree
tree = DecisionTreeRegressor(max_depth=6)
tree.fit(X, y)
predictions = tree.predict(X)
```
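A standalone regression tree is grown by recursively choosing, at each node, the split that most reduces squared error, and stopping at `max_depth`. A compact NumPy sketch of that recursion (illustrative only; names like `build_tree` are made up here and are not part of the package API):

```python
import numpy as np

def build_tree(X, y, depth=0, max_depth=6, min_samples=2):
    """Recursively grow a regression tree by greedy squared-error splits."""
    if depth >= max_depth or len(y) < min_samples or np.all(y == y[0]):
        return {"leaf": y.mean()}
    best = None
    for j in range(X.shape[1]):                       # every feature
        for t in np.unique(X[:, j])[:-1]:             # every candidate threshold
            mask = X[:, j] <= t
            sse = y[mask].var() * mask.sum() + y[~mask].var() * (~mask).sum()
            if best is None or sse < best[0]:
                best = (sse, j, t)
    _, j, t = best
    mask = X[:, j] <= t
    return {
        "feature": j, "threshold": t,
        "left": build_tree(X[mask], y[mask], depth + 1, max_depth, min_samples),
        "right": build_tree(X[~mask], y[~mask], depth + 1, max_depth, min_samples),
    }

def predict_one(tree, x):
    while "leaf" not in tree:
        tree = tree["left"] if x[tree["feature"]] <= tree["threshold"] else tree["right"]
    return tree["leaf"]

X = np.random.default_rng(1).normal(size=(200, 3))
y = (X[:, 0] > 0).astype(float)                       # target depends on feature 0
t = build_tree(X, y, max_depth=3)
preds = np.array([predict_one(t, x) for x in X])
```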
## Parameters

### XGBoostRegressor / XGBoostClassifier

| Parameter | Default | Description |
|---|---|---|
| `n_estimators` | 100 | Number of boosting rounds |
| `max_depth` | 6 | Maximum tree depth |
| `learning_rate` | 0.3 | Step size shrinkage |
| `min_child_weight` | 1.0 | Minimum sum of instance weight in a child |
| `reg_lambda` | 1.0 | L2 regularization term |
| `reg_alpha` | 0.0 | L1 regularization term |
| `gamma` | 0.0 | Minimum loss reduction required for a split |
| `subsample` | 1.0 | Subsample ratio of training instances |
| `colsample_bytree` | 1.0 | Subsample ratio of columns per tree |
| `n_bins` | 256 | Number of histogram bins |
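The `n_bins` parameter relates to histogram-based split finding: feature values are bucketed into at most `n_bins` bins, so candidate splits are evaluated at bin boundaries rather than at every distinct value. A rough NumPy illustration of quantile binning (assumed behavior for illustration; the library's exact binning strategy may differ):

```python
import numpy as np

def quantile_bins(feature, n_bins=256):
    """Map a feature column to integer bin ids via quantile edges."""
    # Interior quantile edges; duplicates removed for skewed features.
    probs = np.linspace(0.0, 1.0, n_bins + 1)[1:-1]
    edges = np.unique(np.quantile(feature, probs))
    return np.searchsorted(edges, feature, side="right"), edges

x = np.random.default_rng(0).normal(size=10_000)
bins, edges = quantile_bins(x, n_bins=256)
# At most 256 distinct bin ids; a split search now scans len(edges)
# thresholds instead of up to 10,000 distinct feature values.
```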
### GradientBoostingRegressor / GradientBoostingClassifier

| Parameter | Default | Description |
|---|---|---|
| `n_estimators` | 100 | Number of boosting rounds |
| `max_depth` | 3 | Maximum tree depth |
| `learning_rate` | 0.1 | Step size shrinkage |
| `min_samples_split` | 2 | Minimum samples required to split a node |
| `min_samples_leaf` | 1 | Minimum samples required in a leaf |
## Performance

MLX-Boosting runs natively on Apple Silicon, and its advantage grows with dataset size. Indicative training speedups over scikit-learn's gradient boosting:

| Dataset Size | vs scikit-learn |
|---|---|
| 10K samples | ~1.5x faster |
| 50K samples | ~2x faster |
| 100K samples | up to 3x faster |
## Working with NumPy

MLX-Boosting works seamlessly with NumPy arrays:

```python
import numpy as np
import mlx.core as mx
from mlx_boosting import XGBoostRegressor

# NumPy data
X_np = np.random.randn(1000, 10).astype(np.float32)
y_np = np.random.randn(1000).astype(np.float32)

# Convert to MLX
X = mx.array(X_np)
y = mx.array(y_np)

# Train
model = XGBoostRegressor(n_estimators=100)
model.fit(X, y)

# Predictions back to NumPy
preds = np.array(model.predict(X))
```
## API Reference

### Classes

- `XGBoostRegressor` - XGBoost-style regression
- `XGBoostClassifier` - XGBoost-style classification (binary and multiclass)
- `GradientBoostingRegressor` - GBDT regression
- `GradientBoostingClassifier` - GBDT classification
- `DecisionTreeRegressor` - Decision tree regression
- `DecisionTreeClassifier` - Decision tree classification

### Common Methods

- `fit(X, y)` - Train the model
- `predict(X)` - Make predictions
- `predict_proba(X)` - Predict probabilities (classifiers only)
## License

MIT License - see LICENSE for details.
## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.
## Acknowledgments

- Apple MLX - The foundation for GPU acceleration
- XGBoost - Inspiration for the algorithm implementation
- scikit-learn - API design patterns