

Keras Mixture Density Network Layer


A mixture density network (MDN) layer for Keras using TensorFlow's distributions module. This makes it a bit simpler to experiment with neural networks that predict multiple real-valued variables that can take on multiple equally likely values.

This layer can help build MDN-RNNs similar to those used in RoboJam, Sketch-RNN, handwriting generation, and maybe even world models. You can do a lot of cool stuff with MDNs!

One benefit of this implementation is that you can predict any number of real-values. TensorFlow's Mixture, Categorical, and MultivariateNormalDiag distribution functions are used to generate the loss function (the probability density function of a mixture of multivariate normal distributions with a diagonal covariance matrix). In previous work, the loss function has often been specified by hand which is fine for 1D or 2D prediction, but becomes a bit more annoying after that.
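For intuition, the quantity such a loss function computes is the negative log-likelihood of the target under a mixture of diagonal Gaussians. A minimal NumPy sketch of that density (the function and variable names here are illustrative, not the library's API):

```python
import numpy as np

def mixture_neg_log_likelihood(y, pis, mus, sigmas):
    """Negative log-likelihood of y under a diagonal-Gaussian mixture.

    y: (D,) target; pis: (K,) mixture weights summing to 1;
    mus, sigmas: (K, D) per-component means and standard deviations.
    """
    # Log-density of y under each diagonal Gaussian component.
    log_probs = (-0.5 * np.sum(((y - mus) / sigmas) ** 2, axis=1)
                 - np.sum(np.log(sigmas), axis=1)
                 - 0.5 * mus.shape[1] * np.log(2 * np.pi))
    # Stable log-sum-exp over components, weighted by the mixture weights.
    weighted = np.log(pis) + log_probs
    m = np.max(weighted)
    return -(m + np.log(np.sum(np.exp(weighted - m))))

# Example: a two-component mixture in 2D, evaluated at the first mean.
pis = np.array([0.6, 0.4])
mus = np.array([[0.0, 0.0], [3.0, 3.0]])
sigmas = np.ones((2, 2))
nll = mixture_neg_log_likelihood(np.array([0.0, 0.0]), pis, mus, sigmas)
```

The library builds the same density out of TensorFlow Probability's Categorical and MultivariateNormalDiag distributions instead, so it stays differentiable and batched.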

Two important functions are provided for training and prediction:

  • get_mixture_loss_func(output_dim, num_mixtures): This function generates a loss function with the correct output dimensions and number of mixtures.
  • sample_from_output(params, output_dim, num_mixtures, temp=1.0, sigma_temp=1.0): This function samples from the mixture distribution output by the model.
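The params vector passed to sample_from_output is the raw layer output: in this implementation, each mixture component contributes output_dim means and output_dim standard deviations, plus one mixture weight per component. A quick sanity check of the expected length (the helper name is illustrative, not part of the library):

```python
def mdn_param_count(output_dim, num_mixtures):
    # means + standard deviations per component, plus one weight each
    return num_mixtures * (2 * output_dim + 1)

mdn_param_count(2, 10)  # → 50 parameters per prediction
```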

Installation

This project requires Python 3.11+, TensorFlow 2.16+, and TensorFlow Probability 0.24+. You can install this package from PyPI via pip like so:

python3 -m pip install keras-mdn-layer

And finally, import the module in Python: import keras_mdn_layer as mdn

Alternatively, you can clone or download this repository and then install via poetry install.

Tested Configurations

This library is tested against the following platform, Python, and TensorFlow combinations:

TensorFlow | TF Probability | tf-keras | Python     | Platforms
2.15.1     | 0.23.0         | 2.15.1   | 3.11       | Ubuntu
2.16.2     | 0.24.0         | 2.16.0   | 3.11, 3.12 | Ubuntu, macOS
2.18.1     | 0.25.0         | —        | 3.11, 3.12 | Ubuntu, macOS
2.20.0     | 0.25.0         | —        | 3.11–3.13  | Ubuntu, macOS, Windows

Other combinations may work but are not regularly tested in CI.

Build

This project builds using poetry. To build a wheel use poetry build.

Examples

Some examples are provided in the notebooks directory.

To run these using poetry, run poetry install and then open Jupyter Lab with poetry run jupyter lab.

There are scripts for fitting multivalued functions, a standard MDN toy problem:

Keras MDN Demo

There's also a script for generating fake kanji characters:

kanji test 1

And finally, for learning how to generate musical touch-screen performances with a temporal component:

Robojam Model Examples

How to use

The MDN layer should be the last layer in your network, and you should use get_mixture_loss_func to generate a loss function. Here's an example of a simple network with one Dense layer followed by the MDN.

from tensorflow import keras
import keras_mdn_layer as mdn

N_HIDDEN = 15  # number of hidden units in the Dense layer
N_MIXES = 10  # number of mixture components
OUTPUT_DIMS = 2  # number of real-values predicted by each mixture component

model = keras.Sequential()
model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))
model.add(mdn.MDN(OUTPUT_DIMS, N_MIXES))
model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES), optimizer=keras.optimizers.Adam())
model.summary()

Fit as normal:

history = model.fit(x=x_train, y=y_train)

The predictions from the network are the parameters of the mixture model, so you have to apply the sample_from_output function to generate samples.

import numpy as np

y_test = model.predict(x_test)
y_samples = np.apply_along_axis(mdn.sample_from_output, 1, y_test, OUTPUT_DIMS, N_MIXES, temp=1.0)
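Conceptually, sampling from the mixture is a two-step process: pick a component according to the mixture weights, then draw from that component's diagonal Gaussian. A hedged NumPy sketch of those two steps (not the library's internal code, which also handles temperature scaling):

```python
import numpy as np

def sample_mixture(pis, mus, sigmas, rng=None):
    """pis: (K,) weights summing to 1; mus, sigmas: (K, D)."""
    rng = np.random.default_rng() if rng is None else rng
    k = rng.choice(len(pis), p=pis)        # step 1: pick a component
    return rng.normal(mus[k], sigmas[k])   # step 2: draw from its Gaussian

# Example: a sample lands near one of the two component means.
rng = np.random.default_rng(0)
sample = sample_mixture(np.array([0.9, 0.1]),
                        np.array([[0.0, 0.0], [5.0, 5.0]]),
                        np.array([[0.1, 0.1], [0.1, 0.1]]), rng)
```

The temp and sigma_temp arguments of sample_from_output adjust the sharpness of the categorical choice in step 1 and the spread of the Gaussian draw in step 2, respectively.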

See the notebooks directory for examples in jupyter notebooks!

Load/Save Model

Saving models is straightforward:

model.save('test_save.keras')

But loading requires custom_objects to be filled with the MDN layer, and a loss function with the appropriate parameters:

m_2 = keras.models.load_model('test_save.keras', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)})

Acknowledgements

References

  1. Christopher M. Bishop. 1994. Mixture Density Networks. Technical Report NCRG/94/004. Neural Computing Research Group, Aston University. http://publications.aston.ac.uk/373/
  2. Axel Brando. 2017. Mixture Density Networks (MDN) for distribution and uncertainty estimation. Master’s thesis. Universitat Politècnica de Catalunya.
  3. A. Graves. 2013. Generating Sequences With Recurrent Neural Networks. ArXiv e-prints (Aug. 2013). https://arxiv.org/abs/1308.0850
  4. David Ha and Douglas Eck. 2017. A Neural Representation of Sketch Drawings. ArXiv e-prints (April 2017). https://arxiv.org/abs/1704.03477
  5. Charles P. Martin and Jim Torresen. 2018. RoboJam: A Musical Mixture Density Network for Collaborative Touchscreen Interaction. In Evolutionary and Biologically Inspired Music, Sound, Art and Design: EvoMUSART ’18, A. Liapis et al. (Ed.). Lecture Notes in Computer Science, Vol. 10783. Springer International Publishing. DOI:10.1007/978-3-319-77583-8_11
