A fast and flexible library for gradient-based click models
CLAX: Fast and Flexible Neural Click Models in JAX
CLAX is a modular framework for building click models with gradient-based optimization in JAX and Flax NNX. CLAX is built for speed: by leveraging automatic differentiation and vectorized computation on GPUs, it provides orders-of-magnitude speed-ups over classic EM-based frameworks such as PyClick.
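To illustrate what "gradient-based" means here, the following is a minimal sketch in plain JAX (not CLAX's API or implementation) of a position-based model (PBM) fitted by gradient descent on the click log-likelihood instead of EM. The toy click log, parameter names, and learning rate are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy log: one entry per impression (query-document id, rank, click).
doc_ids = jnp.array([0, 1, 2, 0, 1, 2, 0, 1])
positions = jnp.array([0, 0, 0, 1, 1, 1, 0, 1])
clicks = jnp.array([1, 0, 1, 0, 0, 1, 1, 0], dtype=jnp.float32)

N_DOCS, N_POSITIONS = 3, 2
LEARNING_RATE = 0.5


def loss_fn(params, doc_ids, positions, clicks):
    # PBM: P(click) = P(examined | rank) * P(attractive | query-document pair),
    # with both probabilities parameterized as sigmoids of free logits.
    p = jax.nn.sigmoid(params["attraction"][doc_ids]) * jax.nn.sigmoid(
        params["examination"][positions]
    )
    p = jnp.clip(p, 1e-6, 1 - 1e-6)
    # Negative log-likelihood (binary cross-entropy) averaged over impressions.
    return -jnp.mean(clicks * jnp.log(p) + (1 - clicks) * jnp.log1p(-p))


@jax.jit
def sgd_step(params, doc_ids, positions, clicks):
    # Autodiff replaces hand-derived EM updates; the whole batch is processed
    # as vectorized array operations.
    loss, grads = jax.value_and_grad(loss_fn)(params, doc_ids, positions, clicks)
    params = jax.tree_util.tree_map(lambda w, g: w - LEARNING_RATE * g, params, grads)
    return params, loss


params = {"attraction": jnp.zeros(N_DOCS), "examination": jnp.zeros(N_POSITIONS)}
initial_loss = float(loss_fn(params, doc_ids, positions, clicks))
for _ in range(200):
    params, _ = sgd_step(params, doc_ids, positions, clicks)
final_loss = float(loss_fn(params, doc_ids, positions, clicks))
print(f"loss: {initial_loss:.3f} -> {final_loss:.3f}")
```

The same pattern scales to millions of query-document pairs because each step is a single compiled, batched computation.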
The current documentation is available here and our pre-print here.
Installation
CLAX requires JAX. To install JAX with CUDA support, please refer to the JAX documentation. CLAX itself is available via PyPI:
pip install clax-models
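After installing, a quick sanity check (plain JAX, independent of CLAX) is to confirm which backend and devices JAX sees, since the GPU speed-ups described above require a CUDA-enabled JAX build.

```python
import jax

# "gpu" indicates a working CUDA installation; a CPU-only install reports "cpu".
backend = jax.default_backend()
devices = jax.devices()
print(f"JAX {jax.__version__} backend: {backend}, devices: {devices}")
```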
Basic Usage
CLAX is designed with sensible defaults while also allowing a high level of customization. For example, training a User Browsing Model in CLAX is as simple as:
from clax import Trainer, UserBrowsingModel
from flax import nnx
from optax import adamw

model = UserBrowsingModel(
    query_doc_pairs=100_000_000,  # Number of query-document pairs in the dataset
    positions=10,  # Number of ranks per result page
    rngs=nnx.Rngs(42),  # NNX random number generator
)
trainer = Trainer(
    optimizer=adamw(0.003),  # AdamW with a learning rate of 0.003
    epochs=50,
)
# train_loader, val_loader, test_loader: data loaders over your click logs
# (not defined here; see examples/ for how to build them)
train_df = trainer.train(model, train_loader, val_loader)
test_df = trainer.test(model, test_loader)
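For intuition about what a User Browsing Model learns: in the standard UBM formulation (Dupret and Piwowarski, SIGIR 2008), the probability of a click at rank r given earlier clicks is P(C_r = 1) = γ(r, r') · α(q, d_r), where r' is the rank of the most recent click above r (0 if there is none). The sketch below computes these conditional probabilities in plain JAX; it is an illustration of the model family, not CLAX's internal implementation, and the function and parameter names are hypothetical.

```python
import jax
import jax.numpy as jnp

N_POSITIONS = 4  # hypothetical small result page


def ubm_click_probs(attraction_logits, examination_logits, clicks):
    """Conditional click probability at each rank under the standard UBM.

    attraction_logits: (n,) logits of alpha for the documents shown at ranks 1..n.
    examination_logits: (n, n) logits of gamma[rank - 1, last_click_rank].
    clicks: (n,) observed 0/1 clicks on the result page.
    """
    ranks = jnp.arange(1, clicks.shape[0] + 1)  # 1-indexed ranks
    # Rank of the most recent click at or above each rank (0 = no click yet).
    clicked_ranks = jnp.where(clicks > 0, ranks, 0)
    last_click_incl = jax.lax.cummax(clicked_ranks, axis=0)
    # Shift by one so rank r only sees clicks strictly above it.
    last_click = jnp.concatenate(
        [jnp.zeros(1, dtype=last_click_incl.dtype), last_click_incl[:-1]]
    )
    gamma = jax.nn.sigmoid(examination_logits[ranks - 1, last_click])
    alpha = jax.nn.sigmoid(attraction_logits)
    return gamma * alpha


clicks = jnp.array([0, 1, 0, 0])
probs = ubm_click_probs(
    jnp.zeros(N_POSITIONS), jnp.zeros((N_POSITIONS, N_POSITIONS)), clicks
)
print(probs)  # every entry is 0.25 because all logits are zero
```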
However, the modular design of CLAX also supports more complex models, from two-tower and mixture models to plugging in custom Flax modules as model parameters. We provide usage examples for getting started under examples/.
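As a rough illustration of the two-tower idea, here is a plain-JAX sketch that pairs a small relevance MLP over query-document features with a per-rank examination tower. The sizes, parameter names, and the multiplicative (PBM-style) combination are assumptions for illustration only; see examples/ for how CLAX itself builds such models.

```python
import jax
import jax.numpy as jnp

N_FEATURES, HIDDEN, N_POSITIONS = 8, 16, 10  # hypothetical sizes


def init_params(key):
    k1, k2 = jax.random.split(key)
    return {
        # Relevance tower: a one-hidden-layer MLP over query-document features.
        "w1": jax.random.normal(k1, (N_FEATURES, HIDDEN)) * 0.1,
        "b1": jnp.zeros(HIDDEN),
        "w2": jax.random.normal(k2, (HIDDEN,)) * 0.1,
        "b2": jnp.zeros(()),
        # Bias tower: one examination logit per rank.
        "exam": jnp.zeros(N_POSITIONS),
    }


def click_prob(params, features, positions):
    # features: (batch, N_FEATURES); positions: (batch,) 0-indexed ranks.
    hidden = jax.nn.relu(features @ params["w1"] + params["b1"])
    relevance_logit = hidden @ params["w2"] + params["b2"]
    examination_logit = params["exam"][positions]
    # Combine the towers: P(click) = P(examined | rank) * P(attractive | features).
    return jax.nn.sigmoid(examination_logit) * jax.nn.sigmoid(relevance_logit)


params = init_params(jax.random.PRNGKey(0))
features = jax.random.normal(jax.random.PRNGKey(1), (32, N_FEATURES))
positions = jnp.arange(32) % N_POSITIONS
probs = click_prob(params, features, positions)
print(probs.shape)
```

Because the towers only meet in `click_prob`, swapping the MLP for another network (for instance a Flax module) or the examination vector for a richer bias model changes one tower without touching the loss or training loop.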
Reproducibility & Development
In the following, we cover how to reproduce the experiments from our paper and how to set up a fork of CLAX for development.
Initial Setup
- Install the UV package manager
UV is a fast Python dependency manager. Install it from: https://github.com/astral-sh/uv
- Clone the CLAX repository
git clone git@github.com:philipphager/clax.git
cd clax/
- Install dependencies
uv sync
This creates a virtual environment and installs all required dependencies.
Running Experiments
Our paper's experiments are located in the experiments/ directory. Each experiment contains:
- main.py: a Python script with the experiment logic
- config.yaml: a Hydra config file for configuration management
- main.sh: a bash script with all experimental configurations
To run an experiment, follow these steps.
- Install experiment dependencies
This installs additional packages for SLURM support and data analysis/plotting.
uv sync --group experiments
- Download datasets
Clone the Yandex and Baidu-ULTR datasets from HuggingFace. If you have Git LFS installed, clone the datasets using:
git lfs install
git clone https://huggingface.co/datasets/philipphager/clax-datasets
Otherwise, download the datasets manually from HuggingFace.
Note: The full datasets require 85 GB of disk space. By default, CLAX expects datasets at ./clax-datasets/ relative to the project root. To use a custom path, update the dataset_dir parameter in each experiment's config.yaml:
dataset_dir: /my/custom/path/to/datasets/
- Run an experiment script
Navigate to your experiment of interest and run the bash script, e.g.:
cd experiments/1-yandex-baseline/
chmod +x ./main.sh
./main.sh
Optionally, you can run the script directly on a SLURM cluster using:
sbatch ./main.sh +launcher=slurm
You can adjust the SLURM configuration to your cluster under: experiments/config/slurm.yaml
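If Git LFS is not available, the dataset step can also be scripted from Python. The sketch below assumes the separately installed huggingface_hub package (its `snapshot_download` function) and checks for roughly 85 GB of free space first; the helper names are hypothetical and not part of CLAX.

```python
import shutil
from pathlib import Path

REQUIRED_GB = 85  # size of the full datasets, per the note above
DATASET_DIR = Path("./clax-datasets")  # default location relative to the project root


def has_enough_space(path: Path, required_gb: float = REQUIRED_GB) -> bool:
    # Check free space on the filesystem that will hold the datasets.
    # Walk up to the nearest existing parent, since the target may not exist yet.
    target = path.resolve()
    while not target.exists():
        target = target.parent
    free_gb = shutil.disk_usage(target).free / 1e9
    return free_gb >= required_gb


def download_datasets(local_dir: Path = DATASET_DIR) -> Path:
    # Assumes `pip install huggingface_hub`; imported lazily so this file
    # can be loaded without it.
    from huggingface_hub import snapshot_download

    if not has_enough_space(local_dir):
        raise RuntimeError(f"Less than {REQUIRED_GB} GB free for {local_dir}")
    path = snapshot_download(
        repo_id="philipphager/clax-datasets",
        repo_type="dataset",
        local_dir=str(local_dir),
    )
    return Path(path)


if __name__ == "__main__":
    print(f"Enough disk space: {has_enough_space(DATASET_DIR)}")
```

If you download to a custom location, pass it as `local_dir` and set the same path as dataset_dir in each experiment's config.yaml.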
PyClick Experiments
Baseline experiments using PyClick require the PyPy interpreter and are maintained in a separate repository: https://github.com/philipphager/clax-baselines
Reference
If CLAX is useful to you, please consider citing our paper:
@misc{hager2025clax,
title = {CLAX: Fast and Flexible Neural Click Models in JAX},
author = {Philipp Hager and Onno Zoeter and Maarten de Rijke},
year = {2025},
howpublished = {arXiv preprint}
}