Learning from Label Proportions (LLP) methods in Python

llp-learn

llp-learn is a Python library that provides implementations of methods for Learning from Label Proportions (LLP).


Installation

pip install llp-learn

Usage

import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report
from llp_learn.dllp import DLLP
from llp_learn.model_selection import gridSearchCV

random = np.random.RandomState(42)

# Creating a synthetic dataset using sklearn
X, y = make_classification(n_features=2, n_redundant=0, n_informative=2, n_clusters_per_class=1, n_samples=1000, random_state=42)

# Randomly assigning each sample to one of 5 bags
bags = random.randint(0, 5, size=X.shape[0])

# Computing the proportion of positive labels in each bag
proportions = np.zeros(5)
for i in range(5):
    bag_i = np.where(bags == i)[0]
    proportions[i] = y[bag_i].sum() / len(bag_i)

# LLP model (DLLP)
llp_model = DLLP(model_type="simple-mlp", lr=0.0001, n_epochs=1000, hidden_layer_sizes=(100, 100), n_jobs=0)

# Grid Search the lr parameter
gs = gridSearchCV(llp_model, param_grid={"lr": [0.1, 0.01, 0.001, 0.0001]}, cv=5, validation_size=0.5, n_jobs=1, random_state=42)

# Train/test split
train_idx = random.choice(np.arange(X.shape[0]), size=int(X.shape[0] * 0.8), replace=False)
test_idx = np.setdiff1d(np.arange(X.shape[0]), train_idx)

# Fitting the model
gs.fit(X[train_idx], bags[train_idx], proportions)

# Predicting the labels of the test set
y_pred_test = gs.predict(X[test_idx])

# Reporting the performance of the model in the test set
print(classification_report(y[test_idx], y_pred_test))
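
The per-bag proportions in the example above are computed with an explicit loop. As a minimal sketch using only NumPy, the same quantity can be obtained with `np.bincount`. Note that `bag_proportions` is a hypothetical helper written for illustration and is not part of llp-learn; the data below is random and only serves to check the helper against the loop version.

```python
import numpy as np


def bag_proportions(y, bags, n_bags):
    """Fraction of positive labels in each bag (NaN for empty bags).

    Hypothetical helper for illustration; not part of llp-learn.
    Assumes binary labels in {0, 1} and integer bag ids in [0, n_bags).
    """
    y = np.asarray(y, dtype=float)
    bags = np.asarray(bags)
    # Number of samples in each bag
    counts = np.bincount(bags, minlength=n_bags)
    # Number of positive labels in each bag (labels used as weights)
    positives = np.bincount(bags, weights=y, minlength=n_bags)
    # An empty bag gives 0 / 0, which is reported as NaN
    with np.errstate(divide="ignore", invalid="ignore"):
        return positives / counts


rng = np.random.RandomState(0)
y = rng.randint(0, 2, size=1000)
bags = rng.randint(0, 5, size=1000)

proportions = bag_proportions(y, bags, n_bags=5)

# Same computation written as a loop, for comparison
loop = np.array([y[bags == i].mean() for i in range(5)])
print(np.allclose(proportions, loop))
```

The helper can be applied to any subset of the data, for example `bag_proportions(y[train_idx], bags[train_idx], n_bags=5)`, if the proportions should reflect only the training samples.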

License

llp-learn is distributed under the terms of the MIT license.
