
A fast and flexible library for gradient-based click models


CLAX: Fast and Flexible Neural Click Models in JAX

CLAX is a modular framework for building click models with gradient-based optimization in JAX and Flax NNX. CLAX is built for speed: by leveraging automatic differentiation and vectorized computation on GPUs, it achieves speed-ups of orders of magnitude over classic EM-based frameworks such as PyClick.
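Here is what "gradient-based" means, shown without requiring CLAX or JAX. This toy sketch fits a position-based model (PBM), where P(click) = P(examined | rank) * P(attractive | document), by full-batch gradient descent on the negative log-likelihood of simulated clicks. Probabilities are parameterized as sigmoids of unconstrained logits, and the gradients are derived by hand where CLAX would rely on JAX autodiff. The function names (`simulate_clicks`, `fit_pbm`) are hypothetical; this is not CLAX's implementation.

```python
import math
import random
from typing import List, Tuple

Click = Tuple[int, int, int]  # (doc_id, rank, clicked)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def simulate_clicks(exam: List[float], attr: List[float], n: int, seed: int = 0) -> List[Click]:
    """Sample a toy click log from known PBM parameters."""
    rng = random.Random(seed)
    log = []
    for _ in range(n):
        d, k = rng.randrange(len(attr)), rng.randrange(len(exam))
        log.append((d, k, int(rng.random() < exam[k] * attr[d])))
    return log


def fit_pbm(clicks: List[Click], n_docs: int, n_ranks: int,
            lr: float = 0.5, epochs: int = 200):
    """Fit PBM logits by gradient descent; returns (exam_logits, attr_logits, losses)."""
    a = [0.0] * n_ranks  # examination logit per rank
    b = [0.0] * n_docs   # attractiveness logit per document
    losses = []
    for _ in range(epochs):
        grad_a, grad_b, nll = [0.0] * n_ranks, [0.0] * n_docs, 0.0
        for d, k, c in clicks:
            exam, attr = sigmoid(a[k]), sigmoid(b[d])
            p = exam * attr  # click probability under the PBM
            nll -= math.log(p) if c else math.log(1.0 - p)
            # d(NLL)/d(logit) simplifies to (p - c) / (1 - p) * (1 - sigmoid(logit))
            g = (p - c) / (1.0 - p)
            grad_a[k] += g * (1.0 - exam)
            grad_b[d] += g * (1.0 - attr)
        n = len(clicks)
        a = [x - lr * gx / n for x, gx in zip(a, grad_a)]
        b = [x - lr * gx / n for x, gx in zip(b, grad_b)]
        losses.append(nll / n)  # mean NLL before this epoch's update
    return a, b, losses


if __name__ == "__main__":
    log = simulate_clicks(exam=[0.9, 0.6, 0.3], attr=[0.8, 0.5, 0.2, 0.6], n=1000)
    exam_logits, attr_logits, losses = fit_pbm(log, n_docs=4, n_ranks=3)
    print(f"mean NLL: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

Every click is a differentiable term in one objective, so the same recipe scales to batched, vectorized updates on accelerators. That is the property the paragraph above attributes to CLAX.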

The current documentation is available here and our pre-print here.

Installation

CLAX requires JAX. To install JAX with CUDA support, refer to the JAX documentation. CLAX itself is available on PyPI:

pip install clax-models
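To confirm the installation, check that both the distribution name (clax-models) and the import name used in the usage example (clax) resolve. The sketch below uses only the standard library, and the helper names are hypothetical. For GPU setups, calling `jax.devices()` after importing JAX lists the devices JAX can use.

```python
import importlib.util
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "clax-models") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def importable(module_name: str = "clax") -> bool:
    """True if `import module_name` would find a top-level module."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    print("clax-models version:", installed_version())
    print("clax importable:", importable())
```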

Basic Usage

CLAX is designed with sensible defaults while also allowing a high level of customization. For example, training a User Browsing Model in CLAX is as simple as:

from clax import Trainer, UserBrowsingModel
from flax import nnx
from optax import adamw

model = UserBrowsingModel(
    query_doc_pairs=100_000_000, # Number of query-document pairs in the dataset
    positions=10, # Number of ranks per result page
    rngs=nnx.Rngs(42), # NNX random number generator
)
trainer = Trainer(
    optimizer=adamw(0.003),
    epochs=50,
)
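# train_loader, val_loader, and test_loader are assumed to be defined already;
# see examples/ for complete, runnable scripts.
train_df = trainer.train(model, train_loader, val_loader)
test_df = trainer.test(model, test_loader)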
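For reference, the standard UBM formulation models a click at rank r as the document's attractiveness times an examination probability that depends on r and on the distance to the previous click. The plain-Python sketch below computes per-rank click probabilities and the session log-likelihood, the kind of objective a gradient-based trainer optimizes. It is not CLAX's API. The function names and the layout of the `examination` matrix (row r uses only its first r + 1 entries) are assumptions for illustration.

```python
import math
from typing import List, Sequence


def ubm_click_probs(attractiveness: Sequence[float],
                    examination: Sequence[Sequence[float]],
                    clicks: Sequence[int]) -> List[float]:
    """Per-rank click probabilities of one session under the UBM.

    attractiveness[r]: P(document at rank r is attractive), ranks 0-indexed.
    examination[r][j]: P(rank r is examined) when the previous click is j + 1
        ranks above r (the "no click yet" case counts as a click at rank -1).
    clicks[r]: observed 0/1 clicks, which determine the distance at each rank.
    """
    probs = []
    last_click = -1  # virtual click above the first result
    for r, (alpha, clicked) in enumerate(zip(attractiveness, clicks)):
        distance = r - last_click  # always >= 1
        probs.append(examination[r][distance - 1] * alpha)
        if clicked:
            last_click = r
    return probs


def ubm_log_likelihood(attractiveness, examination, clicks) -> float:
    """Log-likelihood of the observed clicks; training maximizes its sum over sessions."""
    probs = ubm_click_probs(attractiveness, examination, clicks)
    return sum(math.log(p) if c else math.log(1.0 - p) for p, c in zip(probs, clicks))
```

For example, with attractiveness 0.5 everywhere and clicks [1, 0, 1], rank 2 has distance 2 to the last click (rank 0), so it uses `examination[2][1]`.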

Beyond these defaults, the modular design of CLAX supports more complex models, such as two-tower models, mixture models, or models that plug in custom Flax modules as parameters. Usage examples for getting started are provided under examples/.
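Here is a rough illustration of the two-tower idea. One tower scores query-document features to obtain attractiveness, and a second component captures position bias. The two factors then combine into a click probability. This is a plain-Python sketch, not CLAX's classes: `LinearTower` and `TwoTowerClickModel` are hypothetical names, and in CLAX these roles would be filled by Flax modules, as described above.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


@dataclass
class LinearTower:
    """Hypothetical stand-in for a learnable module: a linear score over features."""
    weights: List[float]
    bias: float = 0.0

    def __call__(self, features: Sequence[float]) -> float:
        return sum(w * x for w, x in zip(self.weights, features)) + self.bias


@dataclass
class TwoTowerClickModel:
    """Click probability = P(attractive | query-doc features) * P(examined | rank)."""
    relevance: LinearTower         # tower over query-document features
    position_logits: List[float]   # bias component: one examination logit per rank

    def click_prob(self, features: Sequence[float], rank: int) -> float:
        attractiveness = sigmoid(self.relevance(features))
        examination = sigmoid(self.position_logits[rank])
        return attractiveness * examination


if __name__ == "__main__":
    model = TwoTowerClickModel(LinearTower([1.0, -0.5]), position_logits=[2.0, 0.0, -1.0])
    for rank in range(3):
        print(rank, round(model.click_prob([0.8, 0.2], rank), 3))
```

Separating the towers means relevance can generalize to unseen query-document pairs through features, while position bias is learned once per rank.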

Reproducibility & Development

The following sections cover how to reproduce the experiments from our paper and how to set up a fork of CLAX for development.

Initial Setup

  1. Install the UV package manager
    UV is a fast Python dependency manager. Install it from: https://github.com/astral-sh/uv

  2. Clone the CLAX repository

   git clone git@github.com:philipphager/clax.git
   cd clax/
  3. Install dependencies
   uv sync

This creates a virtual environment and installs all required dependencies.

Running Experiments

Our paper's experiments are located in the experiments/ directory. Each experiment contains:

  • A Python script with the experiment logic: main.py
  • A Hydra config file for configuration management: config.yaml
  • A bash script with all experimental configurations: main.sh

To run an experiment, follow these steps.

  1. Install experiment dependencies
    This installs additional packages for SLURM support and data analysis/plotting.
   uv sync --group experiments
  2. Download datasets
    Clone the Yandex and Baidu-ULTR datasets from Hugging Face. If you have Git LFS installed, clone the datasets using:
   git lfs install
   git clone https://huggingface.co/datasets/philipphager/clax-datasets

Otherwise, download the datasets manually from Hugging Face. Note that the full datasets require 85 GB of disk space. By default, CLAX expects the datasets at ./clax-datasets/ relative to the project root, so run the clone above from the project root to match this default. To use a custom path, update the dataset_dir parameter in each experiment's config.yaml:

   dataset_dir: /my/custom/path/to/datasets/
  3. Run an experiment script
    Navigate to the experiment of interest and run its bash script, e.g.:
   cd experiments/1-yandex-baseline/
   chmod +x ./main.sh
   ./main.sh

Optionally, you can run the script directly on a SLURM cluster using:

   sbatch ./main.sh +launcher=slurm

To adapt the SLURM configuration to your cluster, edit experiments/config/slurm.yaml.

PyClick Experiments

Baseline experiments using PyClick require the PyPy interpreter and are maintained in a separate repository: https://github.com/philipphager/clax-baselines

Reference

If CLAX is useful to you, please consider citing our paper:

@misc{hager2025clax,
  title        = {CLAX: Fast and Flexible Neural Click Models in JAX},
  author       = {Philipp Hager and Onno Zoeter and Maarten de Rijke},
  year         = {2025},
  howpublished = {arXiv preprint}
}


Download files


Source Distribution

clax_models-0.1.1.tar.gz (30.9 kB)

Uploaded Source

Built Distribution


clax_models-0.1.1-py3-none-any.whl (51.1 kB)

Uploaded Python 3

File details

Details for the file clax_models-0.1.1.tar.gz.

File metadata

  • Download URL: clax_models-0.1.1.tar.gz
  • Upload date:
  • Size: 30.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.7.2

File hashes

Hashes for clax_models-0.1.1.tar.gz
Algorithm Hash digest
SHA256 b09eb3e0780f1f8172c0154564e4f9de9084e4e9e71e0ccf75dbbc5869de8f01
MD5 5773e3e01102738633fc8f139efe4957
BLAKE2b-256 ce2a7b3f5a4f186811d020efe00d4975fe9914bca7f0de597537208bcca76f00


File details

Details for the file clax_models-0.1.1-py3-none-any.whl.

File metadata

File hashes

Hashes for clax_models-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 bcd262ba556ba745a218043c84ce16bc293082902e72fef1ed57b8fedf4c0aa4
MD5 a0047ece79971d48579d6f2209c92768
BLAKE2b-256 21d8645a633187aaef5246b4cb30ad6521467675f072108ff2a17e40e12da2d8
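To check a downloaded file against the SHA256 digests listed above, here is a minimal standard-library sketch. `PUBLISHED_SHA256` copies the two digests from this page, and the function names are hypothetical.

```python
import hashlib
import os
from typing import Iterable

# SHA256 digests as published for release 0.1.1
PUBLISHED_SHA256 = {
    "clax_models-0.1.1.tar.gz": "b09eb3e0780f1f8172c0154564e4f9de9084e4e9e71e0ccf75dbbc5869de8f01",
    "clax_models-0.1.1-py3-none-any.whl": "bcd262ba556ba745a218043c84ce16bc293082902e72fef1ed57b8fedf4c0aa4",
}


def sha256_of_chunks(chunks: Iterable[bytes]) -> str:
    """Hex SHA256 of the concatenation of all chunks (hashing is incremental)."""
    h = hashlib.sha256()
    for chunk in chunks:
        h.update(chunk)
    return h.hexdigest()


def file_sha256(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash a file in 1 MiB chunks so large files need not fit in memory."""
    with open(path, "rb") as f:
        return sha256_of_chunks(iter(lambda: f.read(chunk_size), b""))


def verify_download(path: str) -> bool:
    """Compare a downloaded file against the digest published for its file name."""
    expected = PUBLISHED_SHA256.get(os.path.basename(path))
    if expected is None:
        raise KeyError(f"no published hash for {path!r}")
    return file_sha256(path) == expected
```

Alternatively, pip's hash-checking mode verifies digests at install time. For example, you can add a requirements line `clax-models==0.1.1 --hash=sha256:<digest>` and run `pip install --require-hashes -r requirements.txt`. This mode requires every dependency to be pinned and hashed as well.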

