EASIER-net
Feng, Jean, and Noah Simon. 2022. “Ensembled Sparse‐input Hierarchical Networks for High‐dimensional Datasets.” Statistical Analysis and Data Mining, March. https://doi.org/10.1002/sam.11579.
Python code for fitting EASIER-nets and reproducing all results from the paper. The code uses PyTorch.
R code for fitting EASIER-net is available at https://github.com/jjfeng/easier_net_R.
Quick-start
Set up a Python virtual environment (the code runs on Python 3.6) and install the packages listed in requirements.txt.
Simulate data by following the tutorial notebook, or load your own data in npz format with x and y attributes. You may also perform GridSearchCV by following the tutorial.
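As a rough illustration of the npz format mentioned above, the sketch below saves a toy dataset with NumPy under the keys x and y. The file name my_data.npz, the array shapes, and the toy regression signal are assumptions made for this example; the exact shape expected for y (for instance, whether it must be a column vector) is not stated here, so check the tutorial notebook.

```python
import numpy as np

# Hypothetical toy data: 100 observations, 20 features (shapes are an assumption).
rng = np.random.default_rng(0)
x = rng.normal(size=(100, 20))
# The response depends only on the first two features, plus a little noise.
y = x[:, 0] - 2.0 * x[:, 1] + 0.1 * rng.normal(size=100)

# Save the arrays under the keys "x" and "y".
np.savez("my_data.npz", x=x, y=y)

# Reload the file to confirm both arrays are present.
data = np.load("my_data.npz")
print(data.files, data["x"].shape, data["y"].shape)
```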
To fit an EASIER-net, run
python fit_easier_net.py --n-estimators <N_ESTIMATORS> --input-filter-layer <INPUT_FILTER_LAYER> --n-layers <N_LAYERS> --n-hidden <N_HIDDEN> --input-pen <INPUT_PEN> --full-tree-pen <FULL_TREE_PEN> --batch-size <BATCH_SIZE> --num-classes <NUM_CLASSES> --weight <WEIGHT> --max-iters <MAX_ITERS> --max-prox-iters <MAX_PROX_ITERS> --model-fit-params-file <MODEL_FIT_PARAMS_FILE>
where:
N_ESTIMATORS is the size of the ensemble, i.e. the number of SIER-nets being ensembled.
INPUT_FILTER_LAYER specifies whether to scale the inputs by the parameter β.
N_LAYERS is the number of hidden layers.
N_HIDDEN is the number of hidden nodes per layer.
INPUT_PEN specifies $\lambda_1$ in the paper; it controls the input sparsity.
FULL_TREE_PEN specifies $\lambda_2$ in the paper; it controls the number of active layers and hidden nodes.
BATCH_SIZE specifies the size of the mini-batches for Adam.
NUM_CLASSES should be 0 for regression and the number of classes for binary or multi-class classification.
WEIGHT is a list of weights for the classes.
MAX_ITERS is the number of epochs to run Adam.
MAX_PROX_ITERS is the number of epochs to run batch proximal gradient descent.
MODEL_FIT_PARAMS_FILE is a JSON file that specifies the hyperparameters. If given, it overrides the arguments passed in.
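To keep a long invocation readable, one option is to assemble the command from a dictionary. This is only a sketch: the flag names come from the command shown above, but every value is a hypothetical placeholder, and flags whose value format is not described here (--input-filter-layer, --weight, --model-fit-params-file) are left out.

```python
# Hypothetical values for illustration only; tune them for your data.
args = {
    "n-estimators": 5,
    "n-layers": 3,
    "n-hidden": 50,
    "input-pen": 0.01,       # lambda_1
    "full-tree-pen": 0.001,  # lambda_2
    "batch-size": 32,
    "num-classes": 0,        # 0 means regression
    "max-iters": 100,
    "max-prox-iters": 50,
}

# Build the argument list: each key becomes a "--flag value" pair.
cmd = ["python", "fit_easier_net.py"]
for flag, value in args.items():
    cmd += ["--" + flag, str(value)]

# Print the command so it can be inspected or pasted into a shell.
print(" ".join(cmd))
```

The resulting list can be passed to subprocess.run(cmd, check=True) from the directory that contains fit_easier_net.py.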
To perform cross-validation, run a separate fit_easier_net.py job for each candidate set of penalty parameter values, then select the best penalty parameter values using collate_best_param.py.
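The per-candidate runs can be generated from a grid of penalty values. The sketch below only builds the command strings; the candidate values are hypothetical, the remaining flags are omitted for brevity, and the arguments of collate_best_param.py are not described here, so that step is not shown.

```python
from itertools import product

# Hypothetical candidate grids for the two penalty parameters.
input_pens = [0.001, 0.01, 0.1]    # candidate lambda_1 values
full_tree_pens = [0.0001, 0.001]   # candidate lambda_2 values

# One fit_easier_net.py command per (lambda_1, lambda_2) pair.
commands = []
for input_pen, full_tree_pen in product(input_pens, full_tree_pens):
    commands.append(
        "python fit_easier_net.py "
        f"--input-pen {input_pen} --full-tree-pen {full_tree_pen}"
    )

for command in commands:
    print(command)
```

Each printed line can then be submitted as its own job, with the other flags added as needed.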
Hashes for EASIER_net-0.0.8-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | f04b3635242a51127460fe01c357e61528d6b1ffbb744c6e590064d9da8d8d84
MD5 | ae9cffa3d061291316cac227f7658f77
BLAKE2b-256 | b60b7e2ef9f7801ee2b4317687f455fbef4fa3ace673cbf893dbfb4275dc5730