
A PyTorch package for Non-negative Matrix Factorization


Non-negative Matrix Factorization in PyTorch

PyTorch is not only a good deep learning framework, but also a fast tool for matrix operations and convolutions on large data. A great example is PyTorchWavelets.

In this package I implement NMF, PLCA, and their deconvolutional variants in PyTorch, built on torch.nn.Module, so the models can be moved freely between CPU and GPU devices and take advantage of CUDA's parallel computation.

Modules

NMF

Basic NMF and NMFD modules that minimize the beta-divergence using multiplicative update rules. Part of the multiplier is obtained via torch.autograd, which keeps the code small and easy to maintain (only the denominator is computed explicitly); a minimal sketch of the trick follows.
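
The idea: for the beta-divergence, the gradient with respect to H equals the denominator of the multiplier minus its numerator, so one backward pass recovers the numerator from the hand-coded denominator. Below is a minimal sketch of a single update of H; the function name and structure are illustrative, not the package's actual code.

import torch

def mu_update_h(V, W, H, beta=1.0, eps=1e-8):
    # One multiplicative update of H for the beta-divergence D(V || W @ H).
    # Sketch only: the numerator of the multiplier is recovered from the
    # autograd gradient, since grad = denominator - numerator.
    H = H.clone().requires_grad_(True)
    WH = W @ H
    if beta == 2.0:    # Euclidean distance
        loss = 0.5 * ((V - WH) ** 2).sum()
    elif beta == 1.0:  # generalized KL divergence
        loss = (V * (V.clamp_min(eps) / WH.clamp_min(eps)).log() - V + WH).sum()
    else:              # generic beta-divergence
        loss = ((V.clamp_min(eps) ** beta
                 + (beta - 1) * WH.clamp_min(eps) ** beta
                 - beta * V * WH.clamp_min(eps) ** (beta - 1))
                / (beta * (beta - 1))).sum()
    loss.backward()
    with torch.no_grad():
        denom = W.t() @ WH.pow(beta - 1)  # the only hand-coded term
        numer = denom - H.grad            # numerator recovered via autograd
        return H * numer / denom.clamp_min(eps)

For beta = 2 this reduces to the classic update H <- H * (W.T @ V) / (W.T @ W @ H), and the same code covers every beta because autograd supplies the numerator.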

The interface is similar to sklearn.decomposition.NMF with some extra options.

  • NMF: Original NMF algorithm.
  • NMFD: 1-D deconvolutional NMF algorithm.
  • NMF2D: 2-D deconvolutional NMF algorithm.
  • NMF3D: 3-D deconvolutional NMF algorithm.

PLCA

Basic PLCA and SIPLCA modules that use the EM algorithm to minimize the KL divergence between the target distribution P(X) and the estimated distribution; a minimal EM sketch follows the list below.

  • PLCA: Original PLCA (Probabilistic Latent Component Analysis) algorithm.
  • SIPLCA: Shift-Invariant PLCA algorithm (similar to NMFD).
  • SIPLCA2: 2-D deconvolutional SIPLCA algorithm.
  • SIPLCA3: 3-D deconvolutional SIPLCA algorithm.
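
For intuition, here is a didactic EM step for the 2-D PLCA model P(x1, x2) = sum_z P(z) P(x1|z) P(x2|z); the function name and tensor layout are illustrative assumptions, not the package's API.

import torch

def plca_em_step(P, Pz, Px1_z, Px2_z, eps=1e-12):
    # One EM iteration for 2-D PLCA (sketch, not the package's code).
    # P     : (N1, N2) target distribution, normalized so P.sum() == 1
    # Pz    : (K,)     mixture weights P(z)
    # Px1_z : (N1, K)  marginals P(x1|z)
    # Px2_z : (N2, K)  marginals P(x2|z)
    # E-step: posterior P(z|x1,x2) is proportional to P(z) P(x1|z) P(x2|z)
    joint = Pz[None, None, :] * Px1_z[:, None, :] * Px2_z[None, :, :]
    post = joint / joint.sum(dim=2, keepdim=True).clamp_min(eps)
    # M-step: re-estimate each factor from the expected counts
    counts = P[:, :, None] * post                           # (N1, N2, K)
    Pz_new = counts.sum(dim=(0, 1))
    Px1_z_new = counts.sum(dim=1) / Pz_new.clamp_min(eps)
    Px2_z_new = counts.sum(dim=0) / Pz_new.clamp_min(eps)
    return Pz_new, Px1_z_new, Px2_z_new

Each iteration is guaranteed not to increase the KL divergence between P(X) and the model, which is the usual EM monotonicity property.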

Usage

Here is a short example of decomposing a spectrogram.

import torch
import librosa
from torchnmf import NMF
from torchnmf.metrics import KL_divergence

# load audio and build a magnitude spectrogram on the GPU
y, sr = librosa.load(librosa.util.example_audio_file())
y = torch.from_numpy(y)
windowsize = 2048
S = torch.stft(y, windowsize,
               window=torch.hann_window(windowsize)).pow(2).sum(2).sqrt().cuda()

R = 8   # number of components

net = NMF(S.shape, rank=R).cuda()
# runs extremely fast on GPU
_, V = net.fit_transform(S)       # fit to the target matrix S
print(KL_divergence(V, S))        # KL divergence to S

A more detailed version can be found here, which redoes this example with NMFD; a rough sketch of the NMFD version is shown below.
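
As a rough idea of how the deconvolutional variant plugs into the same interface, here is a hedged sketch; the T (template length) argument is an assumption about the NMFD constructor, so check the repository for the exact signature.

from torchnmf import NMFD

# T is assumed to be the number of time frames per template (hypothetical)
net = NMFD(S.shape, T=5, rank=R).cuda()
_, V = net.fit_transform(S)       # same interface as NMF
print(KL_divergence(V, S))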

Comparison with sklearn

The bar chart shows the time cost per iteration for different beta-divergences. PyTorch-based NMF is clearly faster than the scipy-based NMF in sklearn when beta != 2 (i.e., for losses other than the Euclidean distance), and runs even faster when the computation is done on a GPU. The test was conducted on an Acer E5 laptop with an i5-7200U CPU and a GTX 950M GPU, using PyTorch 0.4.1 (I found that CPU inference is much slower with versions >= 1.0).

Installation

Using pip:

pip install git+http://github.com/yoyololicon/pytorch-NMFs

Or clone this repo and do:

python setup.py install

Requirements

  • PyTorch >= 0.4.1
  • tqdm

Tips

  • If you notice a significant slowdown when running on CPU, flush denormal numbers with torch.set_flush_denormal(True).

TODO

  • Support sparse matrices.
  • Regularization.
  • NNDSVD initialization.
  • 2/3-D deconvolutional module.
  • PLCA.
  • ipynb examples.
