Mixture Density Networks in scikit-learn

scikit-mdn

A mixture density network, built with PyTorch, for scikit-learn

This project started during a livestream for probabl's outreach effort on YouTube. The relevant livestreams can be found here and here.
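For readers new to the idea: a mixture density network maps each input x to the parameters of a Gaussian mixture over the target, namely weights π_k (summing to 1), means μ_k and standard deviations σ_k, so that p(y | x) = Σ_k π_k(x) N(y; μ_k(x), σ_k(x)). Such networks are usually trained by minimising the negative log-likelihood of the observed targets. The NumPy sketch below evaluates that density and loss. It illustrates the general technique and is not taken from scikit-mdn, whose model is written in PyTorch; the function names `mixture_pdf` and `mixture_nll` are hypothetical.

```python
import numpy as np

def mixture_pdf(y, pi, mu, sigma):
    """Density of each y[i] under its own 1-D Gaussian mixture.

    y: shape (n,); pi, mu, sigma: shape (n, k), one mixture per sample.
    (Hypothetical helper for illustration, not part of scikit-mdn.)
    """
    y = y[:, None]  # broadcast each target against its k components
    gauss = np.exp(-0.5 * ((y - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))
    return (pi * gauss).sum(axis=1)  # weight the components and sum them

def mixture_nll(y, pi, mu, sigma):
    """Mean negative log-likelihood, the usual MDN training loss."""
    return -np.mean(np.log(mixture_pdf(y, pi, mu, sigma)))

# Two samples, each described by a two-component mixture
pi = np.array([[0.5, 0.5], [0.9, 0.1]])
mu = np.array([[-1.0, 1.0], [0.0, 3.0]])
sigma = np.array([[0.5, 0.5], [1.0, 1.0]])
y = np.array([1.0, 0.0])
print(mixture_pdf(y, pi, mu, sigma))
print(mixture_nll(y, pi, mu, sigma))
```

A network only has to output `pi`, `mu` and `sigma` per sample; the loss above then tells it how plausible the observed targets are under those mixtures.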

Usage

To get this tool working locally you will first need to install it:

python -m pip install scikit-mdn

Then you can use it in your code. Here is a small demo:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from skmdn import MixtureDensityEstimator

# Generate a reproducible dataset
n_samples = 1000
rng = np.random.default_rng(42)
X_full, _ = make_moons(n_samples=n_samples, noise=0.1, random_state=42)
X = X_full[:, 0].reshape(-1, 1)  # Use only the first column as input
Y = X_full[:, 1].reshape(-1, 1)  # Predict the second column

# Add some noise to Y to make the problem more suitable for an MDN
Y += 0.1 * rng.standard_normal((n_samples, 1))

# Fit the model with its default settings
mdn = MixtureDensityEstimator()
mdn.fit(X, Y)

# Predict the mean and four quantiles for every point in the training set
means, quantiles = mdn.predict(X, quantiles=[0.01, 0.1, 0.9, 0.99], resolution=100000)

# Plot the data, the 1%/99% band (orange), the 10%/90% band (green) and the mean (red)
plt.scatter(X, Y)
plt.scatter(X, quantiles[:, 0], color='orange')
plt.scatter(X, quantiles[:, 1], color='green')
plt.scatter(X, quantiles[:, 2], color='green')
plt.scatter(X, quantiles[:, 3], color='orange')
plt.scatter(X, means, color='red')
plt.show()

This is what the chart looks like:

Example chart
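The `quantiles` and `resolution` arguments in the demo suggest that quantiles are read off a discretised version of each predicted mixture. The NumPy sketch below shows one way to do that for a single mixture. It assumes, without confirmation from the scikit-mdn source, that `resolution` is the number of evenly spaced grid points and that each quantile is the first grid point at which the normalised cumulative density reaches the requested level; `mixture_quantiles` and its `pad` parameter are hypothetical.

```python
import numpy as np

def mixture_quantiles(pi, mu, sigma, quantiles, resolution=1000, pad=5.0):
    """Approximate quantiles of one 1-D Gaussian mixture on a grid.

    pi, mu, sigma: arrays of shape (k,) describing the components.
    quantiles: levels in [0, 1]. (Hypothetical helper, not scikit-mdn's API.)
    """
    # Grid wide enough to cover every component out to `pad` standard deviations
    lo = np.min(mu - pad * sigma)
    hi = np.max(mu + pad * sigma)
    ys = np.linspace(lo, hi, resolution)
    # Mixture density at every grid point
    z = (ys[:, None] - mu) / sigma
    dens = (pi * np.exp(-0.5 * z**2) / (sigma * np.sqrt(2 * np.pi))).sum(axis=1)
    # Discrete cumulative distribution, normalised so the last value is exactly 1
    cdf = np.cumsum(dens)
    cdf /= cdf[-1]
    # First grid point whose cumulative value reaches each requested level
    idx = np.searchsorted(cdf, quantiles)
    return ys[idx]

print(mixture_quantiles(np.array([0.5, 0.5]), np.array([-1.0, 1.0]),
                        np.array([0.3, 0.3]), [0.01, 0.1, 0.9, 0.99], resolution=10_000))
```

Under this scheme the grid spacing, and therefore the precision of each quantile, shrinks as `resolution` grows, which would explain a large value such as 100000 in the demo.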

API Documentation

The API documentation is hosted on GitHub Pages:

https://koaning.github.io/scikit-mdn/

More depth

If you'd like a glimpse of the internals, try the mdn.ipynb notebook, which contains an interactive Jupyter widget.

Example chart
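As a rough picture of what such internals typically look like, assuming the common MDN layout (a hidden layer followed by three output heads: a softmax for the mixing weights, an identity for the means, and an exponential to keep the standard deviations positive), the forward pass can be sketched in NumPy. The weights are random and untrained, and `init_params`, `mdn_forward`, `hidden_dim` and `n_components` are hypothetical names, not the library's API.

```python
import numpy as np

def init_params(n_features, hidden_dim, n_components, seed=0):
    """Random, untrained weights for a one-hidden-layer MDN (hypothetical layout)."""
    rng = np.random.default_rng(seed)
    return {
        "W_h": rng.normal(scale=0.5, size=(n_features, hidden_dim)),
        "b_h": np.zeros(hidden_dim),
        # One head each for mixing logits, means and log standard deviations
        # (biases on the heads are omitted for brevity)
        "W_pi": rng.normal(scale=0.5, size=(hidden_dim, n_components)),
        "W_mu": rng.normal(scale=0.5, size=(hidden_dim, n_components)),
        "W_sigma": rng.normal(scale=0.5, size=(hidden_dim, n_components)),
    }

def mdn_forward(X, params):
    """Map inputs X of shape (n, d) to mixture parameters, each of shape (n, k)."""
    h = np.tanh(X @ params["W_h"] + params["b_h"])
    logits = h @ params["W_pi"]
    # Softmax, shifted by the row maximum for numerical stability, so weights sum to 1
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    pi = e / e.sum(axis=1, keepdims=True)
    mu = h @ params["W_mu"]
    # Exponentiate so every standard deviation is strictly positive
    sigma = np.exp(h @ params["W_sigma"])
    return pi, mu, sigma

X = np.linspace(-1, 2, 5).reshape(-1, 1)
pi, mu, sigma = mdn_forward(X, init_params(n_features=1, hidden_dim=10, n_components=5))
print(pi.shape, mu.shape, sigma.shape)
```

Pairing these outputs with a negative log-likelihood loss and a gradient-based optimiser is what turns this into a trainable model; PyTorch, which scikit-mdn builds on, computes those gradients automatically.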

Extra resources

Download files

Download the file for your platform.

Source Distribution

scikit_mdn-0.0.3.tar.gz (5.1 kB)

Uploaded Source

Built Distributions

scikit_mdn-0.0.3-py3-none-any.whl (5.2 kB)

Uploaded Python 3

scikit_mdn-0.0.3-py2.py3-none-any.whl (5.1 kB)

Uploaded Python 2, Python 3

File details

Details for the file scikit_mdn-0.0.3.tar.gz.

File metadata

  • Download URL: scikit_mdn-0.0.3.tar.gz
  • Upload date:
  • Size: 5.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.4

File hashes

Hashes for scikit_mdn-0.0.3.tar.gz
Algorithm Hash digest
SHA256 9872cac775cdfe11862458814f32cf0dc474ca2dc2b02a65a6925e5335be2990
MD5 f4a6fe80728d8473aec8ab0268c1904a
BLAKE2b-256 48c56576b17e402a12ba53ad46b81b87e92826dc3c894c1fe8c852cef5f06a31
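The published digests can be checked locally before installing. Below is a minimal sketch using Python's standard `hashlib` that compares a downloaded file against the SHA256 value listed above for scikit_mdn-0.0.3.tar.gz. The helper names `sha256_of` and `verify` are hypothetical, and the file is assumed to have been downloaded already.

```python
import hashlib
from pathlib import Path

# SHA256 digest published above for scikit_mdn-0.0.3.tar.gz
EXPECTED = "9872cac775cdfe11862458814f32cf0dc474ca2dc2b02a65a6925e5335be2990"

def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify(path, expected=EXPECTED):
    """True only if the file's SHA256 digest matches the expected one."""
    return sha256_of(path) == expected.lower()
```

For example, `verify("scikit_mdn-0.0.3.tar.gz")` returns `True` only if the local file matches the digest above. pip can also enforce hashes at install time through its hash-checking mode (`--require-hashes`).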

File details

Details for the file scikit_mdn-0.0.3-py3-none-any.whl.

File hashes

Hashes for scikit_mdn-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 cc96f9878e47e1d510e07f8476546ab240efa65f014f033950f3994e1c292de0
MD5 6ed9685355fa1290b4db789488a88d5e
BLAKE2b-256 837006d9f57e24c309cb48cbf7eaec5b79887003aeb8f60acaaf999f7ab9d290

File details

Details for the file scikit_mdn-0.0.3-py2.py3-none-any.whl.

File hashes

Hashes for scikit_mdn-0.0.3-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 5985a2f3338f5deca8af78153ccb76a89a6c717f12b3b6d702e122e5a4ccaf41
MD5 f98f3418808591acdeef38e229f9e2a6
BLAKE2b-256 029b027d57e357eb887c48851894069870eaae08ec1830220bc0ade592bc9f58
