
"WARP loss for Pytorch. WSABIE"

Project description

WARP-Pytorch

An implementation of WARP loss which uses matrices and stays on the GPU in PyTorch.

This means instead of using a for-loop to find the first offending negative sample that ranks above our positive, we compute all of them at once. Only later do we find which sample is the first offender, and compute the loss with respect to this sample.

The advantage is that it can exploit the speedups that come with GPU usage.
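To make the idea concrete, here is a minimal sketch of such a batched computation. It is illustrative only: warp_loss_sketch, the margin of 1, and the log-rank weighting are assumptions for this sketch, not the package's exact implementation.

import torch

def warp_loss_sketch(pos_scores, neg_scores, num_labels):
    # pos_scores: (batch, 1) score of the positive item per row
    # neg_scores: (batch, num_neg) scores of the sampled negatives per row
    num_neg = neg_scores.size(1)

    # Hinge margins against every negative at once, with no Python for-loop
    margins = 1.0 + neg_scores - pos_scores          # (batch, num_neg)
    offends = margins > 0                            # which negatives rank too close or above

    # First offending negative per row: argmax on the int mask returns the first True
    first_idx = offends.int().argmax(dim=1)          # (batch,)
    has_offender = offends.any(dim=1)

    # WSABIE-style rank estimate: the later the first offender appears,
    # the lower the estimated rank of the positive and the smaller the weight
    est_rank = (num_labels - 1.0) / (first_idx.float() + 1.0)
    weight = torch.log1p(est_rank.floor())           # L(rank) ~ log(rank), one common choice

    # Loss is the weighted hinge margin of the first offender only;
    # rows with no offending negative contribute zero
    first_margin = margins.gather(1, first_idx.unsqueeze(1)).squeeze(1)
    return (weight * first_margin * has_offender.float()).sum()

Contrast this with the classic per-sample loop from the WSABIE paper, which keeps drawing negatives until it hits a violator: here all candidate negatives are scored in one tensor operation and the first violator is recovered afterwards with argmax.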

When is WARP loss advantageous?

If you're ranking items or building recommendation models, it's often advantageous to let your loss function optimize directly for ranking. WARP loss pits one explicit positive against the implicit negative items a user never interacted with, and lets us adjust the weights of the network accordingly.

Install

pip install warp_loss

How to use

The loss function requires scores for both positive and negative examples to be supplied, as in the example below.

from torch import nn
import torch
from warp_loss import warp_loss  # import path assumed from the package name

class OurModel(nn.Module):
    def __init__(self, num_labels, emb_dim=10):
        super(OurModel, self).__init__()
        self.emb = nn.Embedding(num_labels, emb_dim)
        self.user_embs = nn.Embedding(1, emb_dim)

    def forward(self, pos, neg):
        batch_size = neg.size(0)
        # Look up the single user embedding and repeat it across the batch
        one_user_vector = self.user_embs(torch.zeros(1).long())
        repeated_user_vector = one_user_vector.repeat((batch_size, 1)).view(batch_size, -1, 1)
        # Dot product of each label embedding with the user vector gives a score per label
        pos_res = torch.bmm(self.emb(pos), repeated_user_vector).squeeze(2)
        neg_res = torch.bmm(self.emb(neg), repeated_user_vector).squeeze(2)

        return pos_res, neg_res

num_labels = 100
model = OurModel(num_labels)
pos_labels = torch.randint(high=num_labels, size=(3,1)) # three random positive labels
neg_labels = torch.randint(high=num_labels, size=(3,2)) # a few random negatives per positive

pos_res, neg_res = model(pos_labels, neg_labels)
print('Positive Labels:', pos_labels)
print('Negative Labels:', neg_labels)
print('Model positive scores:', pos_res)
print('Model negative scores:', neg_res)
loss = warp_loss(pos_res, neg_res, num_labels=num_labels, device=torch.device('cpu'))
print('Loss:', loss)
loss.backward()
Positive Labels: tensor([[65],
        [94],
        [21]])
Negative Labels: tensor([[ 8, 45],
        [37, 93],
        [88, 84]])
Model positive scores: tensor([[-3.7806],
        [-1.9974],
        [-4.1741]], grad_fn=<SqueezeBackward1>)
Model negative scores: tensor([[-1.5696, -4.4905],
        [-1.9300, -0.3826],
        [ 2.4564, -2.1741]], grad_fn=<SqueezeBackward1>)
Loss: tensor(54.7226, grad_fn=<SumBackward0>)
print('We can also see that the gradient is only active for 2x the number of positive labels:', (model.emb.weight.grad.sum(1) != 0).sum().item())
print('Meaning we correctly discard the gradients for all other than the offending negative label.')
We can also see that the gradient is only active for 2x the number of positive labels: 6
Meaning we correctly discard the gradients for all other than the offending negative label.

Assumptions

The loss function assumes you have already sampled your negatives randomly.

For example, this could be done in your dataloader (see the sketch below):

  1. Assume we have a total dataset of 100 items
  2. Select a positive sample with index 8
  3. Your negatives should be a random selection from 0-99, excluding 8.

Example input to the loss function: model scores for pos: [8], neg: [88, 3, 99, 7]
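
A minimal sketch of such a sampler (sample_negatives is a hypothetical helper, not part of this package; your own dataloader can implement this however it likes):

import torch

def sample_negatives(pos_idx, num_items, num_neg):
    # Draw num_neg random item indices and re-draw any that collide with the positive
    neg = torch.randint(high=num_items, size=(num_neg,))
    while (neg == pos_idx).any():
        mask = neg == pos_idx
        neg[mask] = torch.randint(high=num_items, size=(int(mask.sum()),))
    return neg

# e.g. sample_negatives(pos_idx=8, num_items=100, num_neg=4)
# might return tensor([88,  3, 99,  7])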

Currently only tested on PyTorch v0.4

