A Robust ML toolbox
SkWDRO - Wasserstein Distributionally Robust Optimization
Model robustification with a thin interface
You can make pigs fly, [Kolter&Madry, 2018]
`skwdro` is a Python package that offers WDRO versions of a large range of estimators, either by extending `scikit-learn` estimators or by providing a wrapper for `pytorch` modules.
Have a look at the `skwdro` documentation!
Getting started with skwdro
Installation
Development mode with hatch
First install hatch and clone the repository. In the root folder, `make shell` gives you an interactive shell in the correct environment, and `make test` runs the tests (it can be launched from both an interactive shell and a normal shell). `make reset_env` removes the installed environments (useful if you run into trouble).
With pip
`skwdro` will be available on PyPI soon; for now, only the development mode is available.
First steps with skwdro
scikit-learn interface
Robust estimators from `skwdro` can be used as drop-in replacements for `scikit-learn` estimators (they actually inherit from the `scikit-learn` estimator and classifier classes). `skwdro` provides robust estimators for standard problems such as linear regression and logistic regression. `LinearRegression` from `skwdro.linear_model` is a robust version of `LinearRegression` from `scikit-learn` and can be used in the same way. The only difference is that an uncertainty radius `rho` is now required.
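For reference, the generic WDRO problem that the radius `rho` parametrizes can be written as below. This is the standard textbook formulation, not a statement of skwdro's exact objective: the transport cost $c$ defining the Wasserstein distance $W_c$, and any regularization skwdro adds on top, are assumptions here and may differ from what its estimators use.

$$
\min_{\theta} \; \sup_{Q \,:\, W_c(Q, \hat P_n) \le \rho} \; \mathbb{E}_{\xi \sim Q}\big[\ell(\theta; \xi)\big],
\qquad
\hat P_n = \frac{1}{n} \sum_{i=1}^{n} \delta_{\xi_i}
$$

Here $\xi_i = (x_i, y_i)$ are the training samples and $\ell$ is the loss. With $\rho = 0$ the only admissible distribution is $\hat P_n$ itself, and the problem reduces to ordinary empirical risk minimization; a larger $\rho$ guards against larger shifts of the data distribution.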
We assume that we are given `X_train` of shape `(n_train, n_features)` and `y_train` of shape `(n_train,)` as training data, and `X_test` of shape `(n_test, n_features)` as test data.
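If you just want to try the snippet that follows, arrays with these shapes can be generated synthetically. The sketch below uses NumPy only; the sizes, the seed and the linear-plus-noise model are arbitrary illustrative choices, not data shipped with skwdro.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
n_train, n_test, n_features = 200, 50, 5

rng = np.random.default_rng(seed=0)
true_coef = rng.normal(size=n_features)  # ground-truth weights of the synthetic model

# Features: i.i.d. standard Gaussian entries.
X_train = rng.normal(size=(n_train, n_features))
X_test = rng.normal(size=(n_test, n_features))

# Targets: linear model plus small Gaussian noise.
y_train = X_train @ true_coef + 0.1 * rng.normal(size=n_train)
```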
```python
from skwdro.linear_model import LinearRegression

# Uncertainty radius
rho = 0.1

# Fit the model
robust_model = LinearRegression(rho=rho)
robust_model.fit(X_train, y_train)

# Predict the target values
y_pred = robust_model.predict(X_test)
```
You can refer to the documentation to explore the list of `skwdro`'s ready-made estimators.
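To build intuition for what `rho` does in a regression setting, here is a classical closed form from the WDRO literature, not skwdro's implementation. Assume a type-2 Wasserstein ball with cost $\|x - x'\|_2^2$ on the features only (labels kept fixed) and the squared loss of a linear model $w$. Then the worst-case mean squared error equals $(\sqrt{\tfrac{1}{n}\sum_i r_i^2} + \rho \|w\|_2)^2$ with residuals $r_i = y_i - w^\top x_i$. The upper bound follows from $|y - w^\top x'| \le |y - w^\top x| + \|w\|_2 \|x - x'\|_2$ and Minkowski's inequality. The hypothetical helper `worst_case_features` builds a perturbation with transport budget exactly $\rho^2$ that attains it. skwdro's `LinearRegression` may use a different cost or a regularized dual, so its numbers need not match this formula.

```python
import numpy as np

def robust_mse(w, X, y, rho):
    """Worst-case MSE of x -> w @ x over a type-2 Wasserstein ball of radius rho
    (cost ||x - x'||_2^2 on the features, labels fixed). Classical closed form."""
    r = y - X @ w
    return (np.sqrt(np.mean(r ** 2)) + rho * np.linalg.norm(w)) ** 2

def worst_case_features(w, X, y, rho):
    """Hypothetical helper: a feature perturbation attaining robust_mse.
    Assumes w != 0 and that not all residuals are zero."""
    r = y - X @ w
    t = rho / np.sqrt(np.mean(r ** 2))  # scale so that mean ||delta_i||^2 == rho^2
    s = np.where(r >= 0, 1.0, -1.0)     # sign of each residual
    u = w / np.linalg.norm(w)           # unit direction of w
    # Moving x_i by -s_i * t * |r_i| * u multiplies each |residual| by 1 + t * ||w||.
    delta = -(s * t * np.abs(r))[:, None] * u[None, :]
    return X + delta

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
w = np.array([1.0, -2.0, 0.5])
y = X @ w + 0.1 * rng.normal(size=100)
rho = 0.1

X_adv = worst_case_features(w, X, y, rho)
print("empirical MSE:", np.mean((y - X @ w) ** 2))
print("robust MSE   :", robust_mse(w, X, y, rho))
print("MSE on X_adv :", np.mean((y - X_adv @ w) ** 2))  # equals the robust MSE
```

The penalty term $\rho \|w\|_2$ shows why a larger radius pushes the fitted weights toward smaller norms: the worst-case perturbation exploits large weights the most.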
pytorch interface
Didn't find an estimator that suits you? You can compose your own using the `pytorch` interface: it allows more flexibility, including custom models and optimizers.
Assume now that the data is given as a dataloader `train_loader`.
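For experimentation, such a loader can be built from in-memory tensors with PyTorch's `TensorDataset` and `DataLoader`. The data below is hypothetical synthetic regression data. The targets are given shape `(n_train, 1)` so that `nn.MSELoss` compares tensors of the same shape as the output of `nn.Linear(n_features, 1)`; whether skwdro imposes other shape requirements is not covered here.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical synthetic regression data, for illustration only.
n_train, n_features = 256, 5
torch.manual_seed(0)
X = torch.randn(n_train, n_features)
true_coef = torch.randn(n_features, 1)
y = X @ true_coef + 0.1 * torch.randn(n_train, 1)  # shape (n_train, 1)

# Wrap the tensors so that iterating yields (batch_x, batch_y) pairs.
train_loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

sample_batch_x, sample_batch_y = next(iter(train_loader))
print(sample_batch_x.shape, sample_batch_y.shape)  # torch.Size([32, 5]) torch.Size([32, 1])
```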
```python
import torch
import torch.nn as nn
import torch.optim as optim
from skwdro.torch import robustify

# Uncertainty radius
rho = 0.1
# Define the model (n_features: number of input features)
model = nn.Linear(n_features, 1)
# Define the loss function
loss_fn = nn.MSELoss()
# Define a sample batch for initialization
sample_batch_x, sample_batch_y = next(iter(train_loader))
# Robust loss
robust_loss = robustify(loss_fn, model, rho, sample_batch_x, sample_batch_y)
# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop
for epoch in range(100):
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        loss = robust_loss(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()
```
You will find a detailed description of how to robustify modules in the documentation.
Cite
`skwdro` is the result of a research project. It is licensed under the BSD 3-Clause license. You are free to use it; if you do so, please cite:
@article{vincent2024skwdro,
title={skwdro: a library for Wasserstein distributionally robust machine learning},
author={Vincent, Florian and Azizian, Wa{\"\i}ss and Iutzeler, Franck and Malick, J{\'e}r{\^o}me},
journal={arXiv preprint arXiv:2410.21231},
year={2024}
}
File details
Details for the file skwdro-1.0.3.tar.gz.
File metadata
- Download URL: skwdro-1.0.3.tar.gz
- Upload date:
- Size: 48.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: python-httpx/0.25.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 90fabe27e90e16284d1e7f7fdda00e58bcc1d2ca39be6cd35d23744063076d2c
MD5 | 22c6e50aaff7f129abca91975aa0e37c
BLAKE2b-256 | 25d91d1cae9a7a6ab2f714ef2e0fa64809e9752d334e862102e7398575e92c54
File details
Details for the file skwdro-1.0.3-py2.py3-none-any.whl.
File metadata
- Download URL: skwdro-1.0.3-py2.py3-none-any.whl
- Upload date:
- Size: 75.3 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: python-httpx/0.25.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0e7dbc3eac9e922e05b00b7d728a736e0ef1688649ae1e393000892ffc64153c
MD5 | 4672df3e9f21c9ef725a17ce1bc8f830
BLAKE2b-256 | 8a5c7bbe9015de5842e242827ba906671e016e7682fa9e18d84cf36a36e8db8d