Next-generation differentiable gradient boosting with JAX
jaxboost
Differentiable gradient boosting in JAX.
⚠️ This is a personal learning project and very much a work in progress. It is not intended to replace production boosting libraries such as XGBoost, LightGBM, or CatBoost; the main purpose is to learn JAX while rethinking gradient boosting from first principles. Reliability and correctness are not guaranteed at this stage. Issues are welcome!
What it is
A gradient boosting implementation built from soft (differentiable) oblivious trees. Instead of growing each tree greedily, split by split, as traditional boosting libraries do, the entire model is trained end-to-end with gradient descent via optax.
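Schematically, the difference from classic boosting lies in what gets optimized. The notation below ($T$ trees of depth $D$, leaf values $v$, hyperplanes $w$, thresholds $b$, direction bits $r$) is introduced here for illustration and is not taken from the project; treat it as a sketch of the idea, not a transcription of jaxboost's objective. Classic boosting adds one tree at a time, fitting $h_m$ greedily while all earlier trees stay frozen:

$$
F_m(x) = F_{m-1}(x) + \eta\, h_m(x), \qquad h_m \approx \arg\min_{h} \sum_{i} L\big(y_i,\ F_{m-1}(x_i) + h(x_i)\big).
$$

With soft trees, every tree is a smooth function of its parameters, so all $T$ trees can be optimized jointly by gradient descent:

$$
F(x;\theta) = \sum_{t=1}^{T} \sum_{l=1}^{2^D} \pi_{t,l}(x)\, v_{t,l}, \qquad
\pi_{t,l}(x) = \prod_{d=1}^{D} \sigma\big(s_{t,d}(x)\big)^{r_{l,d}} \big(1 - \sigma(s_{t,d}(x))\big)^{1 - r_{l,d}},
$$

$$
s_{t,d}(x) = w_{t,d}^{\top} x - b_{t,d}, \qquad
\theta^{\star} = \arg\min_{\theta} \sum_{i} L\big(y_i,\ F(x_i;\theta)\big),
$$

where $r_{l,d} \in \{0, 1\}$ says whether leaf $l$ lies on the right branch at level $d$. Because the split parameters $w_{t,d}, b_{t,d}$ depend only on the level $d$ and not on the node, each tree is oblivious.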
Key characteristics:
- Soft routing with sigmoid functions (trees are differentiable)
- Oblivious tree structure (every node at a given level uses the same split)
- Hyperplane splits (linear combinations of features)
- Runs on GPU via JAX
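To make the first three characteristics concrete, here is a minimal JAX sketch of a single soft oblivious tree. It is not taken from jaxboost's source: the function name `soft_oblivious_tree`, the parameter layout (one hyperplane `W[d]` and threshold `b[d]` per level, `2**depth` leaf values) and the `temperature` argument are illustrative assumptions. Each level sends a sample right with probability `sigmoid((w·x - b) / temperature)`, and because the tree is oblivious that one split applies at every node of the level, so routing through a depth-`d` tree needs only `d` hyperplane evaluations rather than `2**d - 1`.

```python
import jax
import jax.numpy as jnp


def soft_oblivious_tree(x, W, b, leaf_values, temperature=1.0):
    """Prediction of one soft oblivious tree (hypothetical sketch, not jaxboost's API).

    x:           (n_samples, n_features) inputs
    W:           (depth, n_features) one hyperplane per level, shared by all nodes of that level
    b:           (depth,) one threshold per level
    leaf_values: (2**depth,) value stored in each leaf
    """
    # Soft routing: probability of going right at each level, from a hyperplane split.
    p_right = jax.nn.sigmoid((x @ W.T - b) / temperature)  # (n_samples, depth)

    # Start with all probability mass at the root and split it level by level.
    # Bit d of a leaf's index records the direction taken at level d (0 = left, 1 = right).
    leaf_probs = jnp.ones((x.shape[0], 1))
    for d in range(W.shape[0]):
        p = p_right[:, d:d + 1]
        leaf_probs = jnp.concatenate([leaf_probs * (1 - p), leaf_probs * p], axis=1)

    # The output is the expected leaf value under the routing probabilities.
    return leaf_probs @ leaf_values, leaf_probs


# Usage: a depth-4 tree (16 leaves) on 5 samples with 3 features.
k_x, k_w, k_b, k_l = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k_x, (5, 3))
W = jax.random.normal(k_w, (4, 3))
b = jax.random.normal(k_b, (4,))
leaf_values = jax.random.normal(k_l, (16,))

pred, leaf_probs = soft_oblivious_tree(x, W, b, leaf_values)
print(pred.shape, leaf_probs.shape)  # (5,) (5, 16)

# Gradients flow to every parameter, including the split hyperplanes.
grads_W = jax.grad(lambda W_: soft_oblivious_tree(x, W_, b, leaf_values)[0].sum())(W)
```

Each sample's leaf probabilities sum to 1, so the prediction is a weighted average of leaf values; as `temperature` shrinks toward 0 the routing approaches the hard, one-leaf-per-sample behavior of an ordinary decision tree. Since everything is built from `jnp` operations, `jax.grad` differentiates the output with respect to `W`, `b`, and `leaf_values` alike.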
Installation
```bash
pip install jaxboost
```
Or from source:
```bash
git clone https://github.com/jxucoder/jaxboost.git
cd jaxboost
pip install -e .
```
Usage
```python
from jaxboost import GBMTrainer, TrainerConfig

# X_train, y_train, X_test: your feature matrix and targets (array-like)

# Regression
trainer = GBMTrainer(task="regression")
model = trainer.fit(X_train, y_train)
predictions = model.predict(X_test)

# Classification
trainer = GBMTrainer(task="classification")
model = trainer.fit(X_train, y_train)
probabilities = model.predict(X_test)   # predicted probabilities
classes = model.predict_class(X_test)   # predicted class labels
```
Configuration
```python
from jaxboost import GBMTrainer, TrainerConfig

config = TrainerConfig(
    n_trees=20,          # Number of trees
    depth=4,             # Tree depth
    learning_rate=0.01,  # Optimizer learning rate
    epochs=500,          # Training epochs
    patience=50,         # Early stopping patience
    verbose=True,        # Print progress
)
trainer = GBMTrainer(task="regression", config=config)
```
Requirements
- Python >= 3.10
- JAX >= 0.4.20
- optax >= 0.1.7
License
MIT
Download files
Source Distribution
- jaxboost-0.1.0.tar.gz (121.3 kB)
Built Distribution
- jaxboost-0.1.0-py3-none-any.whl (16.9 kB)
File details
Details for the file jaxboost-0.1.0.tar.gz.
File metadata
- Download URL: jaxboost-0.1.0.tar.gz
- Upload date:
- Size: 121.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5e9f9e3601db35eb9624f7c4ac115e691702220610a7290284b91a19b4fead78 |
| MD5 | 8c1e6a5c0c9892e0163cbd81ca1f0e1a |
| BLAKE2b-256 | 639f03baff91267ef3109677df9fd4bf2f06dca5bf54490169c32a4686967a0e |
Provenance
The following attestation bundles were made for jaxboost-0.1.0.tar.gz:
Publisher: publish.yml on jxucoder/jaxboost
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: jaxboost-0.1.0.tar.gz
- Subject digest: 5e9f9e3601db35eb9624f7c4ac115e691702220610a7290284b91a19b4fead78
- Sigstore transparency entry: 759201896
- Sigstore integration time:
- Permalink: jxucoder/jaxboost@55d497f0292befa6a49109f05125e12ee7a19198
- Branch / Tag: refs/tags/v0.1.0
- Owner: https://github.com/jxucoder
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@55d497f0292befa6a49109f05125e12ee7a19198
- Trigger Event: release
File details
Details for the file jaxboost-0.1.0-py3-none-any.whl.
File metadata
- Download URL: jaxboost-0.1.0-py3-none-any.whl
- Upload date:
- Size: 16.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 40d9e153544dc9a7544a157e7f76b9d379d6b75a2394d75c93f78427dddda7e7 |
| MD5 | 7e77fe08c80eca2b81784e961653b51d |
| BLAKE2b-256 | 8ce621fa394549165b8a4d25b2b68f7447d64f25b8b3700e1eebb111475bebf0 |
Provenance
The following attestation bundles were made for jaxboost-0.1.0-py3-none-any.whl:
Publisher: publish.yml on jxucoder/jaxboost
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: jaxboost-0.1.0-py3-none-any.whl
- Subject digest: 40d9e153544dc9a7544a157e7f76b9d379d6b75a2394d75c93f78427dddda7e7
- Sigstore transparency entry: 759201932
- Sigstore integration time:
- Permalink: jxucoder/jaxboost@55d497f0292befa6a49109f05125e12ee7a19198
- Branch / Tag: refs/tags/v0.1.0
- Owner: https://github.com/jxucoder
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@55d497f0292befa6a49109f05125e12ee7a19198
- Trigger Event: release