
Fast gradient boosting library with native Rust core


Boosters Python

Python 3.12+ · License: MIT · Documentation

Boosters is a fast gradient boosting library with a native Rust core. It provides both a core API and sklearn-compatible estimators.

📚 Full Documentation: see the main documentation for tutorials, the API reference, and guides.

Features

  • High Performance: Native Rust core with Python bindings via PyO3
  • sklearn Compatible: Works with Pipeline, cross_val_score, and GridSearchCV
  • Multiple Objectives: Regression, classification, ranking, quantile regression
  • GBDT & Linear: Both tree-based and linear boosting models
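The core idea behind tree-based gradient boosting (GBDT) is easiest to see in a toy version. The sketch below is conceptual and is not how the Rust core is implemented: for squared error it starts from the mean target, repeatedly fits a tiny tree to the residuals (the negative gradient), and adds that tree's output scaled by a learning rate. The depth-1 trees ("stumps") and all function names are illustrative choices.

```python
# Conceptual sketch of gradient boosting for squared-error regression.
# Not the boosters implementation: pure Python, depth-1 trees ("stumps").


def fit_stump(xs, residuals):
    """Return (feature, threshold, left_value, right_value) minimizing squared error."""
    best = None
    for j in range(len(xs[0])):
        for threshold in sorted({row[j] for row in xs}):
            left = [r for row, r in zip(xs, residuals) if row[j] <= threshold]
            right = [r for row, r in zip(xs, residuals) if row[j] > threshold]
            if not left or not right:
                continue  # a split must send samples both ways
            lv, rv = sum(left) / len(left), sum(right) / len(right)
            sse = sum((r - lv) ** 2 for r in left) + sum((r - rv) ** 2 for r in right)
            if best is None or sse < best[0]:
                best = (sse, j, threshold, lv, rv)
    if best is None:  # no valid split: fall back to a constant leaf
        mean = sum(residuals) / len(residuals)
        return (0, float("inf"), mean, mean)
    return best[1:]


def predict_stump(stump, row):
    j, threshold, left_value, right_value = stump
    return left_value if row[j] <= threshold else right_value


def gradient_boost(xs, ys, n_estimators=50, learning_rate=0.1):
    base = sum(ys) / len(ys)  # initial prediction: the mean target
    preds = [base] * len(ys)
    stumps = []
    for _ in range(n_estimators):
        # For the loss 0.5 * (y - f)^2 the negative gradient is the residual y - f.
        residuals = [y - p for y, p in zip(ys, preds)]
        stump = fit_stump(xs, residuals)
        stumps.append(stump)
        # Shrink each tree's contribution by the learning rate.
        preds = [p + learning_rate * predict_stump(stump, row) for p, row in zip(preds, xs)]
    return base, learning_rate, stumps


def predict(model, row):
    base, learning_rate, stumps = model
    return base + learning_rate * sum(predict_stump(s, row) for s in stumps)


xs = [[float(i)] for i in range(10)]
ys = [0.0] * 5 + [1.0] * 5
model = gradient_boost(xs, ys, n_estimators=100, learning_rate=0.1)
print(predict(model, [2.0]), predict(model, [7.0]))  # close to 0.0 and 1.0
```

Each round only closes a fraction (the learning rate) of the remaining residual, which is why many small trees are combined rather than one large one.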

Installation

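# Install the published package from PyPI (wheels: CPython 3.12+ on Linux x86-64 and macOS ARM64)
pip install boosters

# Development install from the workspace root
cd /path/to/booste-rs
uv run maturin develop -m packages/boosters-python/Cargo.toml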

Quick Start

sklearn API (Recommended)

The sklearn-compatible estimators take familiar flat keyword arguments:

from boosters.sklearn import GBDTRegressor, GBDTClassifier
import numpy as np

# Regression
X = np.random.randn(100, 5).astype(np.float32)
y = X[:, 0] + np.random.randn(100).astype(np.float32) * 0.1

reg = GBDTRegressor(max_depth=5, n_estimators=100)
reg.fit(X, y)
predictions = reg.predict(X)

# Binary classification
y_cls = (X[:, 0] > 0).astype(int)
clf = GBDTClassifier(n_estimators=50)
clf.fit(X, y_cls)
proba = clf.predict_proba(X)

# Multiclass classification (explicit objective required)
from boosters import Objective
y_multi = np.random.randint(0, 3, size=100)
clf_multi = GBDTClassifier(
    n_estimators=50,
    objective=Objective.softmax(n_classes=3)
)
clf_multi.fit(X, y_multi)
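With Objective.softmax(n_classes=3), a gradient-boosted classifier typically keeps one raw score per class and turns those scores into the probabilities returned by predict_proba via the softmax function. The sketch below assumes that standard formulation, including the XGBoost-style diagonal Hessian; it is an illustration, not code taken from boosters.

```python
import math


def softmax(scores):
    """Map raw per-class scores to probabilities that sum to 1."""
    m = max(scores)  # subtract the max to avoid overflow in exp()
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_grad_hess(scores, label):
    """Per-class gradient and diagonal Hessian of multiclass log loss."""
    probs = softmax(scores)
    grad = [p - (1.0 if k == label else 0.0) for k, p in enumerate(probs)]
    hess = [p * (1.0 - p) for p in probs]
    return grad, hess


raw = [2.0, 1.0, 0.1]  # hypothetical raw scores for one sample, 3 classes
print(softmax(raw))
print(softmax_grad_hess(raw, label=0))
```

The gradient is negative only for the true class, so the next trees push that class's score up and the others down.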

Core API

The core API provides full control with flat config parameters:

import boosters as bst
import numpy as np

# Create config
config = bst.GBDTConfig(
    n_estimators=100,
    learning_rate=0.1,
    objective=bst.Objective.squared(),
    metric=bst.Metric.rmse(),
    max_depth=5,
    l2=1.0,
)

# Create model and train
X = np.random.randn(100, 10).astype(np.float32)
y = np.random.randn(100).astype(np.float32)

model = bst.GBDTModel(config=config)
train_data = bst.Dataset(X, y)
model.fit(train_data)

# Predict
predictions = model.predict(bst.Dataset(X))
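The config above sets learning_rate=0.1 and l2=1.0. In XGBoost-style second-order boosting (assumed here, not confirmed for boosters), both enter through the leaf value w = -learning_rate * G / (H + l2), where G and H are the sums of per-sample gradients and Hessians in the leaf. For three samples with targets 1, 2, 3, a current prediction of 0, and squared loss, G = -6 and H = 3, so the leaf value is 0.1 * 6 / (3 + 1) = 0.15, compared with 0.1 * 6 / 3 = 0.2 when l2=0. The helper name leaf_weight is hypothetical.

```python
def leaf_weight(grads, hessians, l2=1.0, learning_rate=0.1):
    """Regularized Newton step for one leaf, shrunk by the learning rate.

    Assumes the XGBoost-style formula w = -G / (H + l2); illustrative only.
    """
    g_sum = sum(grads)
    h_sum = sum(hessians)
    return -learning_rate * g_sum / (h_sum + l2)


# Squared loss 0.5 * (f - y)^2: gradient f - y, Hessian 1.
y = [1.0, 2.0, 3.0]
f = [0.0, 0.0, 0.0]  # current predictions for the samples in this leaf
grads = [fi - yi for fi, yi in zip(f, y)]
hessians = [1.0] * len(y)

print(leaf_weight(grads, hessians, l2=1.0))  # pulled toward 0 by l2
print(leaf_weight(grads, hessians, l2=0.0))  # mean residual times learning_rate
```

Larger l2 values shrink leaves with few samples (small H) the most, which is the usual motivation for this regularizer.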

sklearn Integration

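The sklearn estimators work with standard sklearn tooling such as cross-validation, pipelines, and grid search. This snippet reuses X and y from the examples above: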

from sklearn.model_selection import cross_val_score, GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from boosters.sklearn import GBDTRegressor

# Cross-validation
scores = cross_val_score(GBDTRegressor(), X, y, cv=5)

# Pipeline
pipe = Pipeline([
    ('scaler', StandardScaler()),
    ('model', GBDTRegressor(n_estimators=50)),
])
pipe.fit(X, y)

# Grid search
param_grid = {'n_estimators': [50, 100], 'max_depth': [3, 5]}
grid = GridSearchCV(GBDTRegressor(), param_grid, cv=3)
grid.fit(X, y)
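GridSearchCV trains one model per parameter combination per fold. For the grid above that is 2 * 2 = 4 candidates times 3 folds = 12 fits, plus one refit of the best candidate on the full data because refit=True is sklearn's default. The stdlib sketch below enumerates the same grid to make that cost explicit; it only counts the work and is not part of boosters.

```python
from itertools import product

param_grid = {"n_estimators": [50, 100], "max_depth": [3, 5]}
cv = 3

# Every combination of the listed values is one candidate.
keys = sorted(param_grid)
candidates = [dict(zip(keys, values)) for values in product(*(param_grid[k] for k in keys))]

# Each candidate is fit once per fold; refit=True (the default) adds one final fit.
n_fits = len(candidates) * cv + 1

for c in candidates:
    print(c)
print(n_fits, "model fits in total")
```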

API Reference

sklearn Estimators

| Class | Description |
|---|---|
| GBDTRegressor | Gradient boosted trees for regression |
| GBDTClassifier | Gradient boosted trees for classification |
| GBLinearRegressor | Gradient boosted linear model for regression |
| GBLinearClassifier | Gradient boosted linear model for classification |
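The GBLinear estimators boost a linear model instead of trees. One common formulation, used by XGBoost's gblinear booster and assumed here rather than confirmed for boosters, applies coordinate-wise Newton updates to the bias and each weight every round. A minimal pure-Python sketch for squared loss, with a hypothetical gblinear_round helper:

```python
def gblinear_round(X, y, weights, bias, l2=1.0, learning_rate=0.5):
    """One boosting round of coordinate-wise Newton updates for squared loss.

    Hypothetical sketch in the style of XGBoost's gblinear, not boosters' code.
    """
    weights = list(weights)  # do not mutate the caller's list

    def residual_grads():
        # Gradient of 0.5 * (f - y)^2 with respect to f is f - y.
        return [bias + sum(w * x for w, x in zip(weights, row)) - yi
                for row, yi in zip(X, y)]

    # Bias: unregularized Newton step (each sample contributes Hessian 1).
    grads = residual_grads()
    bias -= learning_rate * sum(grads) / len(grads)

    # Weights: one feature at a time, refreshing gradients after every update.
    for j in range(len(weights)):
        grads = residual_grads()
        g = sum(gi * row[j] for gi, row in zip(grads, X))
        h = sum(row[j] ** 2 for row in X)
        weights[j] -= learning_rate * g / (h + l2)
    return weights, bias


X = [[1.0], [2.0], [3.0], [4.0]]
y = [3.0, 5.0, 7.0, 9.0]  # y = 2x + 1
weights, bias = [0.0], 0.0
for _ in range(500):
    weights, bias = gblinear_round(X, y, weights, bias, l2=0.0)
print(weights, bias)  # approaches [2.0] and 1.0
```

With l2 > 0 the weights are shrunk toward zero, mirroring the role l2 plays for tree leaves.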

Core Types

| Class | Description |
|---|---|
| GBDTModel | Tree-based gradient boosting model |
| GBLinearModel | Linear gradient boosting model |
| GBDTConfig | Configuration for GBDT models |
| GBLinearConfig | Configuration for linear models |
| Dataset | Data wrapper for features and labels |

Objectives

| Method | Description |
|---|---|
| Objective.squared() | L2 regression |
| Objective.absolute() | L1 regression |
| Objective.huber(delta) | Huber loss |
| Objective.logistic() | Binary classification |
| Objective.softmax(n_classes) | Multiclass classification |
| Objective.poisson() | Poisson regression |
| Objective.pinball(alpha) | Quantile regression |
| Objective.lambdarank(ndcg_at) | Learning to rank |
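Each objective determines the per-sample gradient the trees are fit to. As a sketch, here are the textbook gradients with respect to the raw prediction f. The link functions (sigmoid for logistic, log for Poisson) and the subgradient choices at kinks are standard conventions assumed here; boosters may scale or clip them differently. Softmax and lambdarank are omitted because they depend on several scores per sample or on query groups.

```python
import math

# Gradient of each loss with respect to the raw prediction f (textbook forms).
# Trees in gradient boosting are fit to these per-sample gradients.


def squared_grad(y, f):
    return f - y  # loss 0.5 * (f - y)^2


def absolute_grad(y, f):
    return float((f > y) - (f < y))  # loss |f - y|, subgradient sign(f - y)


def huber_grad(y, f, delta):
    r = f - y
    # Quadratic inside [-delta, delta], linear (constant slope) outside.
    return r if abs(r) <= delta else math.copysign(delta, r)


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logistic_grad(y, f):
    return sigmoid(f) - y  # y in {0, 1}, f is the log-odds


def poisson_grad(y, f):
    return math.exp(f) - y  # log link: predicted mean is exp(f)


def pinball_grad(y, f, alpha):
    # Quantile loss: alpha * (y - f) if y >= f, else (1 - alpha) * (f - y).
    return -alpha if y >= f else 1.0 - alpha
```

For pinball with alpha=0.9, under-prediction is penalized 0.9 / 0.1 = 9 times as heavily as over-prediction, which pushes f toward the 90th percentile of y.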

Metrics

| Method | Description |
|---|---|
| Metric.rmse() | Root mean squared error |
| Metric.mae() | Mean absolute error |
| Metric.logloss() | Log loss / cross-entropy |
| Metric.auc() | Area under ROC curve |
| Metric.Accuracy() | Classification accuracy |
| Metric.ndcg(at) | Normalized discounted cumulative gain |
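For reference, here is a minimal pure-Python sketch of the textbook definitions behind these metrics. The logloss clipping epsilon and the 2^rel - 1 gain in NDCG are common conventions assumed here, not confirmed from the boosters source; ndcg_at expects relevance labels already sorted by predicted score.

```python
import math


def rmse(y, p):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(y, p)) / len(y))


def mae(y, p):
    return sum(abs(a - b) for a, b in zip(y, p)) / len(y)


def logloss(y, p, eps=1e-15):
    # p is the predicted probability of class 1; clip to avoid log(0).
    total = 0.0
    for a, b in zip(y, p):
        b = min(max(b, eps), 1.0 - eps)
        total -= a * math.log(b) + (1 - a) * math.log(1.0 - b)
    return total / len(y)


def auc(y, scores):
    # Probability that a random positive outranks a random negative; ties count 1/2.
    pos = [s for a, s in zip(y, scores) if a == 1]
    neg = [s for a, s in zip(y, scores) if a == 0]
    wins = sum(1.0 if sp > sn else 0.5 if sp == sn else 0.0 for sp in pos for sn in neg)
    return wins / (len(pos) * len(neg))


def accuracy(y, predicted_labels):
    return sum(a == b for a, b in zip(y, predicted_labels)) / len(y)


def ndcg_at(relevance_in_ranked_order, k):
    # Gain 2^rel - 1 with a log2 position discount, normalized by the ideal ordering.
    def dcg(rels):
        return sum((2 ** r - 1) / math.log2(i + 2) for i, r in enumerate(rels[:k]))

    ideal = dcg(sorted(relevance_in_ranked_order, reverse=True))
    return dcg(relevance_in_ranked_order) / ideal if ideal > 0 else 0.0


print(auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75: 3 of 4 positive/negative pairs ordered correctly
```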

Examples

See the examples/ directory in the repository for more examples.

License

MIT



Download files


Source Distribution

boosters-0.1.0.tar.gz (388.3 kB)

Uploaded Source

Built Distributions


boosters-0.1.0-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.4 MB)

Uploaded: CPython 3.12+, manylinux: glibc 2.17+ x86-64

boosters-0.1.0-cp312-abi3-macosx_11_0_arm64.whl (1.2 MB)

Uploaded: CPython 3.12+, macOS 11.0+ ARM64
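The wheel filenames encode where each build installs. cp312 together with abi3 means the extension targets CPython's stable ABI, so one wheel serves CPython 3.12 and later; manylinux_2_17_x86_64 and its legacy alias manylinux2014_x86_64 both mean glibc 2.17 or newer on x86-64; macosx_11_0_arm64 means macOS 11.0 or newer on Apple silicon. A small, simplified parser makes the tag structure explicit (a sketch; the packaging library's parse_wheel_filename is the robust option):

```python
def parse_wheel_filename(filename):
    """Split a wheel filename into its tags (simplified PEP 427 parsing).

    Assumes the distribution name itself contains no '-' (true for 'boosters').
    """
    parts = filename.removesuffix(".whl").split("-")
    if len(parts) == 6:  # optional build tag present
        name, version, build, python_tag, abi_tag, platform_tag = parts
    else:
        name, version, python_tag, abi_tag, platform_tag = parts
        build = None
    return {
        "name": name,
        "version": version,
        "build": build,
        "python": python_tag.split("."),
        "abi": abi_tag.split("."),
        "platforms": platform_tag.split("."),  # compressed tag sets use '.'
    }


for wheel in [
    "boosters-0.1.0-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
    "boosters-0.1.0-cp312-abi3-macosx_11_0_arm64.whl",
]:
    print(parse_wheel_filename(wheel))
```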

File details

Details for the file boosters-0.1.0.tar.gz.

File metadata

  • Download URL: boosters-0.1.0.tar.gz
  • Size: 388.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for boosters-0.1.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | bee21b178b46566478dc4ab8d0cabe1d74204f7227d0d771500ae264f7820d74 |
| MD5 | 1ef876b1f67e54e8a95d35986ddba201 |
| BLAKE2b-256 | 8042915bcc993f87454b490cec5476a583388d1fb01ed1b8b23e415e2e05c7dd |
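To check a downloaded sdist against the SHA256 digest listed for it, stream the file through Python's hashlib and compare. The file path below assumes the archive was saved to the current directory.

```python
import hashlib
from pathlib import Path

EXPECTED_SHA256 = "bee21b178b46566478dc4ab8d0cabe1d74204f7227d0d771500ae264f7820d74"


def sha256_of(path, chunk_size=1 << 20):
    """Hash a file in 1 MiB chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


sdist = Path("boosters-0.1.0.tar.gz")  # assumed download location
if sdist.exists():
    print("OK" if sha256_of(sdist) == EXPECTED_SHA256 else "MISMATCH: do not install")
```

pip can enforce the same check in hash-checking mode via a requirements line such as boosters==0.1.0 --hash=sha256:<digest>, listing one --hash for each distribution file that may be installed.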


Provenance

The following attestation bundles were made for boosters-0.1.0.tar.gz:

Publisher: python.yml on egordm/boosters


File details

Details for the file boosters-0.1.0-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.


File hashes

Hashes for boosters-0.1.0-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | c415f36c3b6588ab17a2007231c48f99a10fa62edf9de67340cccfb23e723b35 |
| MD5 | deac5ba4c1ddf8cfc12af61c74c2b56e |
| BLAKE2b-256 | c705f63ab47a35f087d2800da9183a87f18693d75db952f3837c51a522854786 |


Provenance

The following attestation bundles were made for boosters-0.1.0-cp312-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

Publisher: python.yml on egordm/boosters


File details

Details for the file boosters-0.1.0-cp312-abi3-macosx_11_0_arm64.whl.


File hashes

Hashes for boosters-0.1.0-cp312-abi3-macosx_11_0_arm64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9bd5ba708874182741feb1b11ebc3e35780301867482fb62f071442de6e1ef98 |
| MD5 | 5a1a747d521a58a65427a2b0c5b0fb71 |
| BLAKE2b-256 | 2a5346c094109782db2485c3b5ea89611574ad8966266470b6296485d8688317 |


Provenance

The following attestation bundles were made for boosters-0.1.0-cp312-abi3-macosx_11_0_arm64.whl:

Publisher: python.yml on egordm/boosters
