
A Robust ML toolbox



SkWDRO - Wasserstein Distributionally Robust Optimization

Model robustification with a thin interface

You can make pigs fly, [Kolter&Madry, 2018]


skwdro is a Python package that offers WDRO versions of a large range of estimators, either by extending scikit-learn estimators or by providing a wrapper for pytorch modules.

Have a look at skwdro documentation!

Getting started with skwdro

Installation

Development mode with hatch

First install hatch and clone the repository. In the root folder, make shell gives you an interactive shell in the correct environment, and make test runs the tests (it can be launched from either an interactive shell or a normal shell). make reset_env removes the installed environments, which is useful if something goes wrong.

With pip

skwdro will be available on PyPI soon; for now, only the development mode is available.

First steps with skwdro

scikit-learn interface

Robust estimators from skwdro can be used as drop-in replacements for scikit-learn estimators (they inherit from scikit-learn's estimator and classifier classes). skwdro provides robust estimators for standard problems such as linear regression and logistic regression. LinearRegression from skwdro.linear_model is a robust version of LinearRegression from scikit-learn and can be used in the same way. The only difference is that an uncertainty radius rho is now required.
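For orientation, the role of rho can be read off the standard WDRO objective. The notation below (theta for the model parameters, ell for the loss, P_n for the empirical distribution, W for a Wasserstein distance) is not taken from this README and is an assumption; the exact problem solved by skwdro, which may include regularization, is described in its documentation.

```latex
\min_{\theta} \; \sup_{Q \,:\, W(Q, \widehat{P}_n) \le \rho} \; \mathbb{E}_{\xi \sim Q}\left[ \ell_\theta(\xi) \right]
```

Here rho bounds how far the worst-case distribution Q may move away from the training data. With rho = 0, the only admissible Q is the empirical distribution itself, so the problem reduces to ordinary empirical risk minimization.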

We assume that we are given X_train of shape (n_train, n_features) and y_train of shape (n_train,) as training data and X_test of shape (n_test, n_features) as test data.
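If you have no dataset at hand, arrays of the expected shapes can be produced as in the following minimal sketch. The synthetic linear data, the sample sizes and the noise level are arbitrary choices for illustration and are not part of skwdro.

```python
import numpy as np

# Hypothetical synthetic regression data; all sizes are arbitrary.
rng = np.random.default_rng(0)
n_train, n_test, n_features = 150, 50, 5

X = rng.normal(size=(n_train + n_test, n_features))
true_coef = rng.normal(size=n_features)
y = X @ true_coef + 0.1 * rng.normal(size=n_train + n_test)

# First n_train rows for training, the remaining rows for testing.
X_train, X_test = X[:n_train], X[n_train:]
y_train, y_test = y[:n_train], y[n_train:]
```

Any other source of NumPy arrays with these shapes works equally well.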

from skwdro.linear_model import LinearRegression

# Uncertainty radius
rho = 0.1

# Fit the model
robust_model = LinearRegression(rho=rho)
robust_model.fit(X_train, y_train)

# Predict the target values
y_pred = robust_model.predict(X_test)

You can refer to the documentation to explore the list of skwdro's ready-made estimators.

pytorch interface

Didn't find an estimator that suits you? You can compose your own using the pytorch interface, which allows more flexibility, custom models and custom optimizers.

Assume now that the data is given as a dataloader train_loader.
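One way to obtain such a dataloader is sketched below using standard PyTorch utilities. The synthetic tensors, the batch size and the value of n_features are assumptions made for illustration; the targets are given shape (n_train, 1) so that they match the output of a linear layer with a single output unit.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical synthetic regression data; all sizes are arbitrary.
torch.manual_seed(0)
n_train, n_features = 256, 5

X_train = torch.randn(n_train, n_features)
true_weights = torch.randn(n_features, 1)
y_train = X_train @ true_weights + 0.1 * torch.randn(n_train, 1)

# Each iteration yields a (batch_x, batch_y) pair of tensors.
train_loader = DataLoader(
    TensorDataset(X_train, y_train), batch_size=32, shuffle=True
)
```

Any DataLoader that yields (input, target) pairs can be used in its place.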

import torch
import torch.nn as nn
import torch.optim as optim

from skwdro.torch import robustify

# Uncertainty radius
rho = 0.1

# Define the model
model = nn.Linear(n_features, 1)

# Define the loss function
loss_fn = nn.MSELoss()

# Define a sample batch for initialization
sample_batch_x, sample_batch_y = next(iter(train_loader))

# Robust loss
robust_loss = robustify(loss_fn, model, rho, sample_batch_x, sample_batch_y)

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop
for epoch in range(100):
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        loss = robust_loss(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()

You will find a detailed description of how to robustify modules in the documentation.

Cite

skwdro is the result of a research project. It is licensed under BSD 3-Clause. You are free to use it; if you do, please cite:

@article{vincent2024skwdro,
  title={skwdro: a library for Wasserstein distributionally robust machine learning},
  author={Vincent, Florian and Azizian, Wa{\"\i}ss and Iutzeler, Franck and Malick, J{\'e}r{\^o}me},
  journal={arXiv preprint arXiv:2410.21231},
  year={2024}
}

