Samplers in MLX
samplex
A package of useful sampling algorithms written in MLX. We plan to explore how combining unified memory (using the GPU and CPU together) with automatic differentiation can deliver highly efficient and robust sampling locally on your Mac.
Please get in touch if you're interested in contributing (tedwards2412@gmail.com and nash.sabti@gmail.com)!
Installation
```shell
pip install samplex
```
Basic Usage
For a full example, please see the examples folder. Here is the basic structure for a linear regression (fitting a quadratic, which is linear in its parameters m, c and b):
```python
import mlx.core as mx

from samplex.samplex import samplex
from samplex.samplers import MH_Gaussian_sampler

# Illustrative true parameters and initial-guess bounds (not given in the original snippet)
m_true, c_true, b_true = 1.0, 2.0, 0.5
m_min, c_min, b_min = -5.0, -5.0, -5.0
m_max, c_max, b_max = 5.0, 5.0, 5.0

# First, let's generate some data with unit Gaussian noise
x = mx.linspace(-5, 5, 20)
sigma = mx.ones_like(x)  # noise standard deviation of each data point
err = sigma * mx.random.normal(x.shape)
y = b_true * x**2 + m_true * x + c_true + err

# Our model is a quadratic in x; the target is its Gaussian log-likelihood
def log_target_distribution(theta, data):
    m, c, b = theta
    x, y, sigma = data
    model = b * x**2 + m * x + c
    residual = y - model
    return mx.sum(-0.5 * (residual**2 / sigma**2))

# The sampler expects a target distribution that takes a single parameter vector theta
logtarget = lambda theta: log_target_distribution(theta, (x, y, sigma))

# Here are the sampler settings
Nwalkers = 32
Ndim = 3
Nsteps = 10_000
cov_matrix = mx.array([0.01, 0.01, 0.01])
jumping_factor = 1.0
# Initial positions: one row per walker, drawn uniformly between the bounds
theta0_array = mx.random.uniform(
    mx.array([m_min, c_min, b_min]),
    mx.array([m_max, c_max, b_max]),
    (Nwalkers, Ndim),
)

# Finally, we instantiate the sampler and a samplex object, then run!
sampler = MH_Gaussian_sampler(logtarget)
sam = samplex(sampler, Nwalkers)
sam.run(Nsteps, theta0_array, cov_matrix, jumping_factor)
```
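The example relies on `MH_Gaussian_sampler`, whose name suggests random-walk Metropolis-Hastings with a Gaussian proposal. Below is a minimal pure-Python sketch of that algorithm, using only the standard library; it is not samplex's MLX implementation. The function `metropolis_hastings`, its arguments (per-parameter proposal standard deviations `step_sizes` rather than samplex's `cov_matrix`/`jumping_factor`), and all numbers in the setup are hypothetical choices for illustration.

```python
import math
import random


def metropolis_hastings(log_target, theta0, step_sizes, n_steps, rng):
    """Run one random-walk Metropolis-Hastings chain (hypothetical helper).

    Each proposal perturbs every parameter by step_sizes[i] * N(0, 1). The
    Gaussian proposal is symmetric, so a move is accepted with probability
    min(1, exp(log_target(proposal) - log_target(current))).
    Returns the chain (n_steps + 1 points) and the acceptance fraction.
    """
    theta = list(theta0)
    logp = log_target(theta)
    chain = [list(theta)]
    n_accepted = 0
    for _ in range(n_steps):
        proposal = [t + s * rng.gauss(0.0, 1.0) for t, s in zip(theta, step_sizes)]
        logp_new = log_target(proposal)
        # Compare in log space; 1 - random() lies in (0, 1], so log() is safe
        if math.log(1.0 - rng.random()) < logp_new - logp:
            theta, logp = proposal, logp_new
            n_accepted += 1
        chain.append(list(theta))
    return chain, n_accepted / n_steps


# Hypothetical quadratic-fit setup mirroring the example above (illustrative values)
rng = random.Random(0)
m_true, c_true, b_true = 1.0, 2.0, 0.5
xs = [-5 + 10 * i / 19 for i in range(20)]  # same grid as mx.linspace(-5, 5, 20)
ys = [b_true * x**2 + m_true * x + c_true + rng.gauss(0.0, 1.0) for x in xs]
sigma = 1.0


def log_target(theta):
    m, c, b = theta
    return sum(-0.5 * ((y - (b * x**2 + m * x + c)) / sigma) ** 2 for x, y in zip(xs, ys))


# Several independent walkers, each started from a uniform draw in [-5, 5]^3
n_walkers, n_steps = 8, 4000
chains, acc_rates = [], []
for _ in range(n_walkers):
    theta0 = [rng.uniform(-5.0, 5.0) for _ in range(3)]
    chain, acc = metropolis_hastings(log_target, theta0, [0.05, 0.2, 0.02], n_steps, rng)
    chains.append(chain[n_steps // 2:])  # discard the first half as burn-in
    acc_rates.append(acc)
```

Because the Gaussian proposal is symmetric, the Hastings correction cancels and only the ratio of target densities is needed. Working with log densities avoids overflow and underflow when the log-likelihood is large in magnitude, as it is far from the best fit here.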
Next Steps:
- Get NUTS/HMC running
- Get Ensemble sampler running (emcee)
- Refine plotting
- Add helper functions for a variety of priors
- Treat parameters with different update speeds
- Add a file of priors and include them in the target distribution
- Include an autocorrelation calculation for the chain steps
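The autocorrelation item above is commonly handled by estimating the integrated autocorrelation time τ of each parameter's chain: N correlated samples are worth roughly N/τ independent ones. Below is a minimal pure-Python sketch, assuming the self-consistent window heuristic (stop at the first lag M with M ≥ c·τ(M), c = 5) that emcee's autocorrelation utilities also use, after Sokal. The function name is hypothetical and this is not part of samplex.

```python
import random


def integrated_autocorr_time(chain, c=5.0):
    """Estimate the integrated autocorrelation time of a 1-D chain.

    tau(M) = 1 + 2 * sum_{k=1..M} rho(k), where rho(k) is the normalized
    autocorrelation at lag k. The sum stops at the first lag M with
    M >= c * tau(M) (a self-consistent window, c = 5 by default).
    """
    n = len(chain)
    mean = sum(chain) / n
    centered = [v - mean for v in chain]
    var = sum(d * d for d in centered) / n
    if var == 0:
        raise ValueError("chain is constant; autocorrelation is undefined")
    tau = 1.0
    for k in range(1, n):
        # Biased (divide-by-n) autocovariance at lag k, normalized by the variance
        rho_k = sum(centered[i] * centered[i + k] for i in range(n - k)) / (n * var)
        tau += 2.0 * rho_k
        if k >= c * tau:
            break
    return tau


# Usage on an AR(1) chain x_t = phi * x_{t-1} + noise, whose exact tau is (1 + phi) / (1 - phi)
rng = random.Random(42)
phi, x, ar1 = 0.5, 0.0, []
for _ in range(20_000):
    x = phi * x + rng.gauss(0.0, 1.0)
    ar1.append(x)
tau = integrated_autocorr_time(ar1)
n_effective = len(ar1) / tau  # roughly the number of independent samples
```

The window keeps the estimate from summing many noisy long-lag terms, at the cost of a small truncation bias when the true τ is large relative to the chain length.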
Download files
Source Distribution
samplex-0.0.2.tar.gz (9.4 kB)
Built Distribution
samplex-0.0.2-py3-none-any.whl (9.1 kB)
File details
Details for the file samplex-0.0.2.tar.gz.
File metadata
- Download URL: samplex-0.0.2.tar.gz
- Size: 9.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.12.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | 903112a33dbcab71d3f1c37a3cb81a4dd023aa126f35e43ca78a9660e2e7281c
MD5 | b95197f26491ec3eaca55dda997d66a1
BLAKE2b-256 | 1d6ea4360305cae530cc41b83a2800ac47a0889118cd67d0978c0fd910513c2c
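To check a downloaded file against the SHA256 digest listed above, the standard-library `hashlib` module is enough. This is a minimal sketch; the helper name `sha256_of_file` is hypothetical, and the file is read in chunks so large archives need not fit in memory.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until f.read() returns b""
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digest listed above for samplex-0.0.2.tar.gz
EXPECTED = "903112a33dbcab71d3f1c37a3cb81a4dd023aa126f35e43ca78a9660e2e7281c"

# Usage, after downloading the sdist into the current directory:
# assert sha256_of_file("samplex-0.0.2.tar.gz") == EXPECTED
```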
File details
Details for the file samplex-0.0.2-py3-none-any.whl.
File metadata
- Download URL: samplex-0.0.2-py3-none-any.whl
- Size: 9.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.12.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | 9247912cacd545bef373ab9f380dfec619da15d2323333bf66f6cfd6a7f48b4d
MD5 | aaa555bb8852bf2bd5e2cc119cbf51b1
BLAKE2b-256 | 87e1d2c426aea56ea9a170e4434b0c3a8cb05360e4ff54e4324ae983a705f8e6