Mixture Density Networks in scikit-learn
scikit-mdn
A mixture density network, built with PyTorch, for scikit-learn
This project started during a livestream that is part of the probabl outreach effort on YouTube. If you want to watch the relevant livestreams, they can be found here and here.
Usage
To get this tool working locally you will first need to install it:
```shell
python -m pip install scikit-mdn
```
Then you can use it in your code. Here is a small demo:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from skmdn import MixtureDensityEstimator

# Generate dataset
n_samples = 1000
X_full, _ = make_moons(n_samples=n_samples, noise=0.1)
X = X_full[:, 0].reshape(-1, 1)  # Use only the first column as input
Y = X_full[:, 1].reshape(-1, 1)  # Predict the second column

# Add some noise to Y to make the problem more suitable for MDN
Y += 0.1 * np.random.randn(n_samples, 1)

# Fit the model
mdn = MixtureDensityEstimator()
mdn.fit(X, Y)

# Predict the means and some quantiles on the train set
means, quantiles = mdn.predict(X, quantiles=[0.01, 0.1, 0.9, 0.99], resolution=100000)

# Data in blue, 1%/99% quantiles in orange, 10%/90% quantiles in green, means in red
plt.scatter(X, Y)
plt.scatter(X, quantiles[:, 0], color='orange')
plt.scatter(X, quantiles[:, 1], color='green')
plt.scatter(X, quantiles[:, 2], color='green')
plt.scatter(X, quantiles[:, 3], color='orange')
plt.scatter(X, means, color='red')
plt.show()  # Needed to display the chart when running as a script
```
This is what the chart looks like:
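The `quantiles` and `resolution` arguments in the demo suggest that quantiles are read off a numerically evaluated distribution. The sketch below shows one plausible way to do this for a single input row, assuming the network has already produced weights, means and standard deviations for a one-dimensional Gaussian mixture: evaluate the mixture CDF on an evenly spaced grid of `resolution` points and, for each requested quantile, take the first grid point whose CDF reaches it. The helper names and the grid bounds (five standard deviations past the outermost components) are hypothetical; this is not necessarily how skmdn implements `predict`.

```python
import math


def mixture_cdf(y, weights, means, stds):
    """CDF of a 1D Gaussian mixture evaluated at a single point y."""
    return sum(
        w * 0.5 * (1.0 + math.erf((y - m) / (s * math.sqrt(2.0))))
        for w, m, s in zip(weights, means, stds)
    )


def mixture_mean(weights, means):
    """The mean of a mixture is the weight-averaged component mean."""
    return sum(w * m for w, m in zip(weights, means))


def mixture_quantiles(weights, means, stds, quantiles, resolution=1000):
    """Hypothetical grid-based quantile lookup for one input row.

    The grid spans 5 standard deviations beyond the outermost components;
    this bound is an assumption, not taken from skmdn.
    """
    lo = min(m - 5 * s for m, s in zip(means, stds))
    hi = max(m + 5 * s for m, s in zip(means, stds))
    step = (hi - lo) / (resolution - 1)
    grid = [lo + i * step for i in range(resolution)]
    cdf = [mixture_cdf(y, weights, means, stds) for y in grid]

    result = []
    for q in quantiles:
        # First grid point whose cumulative probability reaches q.
        idx = next((i for i, c in enumerate(cdf) if c >= q), resolution - 1)
        result.append(grid[idx])
    return result


# Two equally weighted components, like the upper and lower arcs of the moons.
weights, means, stds = [0.5, 0.5], [-0.5, 0.5], [0.1, 0.1]
print(mixture_mean(weights, means))
print(mixture_quantiles(weights, means, stds, [0.01, 0.1, 0.9, 0.99]))
```

A larger `resolution` makes the grid finer, so the returned quantiles land closer to the exact values at the cost of more CDF evaluations; that trade-off is presumably why the demo passes a large value.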
API Documentation
The API documentation is hosted on GitHub Pages:
https://koaning.github.io/scikit-mdn/
More depth
If you'd like a glimpse of the internals, you may want to play around with the mdn.ipynb notebook, which contains a Jupyter widget.
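The core idea fits in a few lines: a mixture density network models the conditional density p(y | x) as a weighted sum of K Gaussians, sum over k of pi_k(x) * N(y; mu_k(x), sigma_k(x)^2), where a neural network outputs the weights pi_k, means mu_k and standard deviations sigma_k. The sketch below follows the standard textbook formulation, not skmdn's actual code: raw outputs are mapped to valid parameters with a softmax (weights) and an exponential (standard deviations), and training minimizes the negative log-likelihood of the targets, computed with log-sum-exp for numerical stability. The function names are hypothetical, and plain Python stands in for the PyTorch tensors a real implementation would use.

```python
import math


def mixture_params(raw_logits, raw_means, raw_log_stds):
    """Map unconstrained network outputs to valid mixture parameters."""
    # Softmax (shifted by the max for stability) so weights are positive and sum to one.
    top = max(raw_logits)
    exps = [math.exp(z - top) for z in raw_logits]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Means stay unconstrained; exp keeps standard deviations positive.
    stds = [math.exp(s) for s in raw_log_stds]
    return weights, list(raw_means), stds


def mixture_nll(y, weights, means, stds):
    """Negative log-likelihood of a scalar target y under a 1D Gaussian mixture."""
    # log(pi_k) + log N(y; mu_k, sigma_k^2) for every component k.
    log_terms = [
        math.log(w)
        - math.log(s)
        - 0.5 * math.log(2 * math.pi)
        - 0.5 * ((y - m) / s) ** 2
        for w, m, s in zip(weights, means, stds)
    ]
    # log-sum-exp avoids underflow when every component density is tiny.
    top = max(log_terms)
    return -(top + math.log(sum(math.exp(t - top) for t in log_terms)))


weights, means, stds = mixture_params([0.0, 0.0], [-0.5, 0.5], [math.log(0.1)] * 2)
# A target on one of the two modes is far more likely than one between them.
print(mixture_nll(0.5, weights, means, stds))
print(mixture_nll(0.0, weights, means, stds))
```

In a real model these functions would run on batches of tensors, and gradients of the averaged negative log-likelihood would update the network weights.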
Extra resources
Download files
Source Distribution
Built Distributions
File details
Details for the file scikit_mdn-0.0.3.tar.gz.
File metadata
- Download URL: scikit_mdn-0.0.3.tar.gz
- Upload date:
- Size: 5.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.5.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9872cac775cdfe11862458814f32cf0dc474ca2dc2b02a65a6925e5335be2990 |
| MD5 | f4a6fe80728d8473aec8ab0268c1904a |
| BLAKE2b-256 | 48c56576b17e402a12ba53ad46b81b87e92826dc3c894c1fe8c852cef5f06a31 |
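The digests in the table above let you confirm that a downloaded archive is byte-for-byte what was uploaded. A minimal sketch using Python's standard `hashlib`, assuming the file has already been downloaded to the current directory; the helper name `sha256_matches` is hypothetical. pip can also enforce this automatically through `--hash` entries in a requirements file (hash-checking mode).

```python
import hashlib
from pathlib import Path


def sha256_matches(path, expected_hex, chunk_size=65536):
    """Return True if the file's SHA256 digest equals the published hex digest."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        # Read in chunks so large files are never loaded into memory at once.
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_hex.strip().lower()


# Usage with the source distribution listed above (download it first):
# sha256_matches(
#     "scikit_mdn-0.0.3.tar.gz",
#     "9872cac775cdfe11862458814f32cf0dc474ca2dc2b02a65a6925e5335be2990",
# )
```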
File details
Details for the file scikit_mdn-0.0.3-py3-none-any.whl.
File metadata
- Download URL: scikit_mdn-0.0.3-py3-none-any.whl
- Upload date:
- Size: 5.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.5.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | cc96f9878e47e1d510e07f8476546ab240efa65f014f033950f3994e1c292de0 |
| MD5 | 6ed9685355fa1290b4db789488a88d5e |
| BLAKE2b-256 | 837006d9f57e24c309cb48cbf7eaec5b79887003aeb8f60acaaf999f7ab9d290 |
File details
Details for the file scikit_mdn-0.0.3-py2.py3-none-any.whl.
File metadata
- Download URL: scikit_mdn-0.0.3-py2.py3-none-any.whl
- Upload date:
- Size: 5.1 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.5.4
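The "Tags" field mirrors the compatibility tags encoded in each wheel's file name, which follows the pattern `{name}-{version}(-{build})?-{python tag}-{abi tag}-{platform tag}.whl`. A simplified sketch that splits the two wheel names listed on this page; the helper `parse_wheel_name` is hypothetical, and for real use the `packaging` library's `packaging.utils.parse_wheel_filename` is the robust choice.

```python
def parse_wheel_name(filename):
    """Split a wheel file name into its components (simplified sketch)."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 6:
        name, version, build, python_tag, abi_tag, platform_tag = parts
    elif len(parts) == 5:
        name, version, python_tag, abi_tag, platform_tag = parts
        build = None
    else:
        raise ValueError(f"unexpected wheel name: {filename}")
    return {
        "name": name,
        "version": version,
        "build": build,
        # Dot-separated tags form a set meaning "any of these", e.g. py2.py3.
        "python": python_tag.split("."),
        "abi": abi_tag.split("."),
        "platform": platform_tag.split("."),
    }


for wheel in ["scikit_mdn-0.0.3-py3-none-any.whl", "scikit_mdn-0.0.3-py2.py3-none-any.whl"]:
    info = parse_wheel_name(wheel)
    print(wheel, "->", info["python"], info["abi"], info["platform"])
```

Here `py2.py3` is a compressed tag set meaning the wheel targets both Python 2 and Python 3, while `none` and `any` mark a pure-Python wheel with no ABI or platform restriction.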
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5985a2f3338f5deca8af78153ccb76a89a6c717f12b3b6d702e122e5a4ccaf41 |
| MD5 | f98f3418808591acdeef38e229f9e2a6 |
| BLAKE2b-256 | 029b027d57e357eb887c48851894069870eaae08ec1830220bc0ade592bc9f58 |