
The BASSA algorithm as presented in the paper Sparse Linear Bayesian Models for Organic Chemistry

Project description

BASSA: Bayesian Analysis with Spike-and-Slab Arrangements

Overview

Most chemical datasets are small and high-dimensional, making deep learning impractical. Linear regression remains interpretable and effective, but feature selection is critical. Traditional methods pick a single “best” model, overlooking the fact that multiple plausible models may exist.

BASSA combines Bayesian spike-and-slab regression with a filtering method to efficiently discover and organize many valid regression models. This reveals diverse interpretations hidden in chemical data without overcommitting to a single solution.
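The core idea of a spike-and-slab prior can be illustrated in a few lines: each coefficient is either exactly zero (the "spike") or drawn from a wide Gaussian (the "slab"), governed by a binary inclusion indicator. This is a generic sketch of the concept, not the package's implementation; all names and values here are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy spike-and-slab draw: a Bernoulli indicator decides whether each
# coefficient is exactly zero (spike) or drawn from a wide Gaussian (slab).
n_features = 8
inclusion_prob = 0.3   # prior probability that a feature is active
slab_sd = 2.0          # standard deviation of the slab component

gamma = rng.random(n_features) < inclusion_prob            # inclusion indicators
beta = np.where(gamma, rng.normal(0.0, slab_sd, n_features), 0.0)

print(gamma)  # which features are "in" the model
print(beta)   # sparse coefficient vector
```

Posterior inference over the indicators is what lets the sampler visit many distinct sparse models, which BASSA then filters and organizes.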


Installation

pip install bassa-reg

We also recommend installing LaTeX on your system to generate high-quality plots.
On Windows, you can use MiKTeX.
On macOS, using Homebrew:

brew install --cask mactex
brew install ghostscript

For Linux (Ubuntu/Debian):

sudo apt-get install texlive-latex-base texlive-latex-extra texlive-fonts-recommended dvipng

Example Use

import os

import numpy as np
import pandas as pd
from bassa_reg import Bassa
from bassa_reg.spike_and_slab.spike_and_slab import SpikeAndSlabConfigurations, SpikeAndSlabRegression
from bassa_reg.spike_and_slab.spike_and_slab_util_models import SpikeAndSlabPriors, TestSet

def generate_data(N, M, K, noise_level=0.1):
    X = pd.DataFrame(np.random.randn(N, M), columns=[f's_{i}' for i in range(M)])
    coefficients = np.random.randn(K)
    Y = pd.Series(X.iloc[:, :K].dot(coefficients) + np.random.randn(N) * noise_level)
    return X, Y

x_train, y_train = generate_data(100, 20, 5)

priors = SpikeAndSlabPriors()
config = SpikeAndSlabConfigurations(sampler_iterations=5000)
abs_dir = os.path.dirname(os.path.abspath(__file__))

regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    project_path=abs_dir,
                                    experiment_name="demo")

regression.run()
bassa = Bassa(model=regression)
bassa.run()

Results

After running both the spike-and-slab regression and BASSA, results are saved in the specified project directory.
The main output is the bassa_plot.png file, which represents the models chosen by BASSA.

This chart visualizes the different models found by BASSA, with their feature combinations and performance metrics.
Key additional outputs include:

Markov Chain Visualization

The Markov chain visualization shows the exploration of different models over iterations.
It is sorted by feature inclusion frequency, highlighting the most commonly selected features.
Precise feature inclusion frequencies are also provided in a separate file named feature_stats.csv.
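The inclusion frequency behind feature_stats.csv can be understood with a small sketch: given the sampler's binary inclusion indicators as an (iterations × features) matrix, the per-feature frequency is just the column mean. The matrix and column names below are made up for illustration; the package's actual file format may differ.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)

# Hypothetical (iterations x features) matrix of inclusion indicators,
# with each feature included at a different ground-truth rate.
samples = rng.random((5000, 4)) < np.array([0.9, 0.5, 0.1, 0.05])

# Per-feature inclusion frequency is the column mean, sorted for display.
freq = pd.Series(samples.mean(axis=0),
                 index=[f"s_{i}" for i in range(4)],
                 name="inclusion_frequency").sort_values(ascending=False)
print(freq)
```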

Survival Process Plot

The survival plot, accompanied by the survival_table.csv file, illustrates the survival process of models over iterations.

This is an auxiliary output that helps explain how models persist or change across iterations, and it is used to generate the upset chart.

Additional Data

The meta_data.csv file records the configuration of the Spike-and-Slab regression run, including the number of iterations and other settings, along with metrics of regression performance on the training data.

Prediction on New Data

To make predictions on new data, create a TestSet object and pass it to the regression via the test_set argument.

import os

import numpy as np
import pandas as pd
from bassa_reg.spike_and_slab.spike_and_slab import SpikeAndSlabConfigurations, SpikeAndSlabRegression
from bassa_reg.spike_and_slab.spike_and_slab_util_models import SpikeAndSlabPriors, TestSet

def generate_data(N, M, K, noise_level=0.1):
    X = pd.DataFrame(np.random.randn(N, M), columns=[f's_{i}' for i in range(M)])
    coefficients = np.random.randn(K)
    Y = pd.Series(X.iloc[:, :K].dot(coefficients) + np.random.randn(N) * noise_level)
    x_test = pd.DataFrame(np.random.randn(int(N/2), M), columns=[f's_{i}' for i in range(M)])
    return X, Y, x_test

x_train, y_train, x_test = generate_data(100, 20, 5)

priors = SpikeAndSlabPriors()
config = SpikeAndSlabConfigurations(sampler_iterations=5000)
abs_dir = os.path.dirname(os.path.abspath(__file__))

test_set = TestSet(x_test=x_test,
                   samples_per_y=100,
                   iterations=200)

regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    test_set=test_set,
                                    project_path=abs_dir,
                                    experiment_name="prediction_demo")

regression.run()

The sampler will run an extra number of iterations, set by the iterations parameter of the TestSet object.
In each of these iterations, the sampler draws samples_per_y values of y for every sample in the test set.
The average of these draws is the predicted value for each test sample.
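The averaging step described above can be sketched with plain NumPy. The draws here are synthetic stand-ins for the sampler's posterior-predictive samples; only the shapes mirror the iterations and samples_per_y settings from the example.

```python
import numpy as np

rng = np.random.default_rng(2)
iterations, samples_per_y, n_test = 200, 100, 3

# Hypothetical posterior-predictive draws for 3 test points, centered on
# made-up true values, with shape (iterations, samples_per_y, n_test).
draws = rng.normal(loc=[1.0, -2.0, 0.5], scale=0.3,
                   size=(iterations, samples_per_y, n_test))

# Averaging over both sampling axes yields one point prediction per test sample.
y_pred = draws.mean(axis=(0, 1))
print(y_pred)
```

With iterations × samples_per_y draws per point, the Monte Carlo error of the average shrinks accordingly, which is why both parameters trade compute for prediction stability.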

Continuing a Previous Run

To continue a previous run, first set save_samples=True on the SpikeAndSlabConfigurations object.
You can then load the previous run with a SpikeAndSlabLoader and pass it to a new SpikeAndSlabRegression object.

import os

import numpy as np
import pandas as pd
from bassa_reg.spike_and_slab.spike_and_slab import SpikeAndSlabConfigurations, SpikeAndSlabRegression
from bassa_reg.spike_and_slab.spike_and_slab_util_models import SpikeAndSlabPriors, SpikeAndSlabLoader


def generate_data(N, M, K, noise_level=0.1):
    X = pd.DataFrame(np.random.randn(N, M), columns=[f's_{i}' for i in range(M)])
    coefficients = np.random.randn(K)
    Y = pd.Series(X.iloc[:, :K].dot(coefficients) + np.random.randn(N) * noise_level)
    return X, Y

x_train, y_train = generate_data(100, 10, 6, noise_level=0.6)
priors = SpikeAndSlabPriors()
config = SpikeAndSlabConfigurations(sampler_iterations=5000,
                                    save_meta_data=True,
                                    save_samples=True)
abs_dir = os.path.dirname(os.path.abspath(__file__))
regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    project_path=abs_dir,
                                    experiment_name="example_run")
regression.run()
loader = SpikeAndSlabLoader(path=f"{abs_dir}/{regression.full_experiment_name}")
regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    project_path=abs_dir,
                                    experiment_name="example_run",
                                    load_experiment=loader)
regression.run()

Choosing Priors For Spike-and-Slab

There are three latent variables in the spike-and-slab model that require priors:
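In a standard spike-and-slab formulation (generic notation; the actual parameter names exposed by SpikeAndSlabPriors may differ), the three latent quantities are typically the inclusion indicators, the regression coefficients, and the noise variance:

```latex
% Inclusion indicators, coefficients, and noise variance (generic notation)
\gamma_j \sim \mathrm{Bernoulli}(\pi), \qquad
\beta_j \mid \gamma_j \sim \gamma_j\,\mathcal{N}(0,\tau^2) + (1-\gamma_j)\,\delta_0, \qquad
\sigma^2 \sim \mathrm{Inv\text{-}Gamma}(a, b)
```

Here \(\pi\) controls expected sparsity, \(\tau^2\) the slab width, and \((a, b)\) the noise prior; how these map onto the library's prior arguments is not documented on this page.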

BASSA Thresholds

TBD
