
Next-generation differentiable gradient boosting with JAX


jaxboost


JAX autodiff for XGBoost/LightGBM objectives.

Write a loss function, get gradients and Hessians automatically. No manual derivation needed.

Works with XGBoost and LightGBM.
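A boosting library only needs the first and second derivatives of the per-sample loss with respect to the prediction. The sketch below shows that core step in plain JAX. It assumes an elementwise loss `loss(y_pred, y_true)` and is meant to illustrate the technique, not to reproduce jaxboost's internal code.

```python
import jax
import jax.numpy as jnp


def squared_error(y_pred, y_true):
    # Per-sample loss written for scalar inputs.
    return (y_pred - y_true) ** 2


def grad_hess(loss_fn, y_pred, y_true):
    """Per-sample first and second derivatives of loss_fn w.r.t. y_pred."""
    g = jax.grad(loss_fn, argnums=0)  # d loss / d y_pred
    h = jax.grad(g, argnums=0)        # d^2 loss / d y_pred^2
    # vmap applies the scalar derivatives to every sample at once.
    return jax.vmap(g)(y_pred, y_true), jax.vmap(h)(y_pred, y_true)


y_pred = jnp.array([0.5, 2.0, -1.0])
y_true = jnp.array([1.0, 1.0, 1.0])
grad, hess = grad_hess(squared_error, y_pred, y_true)
print(grad)  # equals 2 * (y_pred - y_true)
print(hess)  # 2.0 for every sample
```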

Install

pip install jaxboost

Quick Start

XGBoost

import xgboost as xgb
import jax.numpy as jnp
from jaxboost import auto_objective, focal_loss, huber, quantile

# Prepare your data
dtrain = xgb.DMatrix(X_train, label=y_train)
params = {"max_depth": 4, "eta": 0.1}

# Built-in objectives - just use them
model = xgb.train(params, dtrain, num_boost_round=100, obj=focal_loss.xgb_objective)
model = xgb.train(params, dtrain, num_boost_round=100, obj=huber.xgb_objective)
model = xgb.train(params, dtrain, num_boost_round=100, obj=quantile(0.9).xgb_objective)

# Custom objective - write the loss, autodiff handles the rest
@auto_objective
def asymmetric_mse(y_pred, y_true, alpha=0.7):
    error = y_true - y_pred
    return jnp.where(error > 0, alpha * error**2, (1 - alpha) * error**2)

model = xgb.train(params, dtrain, num_boost_round=100, obj=asymmetric_mse.xgb_objective)

LightGBM

import lightgbm as lgb
from jaxboost import huber

train_data = lgb.Dataset(X_train, label=y_train)
# LightGBM >= 4.0 takes a custom objective through params["objective"]
params = {"max_depth": 4, "learning_rate": 0.1, "objective": huber.lgb_objective}

model = lgb.train(params, train_data, num_boost_round=100)
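Both libraries call a custom objective as `f(preds, data)` and expect a `(grad, hess)` pair of per-sample arrays back. XGBoost receives it through `obj=`. LightGBM >= 4.0 reads it from `params["objective"]` because the older `fobj=` argument was removed. Below is a minimal sketch of such an adapter for a single-output model. `make_objective` and `FakeData` are hypothetical names used for illustration; they are not part of jaxboost's API.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_objective(loss_fn):
    """Hypothetical adapter: loss_fn(y_pred, y_true) -> callback(preds, data).

    The callback returns per-sample (grad, hess) as NumPy arrays."""
    g = jax.grad(loss_fn, argnums=0)
    h = jax.grad(g, argnums=0)
    g_vec = jax.jit(jax.vmap(g))
    h_vec = jax.jit(jax.vmap(h))

    def objective(preds, data):
        # xgb.DMatrix and lgb.Dataset both provide get_label().
        y_true = jnp.asarray(data.get_label(), dtype=jnp.float32)
        y_pred = jnp.asarray(preds, dtype=jnp.float32)
        # Boosting libraries expect NumPy arrays, not JAX arrays.
        grad = np.asarray(g_vec(y_pred, y_true), dtype=np.float64)
        hess = np.asarray(h_vec(y_pred, y_true), dtype=np.float64)
        return grad, hess

    return objective


class FakeData:
    """Stand-in for xgb.DMatrix / lgb.Dataset so the sketch runs without either."""

    def __init__(self, labels):
        self._labels = np.asarray(labels, dtype=np.float32)

    def get_label(self):
        return self._labels


def log_cosh(y_pred, y_true):
    return jnp.log(jnp.cosh(y_pred - y_true))


obj = make_objective(log_cosh)
grad, hess = obj(np.array([0.0, 1.0]), FakeData([0.0, 0.0]))
print(grad, hess)  # tanh(residual) and 1 - tanh(residual)**2
```

With a real DMatrix or Dataset, you would pass `obj` directly as the callback.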

Available Objectives

Regression

Objective            Description
mse                  Mean squared error
huber                Huber loss (robust to outliers)
pseudo_huber         Smooth approximation of Huber loss
log_cosh             Log-cosh loss
mae_smooth           Smooth approximation of MAE
quantile(q)          Quantile regression
asymmetric(alpha)    Asymmetric squared error
tweedie(p)           Tweedie deviance
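Smooth variants of the Huber loss exist because XGBoost and LightGBM build leaf values from sums of Hessians. The Huber loss has a second derivative of exactly zero in its linear region, so a large-residual sample adds nothing to those sums. The sketch below compares the three losses using their textbook definitions. The default delta = 1.0 is an assumption, and jaxboost's own parameterization may differ.

```python
import jax
import jax.numpy as jnp


def huber(r, delta=1.0):
    # Quadratic for |r| <= delta, linear beyond it.
    return jnp.where(jnp.abs(r) <= delta, 0.5 * r**2, delta * (jnp.abs(r) - 0.5 * delta))


def pseudo_huber(r, delta=1.0):
    # Smooth everywhere; behaves like r**2 / 2 near 0 and like delta * |r| far away.
    return delta**2 * (jnp.sqrt(1.0 + (r / delta) ** 2) - 1.0)


def log_cosh(r):
    return jnp.log(jnp.cosh(r))


def second_derivative(loss_fn):
    # d^2 loss / d r^2, mapped over a vector of residuals.
    return jax.vmap(jax.grad(jax.grad(loss_fn)))


r = jnp.array([0.0, 0.5, 3.0])  # residuals y_pred - y_true
hessians = {
    "huber": second_derivative(huber)(r),
    "pseudo_huber": second_derivative(pseudo_huber)(r),
    "log_cosh": second_derivative(log_cosh)(r),
}
for name, h in hessians.items():
    print(name, h)  # huber drops to exactly 0.0 at r = 3.0
```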

Binary Classification

Objective                       Description
focal_loss                      Focal loss for imbalanced data
binary_crossentropy             Standard log loss
weighted_binary_crossentropy    Weighted binary cross-entropy
hinge_loss                      SVM-style hinge loss
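Focal loss scales binary cross-entropy by (1 - p_t)^gamma, where p_t is the predicted probability of the true class. Confident, correct predictions therefore contribute little to the loss. The sketch below works on the raw margin (logit) and uses the common defaults gamma = 2 and alpha = 0.25. jaxboost's `focal_loss` defaults are not stated here and may differ.

```python
import jax
import jax.numpy as jnp


def focal_loss(logit, y, gamma=2.0, alpha=0.25):
    """Binary focal loss on a raw margin; y is 0.0 or 1.0."""
    p = jax.nn.sigmoid(logit)
    p_t = jnp.where(y == 1, p, 1.0 - p)              # probability of the true class
    alpha_t = jnp.where(y == 1, alpha, 1.0 - alpha)  # class weighting
    # log(p_t) computed stably from the logit.
    log_p_t = jnp.where(y == 1, jax.nn.log_sigmoid(logit), jax.nn.log_sigmoid(-logit))
    return -alpha_t * (1.0 - p_t) ** gamma * log_p_t


def bce(logit, y):
    return -(y * jax.nn.log_sigmoid(logit) + (1 - y) * jax.nn.log_sigmoid(-logit))


logits = jnp.array([-2.0, 0.0, 3.0])
labels = jnp.array([1.0, 0.0, 1.0])
# With gamma = 0 and alpha = 0.5, focal loss is exactly half the cross-entropy.
fl = jax.vmap(lambda z, y: focal_loss(z, y, gamma=0.0, alpha=0.5))(logits, labels)
print(jnp.allclose(fl, 0.5 * bce(logits, labels)))

# The (1 - p_t)**gamma factor shrinks easy examples far more than hard ones.
easy = focal_loss(3.0, 1.0) / bce(3.0, 1.0)
hard = focal_loss(-2.0, 1.0) / bce(-2.0, 1.0)
print(easy < hard)
```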

Multi-class Classification

Objective                Description
softmax_cross_entropy    Standard multi-class cross-entropy (softmax log loss)
focal_multiclass         Focal loss for multi-class
label_smoothing(eps)     Label smoothing regularization
class_balanced           Class-balanced loss
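Label smoothing replaces the one-hot target with (1 - eps) * one_hot + eps / K and then takes the usual cross-entropy. The sketch below uses that standard form. jaxboost's `label_smoothing(eps)` may spread eps differently, for example over K - 1 classes.

```python
import jax
import jax.numpy as jnp


def label_smoothing_ce(logits, label, num_classes, eps=0.1):
    """Cross-entropy against a smoothed target distribution.

    The true class gets 1 - eps + eps / K; every other class gets eps / K."""
    one_hot = jax.nn.one_hot(label, num_classes)
    target = (1.0 - eps) * one_hot + eps / num_classes
    return -jnp.sum(target * jax.nn.log_softmax(logits))


logits = jnp.array([2.0, 0.5, -1.0])

# eps = 0 recovers ordinary softmax cross-entropy.
plain = -jax.nn.log_softmax(logits)[0]
print(jnp.allclose(label_smoothing_ce(logits, 0, 3, eps=0.0), plain))

# The gradient w.r.t. the logits is softmax(logits) - smoothed target.
g = jax.grad(label_smoothing_ce)(logits, 0, 3, eps=0.1)
print(g)
```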

Survival Analysis

Objective      Description
aft            Accelerated failure time (log-normal)
weibull_aft    Weibull AFT model
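In a log-normal AFT model, the model output is the predicted mean of log(T), and log(T) is normally distributed around it with scale sigma. An observed event contributes its log-density, and a right-censored sample contributes its log-survival probability. The sketch below assumes a fixed sigma = 1.0 and drops the -log(t) Jacobian term, which is constant in the prediction. jaxboost's `aft` may parameterize this differently.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def lognormal_aft_nll(pred, log_time, event, sigma=1.0):
    """Negative log-likelihood of a log-normal AFT model for one sample.

    pred: predicted mean of log(T); log_time: observed log(T);
    event: 1.0 if the event was observed, 0.0 if right-censored.
    sigma is a fixed hyperparameter in this sketch."""
    z = (log_time - pred) / sigma
    # Observed event: log-density of log(T) at log_time.
    log_density = norm.logpdf(z) - jnp.log(sigma)
    # Right-censored: log P(log T > log_time) = log Phi(-z).
    log_survival = norm.logcdf(-z)
    return -jnp.where(event == 1.0, log_density, log_survival)


preds = jnp.array([1.0, 1.0])
log_times = jnp.array([1.0, 1.0])
events = jnp.array([1.0, 0.0])  # first observed, second censored
nll = jax.vmap(lognormal_aft_nll)(preds, log_times, events)
grads = jax.vmap(jax.grad(lognormal_aft_nll))(preds, log_times, events)
print(nll)    # 0.5 * log(2 * pi) and log(2)
print(grads)  # 0.0 for the exact hit; negative for the censored sample
```

A censored sample only says the event happened later than observed, so its gradient always pushes the prediction upward.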

Ordinal Regression

Objective              Description
ordinal_regression     Cumulative link model (probit/logit)
qwk_ordinal            Expected quadratic error (EQE), aligned with quadratic weighted kappa (QWK)
squared_cdf_ordinal    CRPS / ranked probability score
hybrid_ordinal         Hybrid of negative log-likelihood (NLL) and EQE
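Both quantities named in this table can be computed directly from predicted class probabilities. The expected quadratic error is sum_k p_k (k - y)^2. The ranked probability score is the sum of squared gaps between the predicted and observed cumulative distributions. The sketch below uses these textbook definitions. jaxboost may normalize or weight them differently.

```python
import jax.numpy as jnp


def expected_quadratic_error(probs, y):
    """E[(k - y)^2] under the predicted class distribution."""
    k = jnp.arange(probs.shape[-1])
    return jnp.sum(probs * (k - y) ** 2, axis=-1)


def ranked_probability_score(probs, y):
    """Sum of squared differences between predicted and observed CDFs."""
    k = jnp.arange(probs.shape[-1])
    pred_cdf = jnp.cumsum(probs, axis=-1)
    true_cdf = (k >= y).astype(probs.dtype)
    return jnp.sum((pred_cdf - true_cdf) ** 2, axis=-1)


probs = jnp.array([0.1, 0.6, 0.3])  # 3 ordered classes, true class y = 1
eqe = expected_quadratic_error(probs, 1)  # 0.1*1 + 0.6*0 + 0.3*1, approx 0.4
rps = ranked_probability_score(probs, 1)  # 0.1**2 + 0.3**2 + 0**2, approx 0.1
print(eqe, rps)
```

Both scores penalize mass placed far from the true class more than mass on a neighboring class, and plain cross-entropy does not do that.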

Multi-task Learning

Objective                    Description
multi_task_regression        Multiple regression targets
multi_task_classification    Multiple classification targets
multi_task_huber             Multi-task Huber loss
multi_task_quantile          Multi-task quantile loss
MaskedMultiTaskObjective     Multi-task objective that handles missing labels
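A masked multi-task loss can ignore missing labels by zeroing out their terms. Those entries then get a gradient and Hessian of zero. The sketch below assumes missing labels are encoded as NaN, but jaxboost's `MaskedMultiTaskObjective` may use a different convention. The inner `jnp.where` matters: if a NaN stays in the unselected branch, it still leaks into the gradient.

```python
import jax
import jax.numpy as jnp


def masked_multitask_se(y_pred, y_true):
    """Squared error summed over samples and tasks, skipping NaN labels.

    y_pred, y_true: arrays of shape (n_samples, n_tasks)."""
    mask = ~jnp.isnan(y_true)
    # Replace NaN labels before the arithmetic: a NaN in the unselected
    # branch of jnp.where would still turn the gradient into NaN.
    safe_true = jnp.where(mask, y_true, 0.0)
    per_entry = jnp.where(mask, (y_pred - safe_true) ** 2, 0.0)
    return jnp.sum(per_entry)


y_true = jnp.array([[1.0, jnp.nan], [0.0, 2.0]])  # sample 0 is missing task 1
y_pred = jnp.array([[0.5, 5.0], [0.0, 1.0]])

grad = jax.grad(masked_multitask_se)(y_pred, y_true)
# The full Hessian has shape (2, 2, 2, 2); its diagonal holds the per-entry values.
hess_diag = jnp.diag(jax.hessian(masked_multitask_se)(y_pred, y_true).reshape(4, 4)).reshape(2, 2)
print(grad)       # zero where the label is missing
print(hess_diag)  # zero where the label is missing
```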

Uncertainty Estimation

Objective       Description
gaussian_nll    Gaussian negative log-likelihood; predicts mean and variance
laplace_nll     Laplace negative log-likelihood; predicts median and scale
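These objectives predict two quantities per sample, so each sample has two outputs and a gradient with respect to each. The sketch below writes the Gaussian negative log-likelihood and assumes the second output is the log-variance, a common way to keep the variance positive. `laplace_nll` would follow the same pattern with log(2b) + |y - m| / b.

```python
import jax
import jax.numpy as jnp


def gaussian_nll(mean, log_var, y):
    """Gaussian negative log-likelihood for one sample.

    Two model outputs per sample: the mean and the log-variance
    (log-variance keeps the variance positive without a constraint)."""
    return 0.5 * (jnp.log(2.0 * jnp.pi) + log_var + (y - mean) ** 2 / jnp.exp(log_var))


y = 3.0
nll_exact = gaussian_nll(3.0, 0.0, y)  # mean exact, variance 1: 0.5 * log(2 * pi)

# One gradient per output. Here the squared error (1.0) equals the variance (1.0),
# so the log-variance gradient vanishes.
d_mean, d_log_var = jax.grad(gaussian_nll, argnums=(0, 1))(2.0, 0.0, y)
print(nll_exact, d_mean, d_log_var)
```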

Ordinal Regression Example

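XGBoost has no built-in ordinal regression objective. jaxboost implements cumulative link models (CLMs), which map a single latent score to ordered classes through a set of increasing thresholds: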

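import xgboost as xgb
from jaxboost import ordinal_regression, qwk_ordinal

# Wine quality: 6 ordered classes (3-8 mapped to 0-5)
ordinal = ordinal_regression(n_classes=6, link='probit')
ordinal.init_thresholds_from_data(y_train)

dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test)
params = {"max_depth": 4, "eta": 0.1}
model = xgb.train(params, dtrain, num_boost_round=100, obj=ordinal.xgb_objective)

# Get class probabilities and predicted classes from the raw model scores
scores = model.predict(dtest)
probs = ordinal.predict_proba(scores)
classes = ordinal.predict(scores)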
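A cumulative link model maps a latent score s to K ordered classes through K - 1 increasing thresholds theta_k. It sets P(y <= k) = F(theta_k - s), and each class probability is the difference of consecutive cumulative values. The sketch below covers the probit case, where F is the standard normal CDF. The function name and threshold values are illustrative, not jaxboost's API. A logit link would use `jax.nn.sigmoid` in place of `norm.cdf`.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def clm_probs(score, thresholds):
    """Class probabilities of a probit cumulative link model.

    thresholds: increasing array of K - 1 cut points theta_k."""
    cdf = norm.cdf(thresholds - score)                       # P(y <= k), shape (K - 1,)
    cdf = jnp.concatenate([jnp.zeros(1), cdf, jnp.ones(1)])  # pad with P(y <= -1) = 0, P(y <= K - 1) = 1
    return jnp.diff(cdf)                                     # shape (K,)


thresholds = jnp.array([-1.0, 0.0, 1.0])  # 4 ordered classes; illustrative values
probs = clm_probs(0.3, thresholds)
print(probs, probs.sum())
print(jnp.argmax(probs))  # most likely class for this score
```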

Custom Objectives

The @auto_objective decorator turns a differentiable loss function written with jax.numpy into an XGBoost/LightGBM objective:

import xgboost as xgb
import lightgbm as lgb
import jax.numpy as jnp
from jaxboost import auto_objective

@auto_objective
def my_custom_loss(y_pred, y_true, **kwargs):
    # Write your loss here - JAX computes grad/hess automatically
    return (y_pred - y_true) ** 2

# Use with XGBoost
dtrain = xgb.DMatrix(X_train, label=y_train)
params = {"max_depth": 4, "eta": 0.1}
model = xgb.train(params, dtrain, num_boost_round=100, obj=my_custom_loss.xgb_objective)

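# Use with LightGBM (>= 4.0 takes the custom objective through params)
train_data = lgb.Dataset(X_train, label=y_train)
lgb_params = {"max_depth": 4, "learning_rate": 0.1, "objective": my_custom_loss.lgb_objective}
lgb_model = lgb.train(lgb_params, train_data, num_boost_round=100)

# Pass parameters: keyword arguments are forwarded to the loss function,
# so the loss must accept them (this squared-error example ignores them via **kwargs)
model = xgb.train(
    params, dtrain, num_boost_round=100,
    obj=my_custom_loss.get_xgb_objective(alpha=0.5)
)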
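This README does not show how `get_xgb_objective(alpha=0.5)` forwards parameters internally. One plausible approach fixes the keyword arguments with `functools.partial` before differentiating, so autodiff only sees `(y_pred, y_true)`. The sketch below implements that idea. `bind` is a hypothetical helper, and the loss is the `asymmetric_mse` example from Quick Start.

```python
import functools

import jax
import jax.numpy as jnp


def asymmetric_mse(y_pred, y_true, alpha=0.7):
    error = y_true - y_pred
    return jnp.where(error > 0, alpha * error**2, (1 - alpha) * error**2)


def bind(loss_fn, **params):
    """Hypothetical helper: fix hyperparameters, then build grad/hess functions."""
    bound = functools.partial(loss_fn, **params)
    g = jax.vmap(jax.grad(bound))             # per-sample gradient w.r.t. y_pred
    h = jax.vmap(jax.grad(jax.grad(bound)))   # per-sample Hessian w.r.t. y_pred
    return g, h


g, h = bind(asymmetric_mse, alpha=0.9)
y_pred = jnp.array([0.0, 2.0])  # under-prediction, over-prediction
y_true = jnp.array([1.0, 1.0])
print(g(y_pred, y_true))
print(h(y_pred, y_true))  # 2 * alpha for under-prediction, 2 * (1 - alpha) for over-prediction
```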

Multi-class Example

import xgboost as xgb
import jax
import jax.numpy as jnp
from jaxboost import multiclass_objective

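@multiclass_objective(num_classes=3)
def custom_multiclass(logits, label):
    # logits: (num_classes,), label: scalar class index (cast to int for indexing)
    # log_softmax is numerically stable, so no epsilon is needed inside the log
    log_probs = jax.nn.log_softmax(logits)
    return -log_probs[jnp.asarray(label, dtype=jnp.int32)]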

dtrain = xgb.DMatrix(X_train, label=y_train)
model = xgb.train(
    {"num_class": 3, "max_depth": 4, "eta": 0.1},
    dtrain,
    num_boost_round=100,
    obj=custom_multiclass.xgb_objective
)
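For a multi-class loss, each row's gradient is taken with respect to its whole logit vector. The booster takes one gradient and one Hessian value per sample and class, so a common choice is the diagonal of the per-row Hessian. Whether `multiclass_objective` does exactly this is an assumption. The sketch below uses softmax cross-entropy because its closed forms are known and can be compared against.

```python
import jax
import jax.numpy as jnp


def softmax_ce(logits, label):
    # Cross-entropy for one row: logits of shape (num_classes,), integer label.
    return -jax.nn.log_softmax(logits)[label]


def grad_and_diag_hess(logits, label):
    g = jax.grad(softmax_ce)(logits, label)
    # Full (num_classes, num_classes) Hessian; keep only the diagonal.
    h = jax.hessian(softmax_ce)(logits, label)
    return g, jnp.diag(h)


logits = jnp.array([1.0, 0.0, -1.0])
g, h = grad_and_diag_hess(logits, 2)
p = jax.nn.softmax(logits)
# Closed forms for comparison: g = p - one_hot(label), diag(H) = p * (1 - p)
print(g, p - jax.nn.one_hot(2, 3))
print(h, p * (1 - p))

# Batched over rows with vmap: both outputs have shape (n_samples, num_classes)
G, H = jax.vmap(grad_and_diag_hess)(jnp.zeros((4, 3)), jnp.array([0, 1, 2, 0]))
print(G.shape, H.shape)
```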

Why jaxboost?

Traditional approach        jaxboost
Derive gradients by hand    Write the loss; gradients come for free
Derive Hessians by hand     Write the loss; Hessians come for free
Error-prone math            JAX autodiff computes exact derivatives of the loss as written
One loss = hours of work    One loss = 5 lines of code
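Because autodiff differentiates exactly the code you wrote, it can also audit a hand derivation. The sketch below uses the log-link Tweedie negative log-likelihood as an example, with rho = 1.5 and the terms that do not depend on the margin dropped. It compares hand-derived closed forms against JAX's result.

```python
import jax
import jax.numpy as jnp


def tweedie_nll(f, y, rho=1.5):
    """Tweedie negative log-likelihood for a log-link margin f,
    keeping only terms that depend on f. Requires 1 < rho < 2."""
    return -y * jnp.exp((1 - rho) * f) / (1 - rho) + jnp.exp((2 - rho) * f) / (2 - rho)


def tweedie_grad_hess_by_hand(f, y, rho=1.5):
    # Hand-derived first and second derivatives with respect to f.
    g = -y * jnp.exp((1 - rho) * f) + jnp.exp((2 - rho) * f)
    h = -y * (1 - rho) * jnp.exp((1 - rho) * f) + (2 - rho) * jnp.exp((2 - rho) * f)
    return g, h


f = jnp.linspace(-1.0, 1.0, 5)
y = jnp.array([0.0, 0.5, 1.0, 2.0, 3.0])
g_auto = jax.vmap(jax.grad(tweedie_nll))(f, y)
h_auto = jax.vmap(jax.grad(jax.grad(tweedie_nll)))(f, y)
g_hand, h_hand = tweedie_grad_hess_by_hand(f, y)
print(jnp.allclose(g_auto, g_hand, atol=1e-5), jnp.allclose(h_auto, h_hand, atol=1e-5))
```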

Requirements

  • Python >= 3.10
  • JAX >= 0.4.20

Documentation

Full documentation available at: https://jxucoder.github.io/jaxboost/

License

Apache 2.0

Download files


Source Distribution

jaxboost-0.3.0.tar.gz (299.8 kB)

Uploaded Source

Built Distribution


jaxboost-0.3.0-py3-none-any.whl (39.5 kB)

Uploaded Python 3

File details

Details for the file jaxboost-0.3.0.tar.gz.

File metadata

  • Download URL: jaxboost-0.3.0.tar.gz
  • Upload date:
  • Size: 299.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jaxboost-0.3.0.tar.gz
Algorithm Hash digest
SHA256 29fdbea0054f9f3ae0f36ec25645c47b5aafb9247eec1937c80528b505ebf4b1
MD5 c4579cb864988214652175c3a2a8f8e9
BLAKE2b-256 5459095c32aaae558bcd43acbb952545a48d9c2496d5efb33970a390adb586ed


Provenance

The following attestation bundles were made for jaxboost-0.3.0.tar.gz:

Publisher: publish.yml on jxucoder/jaxboost

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jaxboost-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: jaxboost-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 39.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jaxboost-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 1be97a6b538f51f036a262bdf962c3a580426ce54ea8eefda0a6f4e280836386
MD5 cc5f30b1e4c3fc11d3b0d6c081adad35
BLAKE2b-256 4c8be6a84ccc5b50d15787d5aacd1b34555236664a1cdbe4a20cec8604e389a9


Provenance

The following attestation bundles were made for jaxboost-0.3.0-py3-none-any.whl:

Publisher: publish.yml on jxucoder/jaxboost

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
