
Multiple Instance Learning for Gradient Boosting Models


milgboost


Multiple Instance Learning (MIL) is a weakly supervised learning paradigm in which labels are available for bags (groups of instances) rather than for individual instances. Under the standard MIL assumption, a bag is positive if at least one of its instances is positive. milgboost brings MIL to gradient boosting by wrapping LightGBM and XGBoost with custom differentiable objectives. The only objective currently provided is the LogSumExp Binary Cross-Entropy (LSE-BCE) loss, a smooth approximation of the max-instance MIL loss.
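For reference, one common way to write an LSE-BCE bag loss is shown below. This is an assumed formulation for illustration only; the exact pooling inside LSEBCE (for example log-sum-exp versus log-mean-exp) is not stated here. f_i is the raw boosting score of instance i in bag B, y_B is the bag label, σ is the sigmoid, and r > 0 is a sharpness parameter (presumably the r argument of LSEBCE).

$$
s_B = \frac{1}{r}\log\sum_{i\in B} e^{r f_i},\qquad
\hat p_B = \sigma(s_B),\qquad
\mathcal{L}_B = -\left[y_B\log\hat p_B + (1-y_B)\log\left(1-\hat p_B\right)\right]
$$

$$
\max_{i\in B} f_i \;\le\; s_B \;\le\; \max_{i\in B} f_i + \frac{\log|B|}{r},
\qquad
\frac{\partial \mathcal{L}_B}{\partial f_i} = \left(\hat p_B - y_B\right)\frac{e^{r f_i}}{\sum_{j\in B} e^{r f_j}}
$$

The bounds show why this approximates the max-instance loss: as r grows, s_B approaches the largest instance score, yet every instance still receives a gradient, weighted by its within-bag softmax share.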

Installation

uv add milgboost

Extra options

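Install with a specific boosting backend (quote the argument so that shells such as zsh do not treat the brackets as a glob pattern):

uv add "milgboost[xgboost-cpu]"
uv add "milgboost[xgboost]"  # GPU-enabled
uv add "milgboost[lightgbm]"
uv add "milgboost[xgboost-cpu,lightgbm]"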
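To confirm which backend an environment can actually import after installing an extra, a standard-library check is enough. This sketch assumes the extras pull in the upstream xgboost (or xgboost-cpu) and lightgbm packages, whose import names are xgboost and lightgbm; available_backends is a hypothetical helper, not part of milgboost.

```python
import importlib.util


def available_backends():
    """Map each boosting backend's import name to whether it can be imported.

    Hypothetical helper: assumes the milgboost extras install packages that
    import as `xgboost` and `lightgbm`.
    """
    # find_spec returns None for a top-level module that is not installed
    return {
        name: importlib.util.find_spec(name) is not None
        for name in ("xgboost", "lightgbm")
    }


if __name__ == "__main__":
    print(available_backends())
```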

Module overview

| Module | Description |
| --- | --- |
| milgboost.types | Bag / LabeledBag dataclasses and array↔bag conversion helpers |
| milgboost.datasets | make_mil_data(): synthetic MIL data generator |
| milgboost.model.base | BaseMILModel abstract class (fit / predict / predict_proba) |
| milgboost.model.xgboost | XGBoostMILModel: XGBoost-backed MIL classifier |
| milgboost.model.lightgbm | LightGBMMILModel: LightGBM-backed MIL classifier |
| milgboost.objective.base | BaseMILObjective: abstract interface for custom MIL objectives |
| milgboost.objective.lse | LSEBCE: LogSumExp binary cross-entropy objective |
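A custom gradient-boosting objective such as LSE-BCE has to return a per-instance gradient and Hessian. The NumPy sketch below shows how that can be computed; it is illustrative only. lse_bce_grad_hess is a hypothetical helper, not milgboost's API, and it assumes log-sum-exp pooling s_B = (1/r) log Σ exp(r f_i), bag labels taken as the max of per-instance labels, and a floored Hessian. LSEBCE's actual formulation and safeguards may differ.

```python
import numpy as np


def lse_bce_grad_hess(f, y, z, r=1.0, hess_floor=1e-6):
    """Per-instance gradient and diagonal Hessian of an LSE-BCE bag loss (sketch).

    f : raw instance scores, shape (n,)
    y : per-instance labels; the bag label is assumed to be the max over the bag
    z : bag ID of each instance, shape (n,)
    r : LSE sharpness; larger r moves the bag score closer to max(f)
    """
    f = np.asarray(f, dtype=float)
    _, inv = np.unique(z, return_inverse=True)
    inv = inv.ravel()
    n_bags = int(inv.max()) + 1

    # Bag label = max instance label (standard MIL assumption)
    y_bag = np.zeros(n_bags)
    np.maximum.at(y_bag, inv, np.asarray(y, dtype=float))

    # Numerically stable log-sum-exp: subtract the per-bag max of r * f
    m = np.full(n_bags, -np.inf)
    np.maximum.at(m, inv, r * f)
    e = np.exp(r * f - m[inv])
    denom = np.zeros(n_bags)
    np.add.at(denom, inv, e)

    s = (m + np.log(denom)) / r        # bag logit s_B
    w = e / denom[inv]                 # within-bag softmax weights, ds_B/df_i
    p = 1.0 / (1.0 + np.exp(-s))       # bag probability sigma(s_B)

    resid = (p - y_bag)[inv]           # dL/ds_B broadcast to each instance
    grad = resid * w
    hess = (p * (1.0 - p))[inv] * w**2 + resid * r * w * (1.0 - w)
    # The exact diagonal can turn negative when p < y; flooring it is one
    # common safeguard for Newton-style boosting steps (an assumption here)
    return grad, np.maximum(hess, hess_floor)
```

The gradient term (p̂_B − y_B)·w_i means the instance with the highest score in a bag takes most of the update, which is how the loss mimics max-instance MIL while staying smooth.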

Output ordering

All prediction methods (predict, predict_proba, predict_bags, predict_proba_bags) return results sorted by bag_id in ascending order. For example, if your bag IDs are [3, 1, 2], the output will be ordered as bags [1, 2, 3].

Recommendation: sort both x and z by z before prediction, so that your instance arrays follow the same bag order as the output:

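import numpy as np

# Sort x and z by z (bag ID) before prediction; a stable sort keeps
# the original instance order within each bag
sort_idx = np.argsort(z, kind="stable")
x_sorted, z_sorted = x[sort_idx], z[sort_idx]

# Predictions follow ascending bag-ID order (model is a fitted MIL model, as in the sample code)
probs = model.predict_proba(x_sorted, z_sorted)
# probs[i] corresponds to bag np.unique(z)[i], i.e. the i-th smallest bag ID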

Using sequential bag IDs (0, 1, 2, ...) is the simplest way to avoid confusion: probs[i] is then the prediction for the bag with ID i.
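Because predictions come back in ascending bag-ID order, mapping them to any other bag order only needs np.unique and np.searchsorted. A minimal sketch, assuming the outputs contain exactly one entry per unique bag ID in z; align_to_bag_order is a hypothetical helper, not part of milgboost.

```python
import numpy as np


def align_to_bag_order(outputs, z, bag_order):
    """Reorder bag-level outputs (sorted by ascending bag ID) into bag_order.

    Hypothetical helper: assumes len(outputs) == len(np.unique(z)).
    """
    sorted_ids = np.unique(z)                # ascending IDs, the documented output order
    bag_order = np.asarray(bag_order)
    pos = np.searchsorted(sorted_ids, bag_order)
    # searchsorted alone would silently misplace IDs that are absent from z
    pos_clipped = np.minimum(pos, len(sorted_ids) - 1)
    if not np.array_equal(sorted_ids[pos_clipped], bag_order):
        raise KeyError("bag_order contains IDs that are not in z")
    return np.asarray(outputs)[pos]
```

For example, with z = [3, 3, 1, 2, 2] the outputs are ordered as bags [1, 2, 3], and align_to_bag_order(probs, z, [3, 1, 2]) returns them in the order bag 3, bag 1, bag 2.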

Sample code

import numpy as np
from milgboost.datasets import make_mil_data
from milgboost.objective import LSEBCE
from milgboost.model import LightGBMMILModel

# Generate synthetic MIL data: 200 bags, 10 features
x, y, z = make_mil_data(
    n_bags=200,
    n_features=10,
    n_informative=5,
    key_instance_ratio=0.3,
    random_state=42,
)

# Split by bag ID so that all instances of a bag stay on the same side
n_train = 150
train_idx = z < n_train
test_idx = z >= n_train

x_train, y_train, z_train = x[train_idx], y[train_idx], z[train_idx]
x_test, y_test, z_test = x[test_idx], y[test_idx], z[test_idx]

# Train LSE-BCE LightGBM MIL model
model = LightGBMMILModel(
    objective=LSEBCE(r=1.0),
    lgb_params={"verbose": -1, "num_leaves": 15},
    num_boost_round=100,
)
model.fit(x_train, y_train, z_train)

# Predict
probs = model.predict_proba(x_test, z_test)
preds = model.predict(x_test, z_test)
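# preds has one entry per test bag in ascending bag-ID order, while y_test
# is per instance; assuming a bag's label is the max over its instances,
# build bag-level labels in the same order before comparing
bag_ids = np.unique(z_test)
y_test_bags = np.array([y_test[z_test == b].max() for b in bag_ids])
print(f"Accuracy: {(preds == y_test_bags).mean():.3f}")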
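The split in the sample appears to rely on make_mil_data producing sequential bag IDs, so that z < n_train selects whole bags. For arbitrary, non-contiguous IDs, a random split at the bag level keeps every instance of a bag on the same side. A sketch with a hypothetical split_bags helper (scikit-learn's GroupShuffleSplit with groups=z is an existing alternative):

```python
import numpy as np


def split_bags(z, test_size=0.25, random_state=None):
    """Return instance-level (train_mask, test_mask) for a random bag-level split.

    Hypothetical helper: every instance of a bag lands on the same side,
    and bag IDs need not be contiguous.
    """
    z = np.asarray(z)
    bag_ids = np.unique(z)
    rng = np.random.default_rng(random_state)
    n_test = int(round(test_size * len(bag_ids)))
    # Sample whole bags, then expand the choice back to instances
    test_bags = rng.choice(bag_ids, size=n_test, replace=False)
    test_mask = np.isin(z, test_bags)
    return ~test_mask, test_mask
```

Usage: train_idx, test_idx = split_bags(z, test_size=0.25, random_state=42), after which the masks index x, y and z exactly as in the sample.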

Development

git clone <repo>
cd milgboost

# Create virtualenv and install all extras + dev deps
uv sync --all-extras --group dev

# Type check
uv run poe check

# Lint & format
uv run poe lint
uv run poe format

# Run tests
uv run poe test

License

MIT

Download files

Source Distribution

milgboost-0.1.0.tar.gz (6.4 kB)

Built Distribution

milgboost-0.1.0-py3-none-any.whl (9.7 kB)

File details

Details for the file milgboost-0.1.0.tar.gz.

File metadata

  • Download URL: milgboost-0.1.0.tar.gz
  • Upload date:
  • Size: 6.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.16 (uv publish, Debian GNU/Linux 13 "trixie")

File hashes

Hashes for milgboost-0.1.0.tar.gz
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | cd4840b001f0cce132c8542223af803890b24bf4b1dbcc66f0098ba95d66aa8b |
| MD5 | ff33670d2e798e4d2aee8aea75cc97fd |
| BLAKE2b-256 | e7c3f9577ede7368420b577ff99d17ea9aa7f921b4363b2e18c410a34ee6eb4f |

File details

Details for the file milgboost-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: milgboost-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 9.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.16 (uv publish, Debian GNU/Linux 13 "trixie")

File hashes

Hashes for milgboost-0.1.0-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 7974ead2e5a08641459f44bbc99a299b7de6910cfd7e6e52161b7acd506f1ac7 |
| MD5 | 9ef5559e10289d55ac3860ce038caa9b |
| BLAKE2b-256 | 913892b2b796250a8fc9f88ee68d6eba1e1f7b18b6bce03902925565a00bc644 |
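To check a downloaded file against the SHA256 digests listed above, hashing it in chunks with the standard library is enough (pip also provides a pip hash command for the same purpose). sha256_of is a hypothetical helper name used only for this sketch.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Return the SHA256 hex digest of a file, reading it in 1 MiB chunks."""
    h = hashlib.sha256()
    with Path(path).open("rb") as fh:
        # iter() with a sentinel stops at the empty bytes object returned at EOF
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()
```

For example, sha256_of("milgboost-0.1.0.tar.gz") should equal cd4840b001f0cce132c8542223af803890b24bf4b1dbcc66f0098ba95d66aa8b for an untampered download.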
