🦖 Rax: Learning-to-Rank using JAX

Rax is a Learning-to-Rank library written in JAX. Rax provides off-the-shelf implementations of ranking losses and metrics to be used with JAX. It provides the following functionality:

  • Ranking losses (rax.*_loss): rax.softmax_loss, rax.pairwise_logistic_loss, ...
  • Ranking metrics (rax.*_metric): rax.mrr_metric, rax.ndcg_metric, ...
  • Transformations (rax.*_t12n): rax.approx_t12n, rax.gumbel_t12n, ...

Ranking

A ranking problem is different from traditional classification/regression problems in that its objective is to optimize for the correctness of the relative order of a list of examples (e.g., documents) for a given context (e.g., a query). Rax provides support for ranking problems within the JAX ecosystem. It can be used in, but is not limited to, the following applications:

  • Search: ranking a list of documents with respect to a query.
  • Recommendation: ranking a list of items given a user as context.
  • Question Answering: finding the best answer from a list of candidates.
  • Dialogue System: finding the best response from a list of responses.

Synopsis

In a nutshell, given the scores and labels for a list of items, Rax can compute various ranking losses and metrics:

import jax.numpy as jnp
import rax

scores = jnp.array([2.2, -1.3, 5.4])  # output of a model.
labels = jnp.array([1.0,  0.0, 0.0])  # indicates doc 1 is relevant.

rax.ndcg_metric(scores, labels)  # computes a ranking metric.
# 0.63092977

rax.pairwise_hinge_loss(scores, labels)  # computes a ranking loss.
# 2.1
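
To make the two numbers above less opaque, here is a minimal hand-written sketch of the same computations in plain JAX. It is not Rax's implementation: the names ndcg_sketch and pairwise_hinge_sketch are hypothetical, the NDCG variant assumes the common 2^label - 1 gain with a log2 rank discount, and the hinge variant assumes a margin of 1 averaged over pairs in which the first item has the higher label. Under those assumptions it reproduces the values shown for this input.

```python
import jax.numpy as jnp


def ndcg_sketch(scores, labels):
    """Hand-rolled NDCG (assumed gain 2^label - 1, discount 1 / log2(rank + 1))."""
    order = jnp.argsort(-scores)  # item indices sorted by descending score
    gains = 2.0 ** labels - 1.0
    discounts = 1.0 / jnp.log2(jnp.arange(2, scores.shape[0] + 2))
    dcg = jnp.sum(gains[order] * discounts)
    # Ideal DCG: the same discounts applied to gains sorted in descending order.
    # Assumes at least one item has a non-zero label.
    ideal_dcg = jnp.sum(jnp.sort(gains)[::-1] * discounts)
    return dcg / ideal_dcg


def pairwise_hinge_sketch(scores, labels):
    """Hand-rolled pairwise hinge loss, averaged over pairs with label_i > label_j."""
    diff = scores[:, None] - scores[None, :]   # diff[i, j] = s_i - s_j
    valid = labels[:, None] > labels[None, :]  # item i should outrank item j
    hinge = jnp.where(valid, jnp.maximum(0.0, 1.0 - diff), 0.0)
    return jnp.sum(hinge) / jnp.sum(valid)


scores = jnp.array([2.2, -1.3, 5.4])
labels = jnp.array([1.0, 0.0, 0.0])

ndcg_value = ndcg_sketch(scores, labels)             # approximately 0.6309
hinge_value = pairwise_hinge_sketch(scores, labels)  # approximately 2.1
```

The relevant document is ranked second (behind the item scored 5.4), so its gain of 1 is discounted by 1 / log2(3), which gives roughly 0.63. Of the two label-ordered pairs, only the one against the 5.4-scored item violates the margin, contributing 4.2, and the average over both pairs is 2.1.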

All of the Rax losses and metrics are purely functional and compose well with standard JAX transformations. Additionally, Rax provides ranking-specific transformations so you can build new ranking losses. An example is rax.approx_t12n, which can be used to transform any (non-differentiable) ranking metric into a differentiable loss. For example:

import jax

loss_fn = rax.approx_t12n(rax.ndcg_metric)
loss_fn(scores, labels)  # differentiable approx ndcg loss.
# -0.63282484

jax.grad(loss_fn)(scores, labels)  # computes gradients w.r.t. scores.
# [-0.01276882  0.00549765  0.00727116]
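
The idea behind a differentiable approximation of a ranking metric, and the way such a function composes with jax.grad, jax.jit and jax.vmap, can be sketched without Rax itself. The code below is an illustration under stated assumptions, not the library's code: approx_ranks_sketch and approx_ndcg_loss_sketch are hypothetical names, hard ranks are replaced by sums of sigmoids over score differences, and the NDCG gain and discount are the 2^label - 1 and log2 choices. With these assumptions the sketch gives values close to the ones printed above for the same input.

```python
import jax
import jax.numpy as jnp


def approx_ranks_sketch(scores):
    """Smooth stand-in for 1-based ranks (assumption: sigmoid of score differences)."""
    # pairwise[i, j] = sigmoid(s_j - s_i): soft indicator that item j outranks item i.
    pairwise = jax.nn.sigmoid(scores[None, :] - scores[:, None])
    # The diagonal adds sigmoid(0) = 0.5 per item; the extra 0.5 means a clear
    # top item ends up with a rank close to 1.
    return jnp.sum(pairwise, axis=-1) + 0.5


def approx_ndcg_loss_sketch(scores, labels):
    """Negative approximate NDCG; assumes at least one item has a non-zero label."""
    gains = 2.0 ** labels - 1.0
    dcg = jnp.sum(gains / jnp.log2(1.0 + approx_ranks_sketch(scores)))
    ideal_discounts = 1.0 / jnp.log2(jnp.arange(2, labels.shape[0] + 2))
    ideal_dcg = jnp.sum(jnp.sort(gains)[::-1] * ideal_discounts)
    return -dcg / ideal_dcg


scores = jnp.array([2.2, -1.3, 5.4])
labels = jnp.array([1.0, 0.0, 0.0])

# A purely functional loss can be jit-compiled and differentiated directly.
loss = jax.jit(approx_ndcg_loss_sketch)(scores, labels)
grads = jax.grad(approx_ndcg_loss_sketch)(scores, labels)  # w.r.t. scores

# vmap maps the single-list function over a batch of lists (leading axis).
batch_scores = jnp.stack([scores, scores[::-1]])
batch_labels = jnp.stack([labels, labels[::-1]])
batched = jax.vmap(approx_ndcg_loss_sketch)(batch_scores, batch_labels)
```

Because the sigmoid is smooth, every score receives a gradient: the relevant item's score is pushed up (negative gradient of the loss) while the two competing scores are pushed down.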

Installation

See https://github.com/google/jax#installation for instructions on installing JAX.

We suggest installing the latest stable version of Rax by running:

$ pip install rax

Examples

See the examples/ directory for complete examples on how to use Rax.

Citing Rax

If you use Rax, please consider citing our paper:

@inproceedings{jagerman2022rax,
  title = {Rax: Composable Learning-to-Rank using JAX},
  author  = {Rolf Jagerman and Xuanhui Wang and Honglei Zhuang and Zhen Qin and
  Michael Bendersky and Marc Najork},
  year  = {2022},
  booktitle = {Proceedings of the 28th ACM SIGKDD International Conference on Knowledge Discovery \& Data Mining}
}
