A package for solving regularised optimisation problems in a scikit-learn style.

Scikit-Prox

The goal of this project is to implement a set of algorithms for solving optimisation problems of the form

minimize f(x) + g(x)

where f is a smooth function and g is a (possibly non-smooth) function whose proximal operator can be evaluated. The proximal operator of g with parameter λ > 0 is defined as

prox_{λg}(x) = argmin_y { g(y) + (1/(2λ)) ‖y − x‖₂² }
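
To make the definition concrete, here is a minimal NumPy sketch for the case g(y) = ‖y‖₁, where the minimiser of g(y) + (1/(2λ))‖y − x‖₂² has a well-known closed form (soft-thresholding). The function names `soft_threshold` and `prox_objective` are illustrative and are not part of scikit-prox.

```python
import numpy as np

def soft_threshold(x, lam):
    """Proximal operator of lam * ||.||_1: shrink every entry towards zero by lam."""
    return np.sign(x) * np.maximum(np.abs(x) - lam, 0.0)

def prox_objective(y, x, lam):
    """The function minimised over y in the definition: ||y||_1 + ||y - x||^2 / (2 * lam)."""
    return np.sum(np.abs(y)) + np.sum((y - x) ** 2) / (2.0 * lam)

x = np.array([3.0, -0.5, 1.2, -2.0])
lam = 1.0

# Entries with |x_i| <= lam are set to zero, the rest move towards zero by lam.
p = soft_threshold(x, lam)
print("prox:", p)

# Sanity check: a nearby point should not achieve a lower objective value.
nearby = p + 0.05
print(prox_objective(p, x, lam) <= prox_objective(nearby, x, lam))
```

Small entries are zeroed out entirely, which is why an L1 penalty tends to produce sparse solutions.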

Installation

To install the package, run the following command:

pip install scikit-prox

Usage

Example 1: Lasso

The following code solves the optimisation problem: minimize 1/2‖Ax − b‖₂² + λ‖x‖₁

import numpy as np
from scipy import sparse
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso
from skprox.linear_model import RegularisedRegression

# Generate data
X, y = make_regression(n_samples=100, n_features=1000, random_state=0, noise=4.0, bias=100.0)
X = sparse.csr_matrix(X)

# Solve the problem using scikit-learn
model = Lasso(alpha=0.1)
model.fit(X, y)
print("scikit-learn solution: {}".format(model.coef_))

# Solve the problem using scikit-prox
model = RegularisedRegression(proximal='L1', sigma=0.1)
model.fit(X, y)
print("scikit-prox solution: {}".format(model.coef_))
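
For intuition about how a problem of this form can be solved, the following is a minimal sketch of proximal gradient descent (ISTA) written directly in NumPy. It is not the solver used inside scikit-prox; the helper names and the synthetic data are assumptions made for illustration only.

```python
import numpy as np

def soft_threshold(v, t):
    """Proximal operator of t * ||.||_1."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

def lasso_objective(A, b, x, lam):
    """0.5 * ||Ax - b||^2 + lam * ||x||_1."""
    return 0.5 * np.sum((A @ x - b) ** 2) + lam * np.sum(np.abs(x))

def ista(A, b, lam, n_iter=500):
    """Proximal gradient descent: gradient step on the smooth term, prox step on the L1 term."""
    # Step size 1/L, where L = ||A||_2^2 is the Lipschitz constant of the gradient.
    step = 1.0 / np.linalg.norm(A, 2) ** 2
    x = np.zeros(A.shape[1])
    for _ in range(n_iter):
        grad = A.T @ (A @ x - b)
        x = soft_threshold(x - step * grad, step * lam)
    return x

# Small synthetic problem with a sparse ground truth (illustrative data only).
rng = np.random.default_rng(0)
A = rng.standard_normal((50, 20))
x_true = np.zeros(20)
x_true[:3] = [2.0, -1.5, 1.0]
b = A @ x_true + 0.01 * rng.standard_normal(50)

lam = 1.0
x_hat = ista(A, b, lam)
print("objective at zero:", lasso_objective(A, b, np.zeros(20), lam))
print("objective at ISTA solution:", lasso_objective(A, b, x_hat, lam))
```

Each iteration alternates a gradient step on the smooth term f with the proximal operator of g, which is the f + g splitting described at the top of this page.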

Example 2: Total Variation Regression

The following code solves the optimisation problem: minimize 1/2‖Ax − b‖₂² + λ‖∇x‖₁

import numpy as np
from scipy import sparse
from sklearn.datasets import make_regression
from skprox.linear_model import RegularisedRegression

# Generate data
X, y = make_regression(n_samples=100, n_features=1000, random_state=0, noise=4.0, bias=100.0)
X = sparse.csr_matrix(X)

# Solve the problem using scikit-prox
model = RegularisedRegression(proximal='TV', sigma=0.1)
model.fit(X, y)
print("scikit-prox solution: {}".format(model.coef_))
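
To illustrate what the ‖∇x‖₁ term measures, here is a small sketch that assumes ∇ is the 1-D forward-difference operator. That is an assumption for illustration; the discretisation used inside scikit-prox may differ. The function name `total_variation_1d` is hypothetical.

```python
import numpy as np

def total_variation_1d(x):
    """Anisotropic 1-D total variation: sum of |x[i+1] - x[i]|."""
    return np.sum(np.abs(np.diff(x)))

# Two coefficient vectors with the same entries in a different order,
# so their L1 norms are identical.
piecewise_constant = np.array([1.0, 1.0, 1.0, 4.0, 4.0, 4.0])
oscillating = np.array([1.0, 4.0, 1.0, 4.0, 1.0, 4.0])

print("L1 norms:", np.sum(np.abs(piecewise_constant)), np.sum(np.abs(oscillating)))
print("TV (piecewise constant):", total_variation_1d(piecewise_constant))  # one jump of size 3
print("TV (oscillating):", total_variation_1d(oscillating))                # five jumps of size 3
```

The L1 penalty cannot tell these two vectors apart, whereas the total variation penalty favours the piecewise-constant one.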

Example 3: Grid Search

The following code selects the regularisation strength λ by cross-validated grid search for the Lasso problem: minimize 1/2‖Ax − b‖₂² + λ‖x‖₁

import numpy as np
from scipy import sparse
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso
from skprox.linear_model import RegularisedRegression
from sklearn.model_selection import GridSearchCV

# Generate data
X, y = make_regression(n_samples=100, n_features=1000, random_state=0, noise=4.0, bias=100.0)
X = sparse.csr_matrix(X)

# Solve the problem using scikit-learn
model = Lasso()
grid = GridSearchCV(model, {'alpha': [0.1, 0.2, 0.3]})
grid.fit(X, y)
print("scikit-learn solution: {}".format(grid.best_estimator_.coef_))

# Solve the problem using scikit-prox
model = RegularisedRegression(proximal='L1')
grid = GridSearchCV(model, {'sigma': [0.1, 0.2, 0.3]})
grid.fit(X, y)
print("scikit-prox solution: {}".format(grid.best_estimator_.coef_))

Documentation

The documentation is available at https://scikit-prox.readthedocs.io/en/latest/

License

This project is licensed under the MIT License. See the LICENSE.md file for details.

Acknowledgments

This project leans on the pyproximal package, borrowing all of its proximal operators except Total Variation, which is implemented using functions from skimage.
