NumPy and PyTorch implementations of the beta-divergence loss.

Beta-Divergence Loss Implementations

This repository contains Python implementations of the beta-divergence loss, compatible with both NumPy and PyTorch.

Dependencies

This library is written in Python and requires Python (version >= 3.9 recommended) to run. In addition to a working PyTorch installation, this library relies on the following libraries and recommended version numbers:

Installation

To install the latest stable release, use pip:

$ pip install beta-divergence-metrics

Usage

The numpybd.loss module contains two beta-divergence functions that operate on NumPy arrays: a general beta-divergence between two arrays, and a variant specific to non-negative matrix factorization (NMF). Similarly, the torchbd.loss module contains two beta-divergence loss classes that operate on PyTorch tensors. The implementations can be imported as follows:

# Import beta-divergence loss implementations
from numpybd.loss import *
from torchbd.loss import *

Beta-divergence between two NumPy arrays

To calculate the beta-divergence between a NumPy array a and a target or reference array b, use the beta_div loss function. The beta_div loss function can be used as follows:

# Calculate beta-divergence loss between array a and target array b
loss_val = beta_div(input=a, target=b, beta=0, reduction='mean')
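For intuition about what beta_div computes, here is a minimal NumPy sketch of the elementwise beta-divergence d_beta(b | a), treating b as the reference and a as the approximation, followed by a reduction. The name beta_divergence_sketch is hypothetical and this is not the package's implementation. The 'sum' and 'none' reduction options are an assumption modeled on PyTorch loss conventions (only 'mean' appears in this README), and strictly positive inputs are assumed.

```python
import numpy as np


def beta_divergence_sketch(a, b, beta=0.0, reduction="mean"):
    """Hypothetical helper: elementwise beta-divergence d_beta(b | a), reduced.

    Illustration only, not the package's beta_div. Assumes a and b have the
    same shape and strictly positive entries (needed for beta <= 1).
    """
    x = np.asarray(b, dtype=float)  # reference (target) values
    y = np.asarray(a, dtype=float)  # approximation (input) values

    if beta == 0:
        d = x / y - np.log(x / y) - 1.0  # Itakura-Saito divergence
    elif beta == 1:
        d = x * np.log(x / y) - x + y  # generalized Kullback-Leibler divergence
    else:
        # General case; beta = 2 reduces to (x - y)**2 / 2
        d = (x**beta + (beta - 1.0) * y**beta - beta * x * y ** (beta - 1.0)) / (
            beta * (beta - 1.0)
        )

    # 'sum' and 'none' are assumed options (PyTorch-style), not documented here
    if reduction == "mean":
        return d.mean()
    if reduction == "sum":
        return d.sum()
    if reduction == "none":
        return d
    raise ValueError(f"unknown reduction: {reduction!r}")
```

For example, with a = np.array([1.0, 2.0, 4.0]) and b = np.array([1.0, 1.0, 2.0]), calling beta_divergence_sketch(a, b, beta=2, reduction="sum") gives 2.5, i.e. half the summed squared error.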

Beta-divergence between two PyTorch tensors

To calculate the beta-divergence between a PyTorch tensor a and a target or reference tensor b, use the BetaDivLoss class. A BetaDivLoss object can be instantiated and used as follows:

# Instantiate beta-divergence loss object
loss_func = BetaDivLoss(beta=0, reduction='mean')

# Calculate beta-divergence loss between tensor a and target tensor b
loss_val = loss_func(input=a, target=b)

NMF beta-divergence between NumPy array of data and data reconstruction

To calculate the NMF-specific beta-divergence between a data matrix X, stored as a NumPy array, and the product of a scores matrix H and a components matrix W, use the nmf_beta_div function. The nmf_beta_div function can be used as follows:

# Calculate beta-divergence loss between data matrix X (target or
# reference matrix) and matrix product of H and W
loss_val = nmf_beta_div(X=X, H=H, W=W, beta=0, reduction='mean')
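The NMF-specific variant compares the data matrix X with its reconstruction from H and W. The following plain NumPy sketch illustrates that idea; nmf_beta_divergence_sketch is a hypothetical helper, not the package's code. It assumes H has shape (n_samples, n_components) and W has shape (n_components, n_features), so that H @ W matches the shape of X, that all entries are strictly positive, and (as above) that 'sum' and 'none' are valid reductions.

```python
import numpy as np


def nmf_beta_divergence_sketch(X, H, W, beta=0.0, reduction="mean"):
    """Hypothetical helper: beta-divergence d_beta(X | H @ W).

    Illustration only, not the package's nmf_beta_div. Assumes H is
    (n_samples, n_components), W is (n_components, n_features), and all
    entries are strictly positive.
    """
    x = np.asarray(X, dtype=float)  # data matrix (target)
    y = np.asarray(H, dtype=float) @ np.asarray(W, dtype=float)  # reconstruction
    if y.shape != x.shape:
        raise ValueError(f"H @ W has shape {y.shape}, but X has shape {x.shape}")

    if beta == 0:
        d = x / y - np.log(x / y) - 1.0  # Itakura-Saito divergence
    elif beta == 1:
        d = x * np.log(x / y) - x + y  # generalized Kullback-Leibler divergence
    else:
        d = (x**beta + (beta - 1.0) * y**beta - beta * x * y ** (beta - 1.0)) / (
            beta * (beta - 1.0)
        )

    # 'sum' and 'none' are assumed options, not documented in this README
    if reduction == "mean":
        return d.mean()
    if reduction == "sum":
        return d.sum()
    if reduction == "none":
        return d
    raise ValueError(f"unknown reduction: {reduction!r}")
```

Under these assumptions the loss is zero whenever X equals H @ W exactly, and with beta=2 and reduction="sum" it equals half the squared Frobenius norm of X - H @ W.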

NMF beta-divergence between PyTorch tensor of data and data reconstruction

To calculate the NMF-specific beta-divergence between a data matrix X, stored as a PyTorch tensor, and the matrix product of a scores matrix H and a components matrix W, use the NMFBetaDivLoss class. An NMFBetaDivLoss object can be instantiated and used as follows:

# Instantiate NMF beta-divergence loss object
loss_func = NMFBetaDivLoss(beta=0, reduction='mean')

# Calculate beta-divergence loss between data matrix X (target or
# reference matrix) and matrix product of H and W
loss_val = loss_func(X=X, H=H, W=W)

Choosing beta value

When choosing a beta-divergence loss, the value of beta should depend on the data type and application. For NMF applications, a beta value of 0 (Itakura-Saito divergence) is recommended. Integer values of beta correspond to the following divergences and loss functions:

- beta = 0: Itakura-Saito divergence
- beta = 1: generalized Kullback-Leibler divergence
- beta = 2: squared Euclidean distance
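As a reference point, the standard elementwise definition of the beta-divergence between a data value x and its approximation y, as commonly stated in the NMF literature, is given below. Whether the package applies exactly this scaling and argument order is an assumption, not something this README states.

```latex
d_\beta(x \mid y) =
\begin{cases}
\dfrac{x}{y} - \log\dfrac{x}{y} - 1, & \beta = 0 \quad \text{(Itakura-Saito)} \\[6pt]
x \log\dfrac{x}{y} - x + y, & \beta = 1 \quad \text{(generalized Kullback-Leibler)} \\[6pt]
\dfrac{x^{\beta} + (\beta - 1)\, y^{\beta} - \beta\, x\, y^{\beta - 1}}{\beta(\beta - 1)}, & \text{otherwise}
\end{cases}
```

For beta = 2 the general case reduces to (x - y)^2 / 2, half the squared Euclidean distance, and it approaches the beta = 0 and beta = 1 cases continuously as beta tends to those values.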

Issue Tracking and Reports

Please use the GitHub issue tracker associated with this repository for issue tracking, filing bug reports, and asking general questions about the package or project.

Download files

Source Distribution

beta-divergence-metrics-0.0.2.tar.gz (7.8 kB)

Built Distribution

beta_divergence_metrics-0.0.2-py3-none-any.whl (13.4 kB)