
Conditional Tree Bayesian Network for multi-label classification


Conditional Tree Bayesian Network

This package can be used as a library to fit a Conditional Tree Bayesian Network (CTBN) [1].

The main entry point is the CTBN class.

  1. Use the fit() method to fit the CTBN to your multi-label classification data [1]. This method builds an optimal CTBN, an instance of DirectedGraph, using the Chu-Liu/Edmonds algorithm for finding a maximum spanning arborescence [2].
  2. The predict() method returns the most likely assignment to the class variables along with the probability of that assignment. It uses the junction tree algorithm [3] to run the most likely explanation (MLE) query.
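
To illustrate the structure-learning step in item 1, here is a minimal pure-Python sketch of the Chu-Liu/Edmonds algorithm on a toy graph. It is not the package's implementation: the function names (find_cycle, max_arborescence), the edge-dictionary representation, the node labels, and the edge weights are all assumptions made for this example.

```python
def find_cycle(parent):
    """Return the nodes of one cycle in a {child: parent} map, or None."""
    for start in parent:
        seen = []
        node = start
        while node in parent and node not in seen:
            seen.append(node)
            node = parent[node]
        if node in seen:
            return seen[seen.index(node):]
    return None


def max_arborescence(edges, root):
    """edges: {(u, v): weight}. Returns {child: parent} for a maximum
    spanning arborescence rooted at `root`."""
    # Step 1: every non-root node greedily picks its heaviest incoming edge.
    best = {}
    for (u, v), w in edges.items():
        if v != root and u != v and (v not in best or w > edges[(best[v], v)]):
            best[v] = u
    cycle = find_cycle(best)
    if cycle is None:
        return best  # the greedy choice is already a tree

    # Step 2: contract the cycle into a single super-node.
    in_cycle = set(cycle)
    super_node = ("cycle", tuple(cycle))
    contracted, origin = {}, {}
    for (u, v), w in edges.items():
        if u in in_cycle and v in in_cycle:
            continue
        if v in in_cycle:
            # Entering the cycle at v displaces the cycle edge into v.
            key, w_adj = (u, super_node), w - edges[(best[v], v)]
        elif u in in_cycle:
            key, w_adj = (super_node, v), w
        else:
            key, w_adj = (u, v), w
        if key not in contracted or w_adj > contracted[key]:
            contracted[key], origin[key] = w_adj, (u, v)

    # Step 3: solve the smaller problem, then expand the super-node.
    sub = max_arborescence(contracted, root)
    result = {}
    for child, par in sub.items():
        u, v = origin[(par, child)]
        result[v] = u
    for v in cycle:
        result.setdefault(v, best[v])  # keep the remaining cycle edges
    return result


# Hypothetical weights between three class variables and a root node.
edges = {
    ("root", "Y1"): 1.0, ("root", "Y2"): 1.0, ("root", "Y3"): 1.0,
    ("Y1", "Y2"): 5.0, ("Y2", "Y1"): 4.0,
    ("Y2", "Y3"): 3.0, ("Y3", "Y2"): 2.0,
}
tree = max_arborescence(edges, "root")
total = sum(edges[(p, c)] for c, p in tree.items())
print(tree, total)
```

In this toy graph the greedy choice first creates the cycle Y1 -> Y2 -> Y1, which the contraction step then breaks by attaching Y1 to the root. Each node ends up with exactly one parent, which is the tree shape a CTBN requires over its class variables.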

Usage

An example of using the package can be found in the Jupyter notebook here.

# Import the CTBN class. This assumes you have a dataset X_train and Y_train,
# both of which are numpy arrays.
from ctbn import CTBN
model = CTBN()
model.fit(X_train, Y_train)

# Calling the fit method generates an optimal CTBN graph of type
# DirectedGraph, defined in src/graph_preliminaries.py.

# Get the prediction and its probability for a single sample using the
# predict method. This method in turn calls the junction tree algorithm
# to run the max-sum algorithm on test_sample.

max_log_prob, max_assignment = model.predict(test_sample)
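
The max-sum query mentioned in the comments above can be sketched as follows for a tree-structured network, where it reduces to a dynamic program over the tree. This is a simplified stand-in, not the package's junction tree code: the function names (log_table, max_sum_tree), the table layout, the binary labels, and the probabilities are hypothetical, and the tables are assumed to be already conditioned on one test sample.

```python
import math
from functools import lru_cache


def log_table(p_one):
    """Build {(parent_value, y): log P(Y = y | parent_value)} for a binary Y
    from {parent_value: P(Y = 1 | parent_value)}."""
    return {(pv, y): math.log(p if y == 1 else 1.0 - p)
            for pv, p in p_one.items() for y in (0, 1)}


def max_sum_tree(parent, log_cpt, root):
    """parent: {child: parent} tree over class variables.
    Returns (max log-probability, most likely joint assignment)."""
    children = {}
    for child, par in parent.items():
        children.setdefault(par, []).append(child)

    @lru_cache(maxsize=None)
    def best_subtree(node, parent_value):
        # Best score of the subtree rooted at `node`, given its parent's value.
        best_score, best_assign = -math.inf, None
        for y in (0, 1):
            score = log_cpt[node][(parent_value, y)]
            assign = {node: y}
            for child in children.get(node, []):
                child_score, child_assign = best_subtree(child, y)
                score += child_score
                assign.update(child_assign)
            if score > best_score:
                best_score, best_assign = score, assign
        return best_score, best_assign

    return best_subtree(root, None)


# Hypothetical chain Y1 -> Y2 -> Y3 with made-up conditional probabilities.
parent = {"Y2": "Y1", "Y3": "Y2"}
log_cpt = {
    "Y1": log_table({None: 0.6}),
    "Y2": log_table({1: 0.9, 0: 0.2}),
    "Y3": log_table({1: 0.3, 0: 0.5}),
}
max_log_prob, max_assignment = max_sum_tree(parent, log_cpt, "Y1")
print(max_assignment, math.exp(max_log_prob))
```

Working in log space turns the product of conditional probabilities into a sum, which is why the query is called max-sum and why a result named max_log_prob would need math.exp to be read as a probability.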

References

[1] Batal, Iyad and Hong, Charmgil and Hauskrecht, Milos (2013). An Efficient Probabilistic Framework for Multi-Dimensional Classification. In Proceedings of the 22nd ACM International Conference on Information & Knowledge Management, CIKM ’13, New York, NY, USA, pp. 2417–2422. Association for Computing Machinery. https://doi.org/10.1145/2505515.2505594

[2] Chu, Y. J. and T. H. Liu, "On the Shortest Arborescence of a Directed Graph," Sci. Sinica, 14, 1965, pp. 1396-1400.

[3] Koller, D., & Friedman, N. (2009). Probabilistic Graphical Models: Principles and Techniques. MIT Press.

