
✨ myriade 🌲

Hierarchical extreme multiclass and multi-label classification.



(Cover image credit: https://www.reddit.com/r/drawing/comments/84j0gn/night_sky_9x9in_pen/)

Motivation

Extreme multiclass classification problems are situations where the number of labels is extremely large, typically in the order of tens of thousands. These problems can also be multi-label: a sample can be assigned more than one label. The usual methods don't scale well in these cases.

This Python package provides methods to address extreme multiclass classification. It takes a hierarchical approach: the idea is to organize the labels into a binary tree, and to train a binary classifier at each node.
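
To see why this helps, here's a minimal sketch of how a prediction gets routed down such a tree. The Node class and predict_one function below are purely illustrative; they are not part of myriade's API.

class Node:
    """A node of the label tree (illustrative only)."""

    def __init__(self, left=None, right=None, classifier=None, label=None):
        self.left = left              # left subtree
        self.right = right            # right subtree
        self.classifier = classifier  # binary classifier trained at this node
        self.label = label            # only set for leaf nodes

def predict_one(node, x):
    """Route a single sample down the tree until a leaf is reached."""
    while node.label is None:
        go_right = node.classifier.predict([x])[0]
        node = node.right if go_right else node.left
    return node.label

With k labels, a balanced tree has a depth of roughly log2(k), so each prediction only involves log2(k) binary classifiers instead of the k models a one-vs-rest approach would require.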

🏗️ The package is not ready for prime time yet, but the existing code is tested and usable. You can use it to perform multiclass classification, but not multi-label classification yet. Stay tuned! You can also contribute 🙃

Installation

pip install myriade

Multiclass

Example dataset

A multiclass classification dataset contains a 2D matrix/dataframe of features, and a 1D array/series of labels.

For these examples, we'll load the first 5 digits of the UCI ML hand-written digits dataset.

>>> import myriade
>>> from sklearn import datasets
>>> from sklearn import model_selection
>>> from sklearn import preprocessing

>>> X, y = datasets.load_digits(n_class=5, return_X_y=True)
>>> X = preprocessing.scale(X)
>>> X.shape
(901, 64)

>>> sorted(set(y))
[0, 1, 2, 3, 4]

>>> X_train, X_test, y_train, y_test = model_selection.train_test_split(
...     X, y, test_size=0.5, random_state=42
... )

In this case there are only 5 classes, so of course you could just use scikit-learn's OneVsRestClassifier. The point of this package is to scale to hundreds of thousands of classes, in which case a OneVsRestClassifier would be far too slow.
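
For reference, this is roughly what the scikit-learn baseline would look like on this example (the score is omitted here); keep in mind that OneVsRestClassifier fits one binary model per class, so its cost grows linearly with the number of classes.

from sklearn import linear_model
from sklearn import multiclass

# One binary model per class: fine for 5 classes, not for 100,000.
ovr = multiclass.OneVsRestClassifier(linear_model.LogisticRegression())
ovr = ovr.fit(X_train, y_train)
ovr.score(X_test, y_test)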

Baselines

Random balanced

The most basic strategy is to organize labels into a random hierarchy. The RandomBalancedHierarchyClassifier does just this, by creating a balanced tree. The randomness is controlled with the seed parameter.

>>> from sklearn import linear_model

>>> model = myriade.multiclass.RandomBalancedHierarchyClassifier(
...     classifier=linear_model.LogisticRegression(),
...     seed=42
... )
>>> model = model.fit(X_train, y_train)
>>> print(f"{model.score(X_test, y_test):.2%}")
94.01%

You can use the to_graphviz method of a model's tree_ attribute to obtain a graphviz.Digraph representation.

>>> dot = model.tree_.to_graphviz()
>>> path = dot.render('random', directory='img', format='svg', cleanup=True)


☝️ Note that the graphviz library is not installed by default because it requires a platform-dependent binary. Therefore, you have to install it yourself.

Optimal

It's also possible to search the space of all possible hierarchies and pick the best one. Hierarchies are compared with one another by estimating their performance with cross-validation.

>>> from sklearn import metrics

>>> cv = model_selection.ShuffleSplit(
...     n_splits=1,
...     train_size=0.5,
...     random_state=42
... )
>>> model = myriade.multiclass.OptimalHierarchyClassifier(
...     classifier=linear_model.LogisticRegression(),
...     cv=cv,
...     scorer=metrics.make_scorer(metrics.accuracy_score),
... )
>>> model = model.fit(X_train, y_train)
>>> print(f"{model.score(X_test, y_test):.2%}")
98.89%
>>> dot = model.tree_.to_graphviz()
>>> path = dot.render('optimal', directory='img', format='svg', cleanup=True)


The only downside to this method is that the number of possible hierarchies grows extremely fast with the number of labels. In fact, if I'm not mistaken, this number corresponds to sequence A001147 in the Online Encyclopedia of Integer Sequences (OEIS):

Number of labels    Number of possible hierarchies
1                   1
2                   1
3                   3
4                   15
5                   105
6                   945
7                   10,395
8                   135,135
9                   2,027,025
10                  34,459,425
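
If that identification is right, the number of hierarchies over k labels is the double factorial (2k - 3)!!, which the following snippet reproduces:

def n_hierarchies(k):
    """Number of possible binary hierarchies over k labels: (2k - 3)!!"""
    count = 1
    for odd in range(3, 2 * k - 2, 2):
        count *= odd
    return count

for k in range(1, 11):
    print(k, n_hierarchies(k))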

This method is therefore only useful for benchmarking purposes: for a small number of labels, it's helpful to know whether a given hierarchy is optimal in some sense.

Manual

You can also specify a hierarchy manually via the myriade.Branch class.

>>> b = myriade.Branch
>>> tree = b(
...     b(0, 1),
...     b(
...         2,
...         b(3, 4)
...     )
... )
>>> dot = tree.to_graphviz()
>>> path = dot.render('manual', directory='img', format='svg', cleanup=True)


>>> model = myriade.multiclass.ManualHierarchyClassifier(
...     classifier=linear_model.LogisticRegression(),
...     tree=tree
... )
>>> model = model.fit(X_train, y_train)
>>> print(f"{model.score(X_test, y_test):.2%}")
94.24%

Balanced

The above methods are baselines: they're either too naïve or too expensive. A smarter idea is to use some sort of heuristic for building the hierarchy. The BalancedHierarchyClassifier builds a hierarchy by studying a confusion matrix.

First, a base model produces cross-validated predictions, from which a confusion matrix is built. The two classes which are most confused with each other form a branch. This pairing process is repeated until all classes have been paired together. Next, the confusion matrix is shrunk so that pairs of labels are compared with each other, and the pairing process is repeated. After roughly log2(k) steps, a balanced tree is obtained.
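
As an illustration of the first pairing step, here is a simplified sketch, not the actual implementation, which finds the most confused pair of classes from cross-validated predictions (using a plain logistic regression instead of a hierarchical base model):

import numpy as np
from sklearn import linear_model, metrics, model_selection

# Cross-validated predictions from a (simplified) base model.
y_pred = model_selection.cross_val_predict(
    linear_model.LogisticRegression(), X_train, y_train, cv=2
)
cm = metrics.confusion_matrix(y_train, y_pred)

# Symmetrize and zero the diagonal: entry (i, j) now measures how often
# classes i and j are confused with one another.
confusion = cm + cm.T
np.fill_diagonal(confusion, 0)

# The most confused pair of classes forms the first branch.
i, j = np.unravel_index(confusion.argmax(), confusion.shape)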

>>> base_model = myriade.multiclass.RandomBalancedHierarchyClassifier(
...     classifier=linear_model.LogisticRegression(),
...     seed=42
... )
>>> cv = model_selection.KFold(
...     n_splits=2,
...     shuffle=True,
...     random_state=42
... )
>>> model = myriade.multiclass.BalancedHierarchyClassifier(
...     classifier=linear_model.LogisticRegression(),
...     base_model=base_model,
...     cv=cv
... )
>>> model = model.fit(X_train, y_train)
>>> print(f"{model.score(X_test, y_test):.2%}")
98.45%
>>> dot = model.tree_.to_graphviz()
>>> path = dot.render('balanced', directory='img', format='svg', cleanup=True)


Multi-label

🏗️

Datasets

Name               Function         Size      Samples    Features   Labels   Multi-label  Labels/sample
DMOZ               load_dmoz        614.8 MB  394,756    833,484    36,372   ✓            1.02
Wikipedia (small)  load_wiki_small  135.5 MB  456,886    2,085,165  36,504   ✓            1.84
Wikipedia (large)  load_wiki_large  1.01 GB   2,365,436  2,085,167  325,056  ✓            3.26

Each load_* function returns two arrays which contain the features and the target classes, respectively. In the multi-label case, the target array is 2D. The arrays are sparse when applicable.

The first time you call a load_* function, the data will be downloaded and saved into a .svm file that adheres to the LIBSVM format convention. The loaders will restart from scratch if you interrupt them during their work. You can see where the datasets are stored by calling myriade.datasets.get_data_home.

All of the datasets are loaded in memory with the svmloader library. The latter is much faster than the load_svmlight_file function from scikit-learn. However, when working repeatedly on the same dataset, it is recommended to wrap the dataset loader with joblib.Memory.cache to store a memmapped backup of the results of the first call. This enables near instantaneous loading for subsequent calls.
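
For instance, here is a minimal sketch of such a wrapper, assuming the loaders live under myriade.datasets as the get_data_home mention suggests:

import joblib
import myriade.datasets

# Cache the loader's output on disk; mmap_mode='r' memory-maps the stored
# arrays, which makes subsequent calls near instantaneous.
memory = joblib.Memory(location='cache', mmap_mode='r')
load_dmoz = memory.cache(myriade.datasets.load_dmoz)

X, y = load_dmoz()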

Benchmarks

🏗️

Contributing

# Download and navigate to the source code
git clone https://github.com/MaxHalford/myriade
cd myriade

# Install poetry
curl -sSL https://install.python-poetry.org | POETRY_PREVIEW=1 python3 -

# Install in development mode
poetry install

# Run tests
pytest

There's a small roadmap here if you're willing to contribute and looking for ideas 🙏

License

This project is free and open-source software licensed under the MIT license.

