Density EstimatioN using Masked AutoRegressive Flow
denmarf
This package provides a scikit-learn-like interface to perform density estimation using masked autoregressive flow. The current torch-based implementation uses pytorch-flow as the backbone. A more performant re-implementation in jax is in progress.
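To make the idea of a masked autoregressive flow concrete, here is a minimal sketch of an affine autoregressive flow in plain Python. Each coordinate is modeled as x_i = u_i * exp(alpha_i) + mu_i with u ~ N(0, I), where (mu_i, alpha_i) depend only on the preceding coordinates x_{<i}. The `conditioner` function is a hypothetical fixed linear rule standing in for the trained masked network of a real MAF; this is not denmarf's or pytorch-flow's code.

```python
import math
import random


def conditioner(x_prefix):
    """Hypothetical stand-in for the masked network in a real MAF.

    Returns (mu_i, alpha_i) computed only from the preceding coordinates x_{<i};
    here a fixed linear rule instead of a trained neural network.
    """
    s = sum(x_prefix)
    return 0.5 * s, 0.1 * s  # (shift mu_i, log-scale alpha_i)


def log_prob(x):
    """log p(x) by change of variables for an affine autoregressive flow."""
    total = 0.0
    for i in range(len(x)):
        mu, alpha = conditioner(x[:i])
        u = (x[i] - mu) * math.exp(-alpha)  # invert x_i = u_i * exp(alpha_i) + mu_i
        # standard-normal log-density of u_i plus log|du_i/dx_i| = -alpha_i
        total += -0.5 * (u * u + math.log(2 * math.pi)) - alpha
    return total


def sample(dim, rng):
    """Sampling is sequential: each x_i needs the already-generated x_{<i}."""
    x = []
    for _ in range(dim):
        mu, alpha = conditioner(x)
        x.append(rng.gauss(0.0, 1.0) * math.exp(alpha) + mu)
    return x
```

Density evaluation only needs the observed x, so all u_i can be computed at once (in a real MAF, with a single pass of the masked network), whereas sampling has to proceed one coordinate at a time.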
Requirements
- scipy-stack (numpy, scipy, matplotlib, pandas)
- pytorch
- CUDA (for GPU capability)
Installation
pip install denmarf
Usage
The interface is very similar to the KernelDensity estimator in scikit-learn. To perform a density estimation, one first initializes a DensityEstimate object from denmarf.density. Then one fits the data X, a numpy ndarray of shape (n_samples, n_features), with the method DensityEstimate.fit(X). Once a model is trained, it can be used to generate new samples with DensityEstimate.sample(), or to evaluate the density at arbitrary points with DensityEstimate.score_samples().
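The fit, sample, score_samples workflow can be illustrated without torch using a toy stand-in class. `ToyDensityEstimate` is hypothetical: it fits an independent Gaussian per feature, not a masked autoregressive flow, and its `sample(n_samples, seed)` signature is an assumption. Like scikit-learn's `KernelDensity.score_samples`, the toy returns log-densities.

```python
import math
import random


class ToyDensityEstimate:
    """Hypothetical stand-in mirroring the fit/sample/score_samples workflow.

    Fits an independent Gaussian per feature; this is NOT a masked autoregressive flow.
    """

    def fit(self, X):
        n, d = len(X), len(X[0])
        self.mean_ = [sum(row[j] for row in X) / n for j in range(d)]
        self.std_ = [
            math.sqrt(sum((row[j] - self.mean_[j]) ** 2 for row in X) / n)
            for j in range(d)
        ]
        return self  # returning self allows chaining, e.g. ToyDensityEstimate().fit(X)

    def sample(self, n_samples=1, seed=None):
        rng = random.Random(seed)
        return [
            [rng.gauss(m, s) for m, s in zip(self.mean_, self.std_)]
            for _ in range(n_samples)
        ]

    def score_samples(self, X):
        # Log-density of each row (sum of per-feature Gaussian log-densities)
        out = []
        for row in X:
            lp = 0.0
            for x, m, s in zip(row, self.mean_, self.std_):
                lp += -0.5 * ((x - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
            out.append(lp)
        return out
```

For example, `ToyDensityEstimate().fit([[0.0, 1.0], [2.0, 3.0]])` learns per-feature means `[1.0, 2.0]`, after which `sample(5, seed=0)` returns five new 2-feature rows.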
Initializing a DensityEstimate object
To initialize a DensityEstimate model, one can simply use
from denmarf import DensityEstimate
de = DensityEstimate()
Note that by default the model will try to use a GPU whenever CUDA is available, and fall back to the CPU if it is not. To bypass this behavior and use the CPU even when a GPU is available, use
from denmarf import DensityEstimate
de = DensityEstimate(device="cpu", use_cuda=False)
If multiple GPUs are available, one can specify which device to use by
from denmarf import DensityEstimate
de = DensityEstimate(device="cuda:2")
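The device-selection behavior described above can be summarized as a small rule. The helper below is a hypothetical illustration of that rule, not denmarf's internal code; the `cuda_available` parameter stands in for `torch.cuda.is_available()` so the sketch runs without PyTorch.

```python
def resolve_device(device=None, use_cuda=True, cuda_available=False):
    """Hypothetical device-selection rule mirroring the behavior described above.

    `cuda_available` stands in for torch.cuda.is_available().
    """
    if not (use_cuda and cuda_available):
        return "cpu"  # CUDA disabled or absent: fall back to the CPU
    if device is None:
        return "cuda"  # default: the current GPU
    return device  # explicit choice, e.g. "cpu" or "cuda:2"
```

In PyTorch, the resulting string can be passed to `torch.device(...)`.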
Fitting a bounded distribution
To improve fitting performance for bounded distributions, logit transformations can be used to convert bounded distributions into unbounded ones. denmarf will automatically perform both the linear shifting and rescaling and the actual logit transformation if bounded=True is passed and the lower and upper bounds are given when calling .fit(). When computing the probability density, the appropriate Jacobian is also included.
For example,
from denmarf import DensityEstimate
# X is some np ndarray
de = DensityEstimate().fit(X, bounded=True, lower_bounds=..., upper_bounds=...)
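The rescale-then-logit change of variables and its Jacobian correction can be sketched per coordinate in plain Python. This is a minimal illustration of the standard technique under the assumption of scalar bounds, not denmarf's internal code; the function names are hypothetical. With u = (x - lower) / (upper - lower) and y = logit(u), the density in the original space is log p_X(x) = log p_Y(y) + log|dy/dx|, where dy/dx = 1 / ((upper - lower) u (1 - u)).

```python
import math


def to_unbounded(x, lower, upper):
    """Map x in (lower, upper) to the real line: rescale to (0, 1), then apply logit."""
    u = (x - lower) / (upper - lower)
    return math.log(u) - math.log1p(-u)


def log_abs_jacobian(x, lower, upper):
    """log|dy/dx| for y = logit((x - lower) / (upper - lower))."""
    u = (x - lower) / (upper - lower)
    return -math.log(upper - lower) - math.log(u) - math.log1p(-u)


def bounded_log_density(x, lower, upper, log_density_y):
    """Change of variables: log p_X(x) = log p_Y(y) + log|dy/dx|."""
    y = to_unbounded(x, lower, upper)
    return log_density_y(y) + log_abs_jacobian(x, lower, upper)
```

As a sanity check, if the model in the unbounded space were a standard logistic distribution, mapping it back through the inverse transform gives a uniform density, so `bounded_log_density` returns -log(upper - lower) everywhere inside the bounds.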
Saving a trained model
After training, a model can be saved (pickled) to disk for later use with
de.save("filename_for_the_model.pkl")
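Since saving relies on pickling, the underlying mechanism is a pickle round trip. The helpers below use Python's standard `pickle` module on a hypothetical dictionary of fitted parameters; they illustrate the mechanism only and are not denmarf's `save` implementation.

```python
import os
import pickle
import tempfile


def save(state, filename):
    """Pickle an object to disk (illustrative helper, not denmarf's implementation)."""
    with open(filename, "wb") as f:
        pickle.dump(state, f)


def load(filename):
    """Unpickle an object from disk."""
    with open(filename, "rb") as f:
        return pickle.load(f)


# Hypothetical fitted parameters standing in for a trained model
state = {"lower_bounds": [0.0, 0.0], "upper_bounds": [1.0, 3.0], "n_features": 2}
path = os.path.join(tempfile.mkdtemp(), "filename_for_the_model.pkl")
save(state, path)
restored = load(path)
```

Because unpickling can execute arbitrary code, only load model files from trusted sources.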
Loading a saved model from disk
denmarf has built-in support for loading a trained model from disk and reconstructing it on either a CPU or a GPU (not necessarily the same kind of device on which the training was performed). For example, suppose we have a model trained on a GPU and we want to evaluate it on a CPU instead. This can be done with
from denmarf import DensityEstimate
de = DensityEstimate.from_file(filename="filename_for_the_model.pkl")
By default, the model is always loaded to the CPU.
Contributing to denmarf
Contributions are always welcome!
You can use the issue tracker on GitHub to submit a bug report, request a new feature, or simply ask a question about the code.
If you would like to make changes to the code, just submit a pull request on GitHub.