
A PyTorch package for Non-negative Matrix Factorization


Non-negative Matrix Factorization in PyTorch


PyTorch is not only a good deep learning framework, but also a fast tool for matrix operations and convolutions on large data. A great example is PyTorchWavelets.

In this package I implement NMF, PLCA, and their deconvolutional variants in PyTorch, based on torch.nn.Module, so the models can be moved freely between CPU and GPU and take advantage of CUDA's parallel computation.

Modules

NMF

Basic NMF and NMFD modules minimizing beta-divergence with multiplicative update rules. The multiplier is obtained via torch.autograd, so the amount of code is reduced and easier to maintain.

The interface is similar to sklearn.decomposition.NMF with some extra options.

  • NMF: Original NMF algorithm.
  • NMFD: 1-D deconvolutional NMF algorithm.
  • NMF2D: 2-D deconvolutional NMF algorithm.
  • NMF3D: 3-D deconvolutional NMF algorithm.
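
The multiplicative-update rule that the package derives automatically through torch.autograd can also be written out by hand for the Euclidean (beta = 2) case. The following NumPy sketch is illustrative only; the function name nmf_mu and its parameters are not part of torchnmf:

```python
import numpy as np

def nmf_mu(V, rank, n_iter=200, eps=1e-10, seed=0):
    """Minimal NMF via classic Lee-Seung multiplicative updates,
    minimizing the Euclidean (beta = 2) loss: V ~= W @ H with W, H >= 0."""
    rng = np.random.default_rng(seed)
    m, n = V.shape
    W = rng.random((m, rank))
    H = rng.random((rank, n))
    for _ in range(n_iter):
        # each factor is multiplied by the ratio of the negative and
        # positive parts of its gradient, which keeps it non-negative
        H *= (W.T @ V) / (W.T @ W @ H + eps)
        W *= (V @ H.T) / (W @ H @ H.T + eps)
    return W, H

V = np.abs(np.random.default_rng(1).random((20, 30)))
W, H = nmf_mu(V, rank=4)
print(np.linalg.norm(V - W @ H))  # reconstruction error after fitting
```

For other beta values the multiplier takes a different form; reading it off the autograd graph instead of hand-deriving it for each loss is what keeps the module code small.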

PLCA (not documented)

Basic PLCA and SIPLCA modules using the EM algorithm to minimize the KL divergence between the target distribution P(X) and the estimated distribution.

  • PLCA: Original PLCA (Probabilistic Latent Component Analysis) algorithm.
  • SIPLCA: Shift-Invariant PLCA algorithm (similar to NMFD).
  • SIPLCA2: 2-D deconvolutional SIPLCA algorithm.
  • SIPLCA3: 3-D deconvolutional SIPLCA algorithm.
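
The EM updates mentioned above can be sketched in NumPy under the standard two-factor model P(f, t) = sum_z P(z) P(f|z) P(t|z). This is an illustrative re-derivation, not the package's code; the name plca and its parameters are hypothetical:

```python
import numpy as np

def plca(V, rank, n_iter=100, eps=1e-12, seed=0):
    """Minimal 2-D PLCA via EM: fit P(f, t) = sum_z P(z) P(f|z) P(t|z)
    to the normalized non-negative input V by minimizing KL divergence."""
    rng = np.random.default_rng(seed)
    F, T = V.shape
    P = V / V.sum()                              # target distribution P(f, t)
    Pz = np.full(rank, 1.0 / rank)               # P(z)
    Pf = rng.random((F, rank)); Pf /= Pf.sum(0)  # P(f|z), columns sum to 1
    Pt = rng.random((T, rank)); Pt /= Pt.sum(0)  # P(t|z), columns sum to 1
    for _ in range(n_iter):
        # E-step: posterior P(z|f,t) proportional to P(z) P(f|z) P(t|z)
        joint = Pz[None, None, :] * Pf[:, None, :] * Pt[None, :, :]
        post = joint / (joint.sum(-1, keepdims=True) + eps)
        # M-step: reweight the posterior by the data, then renormalize
        q = P[:, :, None] * post                 # expected counts, shape (F, T, Z)
        Pz = q.sum((0, 1))
        Pf = q.sum(1) / (Pz + eps)
        Pt = q.sum(0) / (Pz + eps)
        Pz /= Pz.sum()
    return Pz, Pf, Pt

V = np.random.default_rng(2).random((10, 12))
Pz, Pf, Pt = plca(V, rank=3)
print((Pf * Pz) @ Pt.T * V.sum())  # rescaled reconstruction of V
```

The shift-invariant variants replace the outer products with convolutions, but the E/M alternation stays the same.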

NOTE

This module is currently undocumented and still uses the old function interface (from before version 0.3). It will be updated to the new interface in a later version.


Usage

Here is a short example of decomposing a spectrogram with deconvolutional NMF:

import torch
import librosa
from torchnmf.nmf import NMFD
from torchnmf.metrics import kl_div

y, sr = librosa.load(librosa.util.example_audio_file())
y = torch.from_numpy(y)
windowsize = 2048
S = torch.stft(y, windowsize, window=torch.hann_window(windowsize)).pow(2).sum(2).sqrt().cuda()
S = S.unsqueeze(0)

R = 8   # number of components
T = 400 # size of convolution window

net = NMFD(S.shape, rank=R, T=T).cuda()
# run extremely fast on gpu
net.fit(S)      # fit to target matrix S
V = net()
print(kl_div(V, S))        # KL divergence to S

A more detailed version can be found here. See our documentation to learn more about the usage of this package.

Comparison with sklearn

The bar chart shows the time cost per iteration for different beta-divergences. The PyTorch-based NMF has a nearly constant processing time across beta values, which is an advantage when beta is not 0, 1, or 2, because our implementation uses the same computational graph regardless of which beta-divergence is being minimized. It runs even faster when the computation is done on a GPU. The test was conducted on an Acer E5 laptop with an i5-7200U CPU and a GTX 950M GPU.
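
For context, the beta-divergence family reduces to Itakura-Saito at beta = 0, generalized KL at beta = 1, and half the squared Euclidean distance at beta = 2. The NumPy sketch below is illustrative (the function is not part of torchnmf; it assumes strictly positive inputs) and shows how one generic formula covers every other beta, which is why a single computational graph suffices:

```python
import numpy as np

def beta_divergence(V, Vhat, beta):
    """Beta-divergence between target V and estimate Vhat (both > 0).
    beta=2 -> half squared Euclidean, beta=1 -> KL, beta=0 -> Itakura-Saito."""
    V = np.asarray(V, dtype=float)
    Vhat = np.asarray(Vhat, dtype=float)
    if beta == 0:  # Itakura-Saito divergence (limit case)
        r = V / Vhat
        return np.sum(r - np.log(r) - 1)
    if beta == 1:  # generalized KL divergence (limit case)
        return np.sum(V * np.log(V / Vhat) - V + Vhat)
    # generic formula, valid for any other beta
    return np.sum(V**beta + (beta - 1) * Vhat**beta
                  - beta * V * Vhat**(beta - 1)) / (beta * (beta - 1))

V = np.array([1.0, 2.0, 3.0])
Vhat = np.array([1.5, 2.0, 2.5])
for b in (0, 0.5, 1, 2):
    print(b, beta_divergence(V, Vhat, b))
```

Only beta = 0, 1, 2 need the special-cased limits; a fixed graph built from the generic branch handles every other beta at the same cost.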

Installation

pip install torchnmf

Requirements

  • PyTorch
  • tqdm

Tips

  • If you notice a significant slowdown when operating on CPU, flush denormal numbers with torch.set_flush_denormal(True).

TODO

  • Support sparse matrix.
  • Regularization.
  • NNDSVD initialization.
  • 2/3-D deconvolutional module.
  • PLCA.
  • Documentation.
  • ipynb examples.
  • Refactor PLCA module.
