Scalable Linear Shallow Autoencoder for Collaborative Filtering

License: MIT

ELSA

This is an official implementation of our paper Scalable Linear Shallow Autoencoder for Collaborative Filtering.
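
For orientation: ELSA factorizes the item-item weight matrix as AAᵀ − I, where A is the item-embedding matrix with L2-normalized rows, so the prediction for an interaction vector x is x̂ = xAAᵀ − x. This is exactly what the numpy snippet in Basic usage below computes.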

Requirements

PyTorch version >=1.10 (along with a compatible CUDA Toolkit) must be installed on the system. If it is not, one can install PyTorch with:

pip install torch==1.10.2+cu113 torchvision==0.11.3+cu113 torchaudio==0.10.2+cu113 --extra-index-url https://download.pytorch.org/whl/cu113

Installation

ELSA can be installed from PyPI with:

pip install elsarec

Basic usage

from elsa import ELSA
import torch
import numpy as np

device = torch.device("cuda")

X_csr = ... # load your interaction matrix (scipy.sparse.csr_matrix with users in rows and items in columns)
X_test = ... # load your test data (scipy.sparse.csr_matrix with users in rows and items in columns)

items_cnt = X_csr.shape[1]
factors = 256 
num_epochs = 5
batch_size = 128

model = ELSA(n_items=items_cnt, device=device, n_dims=factors)
model.compile()

model.fit(X_csr, batch_size=batch_size, epochs=num_epochs)

# get the L2-normalized item embeddings as a numpy array
A = torch.nn.functional.normalize(model.get_items_embeddings(), dim=-1).cpu().numpy()

# get predictions in PyTorch
predictions = model.predict(X_test, batch_size=batch_size)

# equivalently, get predictions in numpy using the normalized embeddings
predictions = ((X_test @ A) @ (A.T)) - X_test

# find the 100 most similar items for a subset of items
itemids = np.array([id1, id2, ...])  # id1, id2 are column indices of items in X_csr
related = model.similar_items(N=100, batch_size=128, sources=itemids)
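
As a quick smoke test, the following self-contained sketch runs the same flow end to end on a small random interaction matrix. All shapes, hyperparameters, and the CPU fallback are illustrative assumptions, not recommendations:

import numpy as np
import scipy.sparse as sp
import torch
from elsa import ELSA

# toy implicit-feedback matrix: 1000 users x 500 items, ~2% density (illustrative)
X = sp.random(1000, 500, density=0.02, format="csr", dtype=np.float32, random_state=42)
X.data[:] = 1.0

# the examples above use CUDA; whether CPU training is supported is an assumption here
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ELSA(n_items=X.shape[1], device=device, n_dims=64)
model.compile()
model.fit(X, batch_size=128, epochs=2)

# numpy predictions, as in the snippet above
A = torch.nn.functional.normalize(model.get_items_embeddings(), dim=-1).cpu().numpy()
predictions = (X @ A) @ A.T - X.toarray()
print(predictions.shape)  # (1000, 500)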

Notes

Reproducibility

Please get in touch with us if you want to reproduce the results from our paper.

TensorFlow users

We decided to implement ELSA in PyTorch, but a TensorFlow implementation is straightforward. One can, for example, implement ELSA as a Keras layer:

import tensorflow as tf

class ELSA(tf.keras.layers.Layer):
    def __init__(self, latent, nr_of_items):
        super(ELSA, self).__init__()
        # item-embedding matrix A of shape (nr_of_items, latent)
        w_init = tf.keras.initializers.HeNormal()
        self.A = tf.Variable(
            initial_value=w_init(shape=(nr_of_items, latent), dtype="float32"),
            trainable=True,
        )

    def get_items_embeddings(self):
        A = tf.math.l2_normalize(self.A, axis=-1)
        return A.numpy()

    @tf.function
    def call(self, x):
        A = tf.math.l2_normalize(self.A, axis=-1)
        xA = tf.matmul(x, A)
        xAAT = tf.matmul(xA, A, transpose_b=True)
        # predictions are xAA^T - x (the subtraction removes the self-interaction)
        return xAAT - x
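
A minimal usage sketch for this layer (the layer name and constructor arguments come from the snippet above; the batch shape and dummy data are assumptions):

import numpy as np

layer = ELSA(latent=64, nr_of_items=500)
x = tf.constant(np.random.binomial(1, 0.02, size=(8, 500)).astype("float32"))
predictions = layer(x)  # xAA^T - x
print(predictions.shape)  # (8, 500)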

License

MIT license

Troubleshooting

If you encounter a problem or have a question about ELSA, do not hesitate to create an issue and ask. In case of an implementation problem, please include the Python, PyTorch and CUDA versions in the description of the issue.
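
One way to collect those versions (a small helper script, not part of the package):

import platform
import torch

print("Python :", platform.python_version())
print("PyTorch:", torch.__version__)
print("CUDA   :", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())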
