Next-generation differentiable gradient boosting with JAX
JAX autodiff for XGBoost/LightGBM objectives.
Write a loss function, get gradients and Hessians automatically. No manual derivation needed.
Works with XGBoost and LightGBM.
Install
```bash
pip install jaxboost
```
Quick Start
XGBoost
```python
import xgboost as xgb
import jax.numpy as jnp
from jaxboost import auto_objective, focal_loss, huber, quantile

# Prepare your data (X_train, y_train: your own features and labels)
dtrain = xgb.DMatrix(X_train, label=y_train)
params = {"max_depth": 4, "eta": 0.1}

# Built-in objectives: pass them straight to xgb.train
# (focal_loss is a binary-classification loss; huber and quantile are for regression)
model = xgb.train(params, dtrain, num_boost_round=100, obj=focal_loss.xgb_objective)
model = xgb.train(params, dtrain, num_boost_round=100, obj=huber.xgb_objective)
model = xgb.train(params, dtrain, num_boost_round=100, obj=quantile(0.9).xgb_objective)

# Custom objective: write the loss, autodiff handles the rest
@auto_objective
def asymmetric_mse(y_pred, y_true, alpha=0.7):
    error = y_true - y_pred
    # Under-predictions (error > 0) are weighted by alpha, over-predictions by 1 - alpha
    return jnp.where(error > 0, alpha * error**2, (1 - alpha) * error**2)

model = xgb.train(params, dtrain, num_boost_round=100, obj=asymmetric_mse.xgb_objective)
```
LightGBM
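```python
import lightgbm as lgb
from jaxboost import huber

train_data = lgb.Dataset(X_train, label=y_train)
# LightGBM >= 4.0 takes a custom objective callable via params["objective"];
# older versions accepted lgb.train(..., fobj=huber.lgb_objective) instead.
params = {"max_depth": 4, "learning_rate": 0.1, "objective": huber.lgb_objective}
model = lgb.train(params, train_data, num_boost_round=100)
```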
Available Objectives
Regression
| Objective | Description |
|---|---|
| `mse` | Mean squared error |
| `huber` | Huber loss (robust to outliers) |
| `pseudo_huber` | Smooth approximation of Huber loss |
| `log_cosh` | Log-cosh loss |
| `mae_smooth` | Smooth approximation of MAE |
| `quantile(q)` | Quantile regression |
| `asymmetric(alpha)` | Asymmetric squared error |
| `tweedie(p)` | Tweedie deviance |
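XGBoost and LightGBM compute each leaf value from the summed gradients and Hessians, so a loss with a smooth, strictly positive Hessian is easier to boost than one whose Hessian is piecewise zero (the exact Huber Hessian is 1 inside `delta` and 0 outside). As an illustration, the sketch below writes out the closed-form gradient and Hessian of the pseudo-Huber loss with respect to the prediction, which is what autodiff would return for it. The function names and the `delta=1.0` default are assumptions for this example, not jaxboost's API.

```python
import math


def pseudo_huber(y_pred: float, y_true: float, delta: float = 1.0) -> float:
    """Pseudo-Huber loss delta^2 * (sqrt(1 + (r / delta)^2) - 1), r = y_pred - y_true."""
    r = y_pred - y_true
    return delta**2 * (math.sqrt(1.0 + (r / delta) ** 2) - 1.0)


def pseudo_huber_grad_hess(y_pred: float, y_true: float, delta: float = 1.0):
    """Closed-form derivatives w.r.t. y_pred (illustrative helper; delta=1.0 is assumed)."""
    r = y_pred - y_true
    s = 1.0 + (r / delta) ** 2
    grad = r / math.sqrt(s)  # ~ r for small errors, bounded by +/- delta for large ones
    hess = s ** -1.5         # always in (0, 1], never exactly zero
    return grad, hess
```

For a residual of 3 with `delta=1`, the gradient is about 0.95 and the Hessian about 0.03: large errors get a bounded gradient and low curvature instead of an unbounded squared-error pull.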
Binary Classification
| Objective | Description |
|---|---|
| `focal_loss` | Focal loss for imbalanced data |
| `binary_crossentropy` | Standard log loss |
| `weighted_binary_crossentropy` | Weighted binary cross-entropy |
| `hinge_loss` | SVM-style hinge loss |
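A minimal sketch of binary focal loss, assuming the objective receives the raw margin (XGBoost passes untransformed margins to custom objectives) and omitting the optional class-balancing weight `alpha`. The `(1 - p_t) ** gamma` factor shrinks the loss on well-classified examples, which is why it helps on imbalanced data; with `gamma=0` it reduces to standard log loss. The function names and the `gamma=2.0` default are illustrative assumptions, not jaxboost's.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def focal_loss(margin: float, label: int, gamma: float = 2.0) -> float:
    """Binary focal loss on a raw margin (no alpha weighting; gamma=2.0 is an assumed default)."""
    p = sigmoid(margin)
    p_t = p if label == 1 else 1.0 - p  # probability assigned to the true class
    return -((1.0 - p_t) ** gamma) * math.log(p_t)
```

Relative to log loss, the focal loss of a confidently correct example (margin 3, label 1) is scaled by roughly 0.002, while a badly misclassified one (margin -3, label 1) keeps about 90% of its loss.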
Multi-class Classification
| Objective | Description |
|---|---|
| `softmax_cross_entropy` | Standard multi-class |
| `focal_multiclass` | Focal loss for multi-class |
| `label_smoothing(eps)` | Label smoothing regularization |
| `class_balanced` | Class-balanced loss |
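Label smoothing replaces the one-hot target with `1 - eps` on the true class plus `eps / K` spread over all K classes. A stdlib-only sketch of the resulting cross-entropy on raw logits follows; the function names and the `eps=0.1` default are assumptions for illustration, not jaxboost's API.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def label_smoothing_ce(logits: list[float], label: int, eps: float = 0.1) -> float:
    """Cross-entropy against a smoothed target (eps=0.1 is an assumed default)."""
    k = len(logits)
    logp = log_softmax(logits)
    # Smoothed target: 1 - eps + eps/K on the true class, eps/K on every other class
    target = [eps / k + (1.0 - eps if j == label else 0.0) for j in range(k)]
    return -sum(t * lp for t, lp in zip(target, logp))
```

With `eps=0` this is ordinary cross-entropy; with `eps > 0` a confidently correct prediction still pays a small loss, which discourages over-confident logits. Because the smoothed target sums to 1, the gradient with respect to the logits is simply `softmax(logits) - target`.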
Survival Analysis
| Objective | Description |
|---|---|
| `aft` | Accelerated failure time (log-normal) |
| `weibull_aft` | Weibull AFT model |
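A log-normal AFT model assumes `log(T) = mu + sigma * Z` with `Z ~ N(0, 1)`, where `mu` is the model's raw prediction. The sketch below computes the per-sample negative log-likelihood for an observed event (the log density of T) and for a right-censored time (the log survival probability). The `event` flag, the fixed `sigma`, and the function name are assumptions; this README does not say how jaxboost's `aft` encodes censoring.

```python
import math


def lognormal_aft_nll(mu: float, t: float, event: bool, sigma: float = 1.0) -> float:
    """Per-sample NLL of a log-normal AFT model (illustrative; event/sigma are assumed inputs)."""
    z = (math.log(t) - mu) / sigma
    if event:
        # Exact failure time observed: -log f_T(t), with f_T(t) = phi(z) / (sigma * t)
        return 0.5 * z * z + 0.5 * math.log(2 * math.pi) + math.log(sigma) + math.log(t)
    # Right-censored: only T > t is known, so use -log S(t) = -log(1 - Phi(z))
    return -math.log(0.5 * math.erfc(z / math.sqrt(2.0)))
```

For a censored sample, predicting a larger `mu` (longer survival) always lowers the loss, while an observed event pulls `mu` toward `log(t)`.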
Ordinal Regression
| Objective | Description |
|---|---|
| `ordinal_regression` | Cumulative Link Model (probit/logit) |
| `qwk_ordinal` | QWK-aligned Expected Quadratic Error |
| `squared_cdf_ordinal` | CRPS / Ranked Probability Score |
| `hybrid_ordinal` | NLL + EQE hybrid |
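Given predicted class probabilities `p_0 ... p_{K-1}` for an ordinal label `y`, the expected quadratic error is `sum_k p_k * (k - y)^2`, and the (unnormalized) ranked probability score compares the predicted CDF with the step function of the true class. Both penalize probability mass placed far from `y` more than mass on a neighboring class. This is a sketch of the two scores only, with illustrative function names; how `qwk_ordinal` and `squared_cdf_ordinal` compute them internally is not documented here.

```python
def expected_quadratic_error(probs: list[float], y: int) -> float:
    # Squared class distance, weighted by the predicted probability of each class
    return sum(p * (k - y) ** 2 for k, p in enumerate(probs))


def ranked_probability_score(probs: list[float], y: int) -> float:
    # Sum over k of (P(pred <= k) - 1[y <= k])^2; the last term is always 0, so skip it
    cdf, total = 0.0, 0.0
    for k, p in enumerate(probs[:-1]):
        cdf += p
        total += (cdf - (1.0 if y <= k else 0.0)) ** 2
    return total
```

For a true class of 0, putting all mass on class 1 scores 1 under both measures, while putting it on class 3 scores 9 (EQE) and 3 (RPS).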
Multi-task Learning
| Objective | Description |
|---|---|
| `multi_task_regression` | Multiple regression targets |
| `multi_task_classification` | Multiple classification targets |
| `multi_task_huber` | Multi-task Huber loss |
| `multi_task_quantile` | Multi-task quantile loss |
| `MaskedMultiTaskObjective` | Handle missing labels |
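One common way to handle missing labels in a multi-task objective is to zero both the gradient and the Hessian wherever a label is absent, so that entry never influences tree construction. The sketch below assumes squared error per task and missing labels encoded as NaN; the function name is hypothetical, and this illustrates the idea rather than `MaskedMultiTaskObjective`'s actual implementation.

```python
import math


def masked_mse_grad_hess(preds: list[list[float]], labels: list[list[float]]):
    """Per-entry grad/hess of (pred - label)^2; NaN labels (assumed encoding) are masked."""
    grad, hess = [], []
    for pred_row, label_row in zip(preds, labels):
        # One column per task; a missing label contributes neither gradient nor curvature
        grad.append([0.0 if math.isnan(y) else 2.0 * (p - y) for p, y in zip(pred_row, label_row)])
        hess.append([0.0 if math.isnan(y) else 2.0 for y in label_row])
    return grad, hess
```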
Uncertainty Estimation
| Objective | Description |
|---|---|
| `gaussian_nll` | Predict mean + variance |
| `laplace_nll` | Predict median + scale |
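For `gaussian_nll`, a model predicting both a mean and a variance needs two outputs per sample. The sketch below assumes the second output is the log-variance (a common choice that keeps the variance positive) and returns the NLL plus per-output gradients and diagonal Hessian entries. The parameterization and the function name are assumptions; jaxboost's convention may differ.

```python
import math


def gaussian_nll_grad_hess(mu: float, log_var: float, y: float):
    """NLL of y under N(mu, exp(log_var)); log-variance output is an assumed parameterization."""
    inv_var = math.exp(-log_var)
    r = y - mu
    nll = 0.5 * (math.log(2 * math.pi) + log_var + r * r * inv_var)
    # Derivatives w.r.t. (mu, log_var); only the diagonal of the 2x2 Hessian is returned
    grad = (-r * inv_var, 0.5 * (1.0 - r * r * inv_var))
    hess = (inv_var, 0.5 * r * r * inv_var)
    return nll, grad, hess
```

The log-variance gradient vanishes when the predicted variance equals the squared residual, so the second output learns the scale of the errors. Its Hessian entry is zero when the residual is exactly zero, which a practical objective may need to floor.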
Ordinal Regression
XGBoost has no native ordinal objective. jaxboost implements cumulative link models (CLMs):
```python
import xgboost as xgb
from jaxboost import ordinal_regression, qwk_ordinal

# Wine quality: 6 ordered classes (3-8 mapped to 0-5)
ordinal = ordinal_regression(n_classes=6, link='probit')
ordinal.init_thresholds_from_data(y_train)
model = xgb.train(params, dtrain, num_boost_round=100, obj=ordinal.xgb_objective)

# Get class probabilities from the raw (untransformed) model output
dtest = xgb.DMatrix(X_test)
margin = model.predict(dtest, output_margin=True)
probs = ordinal.predict_proba(margin)
classes = ordinal.predict(margin)
```
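In a cumulative link model, the booster outputs one latent score `f` per sample and K - 1 increasing thresholds split the latent scale into K classes: `P(y = k) = F(theta_k - f) - F(theta_{k-1} - f)`, with `theta_{-1} = -inf` and `theta_{K-1} = +inf`. The sketch uses the logit link so it needs only the standard library (the probit link would use the normal CDF instead). The function name and thresholds are hypothetical, standing in for whatever `init_thresholds_from_data` actually estimates.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic CDF (the logit link)
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def clm_probs(f: float, thresholds: list[float]) -> list[float]:
    """Class probabilities of a logit cumulative link model (illustrative sketch).

    thresholds must be strictly increasing; K - 1 thresholds give K classes.
    """
    # Cumulative probabilities P(y <= k) = F(theta_k - f); the last class closes at 1
    cdf = [sigmoid(theta - f) for theta in thresholds] + [1.0]
    probs, prev = [], 0.0
    for c in cdf:
        probs.append(c - prev)
        prev = c
    return probs


# Hypothetical thresholds for 4 classes; a higher score f shifts mass to higher classes
low, high = clm_probs(-3.0, [-1.0, 0.0, 1.0]), clm_probs(3.0, [-1.0, 0.0, 1.0])
```

Taking the argmax of these probabilities gives a class prediction; for the hypothetical thresholds above, `f = -3` favors class 0 and `f = 3` favors class 3.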
Custom Objectives
The `@auto_objective` decorator turns a differentiable loss function into an XGBoost/LightGBM objective:
```python
import xgboost as xgb
import lightgbm as lgb
import jax.numpy as jnp
from jaxboost import auto_objective

@auto_objective
def my_custom_loss(y_pred, y_true, **kwargs):
    # Write your loss here; JAX computes grad/hess automatically
    return (y_pred - y_true) ** 2

# Use with XGBoost
dtrain = xgb.DMatrix(X_train, label=y_train)
xgb_params = {"max_depth": 4, "eta": 0.1}
model = xgb.train(xgb_params, dtrain, num_boost_round=100, obj=my_custom_loss.xgb_objective)

# Use with LightGBM (>= 4.0 takes the objective callable via params)
train_data = lgb.Dataset(X_train, label=y_train)
lgb_params = {"max_depth": 4, "learning_rate": 0.1, "objective": my_custom_loss.lgb_objective}
model = lgb.train(lgb_params, train_data, num_boost_round=100)

# Pass parameters to the loss (here alpha lands in **kwargs and is unused)
model = xgb.train(
    xgb_params, dtrain, num_boost_round=100,
    obj=my_custom_loss.get_xgb_objective(alpha=0.5),
)
```
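A minimal sketch of what a decorator like `@auto_objective` could do under the hood, assuming an elementwise loss: vectorize `jax.grad` over samples for the gradient, apply `jax.grad` twice for the Hessian, and wrap the result in the `(preds, data) -> (grad, hess)` signature that XGBoost's `obj=` and LightGBM's callable objective both use. `make_grad_hess` and `as_booster_objective` are hypothetical names, not jaxboost's API.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_grad_hess(loss_fn, **loss_kwargs):
    """Hypothetical helper: per-sample grad/hess of an elementwise loss w.r.t. y_pred."""

    def scalar_loss(y_pred, y_true):
        return loss_fn(y_pred, y_true, **loss_kwargs)

    # jax.grad differentiates w.r.t. the first argument (y_pred);
    # vmap maps the scalar derivative over every sample.
    grad_fn = jax.jit(jax.vmap(jax.grad(scalar_loss)))
    hess_fn = jax.jit(jax.vmap(jax.grad(jax.grad(scalar_loss))))

    def grad_hess(preds, labels):
        p = jnp.asarray(preds, dtype=jnp.float32)
        y = jnp.asarray(labels, dtype=jnp.float32)
        # Boosting libraries expect NumPy arrays back
        return np.asarray(grad_fn(p, y)), np.asarray(hess_fn(p, y))

    return grad_hess


def as_booster_objective(grad_hess):
    """Hypothetical wrapper with the (preds, data) -> (grad, hess) signature."""

    def objective(preds, data):
        return grad_hess(preds, data.get_label())

    return objective


def asymmetric_mse(y_pred, y_true, alpha=0.7):
    error = y_true - y_pred
    return jnp.where(error > 0, alpha * error**2, (1 - alpha) * error**2)


obj = as_booster_objective(make_grad_hess(asymmetric_mse, alpha=0.7))
```

Because both `xgb.DMatrix` and `lgb.Dataset` expose `get_label()`, the same wrapper can serve either library.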
Multi-class Example
```python
import xgboost as xgb
import jax
import jax.numpy as jnp
from jaxboost import multiclass_objective

@multiclass_objective(num_classes=3)
def custom_multiclass(logits, label):
    # logits: (num_classes,), label: scalar class index (DMatrix labels are floats)
    log_probs = jax.nn.log_softmax(logits)              # numerically stable log(softmax)
    one_hot = jax.nn.one_hot(label, logits.shape[-1])   # works for int or float labels
    return -jnp.sum(one_hot * log_probs)

dtrain = xgb.DMatrix(X_train, label=y_train)
model = xgb.train(
    {"num_class": 3, "max_depth": 4, "eta": 0.1},
    dtrain,
    num_boost_round=100,
    obj=custom_multiclass.xgb_objective,
)
```
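Per row, the gradient of softmax cross-entropy with respect to the logits is `p - onehot(label)`, and the diagonal of its Hessian is `p_k * (1 - p_k)`; gradient-boosting libraries generally use only that diagonal. The stdlib sketch below computes both for one row, the kind of per-class `(grad, hess)` a multiclass objective hands back for each sample. It is an illustration with hypothetical names, not how `multiclass_objective` is implemented.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    # Shift by the max logit for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def softmax_ce_grad_hess(logits: list[float], label: int):
    """Gradient and diagonal Hessian of -log softmax(logits)[label] (illustrative)."""
    p = softmax(logits)
    grad = [pk - (1.0 if k == label else 0.0) for k, pk in enumerate(p)]
    hess = [pk * (1.0 - pk) for pk in p]  # diagonal of the full K x K Hessian
    return grad, hess
```

The gradient entries always sum to zero, since both `p` and the one-hot target sum to 1.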
Why jaxboost?
| Traditional Approach | jaxboost |
|---|---|
| Derive gradients by hand | Write loss, get gradients free |
| Derive Hessians by hand | Write loss, get Hessians free |
| Error-prone math | JAX autodiff computes exact derivatives of the loss as written |
| One loss = hours of work | One loss = 5 lines of code |
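To make the left-hand column concrete, the sketch below hand-derives the gradient and Hessian of the asymmetric squared error from the Quick Start and checks them against central finite differences, the kind of verification a manual derivation needs and autodiff makes unnecessary. The function names are illustrative.

```python
def asymmetric_mse(y_pred: float, y_true: float, alpha: float = 0.7) -> float:
    e = y_true - y_pred
    w = alpha if e > 0 else 1.0 - alpha
    return w * e * e


def asymmetric_mse_grad_hess(y_pred: float, y_true: float, alpha: float = 0.7):
    # Hand-derived: d/dy_pred of w * (y_true - y_pred)^2 is -2 * w * (y_true - y_pred);
    # forgetting the chain-rule minus sign is the classic mistake here.
    e = y_true - y_pred
    w = alpha if e > 0 else 1.0 - alpha
    return -2.0 * w * e, 2.0 * w


def finite_diff(f, x: float, h: float = 1e-5) -> float:
    # Central difference approximation of f'(x)
    return (f(x + h) - f(x - h)) / (2 * h)
```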
Requirements
- Python >= 3.10
- JAX >= 0.4.20
Documentation
Full documentation available at: https://jxucoder.github.io/jaxboost/
License
Apache 2.0