
A fast and flexible library for gradient-based click models

CLAX: Fast and Flexible Neural Click Models in JAX

CLAX is a modular framework for building click models with gradient-based optimization in JAX and Flax NNX. It is built to be fast: by leveraging automatic differentiation and vectorized computation on GPUs, it provides orders-of-magnitude speed-ups over classic EM-based frameworks such as PyClick.

The current documentation is available here and our pre-print here.

Installation

CLAX requires JAX. To install JAX with CUDA support, please refer to the JAX documentation. CLAX itself is available on PyPI:

pip install clax-models

Usage

CLAX is designed with sensible defaults while still allowing a high level of customization. For example, training a User Browsing Model in CLAX is as simple as:

from clax import Trainer, UserBrowsingModel
from flax import nnx
from optax import adamw

model = UserBrowsingModel(
    query_doc_pairs=100_000_000, # Number of query-document pairs in the dataset
    positions=10, # Number of ranks per result page
    rngs=nnx.Rngs(42), # NNX random number generator
)
trainer = Trainer(
    optimizer=adamw(0.003),
    epochs=50,
)
# train_loader, val_loader, and test_loader are your own data loaders (not shown here)
train_df = trainer.train(model, train_loader, val_loader)
test_df = trainer.test(model, test_loader)

However, the modular design of CLAX also supports more complex setups, such as two-tower models, mixture models, or custom Flax modules plugged in as model parameters. Usage examples for getting started are provided under examples/.
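To make concrete what the User Browsing Model in the example above estimates, here is a dependency-free sketch of its click probability and one way to fit it with gradient ascent. It is not CLAX's implementation or API: the function names (`ubm_forward`, `attractiveness_gradient`) and the toy session are illustrative assumptions, and the gradient is derived by hand where CLAX would rely on JAX auto-diff. The sketch follows the standard UBM formulation, in which the click probability at a rank is the product of the document's attractiveness and an examination probability that depends on the rank and on the rank of the last clicked result.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def ubm_forward(attr_logits, exam_logits, clicks):
    """Return per-rank click probabilities and the session log-likelihood.

    attr_logits[k]    : attractiveness logit of the document at rank k (0-based)
    exam_logits[k][j] : examination logit at rank k; j = 0 means "no click so far",
                        j = r + 1 means "last click was at rank r"
    clicks[k]         : observed click (0 or 1) at rank k
    """
    probs, last_click = [], 0
    for k, c in enumerate(clicks):
        p = sigmoid(attr_logits[k]) * sigmoid(exam_logits[k][last_click])
        probs.append(p)
        if c:
            last_click = k + 1
    ll = sum(math.log(p) if c else math.log(1.0 - p) for p, c in zip(probs, clicks))
    return probs, ll


def attractiveness_gradient(attr_logits, probs, clicks):
    """Hand-derived d(log-likelihood)/d(attractiveness logit) for each rank."""
    return [
        (c - p) / (1.0 - p) * (1.0 - sigmoid(a))
        for a, p, c in zip(attr_logits, probs, clicks)
    ]


# Toy session with three ranks (hypothetical values, for illustration only).
attr = [0.0, 0.0, 0.0]
exam = [[0.0] * 4 for _ in range(3)]
exam[2][1] = 2.0  # rank 2 is examined more often when the last click was at rank 0
clicks = [1, 0, 1]

probs_before, ll_before = ubm_forward(attr, exam, clicks)

lr = 0.5
for _ in range(20):
    probs, _ = ubm_forward(attr, exam, clicks)
    grad = attractiveness_gradient(attr, probs, clicks)
    attr = [a + lr * g for a, g in zip(attr, grad)]  # gradient ascent step

_, ll_after = ubm_forward(attr, exam, clicks)
print(f"log-likelihood before: {ll_before:.4f}, after: {ll_after:.4f}")
```

Only the attractiveness logits are updated here to keep the sketch short; in a full model the examination logits are learned in the same way, and an auto-diff framework removes the need to derive any of these gradients manually.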

Development & Reproducibility

This guide covers how to set up CLAX for development and reproduce the experiments from our paper.

Initial Setup

  1. Install the UV package manager
    UV is a fast Python dependency manager. Install it from: https://github.com/astral-sh/uv

  2. Clone the CLAX repository

   git clone git@github.com:philipphager/clax.git
   cd clax/
  3. Install dependencies
   uv sync

This creates a virtual environment and installs all required dependencies.

Running Experiments

Our paper's experiments are located in the experiments/ directory. Each experiment contains:

  • A Python script with the experiment logic: main.py
  • A Hydra config file for configuration management: config.yaml
  • A bash script with all experimental configurations: main.sh

Setting up experiments

  1. Install experiment dependencies
   uv sync --group experiments

This installs additional packages needed for SLURM support and plotting.

  2. Download datasets
    Clone the Yandex and Baidu-ULTR datasets from Hugging Face. If you have Git LFS installed, clone the datasets using:
   git lfs install
   git clone https://huggingface.co/datasets/philipphager/clax-datasets

Otherwise, download the datasets manually from Hugging Face. Note that the full datasets require 85 GB of disk space. By default, CLAX expects the datasets at ./clax-datasets/ relative to the project root. To use a custom path, update the dataset_dir parameter in each experiment's config.yaml:

   dataset_dir: /my/custom/path/to/datasets/
  3. Run the experiment script
    Navigate to your experiment of interest and run the bash script, e.g.:
   cd experiments/1-yandex-baseline/
   chmod +x ./main.sh
   ./main.sh

Optionally, you can run the script directly on a SLURM cluster using:

   sbatch ./main.sh +launcher=slurm

You can adjust the SLURM configuration to your cluster under: experiments/config/slurm.yaml
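Because the datasets are large and their location is configurable, a quick pre-flight check before launching a run can save a failed job. The following is a minimal sketch using only the Python standard library. The helpers `resolve_dataset_dir` and `dataset_dir_ready` are hypothetical and not part of CLAX; they only mirror the default location described in the dataset download step (./clax-datasets/ relative to the project root) and assume the script is run from the project root.

```python
from pathlib import Path
from typing import Optional


def resolve_dataset_dir(project_root: Path, custom: Optional[str] = None) -> Path:
    """Return the custom dataset path if given, else <project_root>/clax-datasets."""
    if custom:
        return Path(custom).expanduser()
    return project_root / "clax-datasets"


def dataset_dir_ready(path: Path) -> bool:
    """True if the directory exists and contains at least one entry."""
    return path.is_dir() and any(path.iterdir())


if __name__ == "__main__":
    # Assumes the current working directory is the project root.
    path = resolve_dataset_dir(Path.cwd())
    status = "found" if dataset_dir_ready(path) else "missing or empty"
    print(f"{path}: {status}")
```

If you set a custom dataset_dir in config.yaml, pass the same value as the `custom` argument so the check looks in the right place.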

Baseline Experiments

Baseline experiments using PyClick require the PyPy interpreter and are maintained in a separate repository: https://github.com/philipphager/clax-baselines

Generate documentation

CLAX uses mkdocs to generate the documentation:

  1. Install development dependencies: uv sync --group dev
  2. Run mkdocs locally: uv run mkdocs serve

Reference

If CLAX is useful to you, please consider citing our paper:

@misc{hager2025clax,
  title = {CLAX: Fast and Flexible Neural Click Models in JAX},
  author  = {Philipp Hager and Onno Zoeter and Maarten de Rijke},
  year  = {2025},
  booktitle = {arxiv}
}
