A PyTorch package for Non-negative Matrix Factorization
Project description
Non-negative Matrix Factorization in PyTorch
PyTorch is not only a good deep learning framework, but also a fast tool when it comes to matrix operations and convolutions on large data. A great example is PyTorchWavelets.
In this package I implement NMF, PLCA, and their deconvolutional variants in PyTorch, based on torch.nn.Module, so the models can be moved freely between CPU and GPU and take advantage of CUDA's parallel computation.
Modules
NMF
Basic NMF and NMFD modules that minimize beta-divergence using multiplicative update rules. Part of the multiplier is obtained via torch.autograd, so the amount of code is reduced and easy to maintain (only the denominator is calculated explicitly). The interface is similar to sklearn.decomposition.NMF, with some extra options. A minimal sketch of the multiplicative update rule follows the list below.
- NMF: Original NMF algorithm.
- NMFD: 1-D deconvolutional NMF algorithm.
- NMF2D: 2-D deconvolutional NMF algorithm.
- NMF3D: 3-D deconvolutional NMF algorithm.
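To make this concrete, here is a minimal, self-contained sketch of the classic multiplicative update rule for Euclidean NMF (beta = 2). This is illustrative only, not the package's internals: both numerator and denominator are written out explicitly, and V, R, and eps are assumed example values.

import torch

# Multiplicative updates for Euclidean NMF (beta = 2):
#   H <- H * (W^T V) / (W^T W H),   W <- W * (V H^T) / (W H H^T)
# Every update multiplies by a ratio of non-negative terms,
# so W and H stay non-negative throughout.
V = torch.rand(1025, 100)        # non-negative target matrix (illustrative)
R = 8                            # rank (number of components)
W = torch.rand(V.shape[0], R)
H = torch.rand(R, V.shape[1])
eps = 1e-8                       # guard against division by zero

for _ in range(200):
    H *= (W.t() @ V) / (W.t() @ W @ H + eps)
    W *= (V @ H.t()) / (W @ H @ H.t() + eps)

print(torch.norm(V - W @ H))     # Euclidean reconstruction error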
PLCA
Basic PLCA and SIPLCA modules that use the EM algorithm to minimize the KL divergence between the target distribution P(X) and the estimated distribution; a sketch of one EM iteration follows the list below.
- PLCA: Original PLCA (Probabilistic Latent Component Analysis) algorithm.
- SIPLCA: Shift-Invariant PLCA algorithm (similar to NMFD).
- SIPLCA2: 2-D deconvolutional SIPLCA algorithm.
- SIPLCA3: 3-D deconvolutional SIPLCA algorithm.
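As a rough illustration of the EM updates (a sketch under simplifying assumptions, not the package's implementation), one iteration for plain 2-D PLCA computes the posterior P(z|x1,x2) in the E-step and re-estimates P(z), P(x1|z), and P(x2|z) from expected counts in the M-step. All sizes and names below are assumed example values.

import torch

# One EM loop for PLCA: P(x1, x2) ~ sum_z P(z) P(x1|z) P(x2|z)
X = torch.rand(64, 50)
X /= X.sum()                               # treat the target as a distribution P(X)
R = 8                                      # number of latent components z
Pz = torch.full((R,), 1.0 / R)             # P(z)
Px1 = torch.rand(64, R); Px1 /= Px1.sum(0) # P(x1|z), columns sum to 1
Px2 = torch.rand(50, R); Px2 /= Px2.sum(0) # P(x2|z)

for _ in range(100):
    # E-step: posterior P(z|x1,x2) for every bin, shape (64, 50, R)
    joint = Pz * Px1.unsqueeze(1) * Px2.unsqueeze(0)
    post = joint / (joint.sum(-1, keepdim=True) + 1e-12)
    # M-step: re-estimate the marginals from expected counts
    counts = X.unsqueeze(-1) * post        # P(x1,x2) * P(z|x1,x2)
    Pz = counts.sum((0, 1))
    Px1 = counts.sum(1) / (Pz + 1e-12)
    Px2 = counts.sum(0) / (Pz + 1e-12)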
Usage
Here is a short example of decomposing a spectrogram.
import torch
import librosa
from torchnmf import NMF
from torchnmf.metrics import KL_divergence
y, sr = librosa.load(librosa.util.example_audio_file())
y = torch.from_numpy(y)
windowsize = 2048
S = torch.stft(y, windowsize, window=torch.hann_window(windowsize)).pow(2).sum(2).sqrt().cuda()
R = 8 # number of components
net = NMF(S.shape, rank=R).cuda()
# run extremely fast on gpu
_, V = net.fit_transform(S) # fit to target matrix S
print(KL_divergence(V, S)) # KL divergence to S
A more detailed version, which redoes this example with NMFD, can be found here.
Comparison with sklearn
The bar chart shows the time cost per iteration for different beta-divergences. PyTorch-based NMF is clearly faster than scipy-based NMF (sklearn) when beta != 2 (that is, for losses other than Euclidean distance), and runs even faster when the computation is done on a GPU. The test was conducted on an Acer E5 laptop with an i5-7200U CPU and a GTX 950M GPU, using PyTorch 0.4.1 (I found CPU inference to be much slower with versions >= 1.0).
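A rough harness for this kind of comparison might look like the sketch below. This is an assumed setup, not the original benchmark script; the two libraries' solver and loss options are configured independently (torchnmf is run with its defaults, using only the interface from the usage example above), so treat any numbers as indicative at best.

import time
import numpy as np
import torch
from sklearn.decomposition import NMF as SkNMF
from torchnmf import NMF

V = np.abs(np.random.randn(1025, 1000)).astype(np.float32)  # stand-in for a magnitude spectrogram

# sklearn: multiplicative-update solver with KL divergence (beta = 1)
t0 = time.time()
SkNMF(n_components=8, solver='mu', beta_loss='kullback-leibler',
      init='random', max_iter=50).fit(V)
print('sklearn :', time.time() - t0)

# torchnmf on the same data; add .cuda() to S and net to time the GPU path
S = torch.from_numpy(V)
t0 = time.time()
net = NMF(S.shape, rank=8)
net.fit_transform(S)
print('torchnmf:', time.time() - t0)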
Installation
Using pip:
pip install git+http://github.com/yoyololicon/pytorch-NMFs
Or clone this repo and do:
python setup.py install
Requirements
- PyTorch >= 0.4.1
- tqdm
Tips
- If you notice a significant slowdown when operating on CPU, flush denormal numbers with torch.set_flush_denormal(True), as shown below.
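For example, call it once near the top of your script; it returns True when the CPU supports the flag:

import torch

# Flush denormal floats to zero; returns True if the CPU supports it
if torch.set_flush_denormal(True):
    print('denormals will be flushed to zero')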
TODO
- Support sparse matrices.
- Regularization.
- NNDSVD initialization.
- 2/3-D deconvolutional module.
- PLCA.
- ipynb examples.