
allRank is a framework for training learning-to-rank neural models

Project description

allRank: Learning to Rank in PyTorch

About

allRank is a PyTorch-based framework for training neural Learning-to-Rank (LTR) models, featuring implementations of:

  • common pointwise, pairwise and listwise loss functions
  • fully connected and Transformer-like scoring functions
  • commonly used evaluation metrics like Normalized Discounted Cumulative Gain (NDCG) and Mean Reciprocal Rank (MRR)
  • click-models for experiments on simulated click-through data
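allRank computes its metrics on batched PyTorch tensors. The following is only a pure Python sketch of what NDCG@k and MRR measure on a single slate. It assumes the common exponential gain 2^rel - 1 with a log2 position discount, treats any positive label as relevant for MRR, and returns 0.0 for slates without relevant items. allRank's exact gain, tie handling and edge cases may differ, and the function names are hypothetical.

```python
import math
from typing import List, Sequence


def _rank_labels(scores: Sequence[float], labels: Sequence[float]) -> List[float]:
    # Order ground-truth labels by predicted score, highest first (stable for ties).
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [labels[i] for i in order]


def dcg_at_k(ranked_labels: Sequence[float], k: int) -> float:
    # Gain 2^rel - 1, discounted by log2(rank + 1) where rank is 1-based.
    return sum((2 ** rel - 1) / math.log2(pos + 2)
               for pos, rel in enumerate(ranked_labels[:k]))


def ndcg_at_k(scores: Sequence[float], labels: Sequence[float], k: int) -> float:
    ideal = dcg_at_k(sorted(labels, reverse=True), k)
    if ideal == 0:
        return 0.0  # assumption: slates with no relevant items score 0
    return dcg_at_k(_rank_labels(scores, labels), k) / ideal


def mrr(scores: Sequence[float], labels: Sequence[float]) -> float:
    # Reciprocal rank of the first item with a positive label.
    for pos, rel in enumerate(_rank_labels(scores, labels)):
        if rel > 0:
            return 1.0 / (pos + 1)
    return 0.0
```

For example, with scores [0.9, 0.3, 0.5] and labels [0, 2, 1], the best-scored item is irrelevant and the first relevant item lands at rank 2, so `mrr` returns 0.5, while a ranking that matches the label order gives an NDCG of 1.0.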

Motivation

allRank provides an easy and flexible way to experiment with various LTR neural network models and loss functions. It is easy to add a custom loss, and to configure the model and the training procedure. We hope that allRank will facilitate both research in neural LTR and its industrial applications.

Features

Implemented loss functions:

  1. ListNet (for binary and graded relevance)
  2. ListMLE
  3. RankNet
  4. Ordinal loss
  5. LambdaRank
  6. LambdaLoss
  7. ApproxNDCG
  8. RMSE
  9. NeuralNDCG (introduced in https://arxiv.org/pdf/2102.07831)
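To give a feel for the listwise losses above, here is a minimal pure Python sketch of the top-one ListNet loss for a single slate: the cross-entropy between the softmax of the ground-truth labels and the softmax of the predicted scores. It is an illustration under that formulation, not allRank's implementation, which works on batched, possibly padded PyTorch tensors; the function names are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(xs: Sequence[float]) -> List[float]:
    # Numerically stable log-softmax via the log-sum-exp trick.
    m = max(xs)
    log_z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_z for x in xs]


def listnet_loss(y_pred: Sequence[float], y_true: Sequence[float]) -> float:
    # Target distribution: softmax over the ground-truth labels.
    target = [math.exp(lp) for lp in log_softmax(y_true)]
    # Predicted top-one distribution, kept in log space.
    log_probs = log_softmax(y_pred)
    # Cross-entropy between the target and predicted distributions.
    return -sum(t * lp for t, lp in zip(target, log_probs))
```

Because it is a cross-entropy, the loss is smallest when the predicted scores induce the same distribution as the labels; uniform predictions over n items always cost log(n).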

Getting started guide

To help you get started, we provide a run_example.sh script that generates dummy ranking data in libsvm format and trains a Transformer model on it using the provided example config file, config.json. Once the script has run, the dummy data can be found in the dummy_data directory and the results of the experiment in the test_run directory. Docker is required to run the example.
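The exact generator used by run_example.sh is not described here, so the following is a hypothetical sketch that only illustrates the libsvm ranking format: one document per line, written as `label qid:<query_id> index:value ...` with 1-based feature indices. The graded labels 0 to 4 mirror MSLR-WEB30K; the function name, its parameters and the fixed seed are assumptions.

```python
import random
from typing import List


def make_dummy_libsvm(num_queries: int, docs_per_query: int, num_features: int,
                      max_label: int = 4, seed: int = 0) -> List[str]:
    # Hypothetical helper: random graded labels and features, one line per document.
    rng = random.Random(seed)  # fixed seed so the output is reproducible
    lines = []
    for qid in range(1, num_queries + 1):
        for _ in range(docs_per_query):
            label = rng.randint(0, max_label)
            features = " ".join(f"{j}:{rng.random():.4f}" for j in range(1, num_features + 1))
            lines.append(f"{label} qid:{qid} {features}")
    return lines
```

Writing `"\n".join(make_dummy_libsvm(10, 20, 5))` to a file named train.txt would produce a small training set in the layout allRank expects.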

Configuring your model & training

To train your own model, configure your experiment in the config.json file and run:

python allrank/main.py --config_file_name allrank/config.json --run_id <the_name_of_your_experiment> --job_dir <the_place_to_save_results>

All hyperparameters of the training procedure (model definition, data location, loss and metrics used, training hyperparameters, etc.) are controlled by the config.json file. We provide a template file, config_template.json, which explains the supported attributes, their meaning and their possible values. Note that, following the MSLR-WEB30K convention, your libsvm file with training data should be named train.txt. You can specify the name of the validation dataset (e.g. valid or test) in the config. Results will be saved under <job_dir>/results/<run_id>.
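allRank ships its own data loading, so this is only a hypothetical sketch of how the rows of a libsvm file such as train.txt map onto slates: each row is one document, and rows sharing a qid form one slate (the candidate list for one query). It assumes each row looks like `label qid:<id> index:value ...`, optionally followed by a `#` comment; `parse_libsvm_slates` is an illustrative name, not part of allRank.

```python
from typing import Dict, Iterable, List, Tuple

# A slate: per-document labels plus per-document sparse feature dicts.
Slate = Tuple[List[float], List[Dict[int, float]]]


def parse_libsvm_slates(lines: Iterable[str]) -> Dict[str, Slate]:
    slates: Dict[str, Slate] = {}  # insertion order follows first appearance of each qid
    for raw in lines:
        line = raw.split("#", 1)[0].strip()  # drop trailing comments and whitespace
        if not line:
            continue
        label_token, qid_token, *feature_tokens = line.split()
        qid = qid_token.split(":", 1)[1]
        features = {int(idx): float(val)
                    for idx, val in (tok.split(":", 1) for tok in feature_tokens)}
        labels, feature_rows = slates.setdefault(qid, ([], []))
        labels.append(float(label_token))
        feature_rows.append(features)
    return slates
```

Grouping by qid is what turns a flat file of documents into the per-query lists that listwise losses and metrics such as NDCG operate on.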

allRank supports Google Cloud Storage as a location for both data and job results.

Implementing custom loss functions

To experiment with your own custom loss, implement a function that takes two tensors (model predictions and ground truth) as input, put it in the losses package, and make sure it is exposed at package level. To use it in training, pass the name of your function (and its args, if your loss has hyperparameters) in the "loss" section of the config file:

"loss": {
    "name": "yourLoss",
    "args": {
        "arg1": val1,
        "arg2": val2
    }
}
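As a hypothetical illustration of how a named loss and its "args" fit together, the sketch below defines a custom pairwise hinge loss with a `margin` hyperparameter and binds the config args as keyword arguments. We assume allRank resolves the loss by name and passes "args" as keyword arguments (for example via `functools.partial`); the real losses receive batched PyTorch tensors, whereas this single-slate version uses plain Python lists. `pairwise_hinge_loss`, `LOSSES` and `loss_config` are illustrative names only.

```python
from functools import partial
from typing import Sequence


def pairwise_hinge_loss(y_pred: Sequence[float], y_true: Sequence[float],
                        margin: float = 1.0) -> float:
    """Hypothetical custom loss: mean hinge loss over pairs with y_true[i] > y_true[j]."""
    total, pairs = 0.0, 0
    for i in range(len(y_true)):
        for j in range(len(y_true)):
            if y_true[i] > y_true[j]:
                # Penalize pairs where item i is not scored at least `margin` above item j.
                total += max(0.0, margin - (y_pred[i] - y_pred[j]))
                pairs += 1
    return total / pairs if pairs else 0.0


# Stand-in for the losses package: loss functions looked up by name.
LOSSES = {"pairwise_hinge_loss": pairwise_hinge_loss}

# Mirrors the "loss" section of config.json (values are illustrative).
loss_config = {"name": "pairwise_hinge_loss", "args": {"margin": 0.5}}
loss_fn = partial(LOSSES[loss_config["name"]], **loss_config["args"])
```

With `margin` set to 0.5, constant predictions such as [0.0, 0.0, 0.0] against labels [2, 1, 0] cost 0.5 per pair, while predictions that already separate every pair by at least the margin cost nothing.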

Applying click-model

To apply a click model, you first need a trained allRank model. Then run:

python allrank/rank_and_click.py --input-model-path <path_to_the_model_weights_file> --roles <comma_separated_list_of_ds_roles_to_process e.g. train,valid> --config_file_name allrank/config.json --run_id <the_name_of_your_experiment> --job_dir <the_place_to_save_results>

The model will be used to rank all slates from the dataset specified in the config. Next, the click model configured in the config will be applied, and the resulting click-through dataset will be written under <job_dir>/results/<run_id> in libSVM format. The path to the results directory can then be used as input for training another allRank model.
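The click models available in allRank and their config keys are defined by the library itself. As a hypothetical illustration of the idea, this sketch implements the classic cascade model on one ranked slate: the simulated user scans results top-down, clicks a document with a probability that depends on its relevance label, and stops after the first click. The label-to-probability mapping and the function name are assumptions.

```python
import random
from typing import Dict, List, Sequence


def cascade_clicks(ranked_labels: Sequence[int], click_prob: Dict[int, float],
                   rng: random.Random) -> List[int]:
    # Labels are given in the order produced by the trained model's ranking.
    clicks = [0] * len(ranked_labels)
    for pos, label in enumerate(ranked_labels):
        # The user examines this position and clicks with a label-dependent probability.
        if rng.random() < click_prob[label]:
            clicks[pos] = 1
            break  # Cascade assumption: the user stops after the first click.
    return clicks
```

Presumably, the resulting 0/1 click vectors then take the place of relevance labels in the click-through dataset, which is why it can feed another training run.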

Continuous integration

Run scripts/ci.sh to verify that the code passes the style checks and unit tests.

Research

This framework was developed to support the research project Context-Aware Learning to Rank with Self-Attention. If you use allRank in your research, please cite:

@article{Pobrotyn2020ContextAwareLT,
  title={Context-Aware Learning to Rank with Self-Attention},
  author={Przemyslaw Pobrotyn and Tomasz Bartczak and Mikolaj Synowiec and Radoslaw Bialobrzeski and Jaroslaw Bojar},
  journal={ArXiv},
  year={2020},
  volume={abs/2005.10084}
}

Additionally, if you use the NeuralNDCG loss function, please cite the corresponding work, NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable Relaxation of Sorting:

@article{Pobrotyn2021NeuralNDCG,
  title={NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable Relaxation of Sorting},
  author={Przemyslaw Pobrotyn and Radoslaw Bialobrzeski},
  journal={ArXiv},
  year={2021},
  volume={abs/2102.07831}
}

License

Apache 2 License


Download files


Source Distribution

allRank-1.4.3.tar.gz (39.4 kB)

Uploaded Source

Built Distribution

allRank-1.4.3-py3-none-any.whl (63.8 kB)

Uploaded Python 3

File details

Details for the file allRank-1.4.3.tar.gz.

File metadata

  • Download URL: allRank-1.4.3.tar.gz
  • Upload date:
  • Size: 39.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.7.3 pkginfo/1.5.0.1 requests/2.23.0 requests-toolbelt/0.9.1 tqdm/4.44.1 CPython/3.7.7

File hashes

Hashes for allRank-1.4.3.tar.gz
Algorithm Hash digest
SHA256 6b83220cd3fb8f40890a381028ec0aeda36d30bacf0f001d8342d77d6ea395bb
MD5 9e93da934c235aacd1d4c1dea4c85baf
BLAKE2b-256 94f0d24e9be9d0c9ab9496739b71eb1db57da430c12b89633b2dd76a391cef29


File details

Details for the file allRank-1.4.3-py3-none-any.whl.

File metadata

  • Download URL: allRank-1.4.3-py3-none-any.whl
  • Upload date:
  • Size: 63.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.7.3 pkginfo/1.5.0.1 requests/2.23.0 requests-toolbelt/0.9.1 tqdm/4.44.1 CPython/3.7.7

File hashes

Hashes for allRank-1.4.3-py3-none-any.whl
Algorithm Hash digest
SHA256 74a5d31e3aa6eb269162ecc3bcad89ffd550c984908bb3d96e402ad6eb22f5e2
MD5 37143db24acb39bb758fa80b65158960
BLAKE2b-256 0b7dd5a211afdef374c5996aedd1d347c3268689ff27a96c41ae4d5881f21796

