machine learning package

Project description


XCurve: Machine Learning with Decision-Invariant X-Curve Metrics

Mission: Support end-to-end Training Solutions for Decision Invariant Models

Please visit the website for more details on XCurve!


Latest News

  • (New!) 2022.6: XCurve v1.1.0 has been released! Please try it now!

Introduction

Recently, machine learning and deep learning technologies have been successfully employed in many complicated high-stakes decision-making applications such as disease prediction, fraud detection, outlier detection, and criminal justice sentencing. All these applications share a common trait known as risk aversion in economics and finance. In other words, the decision-makers tend to have an extremely low risk tolerance. In this context, decision-making parameters significantly affect the performance of models. For example, in binary classification problems, the classification threshold serves as the decision parameter. For a fixed model, changing the threshold can lead to significantly different performance.
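To make the threshold sensitivity concrete, here is a minimal sketch on made-up data. The scores, the labels, and the helper names `confusion_at` and `precision_recall` are illustrative assumptions and not part of XCurve. The same set of scores yields quite different precision and recall depending on where the threshold is placed.

```python
# Toy scores from a hypothetical binary classifier (1 = positive, 0 = negative).
scores = [0.95, 0.80, 0.65, 0.55, 0.40, 0.30, 0.20, 0.10]
labels = [1, 1, 0, 1, 0, 1, 0, 0]


def confusion_at(threshold, scores, labels):
    """Return (tp, fp, fn, tn) when predicting positive for score >= threshold."""
    tp = fp = fn = tn = 0
    for s, y in zip(scores, labels):
        predicted_positive = s >= threshold
        if predicted_positive and y == 1:
            tp += 1
        elif predicted_positive:
            fp += 1
        elif y == 1:
            fn += 1
        else:
            tn += 1
    return tp, fp, fn, tn


def precision_recall(threshold, scores, labels):
    """Precision and recall at one threshold (precision is 1.0 if nothing is predicted positive)."""
    tp, fp, fn, _ = confusion_at(threshold, scores, labels)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 1.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return precision, recall


# Same model scores, three different decision thresholds.
for t in (0.25, 0.50, 0.75):
    p, r = precision_recall(t, scores, labels)
    print(f"threshold={t:.2f}  precision={p:.2f}  recall={r:.2f}")
```

On this toy data, a low threshold favors recall and a high threshold favors precision, which is why a single fixed threshold can be a fragile choice when the deployment condition is not known in advance.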

In risk-averse problems, the decision parameters change dynamically at deployment time. Hence, the goal of X-curve learning is to learn high-quality models that can adapt to different decision conditions. Inspired by the fundamental principle of the well-known AUC optimization, our library provides a systematic solution for optimizing the area under different kinds of performance curves. To be more specific, a performance curve is formed by plotting two performance functions $x(\lambda), y(\lambda)$ of a decision parameter $\lambda$. The area under a performance curve is then the integral of the performance over all possible decision conditions. In this way, a learning system only needs to optimize a decision-invariant metric, which avoids the risk-aversion issue.
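As one concrete instance of such a curve, take $x(\lambda)$ to be the false positive rate and $y(\lambda)$ the true positive rate at threshold $\lambda$, which gives the ROC curve. The sketch below sweeps the threshold over toy data and integrates the resulting curve with the trapezoidal rule. The data and the helper names `roc_points` and `area_under_curve` are assumptions made for illustration; this is not how XCurve computes or optimizes the metric.

```python
def roc_points(scores, labels):
    """Sweep the threshold over all distinct scores and collect (FPR, TPR) points."""
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    # Threshold above every score: nothing is predicted positive.
    points = [(0.0, 0.0)]
    for t in sorted(set(scores), reverse=True):
        tp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 0)
        points.append((fp / n_neg, tp / n_pos))
    return points


def area_under_curve(points):
    """Trapezoidal rule over consecutive (x, y) points."""
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2.0
    return area


# Toy data: 1 = positive, 0 = negative.
scores = [0.95, 0.80, 0.65, 0.55, 0.40, 0.30, 0.20, 0.10]
labels = [1, 1, 0, 1, 0, 1, 0, 0]

auroc = area_under_curve(roc_points(scores, labels))
print(f"AUROC = {auroc:.4f}")
```

The resulting number summarizes performance across every threshold at once, so it does not depend on any single choice of $\lambda$.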

XCurve now supports four kinds of performance curves including AUROC for Long-tail Recognition, AUPRC for Imbalanced Retrieval, AUTKC for Classification under Ambiguity, and OpenAUC for Open-Set Recognition.

Outline

The core functions of this library include the following:

Wide Real-World Applications

XCurve has a wide range of real-world applications, especially for data that follows a long-tailed or imbalanced distribution. Several cases are listed below:

Supported Curves in XCurve

X-Curve Description
XCurve.AUROC an efficient optimization library for Area Under the ROC curve (AUROC).
XCurve.AUPRC an efficient optimization library for Area Under the Precision-Recall curve (AUPRC).
XCurve.AUTKC an efficient optimization library for Area Under the Top-K curve (AUTKC).
XCurve.OpenAUC an efficient optimization library for Area Under the Open ROC curve (OpenAUC).
... ...

More X-Curves are under development. Please stay tuned!

Installation

You can install XCurve with pip:

pip install XCurve

Quickstart

Let us take multi-class AUROC optimization as an example here. Detailed tutorials can be found on the website (https://xcurveopt.github.io/).

'''
For the technical details behind this example, see our paper
<Learning with Multiclass AUC: Theory and Algorithms>.
'''
import random

import numpy as np
import torch
from easydict import EasyDict as edict

from XCurve.AUROC.dataloaders import get_datasets # datasets provided by XCurve
from XCurve.AUROC.dataloaders import get_data_loaders # dataloaders provided by XCurve
from XCurve.AUROC.losses import SquareAUCLoss # AUROC loss
from torch.optim import SGD # optimizer (any optimizer supported by PyTorch works)
from XCurve.AUROC.models import generate_net # model factory (any PyTorch DNN model also works)

seed = 1024
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# set params to create the model
args = edict({
    "model_type": "resnet18", # (supports resnet, densenet121 and mlp)
    "num_classes": 10, # number of classes
    "pretrained": None # whether the model is pretrained
})
# Alternatively, you can use any DNN model implemented in PyTorch
model = generate_net(args).cuda() # generate the PyTorch model

num_classes = 10
criterion = SquareAUCLoss(
    num_classes=num_classes, # number of classes
    gamma=1.0, # safe margin
    transform="ovo" # how the multi-class AUROC metric is computed ('ovo' or 'ova')
) # create loss criterion
optimizer = SGD(model.parameters(), lr=0.01) # create optimizer

# set dataset params, see our doc. for more details.
dataset_args = edict({
    "data_dir": "cifar-10-long-tail/", # relative path of dataset
    "input_size": [32, 32],
    "norm_params": {
        "mean": [123.675, 116.280, 103.530],
        "std": [58.395, 57.120, 57.375]
        },
    "use_lmdb": True,
    "resampler_type": "None",
    "npy_style": True,
    "aug": True, 
    "num_classes": num_classes
})

train_set, val_set, test_set = get_datasets(dataset_args) # load dataset
trainloader, valloader, testloader = get_data_loaders(
    train_set,
    val_set,
    test_set,
    train_batch_size=32,
    test_batch_size=64
) # create the dataloaders
# Note that get_datasets() applies stratified sampling to train_set using
# StratifiedSampler (from XCurve.AUROC.dataloaders import StratifiedSampler).

# train the model for one epoch
for index, (x, target) in enumerate(trainloader):
    x, target = x.cuda(), target.cuda()
    # target.shape => [batch_size, ]
    # Note: the loss expects predictions in [0, 1], e.g., obtained via
    # sigmoid (binary) or softmax (multi-class) for AUROC optimization.
    
    # forward
    pred = torch.sigmoid(model(x)) # [batch_size, num_classes] when num_classes > 2, otherwise [batch_size, ]
    loss = criterion(pred, target)
    if index % 30 == 0:
        print("loss:", loss.item())
    
    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
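For readers who want intuition for what a pairwise square surrogate of AUC looks like, the following is a minimal pure-Python sketch for the binary case. It is an illustration under stated assumptions, not XCurve's `SquareAUCLoss` implementation: the helper names `pairwise_auc` and `square_surrogate`, the toy scores, and the exact form `(gamma - (p - n)) ** 2` are choices made here for exposition.

```python
def pairwise_auc(pos_scores, neg_scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count as 0.5."""
    total = 0.0
    for p in pos_scores:
        for n in neg_scores:
            if p > n:
                total += 1.0
            elif p == n:
                total += 0.5
    return total / (len(pos_scores) * len(neg_scores))


def square_surrogate(pos_scores, neg_scores, gamma=1.0):
    """Mean of (gamma - (p - n))^2 over all pairs.

    A differentiable stand-in for the pairwise ranking error: it shrinks as
    positives are scored above negatives by a margin close to gamma.
    """
    total = 0.0
    for p in pos_scores:
        for n in neg_scores:
            total += (gamma - (p - n)) ** 2
    return total / (len(pos_scores) * len(neg_scores))


# Hypothetical predictions in [0, 1] for positive and negative samples.
pos = [0.9, 0.7]
neg = [0.4, 0.8]
print("AUC:", pairwise_auc(pos, neg))
print("square surrogate:", square_surrogate(pos, neg))

# Perfectly separated scores with margin 1 drive the surrogate to zero.
print("surrogate (separated):", square_surrogate([1.0, 1.0], [0.0, 0.0]))
```

Because the surrogate is defined over positive/negative pairs rather than individual samples, each batch needs to contain both classes, which is one plausible reason to pair such a loss with stratified sampling.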

Contact & Contribution

If you find any issues or plan to contribute bug fixes, please contact Shilong Bao (Email: baoshilong@iie.ac.cn) or Zhiyong Yang (Email: yangzhiyong21@ucas.ac.cn).

The authors appreciate all contributions!

Citation

Please cite our paper if you use this library in your own work:

@inproceedings{DBLP:conf/icml/YQBYXQ, 
  author    = {Zhiyong Yang and Qianqian Xu and Shilong Bao and Yuan He and Xiaochun Cao and Qingming Huang},
  title     = {When All We Need is a Piece of the Pie: A Generic Framework for Optimizing Two-way Partial AUC},
  booktitle = {ICML},
  pages     = {11820--11829},
  year      = {2021}
}
