The BASSA algorithm, as presented in the paper "Sparse Linear Bayesian Models for Organic Chemistry".

Project description

BASSA: Bayesian Analysis with Spike-and-Slab Arrangements

Overview

Most chemical datasets are small and high-dimensional, making deep learning impractical. Linear regression remains interpretable and effective, but feature selection is critical. Traditional methods pick a single “best” model, overlooking the fact that multiple plausible models may exist.

BASSA combines Bayesian spike-and-slab regression with a filtering method to efficiently discover and organize many valid regression models. This reveals diverse interpretations hidden in chemical data without overcommitting to a single solution.
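To make the idea concrete, here is a minimal toy sketch of the spike-and-slab prior itself (this is a generic illustration, not the library's implementation): each coefficient is either exactly zero (the "spike") or drawn from a wide Gaussian (the "slab"), so a posterior over the inclusion indicators corresponds to a distribution over many candidate models.

```python
import numpy as np

rng = np.random.default_rng(0)

M = 10            # number of features
pi = 0.3          # prior inclusion probability (assumed value, for illustration)
slab_sd = 2.0     # standard deviation of the slab component

# Binary inclusion indicators: which features enter the model.
included = rng.random(M) < pi

# Coefficients: zero for excluded features, Gaussian for included ones.
beta = np.where(included, rng.normal(0.0, slab_sd, M), 0.0)

print(included)
print(beta)
```

Each distinct pattern of `included` is one candidate regression model; BASSA's role is to filter and organize the patterns the sampler visits.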


Installation

pip install bassa-reg

We also recommend installing LaTeX on your system to generate high-quality plots.
For Windows, you can use MiKTeX.
For macOS, using Homebrew:

brew install --cask mactex
brew install ghostscript

For Linux (Ubuntu/Debian):

sudo apt-get install texlive-latex-base texlive-latex-extra texlive-fonts-recommended dvipng

Example Use

import os

import numpy as np
import pandas as pd
from bassa_reg import Bassa
from bassa_reg.spike_and_slab.spike_and_slab import SpikeAndSlabConfigurations, SpikeAndSlabRegression
from bassa_reg.spike_and_slab.spike_and_slab_util_models import SpikeAndSlabPriors, TestSet

def generate_data(N, M, K, noise_level=0.1):
    X = pd.DataFrame(np.random.randn(N, M), columns=[f's_{i}' for i in range(M)])
    coefficients = np.random.randn(K)
    Y = pd.Series(X.iloc[:, :K].dot(coefficients) + np.random.randn(N) * noise_level)
    return X, Y

x_train, y_train = generate_data(100, 20, 5)

priors = SpikeAndSlabPriors()
config = SpikeAndSlabConfigurations(sampler_iterations=5000)
abs_dir = os.path.dirname(os.path.abspath(__file__))

regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    project_path=abs_dir,
                                    experiment_name="demo")

regression.run()
bassa = Bassa(model=regression)
bassa.run()

Results

After running both the spike-and-slab regression and BASSA, results are saved in the specified project directory.
The main output is the bassa_plot.png file, which represents the models chosen by BASSA.

This chart visualizes the different models found by BASSA, with their feature combinations and performance metrics.
Key additional outputs include:

Markov Chain Visualization

The Markov chain visualization shows the exploration of different models over iterations.
It is sorted by feature inclusion frequency, highlighting the most commonly selected features.
Precise feature inclusion frequencies are also provided in a separate file named feature_stats.csv.
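As a sketch of what "feature inclusion frequency" means (the names and matrix layout here are assumptions, not the library's internal format), frequencies can be computed from a binary matrix with one row per sampler iteration and one column per feature:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)

# Hypothetical inclusion matrix: entry 1 means the feature was in the
# model at that sampler iteration.
samples = rng.integers(0, 2, size=(5000, 4))
features = [f"s_{i}" for i in range(4)]

# Inclusion frequency = fraction of iterations in which each feature appears,
# sorted so the most commonly selected features come first.
freq = pd.Series(samples.mean(axis=0), index=features).sort_values(ascending=False)
print(freq)
```

A file like feature_stats.csv would then hold one such frequency per feature.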

Survival Process Plot

The survival plot, accompanied by the survival_table.csv file, illustrates the survival process of models over iterations.

Alt text
This is an auxiliary output that helps understand how models persist or change and is used to generate the upset chart.

Additional Data

The meta_data.csv file contains information about the Spike-and-Slab regression run, including the number of iterations and other configuration details. It also includes some metrics about the regression performance on the training data.

Prediction on New Data

To make predictions on new data, create a new TestSet object.

import os

import numpy as np
import pandas as pd
from bassa_reg.spike_and_slab.spike_and_slab import SpikeAndSlabConfigurations, SpikeAndSlabRegression
from bassa_reg.spike_and_slab.spike_and_slab_util_models import SpikeAndSlabPriors, TestSet

def generate_data(N, M, K, noise_level=0.1):
    X = pd.DataFrame(np.random.randn(N, M), columns=[f's_{i}' for i in range(M)])
    coefficients = np.random.randn(K)
    Y = pd.Series(X.iloc[:, :K].dot(coefficients) + np.random.randn(N) * noise_level)
    x_test = pd.DataFrame(np.random.randn(int(N/2), M), columns=[f's_{i}' for i in range(M)])
    return X, Y, x_test

x_train, y_train, x_test = generate_data(100, 20, 5)

priors = SpikeAndSlabPriors()
config = SpikeAndSlabConfigurations(sampler_iterations=5000)
abs_dir = os.path.dirname(os.path.abspath(__file__))

test_set = TestSet(x_test=x_test,
                   samples_per_y=100,
                   iterations=200)

regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    test_set=test_set,
                                    project_path=abs_dir,
                                    experiment_name="prediction_demo")

regression.run()

The sampler will run an extra number of iterations, set by the iterations parameter of the TestSet object.
In every iteration, the sampler draws samples_per_y values of y for each sample in the test set.
The average of these draws is the predicted value for each test sample.
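The averaging step can be sketched as follows (a minimal illustration with synthetic draws; the array shapes and names are assumptions, not the library's internals):

```python
import numpy as np

rng = np.random.default_rng(2)

# Hypothetical posterior-predictive draws: `iterations` sampler states, each
# producing `samples_per_y` draws of y for every test point.
iterations, samples_per_y, n_test = 200, 100, 3
draws = rng.normal(loc=[1.0, -2.0, 0.5], scale=0.3,
                   size=(iterations, samples_per_y, n_test))

# The point prediction for each test point is the mean over all of its draws.
y_pred = draws.mean(axis=(0, 1))
print(y_pred)
```

With 200 × 100 draws per test point, the averages concentrate tightly around the underlying predictive means.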

Continuing a Previous Run

To continue a previous run, you first need to set save_samples=True on the SpikeAndSlabConfigurations object.
Then load the previous run using a SpikeAndSlabLoader object and pass it to the SpikeAndSlabRegression object.

import os

import numpy as np
import pandas as pd
from bassa_reg.spike_and_slab.spike_and_slab import SpikeAndSlabConfigurations, SpikeAndSlabRegression
from bassa_reg.spike_and_slab.spike_and_slab_util_models import SpikeAndSlabPriors, SpikeAndSlabLoader


def generate_data(N, M, K, noise_level=0.1):
    X = pd.DataFrame(np.random.randn(N, M), columns=[f's_{i}' for i in range(M)])
    coefficients = np.random.randn(K)
    Y = pd.Series(X.iloc[:, :K].dot(coefficients) + np.random.randn(N) * noise_level)
    return X, Y

x_train, y_train = generate_data(100, 10, 6, noise_level=0.6)
priors = SpikeAndSlabPriors()
config = SpikeAndSlabConfigurations(sampler_iterations=5000,
                                    save_meta_data=True,
                                    save_samples=True)
abs_dir = os.path.dirname(os.path.abspath(__file__))
regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    project_path=abs_dir,
                                    experiment_name="example_run")
regression.run()
loader = SpikeAndSlabLoader(path=f"{abs_dir}/{regression.full_experiment_name}")
regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    project_path=abs_dir,
                                    experiment_name="example_run",
                                    load_experiment=loader)
regression.run()

Choosing Priors For Spike-and-Slab

There are three latent variables in the spike-and-slab model that need priors:
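While the list above is not yet filled in, a common conjugate spike-and-slab setup (given here as a generic reference; the hyperparameter names and defaults are assumptions, not necessarily this library's parameterization) places a Beta prior on the inclusion probability and inverse-gamma priors on the slab and noise variances:

```python
import numpy as np

rng = np.random.default_rng(3)

# Illustrative hyperpriors (all hyperparameter values here are assumed):
#   pi     ~ Beta(a, b)        prior probability that a feature is included
#   tau2   ~ InvGamma(c, d)    slab (coefficient) variance
#   sigma2 ~ InvGamma(e, f)    observation-noise variance
a, b, c, d, e, f = 1.0, 1.0, 2.0, 1.0, 2.0, 1.0

pi = rng.beta(a, b)
tau2 = 1.0 / rng.gamma(c, 1.0 / d)      # inverse-gamma via reciprocal of a gamma draw
sigma2 = 1.0 / rng.gamma(e, 1.0 / f)

print(pi, tau2, sigma2)
```

Tighter Beta priors encode a stronger belief about model sparsity, while the slab variance controls how large included coefficients are allowed to be.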

BASSA Thresholds

TBD
