LossX: A JAX loss function library with config-based construction.

LossX

A functional JAX loss function library with PyTree-based composition and Hydra integration.

Features

  • Pure Functional API - All loss functions are pure functions
  • PyTree-Based Composition - Build complex multi-task losses from nested configurations
  • Hydra Integration - Seamlessly integrate with Hydra configs
  • Flexible Reductions - Custom reduction strategies for combining multiple losses
  • Extensible - Easy to add custom loss functions

Installation

Using uv (recommended)

uv add lossx

Using pip

pip install lossx

Or simply copy the relevant source files into your project.

Quick Start

Simple Loss Function Usage

Import and use loss functions directly:

import jax.numpy as jnp
from lossx.loss import cross_entropy, mse, gaussian_nll

# Mean Squared Error
true_y = jnp.array([1.0, 2.0, 3.0])
pred_y = jnp.array([1.1, 2.1, 3.1])
loss = mse(true_y, pred_y)

# Cross-Entropy with masking
true_labels = jnp.array([0, 1, 2, -100])  # -100 is masked
pred_logits = jnp.ones((4, 3))
loss = cross_entropy(true_labels, pred_logits, mask_index=-100)

# Gaussian NLL (for uncertainty estimation)
true_y = jnp.array([[1.0], [2.0]])
pred_y = jnp.array([[1.0, 0.1], [2.0, 0.1]])  # mean and variance
loss = gaussian_nll(true_y, pred_y, eps=1e-6)
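Because every loss is a pure function, it composes directly with jax.grad and jax.jit. A minimal sketch, using a locally defined mse for self-containment (mirroring the call above):

```python
import jax
import jax.numpy as jnp

def mse(true_y, pred_y):
    # Reference mean-squared-error matching the signature above
    return jnp.mean((true_y - pred_y) ** 2)

true_y = jnp.array([1.0, 2.0, 3.0])
pred_y = jnp.array([1.1, 2.1, 3.1])

# Differentiate w.r.t. the predictions and JIT-compile the result
grad_fn = jax.jit(jax.grad(mse, argnums=1))
grads = grad_fn(true_y, pred_y)  # d(loss)/d(pred_y) = 2 * (pred_y - true_y) / n
```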

Available Loss Functions

  • cross_entropy - Cross-entropy loss with masking and class weights
  • mse - Mean squared error
  • q_loss - Q-loss for classification
  • quantile_loss - Quantile loss for uncertainty estimation
  • gaussian_nll - Gaussian negative log-likelihood
  • penex - Penalized exponential loss
  • contrastive - NT-Xent contrastive loss
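For reference, quantile loss follows the standard pinball formulation; the sketch below is the textbook definition, and lossx's exact quantile_loss signature and defaults may differ:

```python
import jax.numpy as jnp

def pinball(true_y, pred_y, q=0.9):
    # Pinball/quantile loss: under-prediction is penalized by q,
    # over-prediction by (1 - q), so the minimizer is the q-th quantile.
    err = true_y - pred_y
    return jnp.mean(jnp.maximum(q * err, (q - 1.0) * err))
```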

PyTree-Based Loss Building

The real power of LossX comes from building complex losses using PyTree structures:

Single Loss

from lossx import build_loss

config = {
    "target": "cross_entropy",
    "mask_index": -100,
    "weight": 1.0
}

loss_fn = build_loss(config)
loss = loss_fn(true_y, pred_y)

Multi-Task Learning (Dict of Losses)

import jax.numpy as jnp

config = {
    "classification": {
        "target": "cross_entropy",
        "mask_index": -100,
        "weight": 1.0
    },
    "regression": {
        "target": "mse",
        "weight": 0.5
    },
    "uncertainty": {
        "target": "gaussian_nll",
        "eps": 1e-6,
        "weight": 0.3
    }
}

# The default reduction is jnp.sum; here we override it with a mean
loss_fn = build_loss(config, reduction=lambda losses: jnp.mean(jnp.array(losses)))

# Provide PyTree-structured inputs matching the config
losses = loss_fn(
    true={
        "classification": class_labels,
        "regression": regression_targets,
        "uncertainty": uncertainty_targets
    },
    pred={
        "classification": class_logits,
        "regression": regression_preds,
        "uncertainty": uncertainty_preds
    }
)

Composite Loss (List of Losses)

config = [
    {"target": "mse", "weight": 1.0},
    {"target": "cross_entropy", "mask_index": -100, "weight": 2.0}
]

loss_fn = build_loss(config, reduction=lambda losses: jnp.sum(jnp.array(losses)))

# Provide list-structured inputs
loss = loss_fn(
    true=[regression_targets, class_labels],
    pred=[regression_preds, class_logits]
)

Nested PyTree Structures

config = {
    "main": [
        {"target": "mse"},
        {"target": "cross_entropy"}
    ],
    "auxiliary": {
        "task_a": {"target": "mse", "weight": 0.5},
        "task_b": {"target": "gaussian_nll"}
    }
}

loss_fn = build_loss(config)

# Inputs must match the same PyTree structure
loss = loss_fn(
    true={
        "main": [main_target1, main_target2],
        "auxiliary": {
            "task_a": aux_a_target,
            "task_b": aux_b_target
        }
    },
    pred={
        "main": [main_pred1, main_pred2],
        "auxiliary": {
            "task_a": aux_a_pred,
            "task_b": aux_b_pred
        }
    }
)
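Conceptually, build_loss walks the config PyTree and the (true, pred) PyTrees in lockstep, applying each leaf's loss and reducing the results. A simplified sketch of this behavior (not the library's actual implementation; the mini-registry here is illustrative):

```python
import jax.numpy as jnp

# Illustrative mini-registry (lossx's real registry covers the full list above)
LOSSES = {
    "mse": lambda t, p, **kw: jnp.mean((t - p) ** 2),
}

def build_loss_sketch(config, reduction=sum):
    def is_leaf(node):
        # A dict with a "target" key is a single-loss config
        return isinstance(node, dict) and "target" in node

    def loss_fn(true, pred):
        leaves = []

        def visit(cfg, t, p):
            if is_leaf(cfg):
                kwargs = {k: v for k, v in cfg.items() if k not in ("target", "weight")}
                leaves.append(cfg.get("weight", 1.0) * LOSSES[cfg["target"]](t, p, **kwargs))
            elif isinstance(cfg, dict):
                for key in cfg:  # dict nodes: match inputs by key
                    visit(cfg[key], t[key], p[key])
            else:  # list nodes: match inputs by position
                for c, ti, pi in zip(cfg, t, p):
                    visit(c, ti, pi)

        visit(config, true, pred)
        return reduction(leaves)

    return loss_fn
```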

Custom Reductions

You can provide custom reduction functions to combine losses:

import jax.numpy as jnp

# Weighted reduction
weights = jnp.array([0.7, 0.3])
reduction = lambda losses: jnp.sum(jnp.array(losses) * weights)

# Max reduction (worst-case loss)
reduction = lambda losses: jnp.max(jnp.array(losses))

# Mean reduction
reduction = lambda losses: jnp.mean(jnp.array(losses))

loss_fn = build_loss(config, reduction=reduction)
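Since a reduction is just a function from a list of scalars to a scalar, smoother strategies also work; for example a log-sum-exp "soft maximum" (an illustrative reduction, not one shipped with lossx):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def soft_max_reduction(losses, tau=0.1):
    # Smooth upper bound on max(losses); approaches the hard max as tau -> 0
    arr = jnp.array(losses)
    return tau * logsumexp(arr / tau)
```

Unlike a hard max, this keeps nonzero gradients flowing to every task, not just the current worst one.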

Hydra Integration

LossX is designed to work seamlessly with Hydra configs:

config.yaml:

loss:
  target: cross_entropy
  mask_index: -100
  cls_weights: [1.0, 2.0, 1.0]
  weight: 1.0

multi_task.yaml:

loss:
  classification:
    target: cross_entropy
    mask_index: -100
    weight: 1.0
  regression:
    target: mse
    weight: 0.5
  uncertainty:
    target: gaussian_nll
    eps: 1e-6
    weight: 0.3

reduction: mean  # or sum, weighted

Python code:

import hydra
from lossx import build_loss
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig):
    # Recursively convert OmegaConf nodes to plain Python containers
    loss_config = OmegaConf.to_container(cfg.loss, resolve=True)
    loss_fn = build_loss(loss_config)

    # Use in training
    loss = loss_fn(targets, predictions)

Custom Loss Functions

Register your own loss functions:

from lossx import register_loss
import jax.numpy as jnp

def custom_loss(true_y, pred_y, *, alpha=0.5, **kwargs):
    """Custom loss function."""
    return alpha * jnp.mean((true_y - pred_y) ** 2)

# Register it
register_loss("custom", custom_loss)

# Now use it in configs
config = {"target": "custom", "alpha": 0.7}
loss_fn = build_loss(config)

API Reference

Loss Functions

All loss functions follow the signature:

def loss_fn(
    true_y: PyTree[Array],
    pred_y: PyTree[Array],
    **kwargs
) -> Scalar

Builder Functions

build_loss(config, reduction=jnp.sum)

Build a loss function from a configuration PyTree.

Args:

  • config: Either a dict with "target" key (single loss) or a PyTree of such dicts
  • reduction: Function to reduce multiple losses to a scalar (default: jnp.sum)

Returns:

  • A loss function that takes (true, pred) PyTrees and returns a scalar

register_loss(name, loss_fn)

Register a custom loss function.

Args:

  • name: Name to register under
  • loss_fn: Loss function to register

Development

Running Tests

uv run pytest tests/ -v

Installing Development Dependencies

uv sync --group dev

License

MIT License - see LICENSE file for details.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.
