
The BASSA algorithm as presented in the paper "Sparse Linear Bayesian Models for Organic Chemistry".


BASSA: Bayesian Analysis with Spike-and-Slab Arrangements

Overview

Most chemical datasets are small and high-dimensional, making deep learning impractical. Linear regression remains interpretable and effective, but feature selection is critical. Traditional methods pick a single “best” model, overlooking the fact that multiple plausible models may exist.

BASSA combines Bayesian spike-and-slab regression with a filtering method to efficiently discover and organize many valid regression models. This reveals diverse interpretations hidden in chemical data without overcommitting to a single solution.
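To see why committing to a single "best" model can be misleading, here is a small NumPy-only illustration (independent of bassa_reg, with made-up data): when two features are nearly collinear, two different feature subsets fit the data almost equally well, so several models are equally plausible.

```python
import itertools

import numpy as np

rng = np.random.default_rng(0)
N = 50

# Two nearly collinear features plus an independent one: either member of
# the correlated pair can stand in for the other in a linear model.
s0 = rng.standard_normal(N)
s1 = s0 + 0.05 * rng.standard_normal(N)   # almost a copy of s0
s2 = rng.standard_normal(N)
X = np.column_stack([s0, s1, s2])
y = 2.0 * s0 + 1.0 * s2 + 0.1 * rng.standard_normal(N)

def r2(cols):
    """R^2 of an ordinary least-squares fit using only the given columns."""
    beta, *_ = np.linalg.lstsq(X[:, list(cols)], y, rcond=None)
    resid = y - X[:, list(cols)] @ beta
    return 1 - resid.var() / y.var()

# Score every two-feature subset: {s0, s2} and {s1, s2} both fit well.
scores = {cols: r2(cols) for cols in itertools.combinations(range(3), 2)}
```

A single-model selector would report only one of the two near-equivalent subsets; BASSA is designed to surface both.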


Installation

pip install bassa-reg

We also recommend installing LaTeX on your system to generate high-quality plots.
For Windows, you can use MiKTeX.
For macOS, using Homebrew:

brew install --cask mactex
brew install ghostscript

For Linux (Ubuntu/Debian):

sudo apt-get install texlive-latex-base texlive-latex-extra texlive-fonts-recommended dvipng

Example Use

import os

import numpy as np
import pandas as pd
from bassa_reg import Bassa
from bassa_reg.spike_and_slab.spike_and_slab import SpikeAndSlabConfigurations, SpikeAndSlabRegression
from bassa_reg.spike_and_slab.spike_and_slab_util_models import SpikeAndSlabPriors, TestSet

def generate_data(N, M, K, noise_level=0.1):
    X = pd.DataFrame(np.random.randn(N, M), columns=[f's_{i}' for i in range(M)])
    coefficients = np.random.randn(K)
    Y = pd.Series(X.iloc[:, :K].dot(coefficients) + np.random.randn(N) * noise_level)
    return X, Y

x_train, y_train = generate_data(100, 20, 5)

priors = SpikeAndSlabPriors()
config = SpikeAndSlabConfigurations(sampler_iterations=5000)
abs_dir = os.path.dirname(os.path.abspath(__file__))

regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    project_path=abs_dir,
                                    experiment_name="demo")

regression.run()
bassa = Bassa(model=regression)
bassa.run()

Results

After running both the spike-and-slab regression and BASSA, results are saved in the specified project directory.
The main output is the bassa_plot.png file, which represents the models chosen by BASSA.

This chart visualizes the different models found by BASSA, with their feature combinations and performance metrics.
Key additional outputs include:

Markov Chain Visualization

The Markov chain visualization shows how the sampler explores different models over iterations.
It is sorted by feature inclusion frequency, highlighting the most commonly selected features.
Precise feature inclusion frequencies are also provided in a separate file named feature_stats.csv.
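Inclusion frequencies of this kind can be reproduced from any chain of binary inclusion indicators. The sketch below uses a hypothetical four-iteration chain, not the package's actual output format:

```python
import pandas as pd

# Hypothetical chain of inclusion indicators: one row per sampler iteration,
# one column per feature; 1 means the feature was in the model at that step.
chain = pd.DataFrame(
    [[1, 0, 1],
     [1, 1, 0],
     [1, 0, 1],
     [1, 0, 0]],
    columns=["s_0", "s_1", "s_2"],
)

# Inclusion frequency = fraction of iterations in which each feature appears,
# i.e. a Monte Carlo estimate of its posterior inclusion probability.
freq = chain.mean().sort_values(ascending=False)
```

Sorting by this frequency is what puts the most commonly selected features at the top of the visualization.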

Survival Process Plot

The survival plot, accompanied by the survival_table.csv file, illustrates the survival process of models over iterations.

This auxiliary output shows how models persist or change across iterations and is used to generate the upset chart.

Additional Data

The meta_data.csv file records information about the spike-and-slab regression run, including the number of iterations and other configuration details, as well as metrics on the regression's performance on the training data.

Prediction on New Data

To make predictions on new data, create a TestSet object and pass it to the regression.

import os

import numpy as np
import pandas as pd
from bassa_reg.spike_and_slab.spike_and_slab import SpikeAndSlabConfigurations, SpikeAndSlabRegression
from bassa_reg.spike_and_slab.spike_and_slab_util_models import SpikeAndSlabPriors, TestSet

def generate_data(N, M, K, noise_level=0.1):
    X = pd.DataFrame(np.random.randn(N, M), columns=[f's_{i}' for i in range(M)])
    coefficients = np.random.randn(K)
    Y = pd.Series(X.iloc[:, :K].dot(coefficients) + np.random.randn(N) * noise_level)
    x_test = pd.DataFrame(np.random.randn(int(N/2), M), columns=[f's_{i}' for i in range(M)])
    return X, Y, x_test

x_train, y_train, x_test = generate_data(100, 20, 5)

priors = SpikeAndSlabPriors()
config = SpikeAndSlabConfigurations(sampler_iterations=5000)
abs_dir = os.path.dirname(os.path.abspath(__file__))

test_set = TestSet(x_test=x_test,
                   samples_per_y=100,
                   iterations=200)

regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    test_set=test_set,
                                    project_path=abs_dir,
                                    experiment_name="prediction_demo")

regression.run()

The sampler will run an extra number of iterations, set by the iterations parameter of the TestSet object.
In each of these iterations, the sampler draws samples_per_y values of y for every sample in the test set.
The average of these draws is the predicted value for each test sample.
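The averaging step can be sketched in plain NumPy. The sampler internals and the `true_mean` values below are hypothetical stand-ins, chosen only to show the shape of the computation:

```python
import numpy as np

rng = np.random.default_rng(1)

n_test, iterations, samples_per_y = 3, 200, 100

# Hypothetical stand-in for the sampler: at each of the extra `iterations`,
# draw `samples_per_y` noisy y-values per test point around a true mean.
true_mean = np.array([0.5, -1.0, 2.0])
draws = true_mean + 0.3 * rng.standard_normal((iterations, samples_per_y, n_test))

# The predicted value is the average over all sampled y's per test point.
y_pred = draws.mean(axis=(0, 1))
```

With `iterations * samples_per_y` draws per test point, the Monte Carlo error of the average shrinks with the square root of that product.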

Continuing a Previous Run

To continue a previous run, first set save_samples=True on the SpikeAndSlabConfigurations object.
You can then load the previous run with a SpikeAndSlabLoader object and pass it to a new SpikeAndSlabRegression.

import os

import numpy as np
import pandas as pd
from bassa_reg.spike_and_slab.spike_and_slab import SpikeAndSlabConfigurations, SpikeAndSlabRegression
from bassa_reg.spike_and_slab.spike_and_slab_util_models import SpikeAndSlabPriors, SpikeAndSlabLoader


def generate_data(N, M, K, noise_level=0.1):
    X = pd.DataFrame(np.random.randn(N, M), columns=[f's_{i}' for i in range(M)])
    coefficients = np.random.randn(K)
    Y = pd.Series(X.iloc[:, :K].dot(coefficients) + np.random.randn(N) * noise_level)
    return X, Y

x_train, y_train = generate_data(100, 10, 6, noise_level=0.6)
priors = SpikeAndSlabPriors()
config = SpikeAndSlabConfigurations(sampler_iterations=5000,
                                    save_meta_data=True,
                                    save_samples=True)
abs_dir = os.path.dirname(os.path.abspath(__file__))
regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    project_path=abs_dir,
                                    experiment_name="example_run")
regression.run()
loader = SpikeAndSlabLoader(path=f"{abs_dir}/{regression.full_experiment_name}")
regression = SpikeAndSlabRegression(x=x_train,
                                    y=y_train,
                                    priors=priors,
                                    config=config,
                                    project_path=abs_dir,
                                    experiment_name="example_run",
                                    load_experiment=loader)
regression.run()

Choosing Priors For Spike-and-Slab

There are three latent variables in the spike-and-slab model that need priors.
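In the standard spike-and-slab formulation (the symbols below follow the usual textbook convention and are not necessarily the package's internal names), the latent variables are the inclusion indicators, the regression coefficients, and the noise variance:

```latex
y \mid \beta, \sigma^2 \sim \mathcal{N}(X\beta, \sigma^2 I), \qquad
\beta_j \mid \gamma_j \sim \gamma_j \, \mathcal{N}(0, \tau^2) + (1 - \gamma_j) \, \delta_0,
```
```latex
\gamma_j \sim \mathrm{Bernoulli}(\pi), \qquad
\sigma^2 \sim \mathrm{Inv\text{-}Gamma}(a, b)
```

Here each \(\gamma_j \in \{0, 1\}\) switches feature \(j\) between the "slab" \(\mathcal{N}(0, \tau^2)\) and the "spike" at zero \(\delta_0\); the priors on \(\pi\), \(\tau^2\), and \((a, b)\) are what SpikeAndSlabPriors is expected to configure.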

BASSA Thresholds

TBD
