
The BASSA algorithm as presented in the paper Sparse Linear Bayesian Models for Organic Chemistry


BASSA: Bayesian Analysis with Spike-and-Slab Arrangements

Overview

Most chemical datasets are small and high-dimensional, making deep learning impractical. Linear regression remains interpretable and effective, but feature selection is critical. Traditional methods pick a single “best” model, overlooking the fact that multiple plausible models may exist.

BASSA combines Bayesian spike-and-slab regression with a filtering method to efficiently discover and organize many valid regression models. This reveals diverse interpretations hidden in chemical data without overcommitting to a single solution.
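
To make the spike-and-slab idea concrete, here is a minimal NumPy sketch of drawing coefficients from a textbook spike-and-slab prior: each feature gets a binary inclusion indicator, and only included features receive a nonzero coefficient. The function name `sample_spike_and_slab` and the parameters `inclusion_prob` and `slab_scale` are illustrative assumptions, not part of the bassa_reg API, and the package's own parameterization may differ.

```python
import numpy as np

def sample_spike_and_slab(n_features, inclusion_prob=0.2, slab_scale=1.0, rng=None):
    """Draw one coefficient vector from a textbook spike-and-slab prior.

    Hypothetical helper for illustration only; not part of bassa_reg.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Inclusion indicators: True means the feature is in the model.
    gamma = rng.random(n_features) < inclusion_prob
    # Slab: a Gaussian draw for every feature ...
    slab = rng.normal(0.0, slab_scale, size=n_features)
    # ... spike: excluded features are set to exactly zero.
    beta = np.where(gamma, slab, 0.0)
    return gamma, beta

gamma, beta = sample_spike_and_slab(20, rng=np.random.default_rng(0))
print(int(gamma.sum()), "of", gamma.size, "features included")
```

Each distinct pattern of inclusion indicators corresponds to one candidate regression model, which is why a sampler over these indicators can surface several plausible models rather than a single one.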


Installation

pip install bassa-reg

We also recommend installing LaTeX on your system to generate high-quality plots.
For Windows, you can use MiKTeX.
For macOS, using Homebrew:

brew install --cask mactex
brew install ghostscript

For Linux (Ubuntu/Debian):

sudo apt-get install texlive-latex-base texlive-latex-extra texlive-fonts-recommended dvipng
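
After installing, it can help to confirm that the tools are visible on your PATH. The sketch below uses only the Python standard library. The executable names `latex`, `dvipng`, and `gs` are assumptions based on the install commands above; the README does not state which executables the package actually calls. On Windows, Ghostscript is usually installed under a different name such as `gswin64c`.

```python
import shutil

def find_latex_tools(tools=("latex", "dvipng", "gs")):
    """Map each tool name to its full path on PATH, or None if it is missing."""
    return {tool: shutil.which(tool) for tool in tools}

for tool, path in find_latex_tools().items():
    print(f"{tool}: {path or 'not found'}")
```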

Example Use

import os

import numpy as np
import pandas as pd
from bassa_reg import Bassa
from bassa_reg.spike_and_slab.spike_and_slab import SpikeAndSlabConfigurations, SpikeAndSlabRegression
from bassa_reg.spike_and_slab.spike_and_slab_util_models import SpikeAndSlabPriors

def generate_data(N, M, K, noise_level=0.1):
    # N samples, M features; only the first K features influence Y.
    X = pd.DataFrame(np.random.randn(N, M), columns=[f's_{i}' for i in range(M)])
    coefficients = np.random.randn(K)
    Y = pd.Series(X.iloc[:, :K].dot(coefficients) + np.random.randn(N) * noise_level)
    return X, Y

x_train, y_train = generate_data(100, 20, 5)

priors = SpikeAndSlabPriors()
config = SpikeAndSlabConfigurations(sampler_iterations=5000)
# Results are written under the directory containing this script.
abs_dir = os.path.dirname(os.path.abspath(__file__))

regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    project_path=abs_dir,
                                    experiment_name="demo")

# Run the spike-and-slab sampler first, then BASSA on the fitted model.
regression.run()
bassa = Bassa(model=regression)
bassa.run()

Results

After running both the spike-and-slab regression and BASSA, results are saved in the specified project directory.
The main output is the bassa_plot.png file, which represents the models chosen by BASSA.

This chart visualizes the different models found by BASSA, with their feature combinations and performance metrics.
Key additional outputs include:

Markov Chain Visualization

The Markov chain visualization shows the exploration of different models over iterations.
It is sorted by feature inclusion frequency, highlighting the most commonly selected features.
Precise feature inclusion frequencies are also provided in a separate file named feature_stats.csv.
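
As a rough illustration of what an inclusion frequency is, the sketch below computes it from a matrix of binary inclusion indicators, one row per sampler iteration and one column per feature, and sorts the features the same way the plot does. The helper `inclusion_frequencies` and the toy chain are hypothetical; the actual layout of feature_stats.csv is not described here and may differ.

```python
import numpy as np
import pandas as pd

def inclusion_frequencies(chain, feature_names):
    """Fraction of iterations in which each feature was included, sorted descending.

    Hypothetical helper for illustration only; not part of bassa_reg.
    """
    chain = np.asarray(chain, dtype=bool)
    freq = chain.mean(axis=0)  # average over iterations (rows)
    return pd.Series(freq, index=feature_names).sort_values(ascending=False)

# Toy chain: 4 iterations, 3 features.
chain = np.array([[1, 0, 1],
                  [1, 0, 0],
                  [1, 1, 0],
                  [1, 0, 1]])
freqs = inclusion_frequencies(chain, ["s_0", "s_1", "s_2"])
print(freqs)
```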

Survival Process Plot

The survival plot, accompanied by the survival_table.csv file, illustrates the survival process of models over iterations.

This is an auxiliary output that helps you understand how models persist or change, and it is used to generate the UpSet chart.

Additional Data

The meta_data.csv file contains information about the spike-and-slab regression run, including the number of iterations and other configuration details. It also includes some metrics about the regression performance on the training data.

Prediction on New Data

To make predictions on new data, create a TestSet object and pass it to SpikeAndSlabRegression through the test_set argument.

import os

import numpy as np
import pandas as pd
from bassa_reg.spike_and_slab.spike_and_slab import SpikeAndSlabConfigurations, SpikeAndSlabRegression
from bassa_reg.spike_and_slab.spike_and_slab_util_models import SpikeAndSlabPriors, TestSet

def generate_data(N, M, K, noise_level=0.1):
    # N training samples, M features; only the first K features influence Y.
    X = pd.DataFrame(np.random.randn(N, M), columns=[f's_{i}' for i in range(M)])
    coefficients = np.random.randn(K)
    Y = pd.Series(X.iloc[:, :K].dot(coefficients) + np.random.randn(N) * noise_level)
    # Test set with half as many samples and the same feature columns.
    x_test = pd.DataFrame(np.random.randn(N // 2, M), columns=[f's_{i}' for i in range(M)])
    return X, Y, x_test

x_train, y_train, x_test = generate_data(100, 20, 5)

priors = SpikeAndSlabPriors()
config = SpikeAndSlabConfigurations(sampler_iterations=5000)
abs_dir = os.path.dirname(os.path.abspath(__file__))

test_set = TestSet(x_test=x_test,
                   samples_per_y=100,
                   iterations=200)

regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    test_set=test_set,
                                    project_path=abs_dir,
                                    experiment_name="prediction_demo")

regression.run()

The sampler will run an extra number of iterations, set by the iterations parameter of the TestSet object.
In each of these iterations, the sampler draws samples_per_y values of y for each sample in the test set.
The predicted value for each test sample is the average of these draws.
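
The averaging described above can be sketched with NumPy. The random draws below are a stand-in for the sampler's output, and the array layout (iterations, samples_per_y, test samples) is an assumption made for illustration; the package's internal representation is not documented here.

```python
import numpy as np

rng = np.random.default_rng(0)
iterations, samples_per_y, n_test = 200, 100, 50

# Stand-in for the y values drawn by the sampler:
# one block of samples_per_y draws per extra iteration, per test sample.
y_draws = rng.normal(loc=1.0, scale=0.5, size=(iterations, samples_per_y, n_test))

# Prediction for each test sample: mean over all iterations and all draws.
y_pred = y_draws.mean(axis=(0, 1))
print(y_pred.shape)
```

With the settings from the example, each prediction averages 200 × 100 = 20,000 draws, so raising either iterations or samples_per_y reduces the Monte Carlo noise in the predicted value at the cost of runtime.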

Continuing a Previous Run

To continue a previous run, that run must have been executed with save_samples=True set on the SpikeAndSlabConfigurations object.
You can then load it using the SpikeAndSlabLoader object and pass the loader to the SpikeAndSlabRegression object.

import os

import numpy as np
import pandas as pd
from bassa_reg.spike_and_slab.spike_and_slab import SpikeAndSlabConfigurations, SpikeAndSlabRegression
from bassa_reg.spike_and_slab.spike_and_slab_util_models import SpikeAndSlabPriors, SpikeAndSlabLoader


def generate_data(N, M, K, noise_level=0.1):
    X = pd.DataFrame(np.random.randn(N, M), columns=[fr's_{i}' for i in range(M)])
    coefficients = np.random.randn(K)
    Y = pd.Series(X.iloc[:, :K].dot(coefficients) + np.random.randn(N) * noise_level)
    return X, Y

x_train, y_train = generate_data(100, 10, 6, noise_level=0.6)
priors = SpikeAndSlabPriors()
# save_samples=True is required so that the run can be continued later.
config = SpikeAndSlabConfigurations(sampler_iterations=5000,
                                    save_meta_data=True,
                                    save_samples=True)
abs_dir = os.path.dirname(os.path.abspath(__file__))

# First run: samples are saved in the experiment directory.
regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    project_path=abs_dir,
                                    experiment_name="example_run")
regression.run()

# Second run: load the saved experiment and continue from it.
loader = SpikeAndSlabLoader(path=f"{abs_dir}/{regression.full_experiment_name}")
regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    project_path=abs_dir,
                                    experiment_name="example_run",
                                    load_experiment=loader)
regression.run()

Choosing Priors For Spike-and-Slab

There are 3 latent variables in the spike-and-slab model that need priors:

BASSA Thresholds

TBD
