
Bayesian generalized linear modeling and GAMs using NumPyro

Project description

Python code for fast, accurate Bayesian fitting of neural tuning curves with a range of regression models.

Overview

This package is intended to ease the fitting of Generalized Additive Models (GAMs).

Long gone are the days of needing cross-validation, when what we really require is confidence and uncertainty estimates for our models.

This package aims to provide that in an easy-to-use format. The main hurdle for the user is clearly defining the design matrix.

The code and optimization procedures are written in Pyro and NumPyro, allowing efficient use of probabilistic programming techniques.

You have the option to use stochastic variational inference with different AutoGuides (normal, multivariate normal, Laplace, Delta/MAP).
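These options map onto NumPyro's AutoNormal, AutoMultivariateNormal, AutoLaplaceApproximation, and AutoDelta guides. As a rough, library-free intuition for what the simpler guides recover (a toy sketch, not this package's code): for a Poisson rate with a flat prior on theta = log(rate), the Delta/MAP guide corresponds to the posterior mode, and the Laplace guide wraps a Gaussian around that mode using the curvature of the log density there.

```python
import math

def laplace_fit_poisson(y):
    """MAP (Delta-guide analogue) and Laplace approximation for a
    Poisson rate with a flat prior on theta = log(rate)."""
    n, s = len(y), sum(y)
    theta_map = math.log(s / n)                 # posterior mode of log-rate
    # Laplace sd from the curvature of the log-likelihood at the mode,
    # which is -n * exp(theta):
    sd = 1.0 / math.sqrt(n * math.exp(theta_map))
    return theta_map, sd

theta, sd = laplace_fit_poisson([3, 5, 4, 6, 2])
# exp(theta) recovers the sample mean rate, 4.0
```

The richer guides (multivariate normal) capture posterior correlations between coefficients that this one-parameter toy cannot show.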

Current Models

Supports Poisson and Gaussian likelihoods with natural hyperparameter tuning via two approaches: stochastic variational inference or MCMC. You can use any type of supplied tensor or regular basis function. Naturally implements wiggliness and null-space coefficient constraints (a la L1). Naturally implements (Laplacian) Gaussian Markov random field regularization in both 1D and 2D, so 2D auto-regularization for things like place or grid fields is testable.
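The 1D Laplacian (random-walk) penalty behind this kind of GMRF regularization can be sketched without any library (a generic illustration, not this package's internals): the penalty matrix S = DᵀD, built from the first-difference matrix D, penalizes squared differences between adjacent coefficients.

```python
def laplacian_penalty_1d(n):
    """S = D^T D for the (n-1) x n first-difference matrix D, so that
    beta^T S beta = sum_i (beta[i+1] - beta[i])^2  (a 1D GMRF penalty)."""
    D = [[0.0] * n for _ in range(n - 1)]
    for i in range(n - 1):
        D[i][i], D[i][i + 1] = -1.0, 1.0
    return [[sum(D[k][i] * D[k][j] for k in range(n - 1))
             for j in range(n)] for i in range(n)]

S = laplacian_penalty_1d(5)
# tridiagonal: interior diagonal entries 2, edge entries 1, neighbours -1
```

The 2D version used for place or grid fields applies the same idea along both axes of the coefficient grid (a Kronecker-sum construction).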

Example usage

Suppose X is the design matrix (n x m) and Y is the data (n x 1).

Import the library and instantiate a model

import jax.numpy as jnp

import GLM.glm as glm
import GLM.DesignMaker as dm

mod2fitall = glm.PoissonGLMbayes()

Add data to fit the model

mod2fitall.add_data(y=jnp.array(Y))

Build design matrices as marginal effects and tensors using patsy

basis_x_list, S_list, tensor_basis, tensor_S, beta_x_names = dm.pac_cont_dsgn_all_complex(
    X_train,
    params={'basismain': basistype,
            'nbases': nbases,
            'basistypeval': 'linear',
            'nbasis': relval_bins,
            'inter_nbases': 5})
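The exact bases produced by pac_cont_dsgn_all_complex depend on the params above. As a generic illustration of what a marginal basis expansion looks like (a sketch, not this package's implementation), here is a minimal piecewise-linear (tent) design matrix over one covariate:

```python
def tent_design(x, n_bases, lo, hi):
    """Design matrix of piecewise-linear (tent) bases on [lo, hi]:
    one column per knot, each row a local weighting of the knots."""
    width = (hi - lo) / (n_bases - 1)
    knots = [lo + j * width for j in range(n_bases)]
    return [[max(0.0, 1.0 - abs(xi - k) / width) for k in knots]
            for xi in x]

X = tent_design([0.0, 0.3, 0.5, 1.0], n_bases=5, lo=0.0, hi=1.0)
# each row sums to 1 (the tents form a partition of unity)
```

Smoother choices (B-splines, cyclic splines) follow the same pattern: n rows, one column per basis function, which is what the S penalty matrices then act on.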

Define the model type, pass design bases and tensors, and call fit method

mod2fitall.define_model(model='prs_double_penalty', basis_x_list=basis_x_list, S_list=S_list,
                          tensor_basis_list=tensor_basis, S_tensor_list=tensor_S)

params = {'fittype': 'vi', 'guide': 'normal', 'visteps': 10000, 'optimtype': 'scheduled'}
mod2fitall.fit(params=params, beta_x_names=beta_x_names, fit_intercept=True, cauchy=3.0)

The model parameter posteriors can be acquired with different credible interval levels

credible_interval = 95
posterior_samples = 5000
mod2fitall.sample_posterior(posterior_samples).summarize_posterior(credible_interval).coeff_relevance()

'Significant' coefficients and their posterior parameters can be acquired as

posterior_mu_full = mod2fitall.posterior_means
posterior_sd_full = mod2fitall.posterior_sd
coefficients_sig = mod2fitall.coef_keep
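The exact rule coeff_relevance uses isn't shown here; a common convention, sketched generically below, keeps a coefficient when its equal-tailed credible interval excludes zero.

```python
def credible_interval(samples, level=95.0):
    """Equal-tailed credible interval from posterior draws (percentile method)."""
    s = sorted(samples)
    tail = (100.0 - level) / 200.0
    return s[int(tail * (len(s) - 1))], s[int((1.0 - tail) * (len(s) - 1))]

def keep_coefficients(draws_per_coef, level=95.0):
    """Indices of coefficients whose credible interval excludes zero."""
    keep = []
    for j, draws in enumerate(draws_per_coef):
        lo, hi = credible_interval(draws, level)
        if lo > 0.0 or hi < 0.0:
            keep.append(j)
    return keep
```

This zero-exclusion rule is the Bayesian analogue of thresholding on a p-value, but it is read directly off the posterior rather than from a sampling distribution.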

Model Nomenclature

For the model argument in .define_model, there are a few types available.

  • 'prs_double_penalty': implements wiggliness regularization and a null-space penalty for basis functions (a la L1), and directly optimizes the smoothing hyperparameter

  • 'prs_hyperlambda': implements wiggliness regularization and directly optimizes the smoothing hyperparameter.

  • 'ardG_prs_mcmc': implements wiggliness regularization and an automatic relevance determination (ARD) prior over whole variables (not individual bases). Note: if linear (non-basis) effects are used, then the whole variable is the coefficient.
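The objective actually optimized by the package may differ in detail; as a generic sketch of the double-penalty idea, a second penalty matrix S0 shrinks the components that the wiggliness penalty S cannot reach (its null space, e.g. constant and linear trends):

```python
import math

def penalized_nll(beta, X, y, S, S0, lam, lam0):
    """Poisson negative log-likelihood (up to a constant) plus a
    wiggliness penalty lam * b^T S b and a null-space penalty
    lam0 * b^T S0 b -- the structure behind 'prs_double_penalty'."""
    nll = 0.0
    for xi, yi in zip(X, y):
        eta = sum(b * x for b, x in zip(beta, xi))   # linear predictor
        nll += math.exp(eta) - yi * eta              # Poisson NLL term
    def quad(M):
        return sum(beta[i] * M[i][j] * beta[j]
                   for i in range(len(beta)) for j in range(len(beta)))
    return nll + lam * quad(S) + lam0 * quad(S0)
```

With lam0 = 0 this reduces to the single-penalty ('prs_hyperlambda') setting; making both lam and lam0 learnable hyperparameters is what allows whole terms to be shrunk out of the model.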


Installation and dependency notes

Make sure PyTorch, JAX, and Pyro are all installed, in that order.


Please direct questions or bugs to justfineneuro@gmail.com, or submit an issue in the GitHub repo!

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

bayesbrain-0.1.1.tar.gz (22.1 kB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

bayesbrain-0.1.1-py3-none-any.whl (23.9 kB)

Uploaded Python 3

File details

Details for the file bayesbrain-0.1.1.tar.gz.

File metadata

  • Download URL: bayesbrain-0.1.1.tar.gz
  • Upload date:
  • Size: 22.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.16

File hashes

Hashes for bayesbrain-0.1.1.tar.gz
Algorithm Hash digest
SHA256 b44ed9f8905a2420121eb1c60f9575806724518a0ae983f87334c5fdd480bbe6
MD5 c54afc5e6c9211d192dd2ed081acbd7f
BLAKE2b-256 29d0b72fe355c3de20583bbe0275a8e07170d66e029ce81a8c5f53b8bd85a92b

See more details on using hashes here.

File details

Details for the file bayesbrain-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: bayesbrain-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 23.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.16

File hashes

Hashes for bayesbrain-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 4fe49498ac3755f80874d6a7dd38de64f8d76e730521b15f1852d905d0cb8e36
MD5 aba342bbadb00cb5cc60726866704b3d
BLAKE2b-256 c816550d185912e55e909dafc9fdf975ae08970ce3c3c79bd4bd50b900cd6c2d

