LossX: A JAX loss function library with config-based construction.

LossX

A functional JAX loss function library with PyTree-based composition and Hydra integration.

Features

  • Pure Functional API - All loss functions are pure functions
  • PyTree-Based Composition - Build complex multi-task losses from nested configurations
  • Hydra Integration - Build losses directly from Hydra/OmegaConf configs
  • Flexible Reductions - Custom reduction strategies for combining multiple losses
  • Extensible - Register custom loss functions by name and use them in configs

Installation

Using uv (recommended)

uv add lossx

Using pip

pip install lossx

Alternatively, copy one of the source files directly into your project.

Quick Start

Simple Loss Function Usage

Import and use loss functions directly:

import jax.numpy as jnp
from lossx.loss import cross_entropy, mse, gaussian_nll

# Mean Squared Error
true_y = jnp.array([1.0, 2.0, 3.0])
pred_y = jnp.array([1.1, 2.1, 3.1])
loss = mse(true_y, pred_y)

# Cross-Entropy with masking
true_labels = jnp.array([0, 1, 2, -100])  # -100 is masked
pred_logits = jnp.ones((4, 3))
loss = cross_entropy(true_labels, pred_logits, mask_index=-100)

# Gaussian NLL (for uncertainty estimation)
true_y = jnp.array([[1.0], [2.0]])
pred_y = jnp.array([[1.0, 0.1], [2.0, 0.1]])  # mean and variance
loss = gaussian_nll(true_y, pred_y, eps=1e-6)
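For reference, here is what these three calls compute, written in plain Python (no JAX) so the arithmetic is visible. This is a sketch under assumptions, not LossX's implementation: it assumes mse averages over all elements, cross_entropy averages softmax cross-entropy over unmasked positions only, and gaussian_nll uses the common 0.5 * (log var + (y - mean)^2 / var) form with the constant 0.5 * log(2 * pi) dropped. LossX's exact normalisation may differ, and the ref_* names are hypothetical.

```python
import math


def ref_mse(true_y, pred_y):
    """Mean of squared differences over all elements (flat lists)."""
    return sum((t - p) ** 2 for t, p in zip(true_y, pred_y)) / len(true_y)


def ref_cross_entropy(true_labels, pred_logits, mask_index=-100):
    """Softmax cross-entropy averaged over positions whose label != mask_index."""
    total, count = 0.0, 0
    for label, logits in zip(true_labels, pred_logits):
        if label == mask_index:
            continue  # masked positions contribute nothing
        # log-sum-exp with max subtraction for numerical stability
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_z - logits[label]  # -log softmax(logits)[label]
        count += 1
    return total / max(count, 1)  # 0.0 if every position is masked


def ref_gaussian_nll(true_y, pred_y, eps=1e-6):
    """Each pred row is [mean, variance]; the 0.5*log(2*pi) constant is omitted."""
    total = 0.0
    for (y,), (mu, var) in zip(true_y, pred_y):
        var = var + eps  # keep the variance strictly positive
        total += 0.5 * (math.log(var) + (y - mu) ** 2 / var)
    return total / len(true_y)


print(ref_mse([1.0, 2.0, 3.0], [1.1, 2.1, 3.1]))
print(ref_cross_entropy([0, 1, 2, -100], [[1.0, 1.0, 1.0]] * 4))
print(ref_gaussian_nll([[1.0], [2.0]], [[1.0, 0.1], [2.0, 0.1]]))
```

With uniform logits over 3 classes, each unmasked row contributes log 3 (about 1.0986) and the row labelled -100 is skipped, so the masked mean is also log 3.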

Available Loss Functions

  • cross_entropy - Cross-entropy loss with masking and class weights
  • mse - Mean squared error
  • q_loss - Q-loss for classification
  • quantile_loss - Quantile loss for uncertainty estimation
  • gaussian_nll - Gaussian negative log-likelihood
  • penex - Penalized exponential loss
  • contrastive - NT-Xent contrastive loss
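The list only describes quantile_loss as a quantile loss. The standard form of that loss is the pinball loss, sketched below in plain Python. It is an assumption that LossX follows this form, and the keyword q is a hypothetical name for the quantile level; check the source for the real parameter.

```python
def pinball_loss(true_y, pred_y, q=0.5):
    """Mean pinball (quantile) loss for a quantile level q in (0, 1).

    Under-predictions (y > y_hat) are weighted by q and over-predictions by
    (1 - q), so minimising it pushes pred_y toward the q-th quantile of y.
    """
    total = 0.0
    for y, y_hat in zip(true_y, pred_y):
        diff = y - y_hat
        total += max(q * diff, (q - 1) * diff)
    return total / len(true_y)
```

With q = 0.9, under-predicting by 1 costs 0.9 while over-predicting by 1 costs about 0.1, which is what makes the prediction settle near the 90th percentile. With q = 0.5 the loss is half the mean absolute error.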

PyTree-Based Loss Building

LossX's main strength is composing multi-task and composite losses from PyTree-structured configurations:

Single Loss

from lossx import build_loss

config = {
    "target": "cross_entropy",
    "mask_index": -100,
    "weight": 1.0
}

loss_fn = build_loss(config)
loss = loss_fn(true_y, pred_y)
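Conceptually, a single config pairs a registry name ("target") with keyword arguments. A minimal sketch of how such a builder could work, assuming "weight" scales the loss and every other key is forwarded as a keyword argument. REGISTRY, register, toy_mse and build_single are hypothetical stand-ins, not LossX internals.

```python
from functools import partial

# Hypothetical registry mapping config names to loss functions.
REGISTRY = {}


def register(name, fn):
    REGISTRY[name] = fn


def toy_mse(true_y, pred_y, **kwargs):
    # Plain-Python MSE used only for this illustration.
    return sum((t - p) ** 2 for t, p in zip(true_y, pred_y)) / len(true_y)


def build_single(config):
    """Turn {"target": name, "weight": w, **kwargs} into loss(true, pred)."""
    config = dict(config)  # copy so the caller's config is not mutated
    fn = REGISTRY[config.pop("target")]
    weight = config.pop("weight", 1.0)
    bound = partial(fn, **config)  # remaining keys become keyword arguments

    def loss(true_y, pred_y):
        return weight * bound(true_y, pred_y)

    return loss


register("mse", toy_mse)
loss_fn = build_single({"target": "mse", "weight": 0.5})
```

Binding the keyword arguments once at build time keeps the returned loss a plain two-argument function, which is the shape a training step (or jax.jit) expects.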

Multi-Task Learning (Dict of Losses)

import jax.numpy as jnp
from lossx import build_loss

config = {
    "classification": {
        "target": "cross_entropy",
        "mask_index": -100,
        "weight": 1.0
    },
    "regression": {
        "target": "mse",
        "weight": 0.5
    },
    "uncertainty": {
        "target": "gaussian_nll",
        "eps": 1e-6,
        "weight": 0.3
    }
}

# The default reduction is jnp.sum; here the per-task losses are averaged instead
loss_fn = build_loss(config, reduction=lambda losses: jnp.mean(jnp.array(losses)))

# Provide PyTree-structured inputs matching the config
loss = loss_fn(
    true={
        "classification": class_labels,
        "regression": regression_targets,
        "uncertainty": uncertainty_targets
    },
    pred={
        "classification": class_logits,
        "regression": regression_preds,
        "uncertainty": uncertainty_preds
    }
)
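The composition above behaves like a tree map: each leaf config becomes a loss, each loss is applied to the matching true/pred leaves, and the per-leaf results are collected into a list and passed to the reduction. Below is a framework-free sketch under those assumptions. It walks dicts and lists by hand (LossX presumably relies on JAX's tree utilities instead), treats any dict containing "target" as a leaf, and uses a hypothetical two-entry LOSSES table; all names are illustrative.

```python
def mse(true_y, pred_y, **kwargs):
    return sum((t - p) ** 2 for t, p in zip(true_y, pred_y)) / len(true_y)


def abs_err(true_y, pred_y, **kwargs):
    return sum(abs(t - p) for t, p in zip(true_y, pred_y)) / len(true_y)


# Hypothetical registry standing in for LossX's built-in losses.
LOSSES = {"mse": mse, "abs": abs_err}


def is_leaf(node):
    return isinstance(node, dict) and "target" in node


def build_tree_loss(config, reduction=sum):
    """Return loss(true, pred) for a dict/list tree of leaf configs."""

    def per_leaf(cfg, true, pred, out):
        if is_leaf(cfg):
            kwargs = {k: v for k, v in cfg.items() if k not in ("target", "weight")}
            fn = LOSSES[cfg["target"]]
            out.append(cfg.get("weight", 1.0) * fn(true, pred, **kwargs))
        elif isinstance(cfg, dict):
            if set(cfg) != set(true) or set(cfg) != set(pred):
                raise ValueError("true/pred keys must match the config keys")
            for key in sorted(cfg):  # deterministic order, like JAX's sorted dict keys
                per_leaf(cfg[key], true[key], pred[key], out)
        elif isinstance(cfg, list):
            if not (len(cfg) == len(true) == len(pred)):
                raise ValueError("true/pred lists must match the config length")
            for c, t, p in zip(cfg, true, pred):
                per_leaf(c, t, p, out)
        else:
            raise TypeError(f"unsupported config node: {cfg!r}")

    def loss(true, pred):
        losses = []
        per_leaf(config, true, pred, losses)
        return reduction(losses)

    return loss
```

For example, with config {"a": {"target": "mse", "weight": 2.0}, "b": [{"target": "abs"}]}, true {"a": [0.0, 0.0], "b": [[1.0]]} and pred {"a": [1.0, 1.0], "b": [[3.0]]}, the per-leaf losses are 2.0 and 2.0, so the default sum gives 4.0 and a mean reduction gives 2.0. A true/pred tree whose keys differ from the config raises a ValueError instead of silently pairing the wrong leaves.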

Composite Loss (List of Losses)

config = [
    {"target": "mse", "weight": 1.0},
    {"target": "cross_entropy", "mask_index": -100, "weight": 2.0}
]

loss_fn = build_loss(config, reduction=lambda losses: jnp.sum(jnp.array(losses)))

# Provide list-structured inputs
loss = loss_fn(
    true=[regression_targets, class_labels],
    pred=[regression_preds, class_logits]
)

Nested PyTree Structures

config = {
    "main": [
        {"target": "mse"},
        {"target": "cross_entropy"}
    ],
    "auxiliary": {
        "task_a": {"target": "mse", "weight": 0.5},
        "task_b": {"target": "gaussian_nll"}
    }
}

loss_fn = build_loss(config)

# Inputs must have the same PyTree structure as the config
loss = loss_fn(
    true={
        "main": [main_target1, main_target2],
        "auxiliary": {
            "task_a": aux_a_target,
            "task_b": aux_b_target
        }
    },
    pred={
        "main": [main_pred1, main_pred2],
        "auxiliary": {
            "task_a": aux_a_pred,
            "task_b": aux_b_pred
        }
    }
)
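When the per-leaf losses of a nested config are flattened into a sequence, their order matters for any order-sensitive reduction, such as a weighted sum. JAX's jax.tree_util flattens dict keys in sorted order and lists by position. Assuming LossX flattens the same way (not confirmed here), the hypothetical leaf_paths helper below lists the leaves of the nested config above in that order.

```python
def leaf_paths(config, prefix=()):
    """List the path of every leaf config in flattening order.

    Dict keys are visited in sorted order (the convention used by
    jax.tree_util); list items are visited by position.
    """
    if isinstance(config, dict) and "target" in config:
        return [prefix]
    if isinstance(config, dict):
        return [p for k in sorted(config) for p in leaf_paths(config[k], prefix + (k,))]
    if isinstance(config, list):
        return [p for i, c in enumerate(config) for p in leaf_paths(c, prefix + (i,))]
    raise TypeError(f"unsupported config node: {config!r}")


config = {
    "main": [{"target": "mse"}, {"target": "cross_entropy"}],
    "auxiliary": {
        "task_a": {"target": "mse", "weight": 0.5},
        "task_b": {"target": "gaussian_nll"},
    },
}
for path in leaf_paths(config):
    print("/".join(map(str, path)))
```

This prints auxiliary/task_a, auxiliary/task_b, main/0, main/1: "auxiliary" sorts before "main" even though it is written second. A weight vector for this config would therefore need to follow that order, not the order in which the keys appear in the source.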

Custom Reductions

Pass any function that combines the individual losses into a single scalar as the reduction argument:

import jax.numpy as jnp
from lossx import build_loss

# Pick one of the following reductions.

# Weighted reduction (one weight per loss)
weights = jnp.array([0.7, 0.3])
reduction = lambda losses: jnp.sum(jnp.array(losses) * weights)

# Max reduction (worst-case loss)
reduction = lambda losses: jnp.max(jnp.array(losses))

# Mean reduction
reduction = lambda losses: jnp.mean(jnp.array(losses))

loss_fn = build_loss(config, reduction=reduction)

Hydra Integration

Because loss configs are plain nested dicts and lists, they can be written directly as Hydra YAML:

config.yaml:

loss:
  target: cross_entropy
  mask_index: -100
  cls_weights: [1.0, 2.0, 1.0]
  weight: 1.0

multi_task.yaml:

loss:
  classification:
    target: cross_entropy
    mask_index: -100
    weight: 1.0
  regression:
    target: mse
    weight: 0.5
  uncertainty:
    target: gaussian_nll
    eps: 1e-6
    weight: 0.3

reduction: mean  # or sum, weighted

Python code:

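import hydra
from omegaconf import DictConfig, OmegaConf

from lossx import build_loss


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig):
    # Convert the (possibly nested) OmegaConf node to plain Python dicts/lists.
    # dict(cfg.loss) would only convert the top level.
    loss_config = OmegaConf.to_container(cfg.loss, resolve=True)
    loss_fn = build_loss(loss_config)

    # Use in training; targets and predictions come from your training loop
    loss = loss_fn(targets, predictions)


if __name__ == "__main__":
    main()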
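In multi_task.yaml, "reduction: mean" is a string, while build_loss expects a callable. One way to bridge the two is a small lookup table on the user side. The helper below is hypothetical (REDUCTIONS and resolve_reduction are not documented LossX names) and is written in plain Python so it runs without JAX; in a training script the lambdas would use jnp.sum and jnp.mean. It assumes the reduction receives a sequence of per-loss scalars, and that the weights for "weighted" come from an extra config key you add yourself.

```python
def _weighted(weights):
    def reduce(losses):
        # Guard against a weight list that no longer matches the config.
        if len(losses) != len(weights):
            raise ValueError(f"expected {len(weights)} losses, got {len(losses)}")
        return sum(loss * w for loss, w in zip(losses, weights))

    return reduce


# Hypothetical mapping from config strings to reduction callables.
REDUCTIONS = {
    "sum": lambda losses: sum(losses),
    "mean": lambda losses: sum(losses) / len(losses),
}


def resolve_reduction(name, weights=None):
    """Return a reduction callable for a name read from the config."""
    if name == "weighted":
        if weights is None:
            raise ValueError("'weighted' reduction needs a list of weights")
        return _weighted(list(weights))
    try:
        return REDUCTIONS[name]
    except KeyError:
        raise ValueError(f"unknown reduction {name!r}") from None
```

Inside main, when the config defines a reduction key, this would be used as build_loss(loss_config, reduction=resolve_reduction(cfg.reduction)). Failing loudly on an unknown name or a weight/loss count mismatch is preferable to silently falling back to a default.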

Custom Loss Functions

Register your own loss functions:

from lossx import build_loss, register_loss
import jax.numpy as jnp

def custom_loss(true_y, pred_y, *, alpha=0.5, **kwargs):
    """Custom loss function."""
    return alpha * jnp.mean((true_y - pred_y) ** 2)

# Register it
register_loss("custom", custom_loss)

# Now use it in configs
config = {"target": "custom", "alpha": 0.7}
loss_fn = build_loss(config)

API Reference

Loss Functions

All loss functions follow the signature:

def loss_fn(
    true_y: PyTree[Array],
    pred_y: PyTree[Array],
    **kwargs
) -> Scalar: ...
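For static type checking, the same contract can be written as a typing.Protocol. This is a sketch, not a type that LossX exports. It uses Any in place of the PyTree[Array] and Scalar annotations, whose source the document does not name, and half_squared_error is a toy function used only to show that a plain function satisfies the protocol.

```python
from typing import Any, Protocol


class LossFn(Protocol):
    """Structural type for a loss: (true_y, pred_y, **kwargs) -> scalar."""

    def __call__(self, true_y: Any, pred_y: Any, **kwargs: Any) -> Any: ...


def half_squared_error(true_y: Any, pred_y: Any, **kwargs: Any) -> Any:
    # Toy loss: any function with this call shape satisfies LossFn.
    return 0.5 * (true_y - pred_y) ** 2


fn: LossFn = half_squared_error  # accepted by structural type checkers such as mypy
```

Because Protocol matching is structural, custom losses passed to register_loss need no base class; a type checker simply verifies the call signature.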

Builder Functions

build_loss(config, reduction=jnp.sum)

Build a loss function from a configuration PyTree.

Args:

  • config: Either a dict with a "target" key (single loss) or a PyTree of such dicts
  • reduction: Function to reduce multiple losses to a scalar (default: jnp.sum)

Returns:

  • A loss function that takes (true, pred) PyTrees and returns a scalar

register_loss(name, loss_fn)

Register a custom loss function.

Args:

  • name: Name to register under
  • loss_fn: Loss function to register

Development

Running Tests

uv run pytest tests/ -v

Installing Development Dependencies

uv sync --group dev

License

MIT License - see LICENSE file for details.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Download files

Source Distribution

lossx-0.0.1.tar.gz (40.5 kB)

Built Distribution

lossx-0.0.1-py3-none-any.whl (12.0 kB)

File details

Details for the file lossx-0.0.1.tar.gz.

File metadata

  • Download URL: lossx-0.0.1.tar.gz
  • Size: 40.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.9.16 (CI, Ubuntu 24.04)

File hashes

Hashes for lossx-0.0.1.tar.gz

  • SHA256: d86c09478b3150092811483f1ba53508d2798d23caa09bd591e24898378e4ebe
  • MD5: 5d412531ce963d7415ca94c3ec69a494
  • BLAKE2b-256: 83a167051399a79d2076b530770ff377109f03ce15dbf45a6de8a1799afaf356


File details

Details for the file lossx-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: lossx-0.0.1-py3-none-any.whl
  • Size: 12.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.9.16 (CI, Ubuntu 24.04)

File hashes

Hashes for lossx-0.0.1-py3-none-any.whl

  • SHA256: 8ffda6b068841079516ce4aadc3887abf8a97a947404b5d4d6b1a5f83d0f3c3e
  • MD5: d3519f62f6af36481013e0e77197e784
  • BLAKE2b-256: 24e80b4344ccc543898911b12e61e79f8c5ec26f0f22f4db7ba3a0c7d65043b3
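To check a downloaded file against the SHA256 digests listed above, hash it with Python's standard hashlib module and compare. A minimal sketch; the path is wherever you saved the download.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 20):
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected_sha256):
    """True if the file's SHA256 matches the published hex digest."""
    return sha256_of(path) == expected_sha256.strip().lower()
```

For example, verify("lossx-0.0.1.tar.gz", "d86c09478b3150092811483f1ba53508d2798d23caa09bd591e24898378e4ebe") should return True for an untampered source distribution.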

