
Pretrained Lasso: a two-step procedure for sparse linear models with grouped samples

ptlasso

Python implementation of the Pretrained Lasso — a two-step procedure for fitting sparse linear models when samples belong to distinct groups, leveraging shared structure across groups via pretraining.

Based on:

Craig, E., Pilanci, M., Le Menestrel, T., Narasimhan, B., Rivas, M. A., Gullaksen, S. E., ... & Tibshirani, R. (2025). Pretraining and the lasso. Journal of the Royal Statistical Society Series B: Statistical Methodology, qkaf050.


The idea

Standard group-specific Lasso models are fit independently per group, ignoring shared signal. The Pretrained Lasso fits in two steps:

Step 1 — Overall model. Fit a Lasso on all samples to capture shared structure:

$$\hat{\beta}^{\text{overall}} = \arg\min_\beta \frac{1}{2n}\|y - X\beta\|_2^2 + \lambda\|\beta\|_1$$

Step 2 — Group models. For each group $k$, fit a Lasso with an offset equal to $\alpha$ times the overall model's linear predictor:

$$\hat{\beta}^{(k)} = \arg\min_\beta \frac{1}{2n_k}\|y^{(k)} - \underbrace{\alpha \cdot X^{(k)}\hat{\beta}^{\text{overall}}}_{\text{offset}} - X^{(k)}\beta\|_2^2 + \lambda_k\|\beta\|_1$$

The parameter $\alpha \in [0, 1]$ controls the pretraining strength:

  • $\alpha = 0$: pure group-specific models, no pretraining
  • $\alpha = 1$: group models fit the residuals of the overall model
  • $\alpha \in (0, 1)$: group models fit what remains after subtracting a fraction $\alpha$ of the overall prediction, so they stay partially anchored to the overall fit

Final prediction for group $k$: $\hat{y}^{(k)} = \alpha \cdot X^{(k)}\hat{\beta}^{\text{overall}} + X^{(k)}\hat{\beta}^{(k)}$
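For the gaussian family, fitting with an offset is equivalent to subtracting the offset from the response, so the two steps can be sketched in plain NumPy. The block below is a minimal illustration, not ptlasso's implementation: lasso_cd, fit_two_step and predict_two_step are hypothetical helper names, the solver is a bare cyclic coordinate descent without intercepts, and each stage uses one fixed λ rather than a cross-validated path.

```python
import numpy as np

def soft_threshold(z, t):
    """Soft-thresholding operator: shrink z toward zero by t."""
    return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)

def lasso_cd(X, y, lam, n_iter=200):
    """Minimise (1/2n)||y - Xb||^2 + lam * ||b||_1 by cyclic coordinate descent."""
    n, p = X.shape
    b = np.zeros(p)
    col_sq = (X ** 2).sum(axis=0) / n
    r = y.astype(float)                      # running residual y - Xb (a copy of y)
    for _ in range(n_iter):
        for j in range(p):
            if col_sq[j] == 0.0:
                continue
            r += X[:, j] * b[j]              # partial residual without feature j
            rho = X[:, j] @ r / n
            b[j] = soft_threshold(rho, lam) / col_sq[j]
            r -= X[:, j] * b[j]              # full residual with the updated b[j]
    return b

def fit_two_step(X, y, groups, alpha, lam, lam_group):
    """Step 1: overall Lasso. Step 2: per-group Lasso on y minus the offset."""
    b_overall = lasso_cd(X, y, lam)
    b_group = {}
    for k in np.unique(groups):
        mask = groups == k
        offset = alpha * (X[mask] @ b_overall)
        b_group[int(k)] = lasso_cd(X[mask], y[mask] - offset, lam_group)
    return b_overall, b_group

def predict_two_step(X, groups, alpha, b_overall, b_group):
    """Final prediction: alpha * overall linear predictor + group linear predictor."""
    y_hat = alpha * (X @ b_overall)
    for k, b in b_group.items():
        mask = groups == k
        y_hat[mask] += X[mask] @ b
    return y_hat

# Synthetic data for illustration only
rng = np.random.default_rng(0)
n, p = 120, 10
X = rng.standard_normal((n, p))
groups = rng.integers(0, 3, size=n)
beta = np.zeros(p)
beta[:2] = [2.0, -1.5]
y = X @ beta + 0.5 * rng.standard_normal(n)

b_overall, b_group = fit_two_step(X, y, groups, alpha=0.5, lam=0.1, lam_group=0.1)
y_hat = predict_two_step(X, groups, 0.5, b_overall, b_group)
```

With alpha=0 the offset vanishes and each group fit reduces to an independent Lasso on that group's samples, which matches the first bullet above.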

Supports gaussian, binomial, and multinomial families.


Installation

pip install ptlasso

Requires Python ≥ 3.9 and adelie, which provides the underlying Lasso solver. adelie supports fitting with offsets, which scikit-learn's Lasso estimators do not.


Quick start

import numpy as np
from ptlasso import PretrainedLasso, PretrainedLassoCV

rng = np.random.default_rng(42)
n, p, k = 300, 100, 3

X      = rng.standard_normal((n, p))
groups = rng.integers(0, k, size=n)
beta   = np.zeros(p)
beta[:5] = [2, -1.5, 1, -0.8, 0.5]
y      = X @ beta + 0.5 * rng.standard_normal(n)

# Fixed alpha
model = PretrainedLasso(alpha=0.5)
model.fit(X, y, groups)
print(model)
# PretrainedLasso(alpha=0.5, family='gaussian', overall_lambda='lambda.1se', ...)
#   family       : gaussian
#   n_features   : 100
#   n_groups     : 3
#   overall |Ŝ|  : |Ŝ| = 5 / 100  [0, 1, 2, 3, 4]
#   pretrain |Ŝ| : 0: |Ŝ|=5, 1: |Ŝ|=4, 2: |Ŝ|=6

y_pred = model.predict(X, groups)
print("R²:", model.score(X, y, groups))

# Cross-validate over alpha
cv = PretrainedLassoCV(alphas=[0.0, 0.25, 0.5, 0.75, 1.0])
cv.fit(X, y, groups)
print("Best alpha:", cv.alpha_)

Families

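# Illustrative labels derived from the Quick start data;
# any 0/1 labels (binomial) or integer labels 0..K-1 (multinomial) work
y_binary     = (y > np.median(y)).astype(int)
y_multiclass = np.digitize(y, np.quantile(y, [1/3, 2/3]))   # labels 0, 1, 2

# Binary classification
model = PretrainedLasso(alpha=0.5, family="binomial")
model.fit(X, y_binary, groups)
probs = model.predict(X, groups)          # shape (n,), P(y=1)

# Multi-class classification (integer labels 0..K-1)
model = PretrainedLasso(alpha=0.5, family="multinomial")
model.fit(X, y_multiclass, groups)
probs = model.predict(X, groups)          # shape (n, K)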
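In generalised linear models an offset is conventionally added to the linear predictor, so for the binomial family it would enter on the logit scale. The sketch below assumes ptlasso follows that convention, which this README does not state explicitly. binomial_predict is a hypothetical helper, and leaving the overall intercept out of the offset is also an assumption taken from the offset formula above.

```python
import numpy as np

def sigmoid(eta):
    """Logistic function mapping a linear predictor to a probability."""
    return 1.0 / (1.0 + np.exp(-eta))

def binomial_predict(X, alpha, b_overall, b_group, b0_group=0.0):
    """P(y=1) when the stage-1 linear predictor is used as a logit-scale offset.

    Assumed convention: offset = alpha * X @ b_overall, added to the
    group model's own linear predictor before applying the sigmoid.
    """
    offset = alpha * (X @ b_overall)
    eta = offset + b0_group + X @ b_group
    return sigmoid(eta)

# Hypothetical coefficients for illustration only
rng = np.random.default_rng(0)
X_demo = rng.standard_normal((6, 3))
b_overall_demo = np.array([1.0, 0.0, -0.5])
b_group_demo = np.array([0.2, 0.3, 0.0])
probs_demo = binomial_predict(X_demo, 0.5, b_overall_demo, b_group_demo)
```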

Feature names and group labels

Both fit() methods accept human-readable names. pandas DataFrames are supported natively — column names are picked up automatically.

import pandas as pd

X_df         = pd.DataFrame(X, columns=[f"gene_{i}" for i in range(p)])
group_labels = {0: "control", 1: "treated_A", 2: "treated_B"}

model = PretrainedLasso(alpha=0.5)
model.fit(X_df, y, groups, group_labels=group_labels)
# overall |Ŝ|  : |Ŝ| = 5 / 100  [gene_0, gene_1, gene_2, gene_3, gene_4]
# pretrain |Ŝ| : control: |Ŝ|=5, treated_A: |Ŝ|=4, treated_B: |Ŝ|=6

Inspecting the support

The support is the set of features with non-zero coefficients, i.e., the features that the L1 penalty does not shrink to zero and that therefore contribute to the prediction.

from ptlasso import (
    get_overall_support,
    get_pretrain_support,
    get_pretrain_support_split,
    get_individual_support,
)

get_overall_support(model)                      # features from the overall model
get_pretrain_support(model)                     # union across pretrained group models
get_pretrain_support(model, common_only=True)   # features selected by >50% of groups
get_pretrain_support(model, groups=[0, 1])      # restrict to specific groups
get_individual_support(model)                   # features from no-pretraining baselines

common, indiv = get_pretrain_support_split(model)
# common : features from the overall model (stage 1)
# indiv  : additional features picked up by group models (stage 2)
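To make these set operations concrete, the following sketch recomputes them from raw coefficient vectors. The coefficients are made up for illustration and the helper support is hypothetical; it does not call ptlasso and only mirrors the definitions given in the comments above (union, more than 50% of groups, stage 1 versus stage 2).

```python
import numpy as np

def support(beta):
    """Indices of the non-zero coefficients."""
    return set(np.flatnonzero(beta).tolist())

# Hypothetical coefficients: one overall model, three group models
overall_coef = np.array([1.0, 0.0, 0.0, 0.0, 0.0])
group_coefs = {
    "control":   np.array([1.2, 0.0, -0.4, 0.0, 0.0]),
    "treated_A": np.array([0.9, 0.0,  0.0, 0.3, 0.0]),
    "treated_B": np.array([1.1, 0.0, -0.2, 0.0, 0.0]),
}

supports = {g: support(b) for g, b in group_coefs.items()}

# Union across groups
union = sorted(set().union(*supports.values()))

# Features selected by more than half of the groups
n_groups = len(supports)
common = sorted(
    j for j in union
    if sum(j in s for s in supports.values()) > n_groups / 2
)

# Split: features already in the overall model vs. added by group models
stage1 = sorted(support(overall_coef))
stage2_only = sorted(set(union) - set(stage1))
```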

Evaluating all sub-models at once

result = model.evaluate(X_test, y_test, groups_test)
# {"pretrain":   {"predictions": ..., "score": ...},
#  "individual": {"predictions": ..., "score": ...},
#  "overall":    {"predictions": ..., "score": ...}}

Retrieving coefficients

coefs = model.get_coef()                   # all sub-models
coefs["overall"]                           # {"coef": ndarray, "intercept": ndarray}
coefs["pretrain"]["control"]               # {"coef": ndarray, "intercept": ndarray}
coefs["individual"]["treated_A"]

model.get_coef(model="pretrain")           # just pretrain sub-dict

CV details

cv = PretrainedLassoCV(
    alphas=[0.0, 0.25, 0.5, 0.75, 1.0],
    cv=5,
    alphahat_choice="overall",   # or "mean" (unweighted mean of per-group CV errors)
    family="gaussian",
    overall_lambda="lambda.1se", # or "lambda.min"
    foldid=my_foldid,            # optional: custom integer fold assignments
)
cv.fit(X, y, groups)

cv.alpha_                        # globally best alpha
cv.varying_alphahat_             # {group: best_alpha} per group
cv.cv_results_                   # {alpha: mean CV loss}
cv.cv_results_se_                # {alpha: SE of CV loss}
cv.cv_results_per_group_         # {alpha: {group: mean CV loss}}
cv.cv_results_mean_              # {alpha: unweighted mean of per-group losses}
cv.cv_results_wtd_mean_          # {alpha: size-weighted mean of per-group losses}
cv.cv_results_individual_        # CV loss for individual (no-pretraining) baseline
cv.cv_results_overall_           # CV loss for overall model baseline
cv.best_estimator_               # PretrainedLasso fitted with alpha_
cv.all_estimators_               # {alpha: PretrainedLasso} for varying-alpha prediction

# Predict using each group's own best alpha
cv.predict(X, groups, alphatype="varying")
cv.evaluate(X, y, groups, alphatype="varying")
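The foldid argument takes one integer fold label per sample. One way to build such an array, so that every fold contains a proportional share of each group, is sketched below. make_group_balanced_foldid is a hypothetical helper, and the 0-based labels are an assumption: this README does not say whether ptlasso expects 0-based or 1-based fold ids.

```python
import numpy as np

def make_group_balanced_foldid(groups, n_folds=5, seed=0):
    """Assign fold labels 0..n_folds-1, balanced within each group."""
    rng = np.random.default_rng(seed)
    groups = np.asarray(groups)
    foldid = np.empty(groups.shape[0], dtype=int)
    for g in np.unique(groups):
        idx = np.flatnonzero(groups == g)
        rng.shuffle(idx)                          # random order within the group
        foldid[idx] = np.arange(idx.size) % n_folds
    return foldid

# Three groups of unequal size (illustrative)
groups_demo = np.repeat([0, 1, 2], [50, 30, 20])
my_foldid = make_group_balanced_foldid(groups_demo, n_folds=5)
```

Balancing within groups avoids folds in which a small group is missing entirely, which would leave that group's CV loss undefined for the fold.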

Plotting

from ptlasso import plot_cv, plot_paths

plot_cv(cv)           # CV loss curve over alpha with ±1 SE band
plot_paths(model)     # regularisation paths for all sub-models
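A regularisation path traces the coefficients along a decreasing grid of λ values. The sketch below shows the common glmnet-style construction of such a grid for the gaussian family: start at the smallest λ that zeroes every coefficient and decrease geometrically to min_ratio times that value. The exact grid adelie builds may differ; lambda_path is a hypothetical helper, and its defaults merely echo the lmda_path_size and min_ratio defaults listed in the API reference.

```python
import numpy as np

def lambda_path(X, y, path_size=100, min_ratio=0.01):
    """Log-spaced lambda grid from lambda_max down to min_ratio * lambda_max."""
    n = X.shape[0]
    y_centred = y - y.mean()                 # accounts for an unpenalised intercept
    # Smallest lambda at which the gaussian Lasso solution is all zeros
    lam_max = np.max(np.abs(X.T @ y_centred)) / n
    return np.geomspace(lam_max, min_ratio * lam_max, num=path_size)

# Synthetic data for illustration only
rng = np.random.default_rng(0)
X_demo = rng.standard_normal((50, 8))
y_demo = X_demo[:, 0] - 0.5 * X_demo[:, 1] + 0.1 * rng.standard_normal(50)
path = lambda_path(X_demo, y_demo)
```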

Saving and loading models

PretrainedLasso and PretrainedLassoCV can be serialised with joblib:

import joblib

# Save
joblib.dump(model, "model.pkl")

# Load
model = joblib.load("model.pkl")

Note: serialised models are tied to the Python and library versions used to create them. For long-term storage, pin your dependencies or retrain from scratch after major version upgrades.

Avoid pickle directly — the underlying adelie solver stores C++ objects that are not natively picklable. ptlasso handles this transparently through joblib by converting solver state to plain numpy arrays before serialisation. Using pickle directly bypasses this and will raise a TypeError.


API reference

PretrainedLasso

| Parameter | Default | Description |
| --- | --- | --- |
| alpha | 0.5 | Pretraining strength $\in [0, 1]$ |
| family | "gaussian" | "gaussian", "binomial", or "multinomial" |
| overall_lambda | "lambda.1se" | Lambda rule for stage-1 offset: "lambda.1se" or "lambda.min" |
| fit_intercept | True | Fit an intercept in all sub-models |
| lmda_path_size | 100 | Number of $\lambda$ values in the regularisation path |
| min_ratio | 0.01 | Ratio of smallest to largest $\lambda$ |
| verbose | False | Show adelie progress bar |

Methods:

  • fit(X, y, groups, group_labels=None, feature_names=None)
  • predict(X, groups, model="pretrain", lmda_idx=None), where model is one of {"pretrain", "individual", "overall"}
  • score(X, y, groups) — R² or accuracy
  • evaluate(X, y, groups) — predict + score for all three sub-models
  • get_coef(model="all", lmda_idx=None)
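The overall_lambda options are named after glmnet's two selection rules. Under that convention, "lambda.min" is the λ with the lowest mean CV error, and "lambda.1se" is the largest λ whose mean CV error is within one standard error of that minimum. The sketch below implements those two rules on made-up numbers. select_lambda is a hypothetical helper, and that ptlasso applies exactly this definition is an assumption.

```python
import numpy as np

def select_lambda(lambdas, cv_mean, cv_se, rule="lambda.1se"):
    """Pick a lambda from CV results using the glmnet-style rules."""
    lambdas = np.asarray(lambdas, dtype=float)
    cv_mean = np.asarray(cv_mean, dtype=float)
    cv_se = np.asarray(cv_se, dtype=float)
    i_min = int(np.argmin(cv_mean))
    if rule == "lambda.min":
        return float(lambdas[i_min])
    if rule == "lambda.1se":
        # Largest (most regularised) lambda within one SE of the minimum
        threshold = cv_mean[i_min] + cv_se[i_min]
        return float(lambdas[cv_mean <= threshold].max())
    raise ValueError(f"unknown rule: {rule!r}")

# Hypothetical CV results along a decreasing lambda grid
lambdas = [1.0, 0.5, 0.25, 0.125, 0.0625]
cv_mean = [2.0, 1.2, 1.05, 1.0, 1.3]
cv_se   = [0.2, 0.1, 0.1, 0.1, 0.1]

lam_min = select_lambda(lambdas, cv_mean, cv_se, rule="lambda.min")
lam_1se = select_lambda(lambdas, cv_mean, cv_se, rule="lambda.1se")
```

The one-standard-error rule trades a statistically negligible increase in CV error for a sparser stage-1 model, and therefore a sparser offset.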

PretrainedLassoCV

| Parameter | Default | Description |
| --- | --- | --- |
| alphas | [0, 0.25, 0.5, 0.75, 1.0] | Candidate $\alpha$ values |
| cv | 5 | Number of CV folds |
| alphahat_choice | "overall" | "overall" or "mean" (unweighted per-group mean) |
| family | "gaussian" | Same as PretrainedLasso |
| overall_lambda | "lambda.1se" | Same as PretrainedLasso |
| fit_intercept | True | |
| lmda_path_size | 100 | |
| min_ratio | 0.01 | |
| verbose | False | |
| foldid | None | Integer array of fold assignments (overrides cv) |

Same fit / predict / score / evaluate / get_coef interface as PretrainedLasso, plus:

| Fitted attribute | Description |
| --- | --- |
| alpha_ | Best $\alpha$ selected by CV |
| varying_alphahat_ | {group: alpha}, the per-group best $\alpha$ |
| cv_results_ | {alpha: mean CV loss} |
| cv_results_se_ | {alpha: SE of CV loss} |
| cv_results_per_group_ | {alpha: {group: mean CV loss}} |
| cv_results_mean_ | {alpha: unweighted mean of per-group losses} |
| cv_results_wtd_mean_ | {alpha: size-weighted mean of per-group losses} |
| cv_results_individual_ | CV loss for individual baseline |
| cv_results_overall_ | CV loss for overall baseline |
| best_estimator_ | PretrainedLasso fitted with alpha_ |
| all_estimators_ | {alpha: PretrainedLasso} for each unique varying alpha |

predict also accepts alphatype="varying" to route each group through its own best alpha.
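The relationship between the per-group losses, their two averages, and the per-group best alpha can be illustrated with plain dictionaries. The numbers below are invented and the helper names are hypothetical. The sketch only shows the arithmetic implied by the attribute descriptions, not how ptlasso computes or stores them.

```python
# {alpha: {group: mean CV loss}}, hypothetical values
per_group = {
    0.0: {"a": 1.00, "b": 2.00},
    0.5: {"a": 0.70, "b": 2.00},
    1.0: {"a": 1.00, "b": 1.40},
}
group_sizes = {"a": 300, "b": 100}

def unweighted_mean(losses):
    """Every group counts equally, regardless of its sample size."""
    return sum(losses.values()) / len(losses)

def size_weighted_mean(losses, sizes):
    """Each group counts in proportion to its number of samples."""
    total = sum(sizes[g] for g in losses)
    return sum(sizes[g] * loss for g, loss in losses.items()) / total

mean_by_alpha = {a: unweighted_mean(l) for a, l in per_group.items()}
wtd_by_alpha = {a: size_weighted_mean(l, group_sizes) for a, l in per_group.items()}

best_unweighted = min(mean_by_alpha, key=mean_by_alpha.get)
best_weighted = min(wtd_by_alpha, key=wtd_by_alpha.get)

# Per-group best alpha: minimise each group's own loss over alpha
varying = {
    g: min(per_group, key=lambda a, g=g: per_group[a][g])
    for g in group_sizes
}
```

In this toy example the unweighted mean favours alpha = 1.0, the size-weighted mean favours alpha = 0.5 because the larger group "a" dominates, and the per-group choice lets each group keep its own best value.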


Citation

@article{craig2025pretraining,
  title   = {Pretraining and the lasso},
  author  = {Craig, Erin and Pilanci, Mert and Le Menestrel, Thomas and Narasimhan, Balasubramanian and Rivas, Manuel A. and Gullaksen, Stein-Erik and Tibshirani, Robert},
  journal = {Journal of the Royal Statistical Society Series B: Statistical Methodology},
  pages   = {qkaf050},
  year    = {2025}
}

License

MIT
