Samplers in MLX
samplex
A package of useful sampling algorithms written in MLX. We plan to explore how combining unified memory (using the GPU and CPU together) with automatic differentiation (auto-diff) can enable highly efficient and robust sampling locally on your Mac.
Please get in touch if you're interested in contributing (tedwards2412@gmail.com and nash.sabti@gmail.com)!
Installation
pip install samplex
Basic Usage
For a full example, please see the examples folder. Here is the basic structure for a simple regression problem, fitting a quadratic model to noisy data:
```python
import mlx.core as mx

from samplex.samplex import samplex
from samplex.samplers import MH_Gaussian_sampler

# Illustrative values for the true parameters and the initial-position bounds
m_true, c_true, b_true = 1.0, 2.0, 0.5
m_min, c_min, b_min = -5.0, -5.0, -5.0
m_max, c_max, b_max = 5.0, 5.0, 5.0

# First, let's generate some noisy data from a quadratic model
x = mx.linspace(-5, 5, 20)
sigma = mx.ones_like(x)  # per-point noise standard deviation
err = sigma * mx.random.normal(x.shape)
y = b_true * x**2 + m_true * x + c_true + err

# Our target distribution is a Gaussian log-likelihood for the quadratic model
def log_target_distribution(theta, data):
    m, c, b = theta
    x, y, sigma = data
    model = b * x**2 + m * x + c
    residual = y - model
    return mx.sum(-0.5 * (residual**2 / sigma**2))

# The sampler expects a target distribution that takes a single parameter vector theta
logtarget = lambda theta: log_target_distribution(theta, (x, y, sigma))

# Sampler settings
Nwalkers = 32
Ndim = 3
Nsteps = 10_000
cov_matrix = mx.array([0.01, 0.01, 0.01])  # proposal covariance, one value per parameter
jumping_factor = 1.0
# Start each walker at a random point inside the bounds
theta0_array = mx.random.uniform(
    mx.array([m_min, c_min, b_min]),
    mx.array([m_max, c_max, b_max]),
    (Nwalkers, Ndim),
)

# First we instantiate a samplex object, then run it
sampler = MH_Gaussian_sampler(logtarget)
sam = samplex(sampler, Nwalkers)
sam.run(Nsteps, theta0_array, cov_matrix, jumping_factor)
```
Next Steps:
- Get NUTS/HMC running
- Get an ensemble sampler running (as in emcee)
- Refine plotting
- Add helper functions for a variety of priors
- Handle parameters with different update speeds
- Add a file of priors and include it in the target distribution
- Include an autocorrelation calculation for the chain steps
Download files
File details
Details for the file samplex-0.0.2.tar.gz.
File metadata
- Download URL: samplex-0.0.2.tar.gz
- Upload date:
- Size: 9.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.12.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 903112a33dbcab71d3f1c37a3cb81a4dd023aa126f35e43ca78a9660e2e7281c |
| MD5 | b95197f26491ec3eaca55dda997d66a1 |
| BLAKE2b-256 | 1d6ea4360305cae530cc41b83a2800ac47a0889118cd67d0978c0fd910513c2c |
File details
Details for the file samplex-0.0.2-py3-none-any.whl.
File metadata
- Download URL: samplex-0.0.2-py3-none-any.whl
- Upload date:
- Size: 9.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.12.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9247912cacd545bef373ab9f380dfec619da15d2323333bf66f6cfd6a7f48b4d |
| MD5 | aaa555bb8852bf2bd5e2cc119cbf51b1 |
| BLAKE2b-256 | 87e1d2c426aea56ea9a170e4434b0c3a8cb05360e4ff54e4324ae983a705f8e6 |