
Sampling from a Maximum-Likelihood fitted Multi-Gaussian distribution in TensorFlow 2.1

Project description

Multi-Gaussian Sampling

The parametrization of conditional probability density functions (pdfs) plays a crucial role in simulations that must quickly produce large samples of data emulating real-life datasets.

A very simple technique models the pdf of the target dataset explicitly, as a function of the conditions, through a sum of kernel functions.
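With Gaussian kernels, one generic way to write such a conditional model is shown below, together with the maximum-likelihood objective over N training pairs (x_n, y_n). The number of components K and the way the weights w_k, means μ_k and covariances Σ_k depend on the conditions x are assumptions for illustration; this description does not specify how multigaussampler parametrizes them.

```latex
p_\theta(\mathbf{y}\mid\mathbf{x}) = \sum_{k=1}^{K} w_k(\mathbf{x})\,
  \mathcal{N}\!\big(\mathbf{y};\, \boldsymbol{\mu}_k(\mathbf{x}),\, \boldsymbol{\Sigma}_k(\mathbf{x})\big),
\qquad w_k(\mathbf{x}) \ge 0, \quad \sum_{k=1}^{K} w_k(\mathbf{x}) = 1

\hat{\theta} = \arg\min_{\theta}\; \Big[ -\frac{1}{N} \sum_{n=1}^{N} \log p_\theta(\mathbf{y}_n \mid \mathbf{x}_n) \Big]
```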

The package multigaussampler offers a simple Python 3 implementation of this algorithm. The model of the pdf is obtained through a maximum-likelihood fit of a pdf built as a sum of Gaussians, optimized with the TensorFlow implementation of the Adam optimizer.
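The maximum-likelihood fit with Adam can be illustrated with a minimal NumPy sketch. It is not the package's code: it replaces TensorFlow's automatic differentiation and built-in Adam optimizer with hand-written gradients and a hand-rolled Adam update, and it ignores the conditions X, fitting an unconditional one-dimensional mixture of K Gaussians. The function names mixture_nll_and_grads and fit_mixture_adam, and all hyperparameter defaults, are hypothetical.

```python
import numpy as np


def mixture_nll_and_grads(y, logits, mu, log_sigma):
    """Mean negative log-likelihood of 1-D data y under a K-component Gaussian
    mixture, plus its gradients w.r.t. (logits, mu, log_sigma)."""
    w = np.exp(logits - logits.max())
    w /= w.sum()                                   # softmax -> mixture weights
    sigma = np.exp(log_sigma)
    z = (y[:, None] - mu[None, :]) / sigma[None, :]          # (N, K) standardized residuals
    log_comp = (np.log(w)[None, :] - 0.5 * z**2
                - log_sigma[None, :] - 0.5 * np.log(2 * np.pi))  # log(w_k * N(y_n | k))
    m = log_comp.max(axis=1, keepdims=True)
    log_p = m[:, 0] + np.log(np.exp(log_comp - m).sum(axis=1))  # stable logsumexp over K
    r = np.exp(log_comp - log_p[:, None])          # responsibilities r_nk
    n = y.shape[0]
    nll = -log_p.mean()
    g_logits = -(r - w[None, :]).sum(axis=0) / n
    g_mu = -(r * z / sigma[None, :]).sum(axis=0) / n
    g_log_sigma = -(r * (z**2 - 1.0)).sum(axis=0) / n
    return nll, (g_logits, g_mu, g_log_sigma)


def fit_mixture_adam(y, k=2, steps=2000, lr=0.05, beta1=0.9, beta2=0.999, eps=1e-8):
    """Minimize the mixture NLL with a hand-written Adam update (hypothetical helper)."""
    params = [np.zeros(k),                                  # logits -> equal weights
              np.quantile(y, np.linspace(0.1, 0.9, k)),     # means spread over the data
              np.full(k, np.log(y.std()))]                  # log standard deviations
    m = [np.zeros(k) for _ in params]                       # first-moment estimates
    v = [np.zeros(k) for _ in params]                       # second-moment estimates
    for t in range(1, steps + 1):
        nll, grads = mixture_nll_and_grads(y, *params)
        for i, g in enumerate(grads):
            m[i] = beta1 * m[i] + (1 - beta1) * g
            v[i] = beta2 * v[i] + (1 - beta2) * g**2
            m_hat = m[i] / (1 - beta1**t)                   # bias correction
            v_hat = v[i] / (1 - beta2**t)
            params[i] = params[i] - lr * m_hat / (np.sqrt(v_hat) + eps)
    return params, nll  # nll is the loss evaluated just before the final update
```

For example, on data drawn from two well-separated Gaussians, fit_mixture_adam(y, k=2) should recover means and widths close to the generating ones. Parametrizing the weights through a softmax of unconstrained logits and the widths through log σ keeps every Adam step inside the valid parameter space, a common choice for gradient-based mixture fits.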

Sampling from the pdf is also implemented in TensorFlow, providing efficient sampling on both CPU and GPU.
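Sampling from a fitted Gaussian mixture is a two-step (ancestral) procedure: pick a component according to the weights, then draw from that Gaussian. The NumPy sketch below illustrates it, producing one sample per row of conditions. The package performs the sampling in TensorFlow; the per-row array layout (weights of shape (N, K), means and standard deviations of shape (N, K, D), diagonal covariances) and the name sample_mixture are assumptions made for illustration.

```python
import numpy as np


def sample_mixture(weights, means, sigmas, rng=None):
    """Draw one sample per row from a diagonal Gaussian mixture (hypothetical helper).

    weights: (N, K) mixture weights, each row summing to 1
    means, sigmas: (N, K, D) per-row component means and standard deviations
    Returns an (N, D) array.
    """
    rng = np.random.default_rng() if rng is None else rng
    n, k = weights.shape
    # Step 1: pick one component per row by inverting the cumulative weights.
    u = rng.random((n, 1))
    comp = (np.cumsum(weights, axis=1) < u).sum(axis=1)
    comp = np.minimum(comp, k - 1)   # guard against rows summing to slightly less than 1
    rows = np.arange(n)
    # Step 2: draw from the selected Gaussian (diagonal covariance).
    mu = means[rows, comp]           # (N, D)
    sd = sigmas[rows, comp]          # (N, D)
    return mu + sd * rng.standard_normal(mu.shape)
```

Every step is a vectorized array operation over the N rows, so the same structure maps directly onto tensor operations, which is one reason a GPU implementation can pay off for large samples.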

Example code

The code snippet below trains a sampler on a random dataset and then generates a random sample of the y variables conditioned on the same X variables used for training.

import numpy as np
from tqdm import trange
from multigaussampler import MGSampler

## Generate a random dataset as an example
nSamples = 1000
X = np.random.uniform(-20, 10, (nSamples, 4)).astype(np.float32)  # conditions
y = np.random.uniform(0, 1, (nSamples, 2)).astype(np.float32)     # target variables

## Create and configure the MGSampler object
gp = MGSampler(X, y)

## Train the MGSampler on the training dataset
progress_bar = trange(100)
for iEpoch in progress_bar:
  l = gp.train(X, y)
  progress_bar.set_description("Loss: %.1f " % l)

## Sample the obtained parametrization
y_sampled = gp.sample(X)

Author

Lucio Anderlini (Istituto Nazionale di Fisica Nucleare)



Download files


Files for multigaussampler, version 0.1.1

Filename                                  Size     File type   Python version
multigaussampler-0.1.1-py3-none-any.whl   6.0 kB   Wheel       py3
