A package for solving regularised optimisation problems in a scikit-learn style.

Scikit-Prox

The goal of this project is to implement a set of algorithms for solving optimization problems of the form $\min_x f(x) + g(x)$, where $f$ is a smooth function and $g$ is a (possibly non-smooth) regulariser whose proximal operator can be evaluated. The proximal operator of $g$ with parameter $\lambda > 0$ is defined as $\operatorname{prox}_{\lambda g}(x) = \arg\min_y \; g(y) + \frac{1}{2\lambda}\|y - x\|_2^2$.
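To make the definition concrete, here is a minimal NumPy sketch of the proximal operator for the case $g = \|\cdot\|_1$, which has the well-known closed form called soft-thresholding. The function names `prox_l1` and `prox_objective` are illustrative only and are not part of scikit-prox.

```python
import numpy as np


def prox_l1(x, lam):
    """Proximal operator of lam * ||.||_1, i.e. soft-thresholding at lam."""
    return np.sign(x) * np.maximum(np.abs(x) - lam, 0.0)


def prox_objective(y, x, lam):
    """Value of g(y) + 1/(2*lam) * ||y - x||^2 for g = ||.||_1."""
    return np.sum(np.abs(y)) + np.sum((y - x) ** 2) / (2.0 * lam)


x = np.array([3.0, -0.5, 1.2, -2.0])
lam = 1.0

# Entries with magnitude below lam are set to zero; the rest shrink by lam.
p = prox_l1(x, lam)
print(p)
```

Evaluating `prox_objective` at `p` and at nearby points is a quick way to check that `p` really is the minimiser in the definition above.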

Installation

To install the package, run the following command:

pip install scikit-prox

Usage

Example 1: Lasso

The following code solves the Lasso problem $\min_x \frac{1}{2}\|Ax - b\|_2^2 + \lambda\|x\|_1$, where $A$ and $b$ correspond to X and y in the code:

from scipy import sparse
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso
from skprox.linear_model import RegularisedLinearRegression

# Generate data
X, y = make_regression(n_samples=100, n_features=1000, random_state=0, noise=4.0, bias=100.0)
X = sparse.csr_matrix(X)

# Solve the problem using scikit-learn
model = Lasso(alpha=0.1)
model.fit(X, y)
print("scikit-learn solution: {}".format(model.coef_))

# Solve the problem using scikit-prox
model = RegularisedLinearRegression(proximal='L1', sigma=0.1)
model.fit(X, y)
print("scikit-prox solution: {}".format(model.coef_))
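For intuition about what a proximal solver does on this problem, the following is a minimal sketch of proximal gradient descent (ISTA) in plain NumPy. It is an illustration under stated assumptions, not the algorithm scikit-prox uses internally; `soft_threshold`, `ista_lasso` and `lasso_objective` are hypothetical helper names, and the small dense problem is made up for the demonstration.

```python
import numpy as np


def soft_threshold(v, t):
    """Proximal operator of t * ||.||_1."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def ista_lasso(A, b, lam, n_iter=500):
    """Minimise 0.5 * ||A x - b||^2 + lam * ||x||_1 by proximal gradient."""
    # Step size 1/L, where L = ||A||_2^2 is the Lipschitz constant of the gradient.
    step = 1.0 / np.linalg.norm(A, 2) ** 2
    x = np.zeros(A.shape[1])
    for _ in range(n_iter):
        grad = A.T @ (A @ x - b)                           # gradient of the smooth term
        x = soft_threshold(x - step * grad, step * lam)    # prox of the L1 term
    return x


def lasso_objective(A, b, x, lam):
    """Value of 0.5 * ||A x - b||^2 + lam * ||x||_1."""
    return 0.5 * np.sum((A @ x - b) ** 2) + lam * np.sum(np.abs(x))


# Small synthetic problem with a sparse ground truth (illustrative data).
rng = np.random.default_rng(0)
A = rng.standard_normal((50, 20))
x_true = np.zeros(20)
x_true[:3] = [2.0, -1.5, 1.0]
b = A @ x_true + 0.01 * rng.standard_normal(50)

x_hat = ista_lasso(A, b, lam=1.0)
print("objective at zero:", lasso_objective(A, b, np.zeros(20), 1.0))
print("objective at ISTA solution:", lasso_objective(A, b, x_hat, 1.0))
```

When comparing coefficients between the two libraries, keep in mind that scikit-learn's Lasso scales the data-fit term by 1/(2 n_samples), so its alpha is not on the same scale as the $\lambda$ in the formula above. Whether sigma follows the same convention is not stated here, so matching values of alpha and sigma need not give matching solutions.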

Example 2: Total Variation Regression

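The following code solves the total variation regularised problem $\min_x \frac{1}{2}\|Ax - b\|_2^2 + \lambda\|\nabla x\|_1$, where $\nabla x$ denotes the differences between neighbouring coefficients: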

from scipy import sparse
from sklearn.datasets import make_regression
from skprox.linear_model import RegularisedLinearRegression

# Generate data
X, y = make_regression(n_samples=100, n_features=1000, random_state=0, noise=4.0, bias=100.0)
X = sparse.csr_matrix(X)

# Solve the problem using scikit-prox
model = RegularisedLinearRegression(proximal='TV', sigma=0.1)
model.fit(X, y)
print("scikit-prox solution: {}".format(model.coef_))
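To illustrate what the penalty $\|\nabla x\|_1$ measures, here is a small sketch that assumes a one-dimensional coefficient vector and forward differences. The helper `tv_penalty` is hypothetical, and the discretisation used by the package (which relies on skimage) may differ.

```python
import numpy as np


def tv_penalty(x):
    """1-D total variation: sum of |x[i+1] - x[i]| over neighbouring entries."""
    return np.sum(np.abs(np.diff(x)))


# Two vectors with the same L1 norm but very different total variation.
piecewise = np.array([1.0, 1.0, 1.0, 4.0, 4.0, 4.0])
oscillating = np.array([1.0, 4.0, 1.0, 4.0, 1.0, 4.0])

print(tv_penalty(piecewise))    # 3.0: a single jump
print(tv_penalty(oscillating))  # 15.0: five jumps of size 3
```

Unlike the L1 penalty, which favours coefficients that are exactly zero, the total variation penalty favours coefficient vectors that are piecewise constant.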

Example 3: Grid Search

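Because the estimator follows the scikit-learn interface, it can be used with GridSearchCV. The following code selects the regularisation strength for the Lasso problem $\min_x \frac{1}{2}\|Ax - b\|_2^2 + \lambda\|x\|_1$ by cross-validation: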

from scipy import sparse
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso
from skprox.linear_model import RegularisedLinearRegression
from sklearn.model_selection import GridSearchCV

# Generate data
X, y = make_regression(n_samples=100, n_features=1000, random_state=0, noise=4.0, bias=100.0)
X = sparse.csr_matrix(X)

# Solve the problem using scikit-learn
model = Lasso()
grid = GridSearchCV(model, {'alpha': [0.1, 0.2, 0.3]})
grid.fit(X, y)
print("scikit-learn solution: {}".format(grid.best_estimator_.coef_))

# Solve the problem using scikit-prox
model = RegularisedLinearRegression(proximal='L1')
grid = GridSearchCV(model, {'sigma': [0.1, 0.2, 0.3]})
grid.fit(X, y)
print("scikit-prox solution: {}".format(grid.best_estimator_.coef_))
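Beyond the best coefficients, it is often useful to see which value was selected and how each candidate scored. The sketch below uses only scikit-learn's Lasso, on a smaller illustrative dataset with an explicit `cv=5`. The same attributes should be available for any estimator that GridSearchCV accepts, though this has not been checked against scikit-prox here.

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso
from sklearn.model_selection import GridSearchCV

# Smaller problem than in the example above, to keep the run quick.
X, y = make_regression(n_samples=100, n_features=50, random_state=0, noise=4.0)

grid = GridSearchCV(Lasso(), {"alpha": [0.1, 0.2, 0.3]}, cv=5)
grid.fit(X, y)

# The selected hyperparameter and its mean cross-validated score (R^2 by default).
print("best alpha:", grid.best_params_["alpha"])
print("best mean CV score:", grid.best_score_)

# Mean score of every candidate, in the order the grid was given.
for alpha, score in zip(grid.cv_results_["param_alpha"],
                        grid.cv_results_["mean_test_score"]):
    print(alpha, score)
```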

Documentation

The documentation is available at https://scikit-prox.readthedocs.io/en/latest/

License

This project is licensed under the MIT License; see the LICENSE.md file for details.

Acknowledgments

This project leans on the pyproximal package, borrowing all of its proximal operators except Total Variation, which is implemented using functions from skimage.

Source Distribution

scikit_prox-0.0.2.tar.gz (7.8 kB)

Uploaded Source

File details

Details for the file scikit_prox-0.0.2.tar.gz.

File metadata

  • Download URL: scikit_prox-0.0.2.tar.gz
  • Upload date:
  • Size: 7.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.9

File hashes

Hashes for scikit_prox-0.0.2.tar.gz:

  • SHA256: b7702322899457745d901ed939bc2dc3d9451286645c0f104986a6b181b2fe0c
  • MD5: 190b852b216036c7f992c75ca4f09fd2
  • BLAKE2b-256: f9d9ce05c2422e6e083ca8a752af40a566c06520fd444c542c7f216cddec2215
