
Reference implementation of LassoNet


LassoNet

LassoNet is a new family of models that incorporates feature selection into neural networks.

LassoNet works by adding a linear skip connection from the input features to the output. An L1 penalty (LASSO-inspired) is applied to that skip connection, together with a constraint on the network: whenever a feature is ignored by the skip connection, it is ignored by the whole network.
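The same idea written as an optimization problem, transcribed from the JMLR paper listed under References (notation may differ slightly from the paper). The model is a linear term plus a feed-forward network g_W; L is the training loss, d the number of features, W^(1)_j the first-layer weights leaving feature j, and M the hierarchy coefficient:

```latex
f_{\theta, W}(x) = \theta^\top x + g_W(x)

\min_{\theta, W} \; L(\theta, W) + \lambda \lVert \theta \rVert_1
\quad \text{subject to} \quad
\lVert W^{(1)}_j \rVert_\infty \le M \, \lvert \theta_j \rvert, \qquad j = 1, \dots, d
```

If theta_j = 0, the constraint forces every first-layer weight leaving feature j to zero, which is exactly how a feature dropped by the skip connection is also dropped by the network.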


Installation

pip install lassonet

Usage

We have designed the code to follow scikit-learn's standards to the extent possible (e.g. linear_model.Lasso).

from lassonet import LassoNetClassifierCV
# X_train, y_train, X_test, y_test: your (normalized) training and test data
model = LassoNetClassifierCV()  # or LassoNetRegressorCV for regression
path = model.fit(X_train, y_train)
print("Best model scored", model.score(X_test, y_test))
print("Lambda =", model.best_lambda_)

Always pass normalized data to LassoNet, since it uses neural networks under the hood.
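A minimal sketch of one common normalization, assuming NumPy arrays and that zero-mean, unit-variance scaling is appropriate for your features. The statistics are computed on the training split only, so nothing leaks from the test set; scikit-learn's StandardScaler fit on the training data does the same job. The function name standardize is hypothetical.

```python
import numpy as np


def standardize(X_train: np.ndarray, X_test: np.ndarray, eps: float = 1e-12):
    """Scale each column to zero mean and unit variance using training statistics only."""
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)
    # Avoid dividing by zero: constant columns are centered but not rescaled.
    std = np.where(std < eps, 1.0, std)
    return (X_train - mean) / std, (X_test - mean) / std


# Usage on synthetic data with a non-zero mean and a large spread.
rng = np.random.default_rng(0)
X_train = rng.normal(loc=5.0, scale=3.0, size=(100, 4))
X_test = rng.normal(loc=5.0, scale=3.0, size=(20, 4))
X_train_s, X_test_s = standardize(X_train, X_test)
```

The test split is transformed with the training mean and standard deviation, so its columns will be close to, but not exactly, zero mean and unit variance.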

See the full documentation, or the examples, which cover most features.

Features

  • regression, classification, Cox regression and interval-censored Cox regression with LassoNetRegressor, LassoNetClassifier, LassoNetCoxRegressor and LassoNetIntervalRegressor.
  • cross-validation with LassoNetRegressorCV, LassoNetClassifierCV, LassoNetCoxRegressorCV and LassoNetIntervalRegressorCV
  • stability selection with model.stability_selection()
  • group feature selection with the groups argument
  • lambda_start="auto" heuristic (default)
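To illustrate what group selection changes, here is a sketch comparing the plain L1 penalty on the skip-connection weights with a standard, unweighted group-lasso penalty. It assumes groups are given as lists of feature indices; the library's actual groups format and penalty (for example any group-size weighting) may differ, and every function name below is hypothetical. Because each group's norm is penalized as a unit, whole groups tend to be kept or dropped together.

```python
import math
from typing import List, Sequence


def l1_penalty(theta: Sequence[float]) -> float:
    """Plain LASSO penalty: each feature is penalized (and dropped) on its own."""
    return sum(abs(t) for t in theta)


def group_penalty(theta: Sequence[float], groups: Sequence[Sequence[int]]) -> float:
    """Unweighted group-lasso penalty: sum of the Euclidean norms of each group's weights."""
    return sum(math.sqrt(sum(theta[j] ** 2 for j in g)) for g in groups)


def active_groups(theta: Sequence[float], groups: Sequence[Sequence[int]], tol: float = 0.0) -> List[int]:
    """A group counts as selected if any of its skip-connection weights is nonzero."""
    return [i for i, g in enumerate(groups) if any(abs(theta[j]) > tol for j in g)]


# Hypothetical skip-connection weights for 5 features split into 3 groups.
theta = [3.0, 4.0, 0.0, 0.0, 1.0]
groups = [[0, 1], [2, 3], [4]]
print(l1_penalty(theta))             # 8.0
print(group_penalty(theta, groups))  # 6.0 (5.0 + 0.0 + 1.0)
print(active_groups(theta, groups))  # [0, 2]
```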

Note that cross-validation, group feature selection and automatic lambda_start selection have not been published in papers; you can read the code or open an issue to request more details.

Among other things, we are currently working on support for convolutional layers, autoencoders and online experiment logging.

Cross-validation

The original paper describes how to train LassoNet along a regularization path. This requires the user to select a model from the path manually, and it makes the .fit() method useless on its own, since the last model on the path (the one with the largest penalty) has every feature removed. Path training is still available through the .path() method of any model or through the lassonet_path function, either of which returns a list of checkpoints that can be loaded with .load().
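For intuition about what a regularization path looks like, here is a simplified sketch for the linear skip connection alone, assuming an orthonormal design (X^T X = I) and the objective 0.5 * ||y - X theta||^2 + lambda * ||theta||_1. In that special case each solution has the closed form theta_j = sign(z_j) * max(|z_j| - lambda, 0) with z = X^T y (soft-thresholding), so no network training is involved; LassoNet itself trains the full network at each lambda. Function and parameter names are hypothetical (lambda_start only echoes the library argument mentioned under Features). The penalty starts small and grows geometrically until every coefficient is zero, which is why the last checkpoint is always the empty model.

```python
import math
from typing import List, Sequence, Tuple


def soft_threshold(z: float, lam: float) -> float:
    """Proximal operator of lam * |x|: shrink z toward zero by lam."""
    return math.copysign(max(abs(z) - lam, 0.0), z)


def lasso_path_orthonormal(
    z: Sequence[float],
    lambda_start: float,
    multiplier: float = 1.02,
    max_steps: int = 10_000,
) -> List[Tuple[float, List[float]]]:
    """Trace a Lasso path for an orthonormal design, where theta(lam) = S(z, lam).

    Returns (lam, theta) checkpoints with lam increasing geometrically,
    stopping at the first checkpoint where every coefficient is zero.
    """
    path = []
    lam = lambda_start
    for _ in range(max_steps):
        theta = [soft_threshold(zj, lam) for zj in z]
        path.append((lam, theta))
        if all(t == 0.0 for t in theta):
            break  # empty model reached: the path is complete
        lam *= multiplier
    return path


# z plays the role of X^T y for three features of decreasing strength.
z = [3.0, -1.5, 0.2]
path = lasso_path_orthonormal(z, lambda_start=0.1, multiplier=1.5)
for lam, theta in path:
    print(f"lambda={lam:.3f} active={sum(t != 0.0 for t in theta)}")
```

Weaker features leave the model first, so the number of active features only shrinks as lambda grows.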

Since then, we have integrated support for cross-validation (5-fold by default) into the estimators whose names end with CV. A path is trained on each fold, and the regularization value that maximizes the average performance across folds is selected. The model is then retrained on the whole training set up to that regularization value.
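The fold-averaging step described above can be sketched as follows, assuming every fold is scored on one shared grid of lambda values (the library may handle per-fold paths differently) and that higher scores are better. fit_and_score is a hypothetical callback standing in for "train a path on these rows and score each checkpoint on the held-out rows"; the final refit on the full training set is not shown.

```python
import random
from statistics import mean
from typing import Callable, List, Sequence


def kfold_indices(n: int, k: int = 5, seed: int = 0) -> List[List[int]]:
    """Shuffle the row indices 0..n-1 and split them into k nearly equal folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]


def select_lambda(
    lambdas: Sequence[float],
    n: int,
    fit_and_score: Callable[[List[int], List[int]], List[float]],
    k: int = 5,
) -> float:
    """Return the lambda whose validation score, averaged over k folds, is highest."""
    folds = kfold_indices(n, k)
    per_fold = []
    for i, val_idx in enumerate(folds):
        # Train on every fold except the i-th, validate on the i-th.
        train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
        per_fold.append(fit_and_score(train_idx, val_idx))
    # zip(*per_fold) groups the scores of each lambda across folds.
    averages = [mean(scores) for scores in zip(*per_fold)]
    best = max(range(len(lambdas)), key=averages.__getitem__)
    return lambdas[best]


# Toy usage: a stand-in scorer that prefers lambda = 1.0 on every fold.
lambdas = [0.01, 0.1, 1.0, 10.0]


def toy_fit_and_score(train_idx: List[int], val_idx: List[int]) -> List[float]:
    return [0.6, 0.7, 0.8, 0.5]


print(select_lambda(lambdas, n=100, fit_and_score=toy_fit_and_score))  # 1.0
```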

Website

LassoNet's website is https://lasso-net.github.io/. It contains many useful references, including the paper, talks and additional documentation.

References

  • Lemhadri, Ismael, Feng Ruan, Louis Abraham, and Robert Tibshirani. "LassoNet: A Neural Network with Feature Sparsity." Journal of Machine Learning Research 22, no. 127 (2021).
  • Yang, Xuelin, Louis Abraham, Sejin Kim, Petr Smirnov, Feng Ruan, Benjamin Haibe-Kains, and Robert Tibshirani. "FastCPH: Efficient Survival Analysis for Neural Networks." In NeurIPS 2022 Workshop on Learning from Time Series for Health (2022).
  • Meixide, Carlos García, Marcos Matabuena, Louis Abraham, and Michael R. Kosorok. "Neural interval-censored survival regression with feature selection." Statistical Analysis and Data Mining: The ASA Data Science Journal 17, no. 4 (2024): e11704.



Download files


Source Distribution

lassonet-0.0.17.tar.gz (19.6 kB)

Built Distribution

lassonet-0.0.17-py3-none-any.whl (19.2 kB)

File details

Details for the file lassonet-0.0.17.tar.gz.

File metadata

  • Download URL: lassonet-0.0.17.tar.gz
  • Upload date:
  • Size: 19.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.4

File hashes

Hashes for lassonet-0.0.17.tar.gz
Algorithm Hash digest
SHA256 d5e5752f3ec9e73965e09cf3661fbb2b8d026285e998b4cb85b45e2e73486726
MD5 c41730a7df499bb0bc99c13d0334d8b5
BLAKE2b-256 284a3fdb08ab3c44bc8b4d627d7e02162688d71d43bca78eb3bc403838fafdc7


File details

Details for the file lassonet-0.0.17-py3-none-any.whl.

File metadata

  • Download URL: lassonet-0.0.17-py3-none-any.whl
  • Upload date:
  • Size: 19.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.4

File hashes

Hashes for lassonet-0.0.17-py3-none-any.whl
Algorithm Hash digest
SHA256 2a46f4846937ed42f9bf243c9311f98df559fffaf170abf5f8151f93238559ce
MD5 95ea123bb98b8c3a85e77388f36af0a0
BLAKE2b-256 1055274ca0888bff8e40e708828adf9aec0d8a868ddd9222d6e70255928bc704
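To check a downloaded file against the SHA256 digests listed above, here is a short sketch using Python's standard hashlib module. It assumes the files sit in the current working directory under their original names; the helper names are hypothetical. pip's hash-checking mode (--require-hashes) can enforce the same check at install time.

```python
import hashlib
from pathlib import Path


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: str, expected_hex: str) -> bool:
    """True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of(path) == expected_hex.lower()


# SHA256 digests copied from the file details for version 0.0.17.
EXPECTED = {
    "lassonet-0.0.17.tar.gz": "d5e5752f3ec9e73965e09cf3661fbb2b8d026285e998b4cb85b45e2e73486726",
    "lassonet-0.0.17-py3-none-any.whl": "2a46f4846937ed42f9bf243c9311f98df559fffaf170abf5f8151f93238559ce",
}

if __name__ == "__main__":
    for name, digest in EXPECTED.items():
        if Path(name).exists():
            print(name, "OK" if verify(name, digest) else "MISMATCH")
```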

