LossX
A functional JAX loss function library with PyTree-based composition and Hydra integration.
Features
- Pure Functional API - All loss functions are pure functions
- PyTree-Based Composition - Build complex multi-task losses from nested configurations
- Hydra Integration - Seamlessly integrate with Hydra configs
- Flexible Reductions - Custom reduction strategies for combining multiple losses
- Extensible - Easy to add custom loss functions
Installation
Using uv (recommended)
uv add lossx
Using pip
pip install lossx
Or just copy one of the source files into your project
Quick Start
Simple Loss Function Usage
Import and use loss functions directly:
import jax.numpy as jnp
from lossx.loss import cross_entropy, mse, gaussian_nll
# Mean Squared Error
true_y = jnp.array([1.0, 2.0, 3.0])
pred_y = jnp.array([1.1, 2.1, 3.1])
loss = mse(true_y, pred_y)
# Cross-Entropy with masking
true_labels = jnp.array([0, 1, 2, -100]) # -100 is masked
pred_logits = jnp.ones((4, 3))
loss = cross_entropy(true_labels, pred_logits, mask_index=-100)
# Gaussian NLL (for uncertainty estimation)
true_y = jnp.array([[1.0], [2.0]])
pred_y = jnp.array([[1.0, 0.1], [2.0, 0.1]]) # mean and variance
loss = gaussian_nll(true_y, pred_y, eps=1e-6)
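For reference, the math behind two of these losses can be sketched in plain Python (no lossx or JAX required). This is an illustrative reconstruction, not lossx's implementation: the library's exact reductions and the constant term in the NLL may differ, and `eps` here plays the same stabilizing role as in `gaussian_nll`.

```python
import math

def mse_ref(true_y, pred_y):
    # Mean of squared residuals
    return sum((t - p) ** 2 for t, p in zip(true_y, pred_y)) / len(true_y)

def gaussian_nll_ref(true_y, pred_y, eps=1e-6):
    # pred_y rows are (mean, variance); Gaussian NLL up to an additive constant
    total = 0.0
    for (t,), (mu, var) in zip(true_y, pred_y):
        var = var + eps  # eps keeps the log and division stable
        total += 0.5 * (math.log(var) + (t - mu) ** 2 / var)
    return total / len(true_y)

print(round(mse_ref([1.0, 2.0, 3.0], [1.1, 2.1, 3.1]), 4))  # ~0.01
```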
Available Loss Functions
- cross_entropy - Cross-entropy loss with masking and class weights
- mse - Mean squared error
- q_loss - Q-loss for classification
- quantile_loss - Quantile loss for uncertainty estimation
- gaussian_nll - Gaussian negative log-likelihood
- penex - Penalized exponential loss
- contrastive - NT-Xent contrastive loss
PyTree-Based Loss Building
The real power of LossX comes from building complex losses using PyTree structures:
Single Loss
from lossx import build_loss
config = {
    "target": "cross_entropy",
    "mask_index": -100,
    "weight": 1.0
}
loss_fn = build_loss(config)
loss = loss_fn(true_y, pred_y)
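Conceptually, the builder looks up the `target` name in a registry, binds the remaining keys as keyword arguments, and scales the result by `weight`. A minimal sketch of that pattern (the `_REGISTRY` dict and `build_loss_sketch` internals are illustrative, not lossx's actual code):

```python
# Illustrative registry; lossx's internal structure may differ
_REGISTRY = {
    "mse": lambda t, p, **kw: sum((a - b) ** 2 for a, b in zip(t, p)) / len(t),
}

def build_loss_sketch(config):
    cfg = dict(config)                 # avoid mutating the caller's config
    fn = _REGISTRY[cfg.pop("target")]  # dispatch on the "target" key
    weight = cfg.pop("weight", 1.0)    # optional loss weight
    return lambda true_y, pred_y: weight * fn(true_y, pred_y, **cfg)

loss_fn = build_loss_sketch({"target": "mse", "weight": 2.0})
print(loss_fn([1.0, 2.0], [1.5, 2.5]))  # 2.0 * 0.25 = 0.5
```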
Multi-Task Learning (Dict of Losses)
import jax.numpy as jnp
config = {
    "classification": {
        "target": "cross_entropy",
        "mask_index": -100,
        "weight": 1.0
    },
    "regression": {
        "target": "mse",
        "weight": 0.5
    },
    "uncertainty": {
        "target": "gaussian_nll",
        "eps": 1e-6,
        "weight": 0.3
    }
}
# Default reduction is sum; override it here with a mean
loss_fn = build_loss(config, reduction=lambda losses: jnp.mean(jnp.array(losses)))
# Provide PyTree-structured inputs matching the config
loss = loss_fn(
    true={
        "classification": class_labels,
        "regression": regression_targets,
        "uncertainty": uncertainty_targets
    },
    pred={
        "classification": class_logits,
        "regression": regression_preds,
        "uncertainty": uncertainty_preds
    }
)
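The dict form pairs each config key with the same key in `true`/`pred`, computes one scalar per task, and reduces them. A framework-free sketch of that wiring (plain Python stands in for jnp, every task uses mse, and `build_dict_loss` is an illustrative name, not lossx's API):

```python
def build_dict_loss(config, reduction=sum):
    # Extract the per-task weight; every task uses mse in this sketch
    weights = {k: cfg.get("weight", 1.0) for k, cfg in config.items()}

    def mse(t, p):
        return sum((a - b) ** 2 for a, b in zip(t, p)) / len(t)

    def loss_fn(true, pred):
        # One weighted scalar per task key, then a single reduction
        per_task = [w * mse(true[k], pred[k]) for k, w in weights.items()]
        return reduction(per_task)

    return loss_fn

config = {
    "regression": {"target": "mse", "weight": 0.5},
    "aux": {"target": "mse", "weight": 1.0},
}
loss_fn = build_dict_loss(config)
print(loss_fn({"regression": [2.0], "aux": [1.0]},
              {"regression": [3.0], "aux": [1.5]}))  # 0.5*1.0 + 1.0*0.25 = 0.75
```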
Composite Loss (List of Losses)
config = [
    {"target": "mse", "weight": 1.0},
    {"target": "cross_entropy", "mask_index": -100, "weight": 2.0}
]
loss_fn = build_loss(config, reduction=lambda losses: jnp.sum(jnp.array(losses)))
# Provide list-structured inputs
loss = loss_fn(
    true=[regression_targets, class_labels],
    pred=[regression_preds, class_logits]
)
Nested PyTree Structures
config = {
    "main": [
        {"target": "mse"},
        {"target": "cross_entropy"}
    ],
    "auxiliary": {
        "task_a": {"target": "mse", "weight": 0.5},
        "task_b": {"target": "gaussian_nll"}
    }
}
loss_fn = build_loss(config)
# Inputs must match the same PyTree structure
loss = loss_fn(
    true={
        "main": [main_target1, main_target2],
        "auxiliary": {
            "task_a": aux_a_target,
            "task_b": aux_b_target
        }
    },
    pred={
        "main": [main_pred1, main_pred2],
        "auxiliary": {
            "task_a": aux_a_pred,
            "task_b": aux_b_pred
        }
    }
)
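Nesting works because the builder can recurse over the config PyTree: a dict containing a `target` key is a leaf, any other dict or list recurses, and each built leaf consumes the matching branch of `true`/`pred`. A recursive sketch of that idea (illustrative, not lossx source; every leaf is mse here):

```python
def build_recursive(config, reduction=sum):
    if isinstance(config, dict) and "target" in config:
        # Leaf node: a single weighted loss (mse stands in for all losses)
        weight = config.get("weight", 1.0)
        return lambda t, p: weight * sum((a - b) ** 2 for a, b in zip(t, p)) / len(t)
    if isinstance(config, dict):
        # Dict node: recurse per key, reduce the per-key scalars
        subs = {k: build_recursive(v, reduction) for k, v in config.items()}
        return lambda t, p: reduction(subs[k](t[k], p[k]) for k in subs)
    if isinstance(config, list):
        # List node: recurse per position
        subs = [build_recursive(v, reduction) for v in config]
        return lambda t, p: reduction(f(ti, pi) for f, ti, pi in zip(subs, t, p))
    raise TypeError(f"unsupported config node: {config!r}")

config = {
    "main": [{"target": "mse"}],
    "auxiliary": {"task_a": {"target": "mse", "weight": 0.5}},
}
loss_fn = build_recursive(config)
loss = loss_fn(
    {"main": [[1.0]], "auxiliary": {"task_a": [2.0]}},
    {"main": [[2.0]], "auxiliary": {"task_a": [4.0]}},
)
print(loss)  # 1.0 + 0.5 * 4.0 = 3.0
```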
Custom Reductions
You can provide custom reduction functions to combine losses:
import jax.numpy as jnp
# Weighted reduction
weights = jnp.array([0.7, 0.3])
reduction = lambda losses: jnp.sum(jnp.array(losses) * weights)
# Max reduction (worst-case loss)
reduction = lambda losses: jnp.max(jnp.array(losses))
# Mean reduction
reduction = lambda losses: jnp.mean(jnp.array(losses))
loss_fn = build_loss(configs, reduction=reduction)
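A reduction is just a map from the flat list of per-loss scalars to one scalar, so anything with that shape works. As one more example (sketched in plain Python with `math.exp` standing in for jnp; the function name is illustrative), a softmax-weighted reduction interpolates between the mean and the max, emphasizing the currently largest losses:

```python
import math

def softmax_weighted(losses, temperature=1.0):
    # weights = softmax(losses / T): larger losses get larger weights
    exps = [math.exp(l / temperature) for l in losses]
    z = sum(exps)
    return sum(e / z * l for e, l in zip(exps, losses))

print(softmax_weighted([0.0, 1.0]))  # ~0.731, between mean (0.5) and max (1.0)
```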
Hydra Integration
LossX is designed to work seamlessly with Hydra configs:
config.yaml:
loss:
  target: cross_entropy
  mask_index: -100
  cls_weights: [1.0, 2.0, 1.0]
  weight: 1.0
multi_task.yaml:
loss:
  classification:
    target: cross_entropy
    mask_index: -100
    weight: 1.0
  regression:
    target: mse
    weight: 0.5
  uncertainty:
    target: gaussian_nll
    eps: 1e-6
    weight: 0.3
  reduction: mean  # or sum, weighted
Python code:
import hydra
from lossx import build_loss
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
    # Convert OmegaConf to plain Python containers (handles nested configs)
    loss_config = OmegaConf.to_container(cfg.loss, resolve=True)
    loss_fn = build_loss(loss_config)
    # Use in training
    loss = loss_fn(targets, predictions)
Custom Loss Functions
Register your own loss functions:
from lossx import register_loss
import jax.numpy as jnp
def custom_loss(true_y, pred_y, *, alpha=0.5, **kwargs):
    """Custom loss function."""
    return alpha * jnp.mean((true_y - pred_y) ** 2)
# Register it
register_loss("custom", custom_loss)
# Now use it in configs
config = {"target": "custom", "alpha": 0.7}
loss_fn = build_loss(config)
API Reference
Loss Functions
All loss functions follow the signature:
def loss_fn(
    true_y: PyTree[Array],
    pred_y: PyTree[Array],
    **kwargs
) -> Scalar
Builder Functions
build_loss(config, reduction=jnp.sum)
Build a loss function from a configuration PyTree.
Args:
- config: Either a dict with a "target" key (single loss) or a PyTree of such dicts
- reduction: Function to reduce multiple losses to a scalar (default: jnp.sum)
Returns:
- A loss function that takes (true, pred) PyTrees and returns a scalar
register_loss(name, loss_fn)
Register a custom loss function.
Args:
- name: Name to register under
- loss_fn: Loss function to register
Development
Running Tests
uv run pytest tests/ -v
Installing Development Dependencies
uv sync --group dev
License
MIT License - see LICENSE file for details.
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.