✨ myriade 🌲
Hierarchical extreme multiclass and multi-label classification.
Motivation
Extreme multiclass classification problems are situations where the number of labels is extremely large, typically on the order of tens of thousands. These problems can also be multi-label: a sample can be assigned more than one label. Standard methods don't scale well in these cases.
This Python package provides methods to address extreme multiclass classification. It takes a hierarchical approach: labels are organized into a binary tree, and a binary classifier is trained at each node.
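To make the idea concrete, here is a minimal sketch, assuming the tree is written as nested `(left, right)` tuples and scikit-learn's `LogisticRegression` is the binary classifier. The helper names `leaves`, `fit_tree` and `predict_one` are hypothetical and are not myriade's API. Each internal node learns whether a sample's label lies in its left or right subtree, and a prediction descends from the root to a leaf.

```python
import numpy as np
from sklearn import datasets, preprocessing
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression


def leaves(node):
    """Labels under a node; a node is either a label or a (left, right) pair."""
    if isinstance(node, tuple):
        return leaves(node[0]) | leaves(node[1])
    return {node}


def fit_tree(node, X, y, classifier):
    """Train one binary classifier per internal node of the label tree."""
    if not isinstance(node, tuple):
        return node  # a leaf is just a label
    left, right = node
    # Only the samples whose label lives under this node are used here
    mask = np.isin(y, list(leaves(node)))
    # Binary target: 1 if the label belongs to the left subtree, 0 otherwise
    target = np.isin(y[mask], list(leaves(left))).astype(int)
    model = clone(classifier).fit(X[mask], target)
    return (model, fit_tree(left, X, y, classifier), fit_tree(right, X, y, classifier))


def predict_one(fitted, x):
    """Descend from the root to a leaf, following each node's binary decision."""
    node = fitted
    while isinstance(node, tuple):
        model, left, right = node
        node = left if model.predict(x.reshape(1, -1))[0] == 1 else right
    return node


# Usage on the same digits subset as the examples in this README
X, y = datasets.load_digits(n_class=5, return_X_y=True)
X = preprocessing.scale(X)
fitted = fit_tree(((0, 1), (2, (3, 4))), X, y, LogisticRegression(max_iter=1000))
predictions = np.array([predict_one(fitted, x) for x in X])
print(f"training accuracy: {(predictions == y).mean():.2%}")
```

Predicting a label only requires evaluating the classifiers along one root-to-leaf path, which is what makes the approach attractive when there are many labels.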
🏗️ The package isn't ready for prime time yet, but the existing code is tested and usable. You can use it to perform multiclass classification, but not multi-label classification. Stay tuned! You can also contribute 🙃
Installation
pip install myriade
Multiclass
Example dataset
A multiclass classification dataset contains a 2D matrix/dataframe of features, and a 1D array/series of labels.
For these examples, we'll load the first 5 digits of the UCI ML hand-written digits dataset.
>>> import myriade
>>> from sklearn import datasets
>>> from sklearn import model_selection
>>> from sklearn import preprocessing
>>> X, y = datasets.load_digits(n_class=5, return_X_y=True)
>>> X = preprocessing.scale(X)
>>> X.shape
(901, 64)
>>> sorted(set(y))
[0, 1, 2, 3, 4]
>>> X_train, X_test, y_train, y_test = model_selection.train_test_split(
... X, y, test_size=0.5, random_state=42
... )
In this case there are only 5 classes, so of course you could just use scikit-learn's OneVsRestClassifier. The point of this package is to scale to hundreds of thousands of classes, in which case a OneVsRestClassifier would be way too slow.
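A back-of-the-envelope count shows why: to predict a single sample, one-vs-rest evaluates one binary classifier per class, whereas a balanced tree only evaluates the classifiers on one root-to-leaf path, whose length is ceil(log2(k)) for k classes. This sketch assumes a perfectly balanced tree and ignores the cost of each individual classifier, so it is an order-of-magnitude illustration rather than a benchmark.

```python
def classifiers_per_prediction(n_classes):
    """Count the binary classifiers evaluated to predict a single sample."""
    # One-vs-rest scores every class, so it runs one classifier per class
    one_vs_rest = n_classes
    # A balanced tree only runs the classifiers on one root-to-leaf path,
    # whose length is ceil(log2(n_classes)), computed exactly with integers
    balanced_tree = (n_classes - 1).bit_length()
    return one_vs_rest, balanced_tree


for k in (5, 1_000, 100_000):
    print(k, classifiers_per_prediction(k))
# 5 (5, 3)
# 1000 (1000, 10)
# 100000 (100000, 17)
```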
Baselines
Random balanced
The most basic strategy is to organize labels into a random hierarchy. The RandomBalancedHierarchyClassifier does just this, by creating a balanced tree. The randomness is controlled with the seed parameter.
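A random balanced hierarchy can be sketched in a few lines: shuffle the labels with a seeded generator, then recursively split them in halves. This is an assumption about how such a tree can be built, not necessarily what RandomBalancedHierarchyClassifier does internally, and the nested-tuple representation and the `random_balanced_tree` name are hypothetical.

```python
import random


def random_balanced_tree(labels, seed=None):
    """Randomly arrange labels into a balanced binary tree of nested pairs."""
    labels = list(labels)
    # A dedicated generator keeps the global random state untouched
    random.Random(seed).shuffle(labels)

    def split(group):
        if len(group) == 1:
            return group[0]
        middle = len(group) // 2
        # Halving at every level keeps the depth at ceil(log2(len(labels)))
        return (split(group[:middle]), split(group[middle:]))

    return split(labels)


print(random_balanced_tree(range(5), seed=42))
```

Using the same seed yields the same tree, which is what makes the baseline reproducible.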
>>> from sklearn import linear_model
>>> model = myriade.multiclass.RandomBalancedHierarchyClassifier(
... classifier=linear_model.LogisticRegression(),
... seed=42
... )
>>> model = model.fit(X_train, y_train)
>>> print(f"{model.score(X_test, y_test):.2%}")
94.01%
You can use the to_graphviz method of a model's tree_ attribute to obtain a graphviz.Digraph representation.
>>> dot = model.tree_.to_graphviz()
>>> path = dot.render('random', directory='img', format='svg', cleanup=True)
☝️ Note that the graphviz library is not installed by default because it requires a platform-dependent binary. Therefore, you have to install it yourself.
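If you only need the DOT source, for instance to render it later with the `dot` command-line tool or to paste it into a DOT viewer, it can be produced without the graphviz Python library. The sketch below is a hypothetical stand-alone helper working on nested `(left, right)` tuples, not a replacement for myriade's to_graphviz; internal nodes are drawn as points and leaves carry their label.

```python
def to_dot(tree):
    """Return Graphviz DOT source for a tree of nested (left, right) pairs."""
    lines = ["digraph {"]
    counter = 0

    def visit(node):
        nonlocal counter
        node_id = f"n{counter}"
        counter += 1
        if isinstance(node, tuple):
            # Internal nodes are drawn as small points, leaves show their label
            lines.append(f'  {node_id} [label="", shape=point];')
            for child in node:
                lines.append(f"  {node_id} -> {visit(child)};")
        else:
            lines.append(f'  {node_id} [label="{node}"];')
        return node_id

    visit(tree)
    lines.append("}")
    return "\n".join(lines)


print(to_dot(((0, 1), (2, (3, 4)))))
```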
Optimal
It's also possible to search the space of all possible hierarchies and pick the best one. Hierarchies are compared with each other by estimating their performance with cross-validation.
>>> from sklearn import metrics
>>> cv = model_selection.ShuffleSplit(
... n_splits=1,
... train_size=0.5,
... random_state=42
... )
>>> model = myriade.multiclass.OptimalHierarchyClassifier(
... classifier=linear_model.LogisticRegression(),
... cv=cv,
... scorer=metrics.make_scorer(metrics.accuracy_score),
... )
>>> model = model.fit(X_train, y_train)
>>> print(f"{model.score(X_test, y_test):.2%}")
98.89%
>>> dot = model.tree_.to_graphviz()
>>> path = dot.render('optimal', directory='img', format='svg', cleanup=True)
The only downside to this method is that the number of possible hierarchies grows extremely fast with the number of labels. In fact, this number corresponds to sequence A001147 in the Online Encyclopedia of Integer Sequences (OEIS):
Number of labels | Number of possible hierarchies |
---|---|
1 | 1 |
2 | 1 |
3 | 3 |
4 | 15 |
5 | 105 |
6 | 945 |
7 | 10,395 |
8 | 135,135 |
9 | 2,027,025 |
10 | 34,459,425 |
This method is therefore only useful for benchmarking purposes. Indeed, for a small number of labels, it's useful to know if a hierarchy is optimal in some sense.
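The counts in the table follow from a simple recursion. A hierarchy over k - 1 labels has 2k - 3 nodes, and the k-th label can be attached as the sibling of any one of them, which multiplies the count by 2k - 3 and yields (2k - 3)!! = 1 × 3 × 5 × ... × (2k - 3). The sketch below enumerates hierarchies this way, using hypothetical nested tuples to represent them, and checks the enumeration against the closed form. An exhaustive search then has to cross-validate a model for each of these trees, which is why it's only practical for a handful of labels.

```python
def insert_everywhere(tree, label):
    """Yield every tree obtained by attaching `label` as the sibling of one node."""
    yield (tree, label)  # sibling of the whole (sub)tree
    if isinstance(tree, tuple):
        left, right = tree
        for new_left in insert_everywhere(left, label):
            yield (new_left, right)
        for new_right in insert_everywhere(right, label):
            yield (left, new_right)


def all_hierarchies(labels):
    """Yield every rooted binary tree whose leaves are `labels`, each exactly once."""
    *rest, last = labels
    if not rest:
        yield last
        return
    for tree in all_hierarchies(rest):
        yield from insert_everywhere(tree, last)


def count_hierarchies(k):
    """Closed form (2k - 3)!! = 1 * 3 * 5 * ... * (2k - 3); equals 1 when k <= 2."""
    count = 1
    for odd in range(3, 2 * k - 2, 2):
        count *= odd
    return count


for k in range(1, 7):
    n_trees = sum(1 for _ in all_hierarchies(list(range(k))))
    print(k, n_trees, count_hierarchies(k))
print(count_hierarchies(10))  # 34459425, as in the table
```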
Manual
You can also specify a hierarchy manually via the myriade.Branch class.
>>> b = myriade.Branch
>>> tree = b(
... b(0, 1),
... b(
... 2,
... b(3, 4)
... )
... )
>>> dot = tree.to_graphviz()
>>> path = dot.render('manual', directory='img', format='svg', cleanup=True)
>>> model = myriade.multiclass.ManualHierarchyClassifier(
... classifier=linear_model.LogisticRegression(),
... tree=tree
... )
>>> model = model.fit(X_train, y_train)
>>> print(f"{model.score(X_test, y_test):.2%}")
94.24%
Balanced
The above methods are baselines: they're either too naïve or too expensive. A smarter idea is to use a heuristic to build the hierarchy. The BalancedHierarchyClassifier builds a hierarchy by studying a confusion matrix.
First, a base model produces cross-validated predictions, from which a confusion matrix is built. The two classes that are most confused with each other form a branch. The process is repeated until all classes have been paired together. Next, the confusion matrix is shrunk so that pairs of labels are compared with each other, and the pairing process is repeated. After roughly log2(k) steps, where k is the number of labels, a balanced tree is obtained.
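A minimal sketch of the pairing procedure described above, assuming the confusion matrix has already been computed (for instance with scikit-learn's `cross_val_predict` and `confusion_matrix`). Symmetrizing the matrix, greedily taking the most confused remaining pair, and carrying an odd node up a level on its own are assumptions; myriade's BalancedHierarchyClassifier may make different choices. The `confusion_balanced_tree` name is hypothetical.

```python
import numpy as np


def confusion_balanced_tree(confusion, labels):
    """Greedily pair the most confused nodes, level by level, until one root remains.

    `confusion[i][j]` counts samples of class `labels[i]` predicted as `labels[j]`.
    """
    nodes = list(labels)
    C = np.asarray(confusion, dtype=float)
    C = C + C.T  # confusion between two nodes, in either direction
    while len(nodes) > 1:
        scores = C.copy()
        np.fill_diagonal(scores, -np.inf)  # a node can't be paired with itself
        unpaired = list(range(len(nodes)))
        groups = []
        while len(unpaired) > 1:
            sub = scores[np.ix_(unpaired, unpaired)]
            i, j = np.unravel_index(np.argmax(sub), sub.shape)
            a, b = unpaired[i], unpaired[j]
            groups.append([a, b])
            unpaired = [u for u in unpaired if u not in (a, b)]
        if unpaired:  # odd count: the leftover node moves up a level on its own
            groups.append(unpaired)
        # Shrink the matrix: confusion between two groups sums over their members
        C = np.array([[C[np.ix_(g, h)].sum() for h in groups] for g in groups])
        nodes = [
            (nodes[g[0]], nodes[g[1]]) if len(g) == 2 else nodes[g[0]] for g in groups
        ]
    return nodes[0]


confusion = [
    [50, 10, 0, 1],
    [8, 50, 1, 0],
    [0, 1, 50, 9],
    [1, 0, 7, 50],
]
print(confusion_balanced_tree(confusion, [0, 1, 2, 3]))  # ((0, 1), (2, 3))
```

Classes 0 and 1 are often mistaken for each other, as are 2 and 3, so each pair ends up sharing a branch: the classifiers near the leaves then specialize in the hardest distinctions.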
>>> base_model = myriade.multiclass.RandomBalancedHierarchyClassifier(
... classifier=linear_model.LogisticRegression(),
... seed=42
... )
>>> cv = model_selection.KFold(
... n_splits=2,
... shuffle=True,
... random_state=42
... )
>>> model = myriade.multiclass.BalancedHierarchyClassifier(
... classifier=linear_model.LogisticRegression(),
... base_model=base_model,
... cv=cv
... )
>>> model = model.fit(X_train, y_train)
>>> print(f"{model.score(X_test, y_test):.2%}")
98.45%
>>> dot = model.tree_.to_graphviz()
>>> path = dot.render('balanced', directory='img', format='svg', cleanup=True)
Multi-label
🏗️
Datasets
Name | Function | Size | Samples | Features | Labels | Multi-label | Labels/sample |
---|---|---|---|---|---|---|---|
DMOZ | load_dmoz | 614.8 MB | 394,756 | 833,484 | 36,372 | ✓ | 1.02 |
Wikipedia (small) | load_wiki_small | 135.5 MB | 456,886 | 2,085,165 | 36,504 | ✓ | 1.84 |
Wikipedia (large) | load_wiki_large | 1.01 GB | 2,365,436 | 2,085,167 | 325,056 | ✓ | 3.26 |
Each load_* function returns two arrays, which contain the features and the target classes, respectively. In the multi-label case, the target array is 2D. The arrays are sparse when applicable.
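In the multi-label case, the 2D target can be read as a binary indicator matrix with one row per sample and one column per label; this exact layout is an assumption about what the loaders return. Under that reading, the Labels/sample column of the table is the average number of non-zero entries per row, as this toy example with SciPy's sparse matrices shows.

```python
import numpy as np
from scipy import sparse

# A toy multi-label target: 3 samples (rows), 4 labels (columns)
Y = sparse.csr_matrix(np.array([
    [1, 0, 1, 0],
    [0, 1, 0, 0],
    [1, 1, 1, 0],
]))

labels_per_sample = Y.getnnz(axis=1)  # number of labels assigned to each sample
samples_per_label = Y.getnnz(axis=0)  # how often each label occurs
print(labels_per_sample.mean())  # 2.0, the "Labels/sample" statistic
```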
The first time you call a load_* function, the data is downloaded and saved into a .svm file that adheres to the LIBSVM format convention. If you interrupt a loader while it's working, it will restart from scratch the next time it's called. You can see where the datasets are stored by calling myriade.datasets.get_data_home.
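To give an idea of what a .svm file contains, here is a small parser for a single line. It assumes the common multi-label variant of the format, where the line starts with comma-separated integer label ids followed by sparse `index:value` features; the exact conventions of myriade's files aren't specified here, and the package itself relies on svmloader rather than anything like this hypothetical `parse_svm_line` helper.

```python
def parse_svm_line(line):
    """Parse one line of a (multi-label) LIBSVM/svmlight file into labels and features."""
    tokens = line.split("#", 1)[0].split()  # drop any trailing comment
    labels = []
    if tokens and ":" not in tokens[0]:
        # The first token holds the labels, separated by commas when multi-label
        labels = [int(label) for label in tokens.pop(0).split(",")]
    features = {}
    for token in tokens:
        index, value = token.split(":")
        if index == "qid":  # query ids are not features
            continue
        features[int(index)] = float(value)
    return labels, features


print(parse_svm_line("3,7 1:0.5 10:2"))
# ([3, 7], {1: 0.5, 10: 2.0})
```

Only non-zero features are written, which is what keeps files with millions of features manageable.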
All of the datasets are loaded in memory with the svmloader library, which is much faster than scikit-learn's load_svmlight_file function. Still, when working repeatedly on the same dataset, it is recommended to wrap the dataset loader with joblib.Memory.cache to store a memmapped backup of the results of the first call. This enables near-instantaneous loading on subsequent calls.
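A minimal sketch of that recommendation with joblib's `Memory`. Here `load_example` is a hypothetical stand-in for a slow loader; with myriade you would cache one of the loaders from the table instead, e.g. `memory.cache(myriade.datasets.load_dmoz)`, assuming the loaders live in `myriade.datasets` alongside get_data_home. Setting `mmap_mode="r"` asks joblib to memory-map the NumPy arrays it reads back from the cache.

```python
import tempfile

import numpy as np
from joblib import Memory


def load_example():
    """Hypothetical stand-in for a slow dataset loader."""
    rng = np.random.default_rng(42)
    return rng.random((1_000, 50)), rng.integers(0, 10, size=1_000)


# In practice, use a persistent directory instead of a temporary one
memory = Memory(location=tempfile.mkdtemp(), mmap_mode="r", verbose=0)
cached_load = memory.cache(load_example)

X, y = cached_load()  # first call: runs the loader and stores the result on disk
X, y = cached_load()  # later calls: read back from disk instead of recomputing
```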
Benchmarks
🏗️
Contributing
# Download and navigate to the source code
git clone https://github.com/MaxHalford/myriade
cd myriade
# Install poetry
curl -sSL https://install.python-poetry.org | POETRY_PREVIEW=1 python3 -
# Install in development mode
poetry install --dev
# Run tests
pytest
There's a small roadmap here if you're willing to contribute and looking for ideas 🙏
License
This project is free and open-source software licensed under the MIT license.