Pretrained Lasso: a two-step procedure for sparse linear models with grouped samples

ptlasso

Python implementation of the Pretrained Lasso — a two-step procedure for fitting sparse linear models when samples belong to distinct groups, leveraging shared structure across groups via pretraining.

Based on:

Craig, E., Pilanci, M., Le Menestrel, T., Narasimhan, B., Rivas, M. A., Gullaksen, S. E., ... & Tibshirani, R. (2025). Pretraining and the lasso. Journal of the Royal Statistical Society Series B: Statistical Methodology, qkaf050.


The idea

Standard group-specific Lasso models are fit independently per group, ignoring shared signal. The Pretrained Lasso fits in two steps:

Step 1 — Overall model. Fit a Lasso on all samples to capture shared structure:

$$\hat{\beta}^{\text{overall}} = \arg\min_\beta \frac{1}{2n}\|y - X\beta\|_2^2 + \lambda\|\beta\|_1$$

Step 2 — Group models. For each group $k$, fit a Lasso with an offset equal to $\alpha$ times the overall model's linear predictor:

$$\hat{\beta}^{(k)} = \arg\min_\beta \frac{1}{2n_k}\|y^{(k)} - \underbrace{\alpha \cdot X^{(k)}\hat{\beta}^{\text{overall}}}_{\text{offset}} - X^{(k)}\beta\|_2^2 + \lambda_k\|\beta\|_1$$

The parameter $\alpha \in [0, 1]$ controls the pretraining strength:

  • $\alpha = 0$: pure group-specific models, no pretraining
  • $\alpha = 1$: group models explain residuals from the overall model
  • $\alpha \in (0, 1)$: group models fit what is left after subtracting a fraction $\alpha$ of the overall linear predictor

Final prediction for group $k$: $\hat{y}^{(k)} = \alpha \cdot X^{(k)}\hat{\beta}^{\text{overall}} + X^{(k)}\hat{\beta}^{(k)}$
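
The two steps can be sketched in plain NumPy for the gaussian family, where fitting with an offset is the same as fitting the response minus the offset. This is a minimal illustration, not ptlasso's implementation: the helper names (lasso_cd, fit_two_step, predict_two_step) are hypothetical, intercepts are omitted, and a single fixed $\lambda$ is used per step instead of a cross-validated path.

```python
import numpy as np

def soft_threshold(z, t):
    """Soft-thresholding operator used by the Lasso coordinate update."""
    return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)

def lasso_cd(X, y, lam, n_iter=100):
    """Coordinate descent for (1/2n)||y - Xb||^2 + lam * ||b||_1, no intercept."""
    n, p = X.shape
    beta = np.zeros(p)
    col_sq = (X ** 2).sum(axis=0) / n     # x_j'x_j / n for each feature
    resid = y.astype(float).copy()        # y - X @ beta, kept up to date
    for _ in range(n_iter):
        for j in range(p):
            if col_sq[j] == 0.0:
                continue
            # correlation of feature j with the partial residual
            rho = X[:, j] @ resid / n + col_sq[j] * beta[j]
            new = soft_threshold(rho, lam) / col_sq[j]
            resid += X[:, j] * (beta[j] - new)
            beta[j] = new
    return beta

def fit_two_step(X, y, groups, alpha, lam_overall, lam_group):
    """Step 1: overall Lasso. Step 2: per-group Lasso on y minus the offset."""
    beta_overall = lasso_cd(X, y, lam_overall)
    group_betas = {}
    for g in np.unique(groups):
        mask = groups == g
        offset = alpha * (X[mask] @ beta_overall)
        group_betas[int(g)] = lasso_cd(X[mask], y[mask] - offset, lam_group)
    return beta_overall, group_betas

def predict_two_step(X, groups, alpha, beta_overall, group_betas):
    """Prediction = alpha * overall linear predictor + group linear predictor."""
    y_hat = np.empty(X.shape[0])
    for g, beta_g in group_betas.items():
        mask = groups == g
        y_hat[mask] = alpha * (X[mask] @ beta_overall) + X[mask] @ beta_g
    return y_hat

# Synthetic data in the spirit of the quick start (values are illustrative)
rng = np.random.default_rng(0)
n, p = 300, 20
X = rng.standard_normal((n, p))
groups = rng.integers(0, 3, size=n)
beta_true = np.zeros(p)
beta_true[:3] = [2.0, -1.5, 1.0]
y = X @ beta_true + 0.5 * rng.standard_normal(n)

alpha = 0.5
beta_overall, group_betas = fit_two_step(X, y, groups, alpha, 0.1, 0.1)
y_hat = predict_two_step(X, groups, alpha, beta_overall, group_betas)
print("overall support:", np.flatnonzero(beta_overall))
print("training MSE:", np.mean((y - y_hat) ** 2))
```

Setting alpha to 0 in this sketch makes the offset vanish, so each group model reduces to an independent Lasso on that group's data.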

Supports gaussian, binomial, and multinomial families.


Installation

pip install ptlasso

Requires Python ≥ 3.9 and adelie, which provides the underlying Lasso solver. adelie supports fitting with offsets, which scikit-learn's Lasso estimators do not.


Quick start

import numpy as np
from ptlasso import PretrainedLasso, PretrainedLassoCV

rng = np.random.default_rng(42)
n, p, k = 300, 100, 3

X      = rng.standard_normal((n, p))
groups = rng.integers(0, k, size=n)
beta   = np.zeros(p)
beta[:5] = [2, -1.5, 1, -0.8, 0.5]
y      = X @ beta + 0.5 * rng.standard_normal(n)

# Fixed alpha
model = PretrainedLasso(alpha=0.5)
model.fit(X, y, groups)
print(model)
# PretrainedLasso(alpha=0.5, family='gaussian', overall_lambda='lambda.1se', ...)
#   family       : gaussian
#   n_features   : 100
#   n_groups     : 3
#   overall |Ŝ|  : |Ŝ| = 5 / 100  [0, 1, 2, 3, 4]
#   pretrain |Ŝ| : 0: |Ŝ|=5, 1: |Ŝ|=4, 2: |Ŝ|=6

y_pred = model.predict(X, groups)
print("R²:", model.score(X, y, groups))

# Cross-validate over alpha
cv = PretrainedLassoCV(alphas=[0.0, 0.25, 0.5, 0.75, 1.0])
cv.fit(X, y, groups)
print("Best alpha:", cv.alpha_)

Families

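# Illustrative labels derived from the quick-start y (not real data)
y_binary     = (y > np.median(y)).astype(int)                  # values in {0, 1}
y_multiclass = np.digitize(y, np.quantile(y, [1 / 3, 2 / 3]))  # values in {0, 1, 2}

# Binary classification
model = PretrainedLasso(alpha=0.5, family="binomial")
model.fit(X, y_binary, groups)
probs = model.predict(X, groups)          # shape (n,), P(y=1)

# Multi-class classification (integer labels 0..K-1)
model = PretrainedLasso(alpha=0.5, family="multinomial")
model.fit(X, y_multiclass, groups)
probs = model.predict(X, groups)          # shape (n, K)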

Feature names and group labels

The fit() methods of both PretrainedLasso and PretrainedLassoCV accept human-readable feature names and group labels. pandas DataFrames are supported natively: column names are picked up automatically.

import pandas as pd

X_df         = pd.DataFrame(X, columns=[f"gene_{i}" for i in range(p)])
group_labels = {0: "control", 1: "treated_A", 2: "treated_B"}

model = PretrainedLasso(alpha=0.5)
model.fit(X_df, y, groups, group_labels=group_labels)
# overall |Ŝ|  : |Ŝ| = 5 / 100  [gene_0, gene_1, gene_2, gene_3, gene_4]
# pretrain |Ŝ| : control: |Ŝ|=5, treated_A: |Ŝ|=4, treated_B: |Ŝ|=6

Inspecting the support

The support is the set of non-zero variables selected by the model, i.e., the features whose coefficients are not shrunk to zero by the L1 regularization and therefore actively contribute to the prediction.

from ptlasso import (
    get_overall_support,
    get_pretrain_support,
    get_pretrain_support_split,
    get_individual_support,
)

get_overall_support(model)                      # features from the overall model
get_pretrain_support(model)                     # union across pretrained group models
get_pretrain_support(model, common_only=True)   # features selected by >50% of groups
get_pretrain_support(model, groups=[0, 1])      # restrict to specific groups
get_individual_support(model)                   # features from no-pretraining baselines

common, indiv = get_pretrain_support_split(model)
# common : features from the overall model (stage 1)
# indiv  : additional features picked up by group models (stage 2)
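
The set logic behind these helpers can be illustrated on bare coefficient vectors. The sketch below is an assumption about what the options mean, based only on the comments above (union across groups, features selected by more than 50% of groups, and the common/individual split); the function names and toy coefficients are hypothetical and are not taken from ptlasso.

```python
import numpy as np

def support(coef):
    """Indices of the non-zero coefficients."""
    return set(np.flatnonzero(np.asarray(coef)).tolist())

def union_support(group_coefs):
    """Features selected by at least one group model."""
    out = set()
    for coef in group_coefs.values():
        out |= support(coef)
    return out

def common_support(group_coefs, threshold=0.5):
    """Features selected by more than `threshold` of the group models."""
    counts = {}
    for coef in group_coefs.values():
        for j in support(coef):
            counts[j] = counts.get(j, 0) + 1
    n_groups = len(group_coefs)
    return {j for j, c in counts.items() if c > threshold * n_groups}

# Toy coefficients, made up for illustration
overall_coef = [1.0, 0.0, -0.3, 0.0, 0.0]
group_coefs = {
    0: [1.2, 0.0, -0.4, 0.0, 0.0],
    1: [0.9, 0.0, 0.0, 0.3, 0.0],
    2: [1.1, 0.0, -0.2, 0.0, 0.0],
}

union = union_support(group_coefs)
common_only = common_support(group_coefs)
stage1 = support(overall_coef)   # features from the overall model
stage2 = union - stage1          # extra features picked up by group models
print(sorted(union), sorted(common_only), sorted(stage1), sorted(stage2))
```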

Evaluating all sub-models at once

result = model.evaluate(X_test, y_test, groups_test)
# {"pretrain":   {"predictions": ..., "score": ...},
#  "individual": {"predictions": ..., "score": ...},
#  "overall":    {"predictions": ..., "score": ...}}
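
For the gaussian family the score is R². As a point of reference, the sketch below computes the standard R² from a vector of predictions, overall and within each group. It is a generic illustration with hypothetical helper names and toy numbers; whether ptlasso aggregates scores across groups in exactly this way is not stated here.

```python
import numpy as np

def r2(y, y_hat):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y = np.asarray(y, dtype=float)
    y_hat = np.asarray(y_hat, dtype=float)
    ss_res = np.sum((y - y_hat) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

def r2_by_group(y, y_hat, groups):
    """R² computed separately within each group."""
    y, y_hat, groups = map(np.asarray, (y, y_hat, groups))
    return {int(g): r2(y[groups == g], y_hat[groups == g])
            for g in np.unique(groups)}

# Toy values, made up for illustration
y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.0, 2.0, 3.0, 5.0])
grp = np.array([0, 0, 1, 1])

overall_r2 = r2(y_true, y_pred)
per_group_r2 = r2_by_group(y_true, y_pred, grp)
print(overall_r2, per_group_r2)
```

A prediction can look good overall and still be poor within one group, which is why comparing the three sub-models per group can be informative.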

Retrieving coefficients

coefs = model.get_coef()                   # all sub-models
coefs["overall"]                           # {"coef": ndarray, "intercept": ndarray}
coefs["pretrain"]["control"]               # {"coef": ndarray, "intercept": ndarray}
coefs["individual"]["treated_A"]

model.get_coef(model="pretrain")           # just pretrain sub-dict

CV details

cv = PretrainedLassoCV(
    alphas=[0.0, 0.25, 0.5, 0.75, 1.0],
    cv=5,
    alphahat_choice="overall",   # or "mean" (unweighted mean of per-group CV errors)
    family="gaussian",
    overall_lambda="lambda.1se", # or "lambda.min"
    foldid=my_foldid,            # optional: custom integer fold assignments
)
cv.fit(X, y, groups)

cv.alpha_                        # globally best alpha
cv.varying_alphahat_             # {group: best_alpha} per group
cv.cv_results_                   # {alpha: mean CV loss}
cv.cv_results_se_                # {alpha: SE of CV loss}
cv.cv_results_per_group_         # {alpha: {group: mean CV loss}}
cv.cv_results_mean_              # {alpha: unweighted mean of per-group losses}
cv.cv_results_wtd_mean_          # {alpha: size-weighted mean of per-group losses}
cv.cv_results_individual_        # CV loss for individual (no-pretraining) baseline
cv.cv_results_overall_           # CV loss for overall model baseline
cv.best_estimator_               # PretrainedLasso fitted with alpha_
cv.all_estimators_               # {alpha: PretrainedLasso} for varying-alpha prediction

# Predict using each group's own best alpha
cv.predict(X, groups, alphatype="varying")
cv.evaluate(X, y, groups, alphatype="varying")
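
The foldid argument takes one integer fold label per sample. One way to build such an array so that every group is spread evenly across folds is sketched below. The helper make_foldid is hypothetical, and the labels here start at 0; whether ptlasso expects 0-based or 1-based labels is an assumption to check against the library documentation.

```python
import numpy as np

def make_foldid(groups, n_folds=5, seed=0):
    """Assign fold labels 0..n_folds-1, balanced within each group."""
    rng = np.random.default_rng(seed)
    groups = np.asarray(groups)
    foldid = np.empty(groups.shape[0], dtype=int)
    for g in np.unique(groups):
        idx = np.flatnonzero(groups == g)
        rng.shuffle(idx)                          # random order within the group
        foldid[idx] = np.arange(idx.size) % n_folds
    return foldid

# Toy group sizes, made up for illustration
demo_groups = np.repeat([0, 1, 2], [10, 7, 13])
my_foldid = make_foldid(demo_groups, n_folds=5)
for g in np.unique(demo_groups):
    print(g, np.bincount(my_foldid[demo_groups == g], minlength=5))
```

Balancing within groups keeps small groups from being absent in a training fold, which matters because each group gets its own stage-2 model.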

Plotting

from ptlasso import plot_cv, plot_paths

plot_cv(cv)           # CV loss curve over alpha with ±1 SE band
plot_paths(model)     # regularisation paths for all sub-models

Saving and loading models

PretrainedLasso and PretrainedLassoCV can be serialised with joblib:

import joblib

# Save
joblib.dump(model, "model.pkl")

# Load
model = joblib.load("model.pkl")

Note: serialised models are tied to the Python and library versions used to create them. For long-term storage, pin your dependencies or retrain from scratch after major version upgrades.

Avoid pickle directly — the underlying adelie solver stores C++ objects that are not natively picklable. ptlasso handles this transparently through joblib by converting solver state to plain numpy arrays before serialisation. Using pickle directly bypasses this and will raise a TypeError.
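
Since serialised models are tied to the versions that produced them, it can help to record those versions next to the saved file. The sketch below uses only the standard library; the helper name environment_snapshot and the choice of packages to record are assumptions for illustration, not part of ptlasso.

```python
import json
import platform
from importlib import metadata

def environment_snapshot(packages=("ptlasso", "adelie", "numpy", "joblib")):
    """Record the Python version and the installed versions of some packages."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None   # package not installed in this environment
    return {"python": platform.python_version(), "packages": versions}

snapshot = environment_snapshot()
print(json.dumps(snapshot, indent=2))
```

Writing this dictionary to a JSON file alongside model.pkl gives a record to compare against before loading the model in a different environment.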


API reference

PretrainedLasso

| Parameter | Default | Description |
|---|---|---|
| alpha | 0.5 | Pretraining strength $\in [0, 1]$ |
| family | "gaussian" | "gaussian", "binomial", or "multinomial" |
| overall_lambda | "lambda.1se" | Lambda rule for stage-1 offset: "lambda.1se" or "lambda.min" |
| fit_intercept | True | Fit an intercept in all sub-models |
| lmda_path_size | 100 | Number of $\lambda$ values in the regularisation path |
| min_ratio | 0.01 | Ratio of smallest to largest $\lambda$ |
| verbose | False | Show adelie progress bar |

Methods:

  • fit(X, y, groups, group_labels=None, feature_names=None)
  • predict(X, groups, model="pretrain", lmda_idx=None), where model ∈ {"pretrain", "individual", "overall"}
  • score(X, y, groups) — R² or accuracy
  • evaluate(X, y, groups) — predict + score for all three sub-models
  • get_coef(model="all", lmda_idx=None)

PretrainedLassoCV

| Parameter | Default | Description |
|---|---|---|
| alphas | [0, 0.25, 0.5, 0.75, 1.0] | Candidate $\alpha$ values |
| cv | 5 | Number of CV folds |
| alphahat_choice | "overall" | "overall" or "mean" (unweighted per-group mean) |
| family | "gaussian" | Same as PretrainedLasso |
| overall_lambda | "lambda.1se" | Same as PretrainedLasso |
| fit_intercept | True | Same as PretrainedLasso |
| lmda_path_size | 100 | Same as PretrainedLasso |
| min_ratio | 0.01 | Same as PretrainedLasso |
| verbose | False | Same as PretrainedLasso |
| foldid | None | Integer array of fold assignments (overrides cv) |

Same fit / predict / score / evaluate / get_coef interface as PretrainedLasso, plus:

| Fitted attribute | Description |
|---|---|
| alpha_ | Best $\alpha$ selected by CV |
| varying_alphahat_ | {group: alpha}, the per-group best $\alpha$ |
| cv_results_ | {alpha: mean CV loss} |
| cv_results_se_ | {alpha: SE of CV loss} |
| cv_results_per_group_ | {alpha: {group: mean CV loss}} |
| cv_results_mean_ | {alpha: unweighted mean of per-group losses} |
| cv_results_wtd_mean_ | {alpha: size-weighted mean of per-group losses} |
| cv_results_individual_ | CV loss for individual baseline |
| cv_results_overall_ | CV loss for overall baseline |
| best_estimator_ | PretrainedLasso fitted with alpha_ |
| all_estimators_ | {alpha: PretrainedLasso} for each unique varying alpha |

predict also accepts alphatype="varying" to route each group through its own best alpha.
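
The difference between the unweighted and size-weighted summaries, and between a single global alpha and per-group alphas, can be seen on a small dictionary shaped like cv_results_per_group_. The sketch below uses made-up losses and group sizes and hypothetical helper names. It only illustrates the averaging rules named above; how ptlasso computes its own summaries internally is not shown here.

```python
def summarize(per_group_loss, group_sizes):
    """Return ({alpha: unweighted mean}, {alpha: size-weighted mean})."""
    total = sum(group_sizes.values())
    unweighted, weighted = {}, {}
    for alpha, losses in per_group_loss.items():
        unweighted[alpha] = sum(losses.values()) / len(losses)
        weighted[alpha] = sum(group_sizes[g] * v for g, v in losses.items()) / total
    return unweighted, weighted

def best_alpha(results):
    """Alpha with the smallest loss."""
    return min(results, key=results.get)

def best_alpha_per_group(per_group_loss):
    """For each group, the alpha with the smallest loss in that group."""
    group_names = next(iter(per_group_loss.values())).keys()
    return {g: min(per_group_loss, key=lambda a: per_group_loss[a][g])
            for g in group_names}

# Toy CV losses and group sizes, made up for illustration
per_group_loss = {
    0.0: {"a": 1.0, "b": 4.0},
    0.5: {"a": 2.0, "b": 2.0},
    1.0: {"a": 3.0, "b": 1.5},
}
group_sizes = {"a": 90, "b": 10}

unweighted, weighted = summarize(per_group_loss, group_sizes)
varying = best_alpha_per_group(per_group_loss)
print(best_alpha(unweighted), best_alpha(weighted), varying)
```

In this toy case the large group dominates the size-weighted mean, so the two rules pick different alphas, and each group's own best alpha differs from both.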


Citation

@article{craig2025pretraining,
  title   = {Pretraining and the lasso},
  author  = {Craig, Erin and Pilanci, Mert and Le Menestrel, Thomas and Narasimhan, Balasubramanian and Rivas, Manuel A. and Gullaksen, Stein-Erik and Tibshirani, Robert},
  journal = {Journal of the Royal Statistical Society Series B: Statistical Methodology},
  pages   = {qkaf050},
  year    = {2025}
}

License

MIT
