
PyTorch implementation of group elastic net


torch-gel

This package provides PyTorch implementations to solve the group elastic net problem. Let Aj (j = 1 … p) be feature matrices of sizes m × nj (m is the number of samples, and nj is the number of features in the jth group), and let y be an m × 1 vector of responses. Group elastic net finds coefficients βj and a bias β0 that solve the optimization problem

min β0, …, βp   ½ ║y − β0 − ∑j Aj βj║² + m ∑j √nj1 ║βj║ + λ2 ║βj║²).

Here λ1 and λ2 are scalar coefficients that control the amounts of 2-norm and squared 2-norm regularization. The 2-norm regularization encourages sparsity at the group level; entire βj vectors can become 0. The squared 2-norm regularization is similar in spirit to the elastic net and addresses some of the issues of the lasso. Note that group elastic net includes as special cases group lasso (λ2 = 0), ridge regression (λ1 = 0), elastic net (each nj = 1), and lasso (each nj = 1 and λ2 = 0). The optimization problem is convex and can be solved efficiently. This package provides two implementations: one based on proximal gradient descent and one based on coordinate descent.
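
For example, setting λ2 = 0 reduces the objective to the group lasso, ½ ║y − β0 − ∑j Aj βj║² + m λ1 ∑j √nj ║βj║, and additionally setting each nj = 1 reduces the penalty to m λ1 ∑j |βj|, i.e. the lasso.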

Installation

Install with pip

pip install torchgel

tqdm (for progress bars) and numpy are pulled in as dependencies. PyTorch (v1.0+) is also required, and needs to be installed manually. Refer to the PyTorch website for instructions.

Usage

examples/main.ipynb is a Jupyter notebook that walks through using the package for a typical use case. A more formal description of the functions follows; for details about the algorithms, refer to the docstrings of the files in the gel directory.

Solving Single Instances

The modules gel.gelfista and gel.gelcd provide implementations based on proximal gradient descent and coordinate descent respectively. Both have similar interfaces and expose two main public functions: make_A and gel_solve. The feature matrices should be stored as PyTorch tensors in a list (say As), and the responses should be stored in a PyTorch vector (say y). Additionally, the sizes of the groups (nj) should be stored in a vector (say ns). First, use the make_A function to convert the feature matrices into a suitable format:

A = make_A(As, ns)

Then pass A, y and other required arguments to gel_solve. The general interface is:

b_0, B = gel_solve(A, y, l_1, l_2, ns, **kwargs)

l_1 and l_2 are floats representing λ1 and λ2 respectively. The method returns a float b_0 representing the bias and a PyTorch matrix B holding the other coefficients. B has size p × maxj nj with suitable zero padding. The following sections cover additional details for the specific implementations.
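
For instance, here is a minimal sketch with random data using the gel.gelfista module (gel.gelcd follows the same template); solver-specific keyword arguments are omitted and may need to be supplied per the gel_solve docstrings:

import torch
from gel.gelfista import make_A, gel_solve

torch.manual_seed(0)
m = 50                                     # number of samples
ns = torch.tensor([3, 2, 4])               # group sizes n_j
As = [torch.randn(m, int(n)) for n in ns]  # one m x n_j feature matrix per group
y = torch.randn(m)                         # response vector

A = make_A(As, ns)                         # convert feature matrices to the solver's format
b_0, B = gel_solve(A, y, l_1=0.05, l_2=0.01, ns=ns)
print(b_0)                                 # bias (float)
print(B.shape)                             # p x max_j n_j, zero padded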

Proximal Gradient Descent (FISTA)

The gel.gelfista module contains a proximal gradient descent implementation. Its usage is exactly as described in the template above. Refer to the docstring of gel.gelfista.gel_solve for details about the other arguments.

Coordinate Descent

The gel.gelcd module contains a coordinate descent implementation. Its usage is a bit more involved than that of the FISTA implementation. Coordinate descent iteratively solves single blocks (each corresponding to a single βj). Multiple solvers are provided for the individual blocks; these are the gel.gelcd.block_solve_* functions. Refer to their docstrings for details about their arguments. gel.gelcd.gel_solve requires passing a block solve function and its arguments (as a dictionary), as sketched below. Refer to its docstring for further details.
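
As an illustration, here is a hedged sketch of the call pattern; the solver name block_solve_agd and the keyword names block_solve_fun and block_solve_kwargs are assumptions and should be checked against the gel.gelcd docstrings:

import torch
from gel.gelcd import make_A, gel_solve, block_solve_agd  # block_solve_agd: assumed block solver name

torch.manual_seed(0)
ns = torch.tensor([3, 2, 4])
As = [torch.randn(50, int(n)) for n in ns]
y = torch.randn(50)
A = make_A(As, ns)

b_0, B = gel_solve(
    A, y, l_1=0.05, l_2=0.01, ns=ns,
    block_solve_fun=block_solve_agd,  # block solver used for each beta_j (assumed keyword name)
    block_solve_kwargs={},            # arguments forwarded to the block solver; see its docstring
)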

Solution Paths

gel.gelpaths provides a wrapper function gel_paths to solve the group elastic net problem for multiple values of the regularization coefficients. It implements a two-stage process. For a given λ1 and λ2, the group elastic net problem is first solved and the feature blocks with non-zero coefficients are extracted (the support). Then ridge regression models are learned on the support for each of several provided regularization values. Each final model is summarized using a user-provided summary function, and the summary for each combination of regularization values is returned as a dictionary. The docstring contains more details. gel.ridgepaths contains another useful function, ridge_paths, which can efficiently solve ridge regression for multiple regularization values.
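
To make the two-stage process concrete, here is a hedged sketch that reproduces it by hand rather than through gel_paths or ridge_paths (whose signatures should be taken from their docstrings); the support test, the closed-form ridge solve, and the coefficient-norm summary are illustrative assumptions:

import torch
from gel.gelfista import make_A, gel_solve

torch.manual_seed(0)
ns = torch.tensor([3, 2, 4])
As = [torch.randn(50, int(n)) for n in ns]
y = torch.randn(50)

# Stage 1: solve group elastic net and extract the support (groups with non-zero coefficients).
A = make_A(As, ns)
b_0, B = gel_solve(A, y, l_1=0.05, l_2=0.01, ns=ns)
support = [j for j in range(len(ns)) if B[j].abs().sum() > 0]

# Stage 2: ridge regression restricted to the support, for several regularization values.
X = torch.cat([As[j] for j in support], dim=1)  # m x (total size of the support groups)
Xc, yc = X - X.mean(0), y - y.mean()            # center so the bias is handled separately
summaries = {}
for lam in [0.01, 0.1, 1.0]:
    # Closed-form ridge solution: (Xc^T Xc + lam * I)^{-1} Xc^T yc.
    w = torch.linalg.solve(Xc.T @ Xc + lam * torch.eye(Xc.shape[1]), Xc.T @ yc)
    summaries[lam] = w.norm().item()            # stand-in for a user-provided summary function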

Citation

If you find this code useful in your research, please cite

@misc{koushik2017torchgel,
  author = {Koushik, Jayanth},
  title = {torch-gel},
  year = {2017},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/jayanthkoushik/torch-gel}},
}

