
A package for solving regularised optimisation problems in a scikit-learn style.



Scikit-Prox

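The goal of this project is to implement a set of algorithms for solving optimization problems of the form

$$\min_{x} \; f(x) + g(x)$$

where f is a smooth (differentiable) function and g is a possibly non-smooth function whose proximal operator is cheap to evaluate. The proximal operator of g with parameter λ > 0 is defined as:

$$\operatorname{prox}_{\lambda g}(x) = \arg\min_{y} \; g(y) + \frac{1}{2\lambda}\|y - x\|_2^2$$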
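
For the common choice g(x) = ‖x‖₁, the proximal operator has a closed form called soft thresholding. Each entry is shrunk towards zero by λ and set to zero if its magnitude is below λ. The NumPy sketch below implements that formula and checks it against a brute-force minimisation of the definition above. The function name `prox_l1` is illustrative and is not part of scikit-prox.

```python
import numpy as np

def prox_l1(x, lam):
    """prox_{lam*||.||_1}(x): soft thresholding, applied entry-wise."""
    return np.sign(x) * np.maximum(np.abs(x) - lam, 0.0)

x = np.array([3.0, -0.5, 1.2, -2.0])
lam = 1.0
p = prox_l1(x, lam)

# Brute-force check of the definition: the objective |y| + (y - x_i)^2 / (2*lam)
# is separable, so each coordinate can be minimised independently on a fine grid.
grid = np.linspace(-5.0, 5.0, 100001)
brute = np.array([grid[np.argmin(np.abs(grid) + (grid - xi) ** 2 / (2 * lam))] for xi in x])

print(p)                                 # soft-thresholded values
print(np.allclose(p, brute, atol=1e-3))  # True
```

Entries with |x_i| ≤ λ (here -0.5) are mapped exactly to zero. This is why L1 proximal steps produce sparse solutions.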

Installation

To install the package, run the following command:

pip install scikit-prox

Usage

Example 1: Lasso

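The following code solves the Lasso problem

$$\min_{x} \; \frac{1}{2}\|Ax - b\|_2^2 + \lambda\|x\|_1$$

where A is the design matrix (X in the code) and b is the target vector (y). It compares the result with scikit-learn's Lasso. Note that scikit-learn's Lasso minimises $\frac{1}{2n}\|b - Ax\|_2^2 + \alpha\|x\|_1$, where n is the number of samples. That is the problem above with λ = nα, so `alpha` and `sigma` values are not necessarily on the same scale.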

import numpy as np
from scipy import sparse
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso
from skprox.linear_model import RegularisedLinearRegression

# Generate a synthetic regression problem with more features than samples, stored as a sparse CSR matrix
X, y = make_regression(n_samples=100, n_features=1000, random_state=0, noise=4.0, bias=100.0)
X = sparse.csr_matrix(X)

# Solve the problem using scikit-learn
model = Lasso(alpha=0.1)
model.fit(X, y)
print("scikit-learn solution: {}".format(model.coef_))

# Solve the problem using scikit-prox (sigma sets the weight λ of the L1 penalty)
model = RegularisedLinearRegression(proximal='L1', sigma=0.1)
model.fit(X, y)
print("scikit-prox solution: {}".format(model.coef_))
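
Problems of the form f(x) + g(x) are classically solved by proximal gradient descent (ISTA). Each iteration takes a gradient step on the smooth term f and then a proximal step on g. The sketch below is a generic NumPy implementation for the Lasso, written for illustration only. It is not necessarily the algorithm scikit-prox uses, it fits no intercept, and its data are randomly generated. The helper names are hypothetical.

```python
import numpy as np

def soft_threshold(v, t):
    # Proximal operator of t * ||.||_1
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

def lasso_objective(A, b, x, lam):
    return 0.5 * np.sum((A @ x - b) ** 2) + lam * np.sum(np.abs(x))

def ista(A, b, lam, n_iter=500):
    """Proximal gradient descent (ISTA) for 0.5*||Ax - b||^2 + lam*||x||_1."""
    L = np.linalg.norm(A, 2) ** 2  # Lipschitz constant of grad f (largest singular value squared)
    x = np.zeros(A.shape[1])
    history = []
    for _ in range(n_iter):
        grad = A.T @ (A @ x - b)                   # gradient step on the smooth part f
        x = soft_threshold(x - grad / L, lam / L)  # proximal step on g with step size 1/L
        history.append(lasso_objective(A, b, x, lam))
    return x, history

rng = np.random.default_rng(0)
A = rng.standard_normal((50, 200))
x_true = np.zeros(200)
x_true[:5] = [3.0, -2.0, 1.5, 4.0, -1.0]
b = A @ x_true + 0.1 * rng.standard_normal(50)

x_hat, history = ista(A, b, lam=1.0)
print("non-zero coefficients:", np.count_nonzero(x_hat))
print("final objective:", history[-1])
```

With step size 1/L, where L = ‖A‖₂² is the Lipschitz constant of ∇f, the objective values recorded in `history` never increase.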

Example 2: Total Variation Regression

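The following code solves the total variation (TV) regularised problem

$$\min_{x} \; \frac{1}{2}\|Ax - b\|_2^2 + \lambda\|\nabla x\|_1$$

where ∇x is the discrete gradient of x, that is, the differences between neighbouring coefficients. The penalty therefore favours piecewise-constant coefficient vectors.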

import numpy as np
from scipy import sparse
from sklearn.datasets import make_regression
from skprox.linear_model import RegularisedLinearRegression

# Generate a synthetic regression problem with more features than samples, stored as a sparse CSR matrix
X, y = make_regression(n_samples=100, n_features=1000, random_state=0, noise=4.0, bias=100.0)
X = sparse.csr_matrix(X)

# Solve the problem using scikit-prox
model = RegularisedLinearRegression(proximal='TV', sigma=0.1)
model.fit(X, y)
print("scikit-prox solution: {}".format(model.coef_))
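
For a 1-D coefficient vector with forward differences, the TV penalty ‖∇x‖₁ is the sum of absolute differences between neighbouring entries. This discretisation is an assumption. scikit-prox's own TV operator, built on scikit-image, may discretise differently. The sketch below computes the penalty both directly and through an explicit difference matrix D. All helper names are hypothetical.

```python
import numpy as np

def tv_norm(x):
    """1-D total variation ||grad x||_1: sum of absolute differences of neighbours."""
    return np.sum(np.abs(np.diff(x)))

def difference_matrix(n):
    """(n-1) x n forward-difference matrix D, so that D @ x == np.diff(x)."""
    return np.diff(np.eye(n), axis=0)

def tv_objective(A, b, x, lam):
    """0.5 * ||Ax - b||^2 + lam * ||grad x||_1."""
    return 0.5 * np.sum((A @ x - b) ** 2) + lam * tv_norm(x)

# A piecewise-constant vector with two jumps, of sizes 2 and 3
piecewise = np.array([0.0, 0.0, 0.0, 2.0, 2.0, 2.0, -1.0, -1.0])
D = difference_matrix(piecewise.size)

print(tv_norm(piecewise))              # 5.0
print(np.sum(np.abs(D @ piecewise)))   # same value via the matrix form
```

Only the jumps contribute to the penalty, so flat stretches cost nothing. This is why TV regularisation recovers piecewise-constant coefficients.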

Example 3: Grid Search

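The following code selects the regularisation strength for the Lasso problem

$$\min_{x} \; \frac{1}{2}\|Ax - b\|_2^2 + \lambda\|x\|_1$$

by cross-validated grid search. It tunes `alpha` for scikit-learn's Lasso and `sigma` for scikit-prox's RegularisedLinearRegression, which can be passed to GridSearchCV like any scikit-learn estimator.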

import numpy as np
from scipy import sparse
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso
from skprox.linear_model import RegularisedLinearRegression
from sklearn.model_selection import GridSearchCV

# Generate a synthetic regression problem with more features than samples, stored as a sparse CSR matrix
X, y = make_regression(n_samples=100, n_features=1000, random_state=0, noise=4.0, bias=100.0)
X = sparse.csr_matrix(X)

# Tune alpha for scikit-learn's Lasso (GridSearchCV uses 5-fold cross-validation by default)
model = Lasso()
grid = GridSearchCV(model, {'alpha': [0.1, 0.2, 0.3]})
grid.fit(X, y)
print("scikit-learn solution: {}".format(grid.best_estimator_.coef_))

# Tune sigma for scikit-prox's L1-regularised model in the same way
model = RegularisedLinearRegression(proximal='L1')
grid = GridSearchCV(model, {'sigma': [0.1, 0.2, 0.3]})
grid.fit(X, y)
print("scikit-prox solution: {}".format(grid.best_estimator_.coef_))
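
GridSearchCV can tune `sigma` because it relies only on scikit-learn's estimator conventions. Hyper-parameters must be stored unchanged in `__init__`, so that `get_params`, `set_params` and `clone` work, and the estimator must provide `fit`, `predict` and `score`. The hypothetical `ToyL1Regression` class below is not part of scikit-prox. It is a minimal sketch of an estimator that follows those conventions, fitted with the ISTA iteration on centred, dense data, with a smaller problem than the example above to keep it fast.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin, clone
from sklearn.datasets import make_regression
from sklearn.model_selection import GridSearchCV

class ToyL1Regression(RegressorMixin, BaseEstimator):
    """Hypothetical L1-regularised regressor (not part of scikit-prox), fitted with ISTA."""

    def __init__(self, sigma=1.0, n_iter=1000):
        # Store hyper-parameters unchanged so get_params/set_params/clone work
        self.sigma = sigma
        self.n_iter = n_iter

    def fit(self, X, y):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float)
        # Centre the data so the intercept can be recovered after fitting
        x_mean, y_mean = X.mean(axis=0), y.mean()
        Xc, yc = X - x_mean, y - y_mean
        L = np.linalg.norm(Xc, 2) ** 2  # Lipschitz constant of the least-squares gradient
        w = np.zeros(X.shape[1])
        for _ in range(self.n_iter):
            v = w - Xc.T @ (Xc @ w - yc) / L                          # gradient step
            w = np.sign(v) * np.maximum(np.abs(v) - self.sigma / L, 0.0)  # L1 proximal step
        self.coef_ = w
        self.intercept_ = y_mean - x_mean @ w
        return self

    def predict(self, X):
        return np.asarray(X, dtype=float) @ self.coef_ + self.intercept_

X, y = make_regression(n_samples=200, n_features=50, random_state=0, noise=4.0, bias=100.0)
grid = GridSearchCV(ToyL1Regression(), {'sigma': [0.1, 1.0, 10.0]}, cv=3)
grid.fit(X, y)
print(grid.best_params_)
print(clone(grid.best_estimator_).get_params())  # unfitted copy with the same hyper-parameters
```

`score` comes from `RegressorMixin` (the R² of `predict`), which is what GridSearchCV uses to rank the candidate `sigma` values by default.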

Documentation

The documentation is available at https://scikit-prox.readthedocs.io/en/latest/

License

This project is licensed under the MIT License; see the LICENSE.md file for details.

Acknowledgments

This project relies on the PyProximal package for all of its proximal operators except Total Variation, which is implemented using functions from scikit-image (skimage).



Download files


Source Distribution

scikit_prox-0.0.5.tar.gz (9.0 kB)


File details

Details for the file scikit_prox-0.0.5.tar.gz.

File metadata

  • Download URL: scikit_prox-0.0.5.tar.gz
  • Upload date:
  • Size: 9.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.9

File hashes

Hashes for scikit_prox-0.0.5.tar.gz
Algorithm Hash digest
SHA256 9d6c76072c0488e72d94aaad38e00e40ba20fe2aa93b905c2319d791889a2d43
MD5 573a1379412ff48ccb34b530e8786fc5
BLAKE2b-256 ee7afd677f75f467c3ec7cbe5fe451276765b429e9bd2efa3d70ba076c2bbccf

