A logistic regression model with missing data handling using the SAEM algorithm.

misaem: Logistic Regression with Missing Covariates

misaem is a Python package that implements logistic regression for data with missing covariates. It is based on the Stochastic Approximation Expectation-Maximization (SAEM) algorithm. Within each iteration, an MCMC step uses Metropolis-Hastings sampling to draw values for the missing covariates, under the assumption that the covariates follow a multivariate normal distribution.
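To make the multivariate normal assumption concrete, here is a minimal NumPy sketch of one way a Metropolis-Hastings step for a single row could look. It assumes the proposal is the normal distribution of the missing entries given the observed ones, so the acceptance ratio reduces to a ratio of logistic likelihoods. The helper names (`conditional_mvn`, `log_lik`, `mh_step`) and the toy parameters are hypothetical and are not taken from misaem's source.

```python
import numpy as np

def conditional_mvn(x, mu, sigma):
    """Mean and covariance of the missing entries of x given the observed ones."""
    mis = np.isnan(x)
    obs = ~mis
    s_oo = sigma[np.ix_(obs, obs)]
    s_mo = sigma[np.ix_(mis, obs)]
    s_mm = sigma[np.ix_(mis, mis)]
    w = np.linalg.solve(s_oo, s_mo.T).T  # equals s_mo @ inv(s_oo)
    cond_mean = mu[mis] + w @ (x[obs] - mu[obs])
    cond_cov = s_mm - w @ s_mo.T
    return mis, cond_mean, cond_cov

def log_lik(y, x, beta):
    """log p(y | x, beta) for a logistic model; beta[0] is the intercept."""
    eta = beta[0] + x @ beta[1:]
    return y * eta - np.logaddexp(0.0, eta)

def mh_step(x_filled, mis, cond_mean, cond_cov, y, beta, rng):
    """One independence-sampler step: propose from p(x_mis | x_obs), accept on likelihood ratio."""
    proposal = x_filled.copy()
    proposal[mis] = rng.multivariate_normal(cond_mean, cond_cov)
    log_ratio = log_lik(y, proposal, beta) - log_lik(y, x_filled, beta)
    return proposal if np.log(rng.uniform()) < log_ratio else x_filled

# Toy setup (illustrative values only)
rng = np.random.default_rng(0)
mu = np.zeros(2)
sigma = np.array([[1.0, 0.5], [0.5, 1.0]])
beta = np.array([0.0, 1.0, -1.0])
x = np.array([2.0, np.nan])

mis, cond_mean, cond_cov = conditional_mvn(x, mu, sigma)
x_filled = x.copy()
x_filled[mis] = cond_mean  # start the chain at the conditional mean
for _ in range(50):
    x_filled = mh_step(x_filled, mis, cond_mean, cond_cov, y=1, beta=beta, rng=rng)
```

Observed entries are never modified; only the positions flagged in `mis` are resampled.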

The package supports parameter estimation and prediction when covariates are Missing Completely At Random (MCAR) or Missing At Random (MAR). It is a direct port of the misaem R package and follows scikit-learn's API conventions, so it can be used with standard scikit-learn tooling.
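The difference between the two missingness assumptions can be illustrated with synthetic masks. In this sketch, which uses made-up data and an arbitrary logistic missingness model, MCAR drops entries with a constant probability, while MAR drops entries of one column with a probability that depends on another, fully observed column.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
X = rng.normal(size=(n, 2))

# MCAR: every entry of column 1 is dropped with the same probability.
mcar_mask = rng.random(n) < 0.2

# MAR: the probability of dropping column 1 depends only on column 0,
# which is always observed (the logistic form here is an arbitrary choice).
p_missing = 1 / (1 + np.exp(-(X[:, 0] - 1.5)))
mar_mask = rng.random(n) < p_missing

X_mcar = X.copy()
X_mcar[mcar_mask, 1] = np.nan

X_mar = X.copy()
X_mar[mar_mask, 1] = np.nan

# Under MAR, rows with a missing column 1 tend to have larger column-0 values.
mean_x0_missing = X[mar_mask, 0].mean()
mean_x0_observed = X[~mar_mask, 0].mean()
```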

🌟 Key Features

  • Robust Handling of Missing Data: Employs the SAEM algorithm to handle missing covariate data during model fitting and prediction.

  • Predictive Power: Can be used to make predictions on new, incomplete datasets.

  • scikit-learn Compatibility: Follows scikit-learn's API, allowing for direct use in pipelines, cross-validation, and other standard workflows.

  • Statistical Foundation: Provides estimated coefficients and their standard errors, enabling statistical inference.
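As a sketch of the kind of inference that coefficients and standard errors make possible, the snippet below computes Wald z-statistics, two-sided p-values, and 95% confidence intervals. It assumes the estimates are approximately normal; the function name `wald_test` and the input numbers are illustrative placeholders, not output from misaem.

```python
import math
import numpy as np

def wald_test(coef, std_err):
    """Wald z-statistics, two-sided p-values, and 95% confidence bounds."""
    coef = np.asarray(coef, dtype=float)
    std_err = np.asarray(std_err, dtype=float)
    z = coef / std_err
    # Two-sided p-value: 2 * (1 - Phi(|z|)) = erfc(|z| / sqrt(2))
    p = np.array([math.erfc(abs(v) / math.sqrt(2)) for v in z])
    lower = coef - 1.96 * std_err
    upper = coef + 1.96 * std_err
    return z, p, lower, upper

# Placeholder numbers for illustration only
z, p, lower, upper = wald_test(coef=[0.0, 1.96], std_err=[1.0, 1.0])
for j in range(len(z)):
    print(f"coef {j}: z={z[j]:.2f}, p={p[j]:.3f}, 95% CI=({lower[j]:.2f}, {upper[j]:.2f})")
```

With a fitted model, the same function could be given `model.coef_.ravel()` and the matching entries of `model.std_err_`.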

💻 Installation

You can install the package via pip:

pip install misaem

🚀 Getting Started

Here is a full example demonstrating how to use SAEMLogisticRegression to fit a model, make predictions, and visualize the results.

import numpy as np
from misaem import SAEMLogisticRegression
import matplotlib.pyplot as plt


# 1. Generate Data
np.random.seed(1324)
n_train = 1000
n_test = 300
n_features = 5

cov = [[1,0.75,0.5,0.25,0],
       [0.75,1,0.75,0.5,0.25],
       [0.5,0.75,1,0.75,0.5],
       [0.25,0.5,0.75,1,0.75],
       [0,0.25,0.5,0.75,1]]

X = np.random.multivariate_normal(mean=np.zeros(n_features), cov=cov, size=n_train + n_test)
true_beta = np.hstack([0.5, np.random.normal(size=n_features)])
linear_pred = np.hstack([np.ones((n_train + n_test, 1)), X]) @ true_beta
probabilities = 1 / (1 + np.exp(-linear_pred))
y = np.random.binomial(1, probabilities)

X_missing = X.copy()
missing_mask = np.random.rand(n_train + n_test, n_features) < 0.2
X_missing[missing_mask] = np.nan

X_train = X_missing[:n_train, :]
X_test = X_missing[n_train:, :]  # Python slices are 0-based: rows n_train onward
y_train = y[:n_train]
y_test = y[n_train:]

# 2. Fit SAEMLogisticRegression
model = SAEMLogisticRegression()
model.fit(X_train, y_train)

# 3. Predict on test set
preds = model.predict_proba(X_test, method="map")[:,1]

# 4. Evaluate
accuracy = np.mean((preds > 0.5) == y_test)
print(f"Accuracy of SAEMLogisticRegression: {accuracy:.4f}")

coef = model.coef_.ravel()
se = model.std_err_.ravel()[1:] # remove the intercept

plt.figure(figsize=(8, 4)) 
plt.errorbar(range(len(coef)), coef, yerr=1.96*se, fmt='o', label='Estimated Coef with 95% CI')
plt.xticks(range(len(coef)), [f'X{i}' for i in range(len(coef))])
plt.plot(range(len(coef)), true_beta[1:], 'o', color='red', label='True Coef')
plt.legend()
plt.show()

[Figure: SAEM coefficient estimates with 95% confidence intervals, plotted against the true coefficients]

⚙️ scikit-learn Integration

The misaem package is compatible with scikit-learn's API, which allows you to seamlessly integrate it into your machine learning pipelines. For instance, you can combine it with data preprocessing steps like StandardScaler.

import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from misaem import SAEMLogisticRegression

# X_train, y_train, X_test, y_test come from the Getting Started example.

pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('model', SAEMLogisticRegression(random_state=42))
])

pipeline.fit(X_train, y_train)
y_pred = pipeline.predict(X_test)

accuracy_pipeline = np.mean(y_pred == y_test)
print(f"Accuracy of the StandardScaler + SAEMLogisticRegression pipeline: {accuracy_pipeline:.4f}")
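A scaler can sit in front of a model that handles missing values itself because scikit-learn documents that StandardScaler disregards NaN when fitting and leaves them in place when transforming. The NumPy sketch below mimics that behavior on a small made-up array; it is an illustration of the documented behavior, not scikit-learn's implementation.

```python
import numpy as np

X_demo = np.array([[1.0, 10.0],
                   [2.0, np.nan],
                   [3.0, 30.0],
                   [np.nan, 50.0]])

# Column statistics computed over the observed entries only
col_mean = np.nanmean(X_demo, axis=0)
col_std = np.nanstd(X_demo, axis=0)  # population std (ddof=0)

# NaN propagates through the arithmetic, so missing entries stay missing
X_scaled = (X_demo - col_mean) / col_std
```

The missing entries reach the downstream model unchanged, so the model still decides how to treat them.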


📚 Reference

This package is based on the SAEM algorithm for logistic regression with missing covariates. For more details on the methodology, please refer to the following publication:

Jiang, W., Josse, J., Lavielle, M., & TraumaBase Group. (2020). Logistic regression with missing covariates—Parameter estimation, model selection and prediction within a joint-modeling framework. Computational Statistics & Data Analysis, 145, 106907.

🤝 Contributing

We welcome contributions! If you encounter any bugs, have feature requests, or want to contribute to the code, please open an issue or a pull request on our GitHub repository.
