Reference implementation of LassoNet
LassoNet
LassoNet is a new family of models that incorporates feature selection into neural networks.
LassoNet works by adding a linear skip connection from the input features to the output. An L1 penalty (inspired by the LASSO) is placed on that skip connection, together with a constraint on the network ensuring that whenever the skip connection ignores a feature, the whole network ignores it too.
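To make the constraint concrete: the LassoNet paper (Lemhadri et al., 2021) writes the model as f(x) = θᵀx + g_W(x), penalizes λ‖θ‖₁, and requires ‖W_j^(1)‖∞ ≤ M·|θ_j| for every feature j, where W_j^(1) holds feature j's first-layer weights and M is a hierarchy multiplier. The NumPy sketch below assumes a single hidden layer, made-up sizes, and hypothetical helper names (forward, clip_to_hierarchy). It enforces the constraint by naive clipping with θ held fixed, a simplification that is not the paper's Hier-Prox operator (which updates θ and W jointly). It checks that once θ_0 = 0, perturbing feature 0 leaves the output unchanged.

```python
import numpy as np

def forward(x, theta, W1, W2):
    """LassoNet-style output: linear skip connection plus a one-hidden-layer ReLU network."""
    return x @ theta + np.maximum(x @ W1, 0.0) @ W2

def clip_to_hierarchy(theta, W1, M):
    """Naively enforce ||W1[j]||_inf <= M * |theta[j]| by clipping W1 while theta stays fixed."""
    bound = M * np.abs(theta)[:, None]
    return np.clip(W1, -bound, bound)

rng = np.random.default_rng(0)
d, k, M = 4, 8, 10.0          # hypothetical: 4 features, 8 hidden units, hierarchy multiplier
theta = rng.normal(size=d)    # skip-connection weights (the L1-penalized part)
W1 = rng.normal(size=(d, k))  # row j holds feature j's first-layer weights
W2 = rng.normal(size=k)       # hidden-to-output weights

theta[0] = 0.0                        # the skip connection drops feature 0
W1 = clip_to_hierarchy(theta, W1, M)  # the constraint then forces row 0 of W1 to zero

x = rng.normal(size=d)
x_perturbed = x.copy()
x_perturbed[0] += 100.0               # change only feature 0
print(np.allclose(forward(x, theta, W1, W2), forward(x_perturbed, theta, W1, W2)))  # True
```

Because feature 0 reaches the output only through θ_0 and row 0 of W1, zeroing θ_0 removes it from the whole network, which is the property described above.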
Installation
```shell
pip install lassonet
```
Usage
We have designed the code to follow scikit-learn's standards to the extent possible (e.g. linear_model.Lasso).
```python
from lassonet import LassoNetClassifierCV

# X_train, y_train, X_test, y_test are your own (normalized) data
model = LassoNetClassifierCV()  # or LassoNetRegressorCV for regression
model.fit(X_train, y_train)
print("Best model scored", model.score(X_test, y_test))
print("Lambda =", model.best_lambda_)
```
Always give normalized data to LassoNet, since it uses neural networks under the hood.
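One common way to normalize is to standardize each feature with statistics computed on the training set only, then apply the same shift and scale to the test set so that no test information leaks into training. The sketch below is plain NumPy with a hypothetical helper name (standardize); scikit-learn's StandardScaler does the same job if you already use it.

```python
import numpy as np

def standardize(X_train, X_test):
    """Scale each column to zero mean and unit variance using training-set statistics only."""
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)
    std = np.where(std == 0, 1.0, std)  # leave constant columns unscaled instead of dividing by zero
    return (X_train - mean) / std, (X_test - mean) / std
```

The standardized arrays can then be passed to fit and score in place of the raw ones.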
You can read the full documentation or the examples, which cover all features.
Features
- regression, classification and Cox regression with LassoNetRegressor, LassoNetClassifier and LassoNetCoxRegressor
- cross-validation with LassoNetRegressorCV, LassoNetClassifierCV and LassoNetCoxRegressorCV
- group feature selection with the groups argument
- lambda_start="auto" heuristic
Note that cross-validation, group feature selection and automatic lambda_start selection have not been published in papers; you can read the code or open an issue to request more details.
Among other things, we are currently working on adding support for convolution layers, auto-encoders and online logging of experiments.
Cross-validation
The original paper describes how to train LassoNet along a regularization path. This requires the user to select a model from the path manually, and it made the .fit() method useless, since the model left at the end of the path is always empty. This feature is still available through the .path() method of any model or through the lassonet_path function; both return a list of checkpoints that can be loaded with .load().
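To see why a path ends with an empty model, the sketch below swaps the neural network for a plain linear lasso (an assumption made for brevity) and follows the same dense-to-sparse idea: start from a small λ, warm-start each fit from the previous checkpoint, multiply λ by a constant factor, and stop once no feature is selected. It uses proximal gradient descent (ISTA); the function name, parameters and default values are hypothetical and are not LassoNet's API. Picking a checkpoint from the returned list is the manual step that the CV estimators automate.

```python
import numpy as np

def soft_threshold(z, t):
    """Proximal operator of t * ||.||_1: shrinks toward zero, with exact zeros inside [-t, t]."""
    return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)

def lasso_path(X, y, lambda_start=1e-3, multiplier=1.5, n_iter=300, max_steps=100):
    """Warm-started path of (1/2n)||y - Xw||^2 + lam*||w||_1, increasing lam until no feature is left."""
    n, d = X.shape
    step = n / np.linalg.norm(X, 2) ** 2  # 1 / Lipschitz constant of the gradient
    w = np.zeros(d)
    lam = lambda_start
    checkpoints = []
    for _ in range(max_steps):
        for _ in range(n_iter):  # ISTA, warm-started from the previous checkpoint
            w = soft_threshold(w - step * X.T @ (X @ w - y) / n, step * lam)
        n_selected = int(np.count_nonzero(w))
        checkpoints.append({"lambda": lam, "n_selected": n_selected, "weights": w.copy()})
        if n_selected == 0:  # the path ends at the empty model
            break
        lam *= multiplier
    return checkpoints

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 10))
y = X @ np.array([3.0, -2.0, 1.5] + [0.0] * 7) + 0.1 * rng.normal(size=100)
path = lasso_path(X, y)
for cp in path:
    print(f"lambda={cp['lambda']:.4g}  selected={cp['n_selected']}")
```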
Since then, we have integrated support for cross-validation (5-fold by default) into the estimators whose names end with CV. For each fold, a path is trained. The best regularization value is then chosen to maximize the average performance over all folds, and the model is retrained on the whole training dataset until it reaches that regularization.
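The selection procedure can be sketched as follows, again with a linear lasso standing in for the network and a fixed, shared λ grid (both assumptions for illustration; the helper names are hypothetical and not LassoNet's API). Each fold trains one warm-started path, the score is negative validation MSE averaged over folds, and the final model is retrained on all data along the path up to the chosen λ.

```python
import numpy as np

def soft_threshold(z, t):
    return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)

def lasso_path(X, y, lambdas, n_iter=300):
    """Warm-started ISTA lasso solutions for an increasing sequence of lambdas."""
    n, d = X.shape
    step = n / np.linalg.norm(X, 2) ** 2
    w = np.zeros(d)
    weights = []
    for lam in lambdas:
        for _ in range(n_iter):
            w = soft_threshold(w - step * X.T @ (X @ w - y) / n, step * lam)
        weights.append(w.copy())
    return weights

def cross_validate(X, y, lambdas, n_folds=5, seed=0):
    """Pick the lambda with the best mean validation score, then retrain on all data up to it."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(y)), n_folds)
    scores = np.zeros((n_folds, len(lambdas)))
    for k, val_idx in enumerate(folds):
        train_idx = np.setdiff1d(np.arange(len(y)), val_idx)
        # one path per fold, scored at every lambda on the held-out part
        for i, w in enumerate(lasso_path(X[train_idx], y[train_idx], lambdas)):
            scores[k, i] = -np.mean((y[val_idx] - X[val_idx] @ w) ** 2)  # higher is better
    best = int(np.argmax(scores.mean(axis=0)))
    final_w = lasso_path(X, y, lambdas[: best + 1])[-1]  # retrain on all data up to the chosen lambda
    return lambdas[best], final_w

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))
y = X @ np.array([3.0, -2.0, 1.5] + [0.0] * 7) + 0.1 * rng.normal(size=200)
lambdas = np.geomspace(1e-3, 10.0, 12)
best_lambda, w = cross_validate(X, y, lambdas)
print("best lambda:", best_lambda, "selected features:", np.flatnonzero(w))
```

Sharing one λ grid across folds keeps the per-fold scores aligned, so averaging them column by column is straightforward.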
Website
LassoNet's website is https://lassonet.ml. It contains many useful references including the paper, live talks and additional documentation.
References
- Lemhadri, Ismael, Feng Ruan, Louis Abraham, and Robert Tibshirani. "LassoNet: A Neural Network with Feature Sparsity." Journal of Machine Learning Research 22, no. 127 (2021).
- Yang, Xuelin, Louis Abraham, Sejin Kim, Petr Smirnov, Feng Ruan, Benjamin Haibe-Kains, and Robert Tibshirani. "FastCPH: Efficient Survival Analysis for Neural Networks." arXiv preprint arXiv:2208.09793 (2022).
File details
Details for the file lassonet-0.0.14.tar.gz.
File metadata
- Download URL: lassonet-0.0.14.tar.gz
- Size: 16.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.13.0 pkginfo/1.7.0 requests/2.28.1 requests-toolbelt/0.9.1 tqdm/4.64.0 CPython/3.8.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 2118c748cf827fa6e9d3fa59192cd376f602e67039100ab5201da6ed8f3cf677
MD5 | 7bec61c7b6cd7b9db81958f7c284e6ae
BLAKE2b-256 | 89dace1f69d40396624e60582da8f76657de3a6a6cf3626ffe6ac8447c61e6f7
File details
Details for the file lassonet-0.0.14-py3-none-any.whl.
File metadata
- Download URL: lassonet-0.0.14-py3-none-any.whl
- Size: 17.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.13.0 pkginfo/1.7.0 requests/2.28.1 requests-toolbelt/0.9.1 tqdm/4.64.0 CPython/3.8.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 2d1ab682d46391bdd49c6dd5b06f4c0ebfa782ec2b683999a58f584ae602dff4
MD5 | 7f5f2efc227b4ba4f85278a69bf04632
BLAKE2b-256 | 7896b0208161fbcabbfe133c4695d8744ef228df2befdff0441c794efde7b845