Samplers in MLX

Project description

samplex

A package of useful sampling algorithms written in MLX. We plan to explore how a combination of unified memory (exploiting the GPU and CPU together) and automatic differentiation can be used to run highly efficient and robust sampling locally on your Mac.

Please get in touch if you're interested in contributing (tedwards2412@gmail.com and nash.sabti@gmail.com)!

Installation

pip install samplex

Basic Usage

For a full example, please see the examples folder. Here is the basic structure for fitting a quadratic model to noisy data:

import mlx.core as mx

from samplex.samplex import samplex
from samplex.samplers import MH_Gaussian_sampler

# True parameters and noise level for the synthetic data (example values)
m_true, c_true, b_true = 2.0, 1.0, 0.5
sigma = 1.0

# First let's generate some data
x = mx.linspace(-5, 5, 20)
err = sigma * mx.random.normal(x.shape)
y = b_true * x**2 + m_true * x + c_true + err


# Our target distribution is a Gaussian likelihood for the quadratic model
def log_target_distribution(theta, data):
    m, c, b = theta
    x, y, sigma = data
    model = b * x**2 + m * x + c
    residual = y - model
    return mx.sum(-0.5 * (residual**2 / sigma**2))

# The sampler assumes it gets a target distribution with a single input vector theta
logtarget = lambda theta: log_target_distribution(theta, (x, y, sigma))

# Here are the sampler settings
Nwalkers = 32
Ndim = 3
Nsteps = 10_000
cov_matrix = mx.array([0.01, 0.01, 0.01])  # diagonal of the Gaussian proposal covariance
jumping_factor = 1.0

# Bounds for the initial walker positions (example values)
m_min, c_min, b_min = -5.0, -5.0, -5.0
m_max, c_max, b_max = 5.0, 5.0, 5.0

theta0_array = mx.random.uniform(
    mx.array([m_min, c_min, b_min]),
    mx.array([m_max, c_max, b_max]),
    (Nwalkers, Ndim),
)

# First we instantiate a samplex class, then run!
sampler = MH_Gaussian_sampler(logtarget)
sam = samplex(sampler, Nwalkers)
sam.run(Nsteps, theta0_array, cov_matrix, jumping_factor)
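For reference, the update that a Gaussian random-walk Metropolis-Hastings sampler performs is the standard accept/reject step. Here is a minimal NumPy sketch of that step (illustrative only, not samplex's actual implementation; `mh_gaussian_step` and its signature are our own):

```python
import numpy as np

def mh_gaussian_step(theta, logp, log_target, step_sizes, rng):
    """One Gaussian random-walk Metropolis update.

    theta:      current parameter vector
    logp:       log target density at theta
    step_sizes: per-parameter proposal standard deviations
    """
    proposal = theta + step_sizes * rng.standard_normal(theta.shape)
    logp_prop = log_target(proposal)
    # Accept with probability min(1, p(proposal) / p(theta))
    if np.log(rng.uniform()) < logp_prop - logp:
        return proposal, logp_prop, True
    return theta, logp, False

# Example: sample a 1-D standard normal
rng = np.random.default_rng(0)
log_target = lambda t: -0.5 * np.sum(t**2)
theta = np.zeros(1)
logp = log_target(theta)
samples = []
for _ in range(5000):
    theta, logp, _ = mh_gaussian_step(theta, logp, log_target, np.array([1.0]), rng)
    samples.append(theta[0])
samples = np.array(samples)
```

The per-parameter `step_sizes` play the same role as the `cov_matrix` and `jumping_factor` settings above: they control how far each proposal jumps, and hence the acceptance rate.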

Next Steps:

  • Get NUTS/HMC running
  • Get an ensemble sampler running (emcee-style)
  • Refine plotting
  • Add helper functions for a variety of priors
  • Treat parameters with different update speeds
  • Add a file of priors and include it in the target distribution
  • Include an autocorrelation calculation for the steps
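As a preview of the last item, the integrated autocorrelation time of a chain can be estimated from its normalized autocorrelation function via the FFT. A minimal NumPy sketch (our own helper under standard assumptions, not part of samplex):

```python
import numpy as np

def autocorr_time(chain, c=5.0):
    """Estimate the integrated autocorrelation time of a 1-D chain.

    Computes the autocorrelation function with the FFT, then sums it
    with a self-consistent window (stop at the first M with M >= c * tau).
    """
    x = np.asarray(chain) - np.mean(chain)
    n = len(x)
    # Autocorrelation via FFT, zero-padded to avoid circular wrap-around
    f = np.fft.rfft(x, n=2 * n)
    acf = np.fft.irfft(f * np.conjugate(f))[:n]
    acf /= acf[0]  # normalize so acf[0] == 1
    # Running estimate tau(M) = 1 + 2 * sum_{k=1}^{M} acf[k]
    taus = 2.0 * np.cumsum(acf) - 1.0
    for m in range(1, n):
        if m >= c * taus[m]:
            return taus[m]
    return taus[-1]

# Example: an AR(1) chain whose true tau is (1 + rho) / (1 - rho)
rng = np.random.default_rng(1)
rho = 0.9
chain = np.zeros(20000)
for i in range(1, len(chain)):
    chain[i] = rho * chain[i - 1] + rng.standard_normal()
tau = autocorr_time(chain)  # true value is (1 + 0.9) / (1 - 0.9) = 19
```

Dividing the chain length by `tau` gives a rough effective number of independent samples, which is the usual convergence diagnostic for Metropolis-Hastings output.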

Download files

Download the file for your platform.

Source Distribution

samplex-0.0.2.tar.gz (9.4 kB)

Uploaded Source

Built Distribution

samplex-0.0.2-py3-none-any.whl (9.1 kB)

Uploaded Python 3

File details

Details for the file samplex-0.0.2.tar.gz.

File metadata

  • Download URL: samplex-0.0.2.tar.gz
  • Size: 9.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.12.1

File hashes

Hashes for samplex-0.0.2.tar.gz:

  • SHA256: 903112a33dbcab71d3f1c37a3cb81a4dd023aa126f35e43ca78a9660e2e7281c
  • MD5: b95197f26491ec3eaca55dda997d66a1
  • BLAKE2b-256: 1d6ea4360305cae530cc41b83a2800ac47a0889118cd67d0978c0fd910513c2c

File details

Details for the file samplex-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: samplex-0.0.2-py3-none-any.whl
  • Size: 9.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.12.1

File hashes

Hashes for samplex-0.0.2-py3-none-any.whl:

  • SHA256: 9247912cacd545bef373ab9f380dfec619da15d2323333bf66f6cfd6a7f48b4d
  • MD5: aaa555bb8852bf2bd5e2cc119cbf51b1
  • BLAKE2b-256: 87e1d2c426aea56ea9a170e4434b0c3a8cb05360e4ff54e4324ae983a705f8e6
