
A fast Tsetlin Machine implementation, built on C++.


Green Tsetlin


Installation

Green Tsetlin can be installed with pip:

pip install green-tsetlin

Tsetlin Machine

The Tsetlin Machine is the core of Green Tsetlin.

import green_tsetlin as gt

tm = gt.TsetlinMachine(n_literals=4,
                       n_clauses=5,
                       n_classes=2,
                       s=3.0,
                       threshold=42,
                       literal_budget=4,
                       boost_true_positives=False,
                       multi_label=False)
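For intuition about `s` and `threshold`, here is a minimal sketch of the feedback probabilities in the standard (textbook) Tsetlin Machine formulation. It is not Green Tsetlin's C++ code, options such as `boost_true_positives` alter these rules and are not modelled, and the function names are hypothetical.

```python
def clamp(v: int, threshold: int) -> int:
    """Clip a class vote sum to the range [-threshold, threshold]."""
    return max(-threshold, min(threshold, v))


def feedback_probability(class_sum: int, threshold: int, is_target: bool) -> float:
    """Probability that a clause receives feedback for a class.

    Target class:     (T - clamp(v)) / (2T)
    Non-target class: (T + clamp(v)) / (2T)
    Feedback fades out as the vote sum reaches the threshold T, which limits
    how many clauses get recruited for the same pattern.
    """
    v = clamp(class_sum, threshold)
    if is_target:
        return (threshold - v) / (2 * threshold)
    return (threshold + v) / (2 * threshold)


def type_i_probabilities(s: float) -> tuple[float, float]:
    """Type Ia (clause=1, literal=1): move toward include with prob (s-1)/s.
    Type Ib (clause=0 or literal=0): move toward exclude with prob 1/s.
    A larger s therefore favours more specific clauses (more literals)."""
    return (s - 1.0) / s, 1.0 / s


T = 42
print(feedback_probability(0, T, is_target=True))     # 0.5
print(feedback_probability(42, T, is_target=True))    # 0.0
print(feedback_probability(100, T, is_target=False))  # 1.0
print(type_i_probabilities(3.0))
```

With `threshold=42`, a class whose votes already sum to 42 gets no further reinforcement, while a vote sum of 0 gets feedback half the time.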

Trainer

The Green Tsetlin `Trainer` is a simple wrapper that handles training a Tsetlin Machine.

import green_tsetlin as gt

tm = gt.TsetlinMachine(n_literals=4,
                       n_clauses=5,
                       n_classes=2,
                       s=3.0,
                       threshold=42,
                       literal_budget=4)

trainer = gt.Trainer(tm, seed=42, n_jobs=2)

# train_x/eval_x: binarized feature arrays; train_y/eval_y: class labels (prepared beforehand)
trainer.set_train_data(train_x, train_y)
trainer.set_eval_data(eval_x, eval_y)

trainer.train()
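Tsetlin Machines operate on boolean literals, so numeric data has to be binarized before it is passed as `train_x` or `eval_x`. The sketch below uses thermometer encoding, one common choice; the `thermometer_encode` helper and its thresholds are hypothetical, and the `uint8` output dtype is an assumption about what Green Tsetlin expects.

```python
import numpy as np


def thermometer_encode(x: np.ndarray, thresholds: list[float]) -> np.ndarray:
    """Turn each numeric column into len(thresholds) boolean columns.

    Bit j of a feature is 1 when the value is >= thresholds[j], so larger
    values switch on more bits, like the mercury in a thermometer.
    """
    x = np.asarray(x, dtype=np.float64)
    t = np.asarray(thresholds, dtype=np.float64)
    # (n_samples, n_features, 1) >= (n_thresholds,) -> (n_samples, n_features, n_thresholds)
    bits = x[:, :, None] >= t
    # Flatten the per-feature bits into one row of literals per sample.
    return bits.reshape(x.shape[0], -1).astype(np.uint8)


raw = np.array([[0.1, 0.9],
                [0.6, 0.4]])
train_x = thermometer_encode(raw, thresholds=[0.25, 0.5])
print(train_x)
# [[0 0 1 1]
#  [1 1 1 0]]
```

Two numeric features with two thresholds each yield four boolean columns, which would match `n_literals=4` above, assuming `n_literals` counts input features.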

Exporting Tsetlin Machines

Save the state of a trained Tsetlin Machine to a file:

# ... create and train `tm` as shown above
tm.save_state("tsetlin_state.npz")

Loading exported Tsetlin Machines

Load a saved state into a Tsetlin Machine, either to continue training or to use it for inference:

# ... create `tm` as shown above
tm.load_state("tsetlin_state.npz")
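Assuming `save_state` writes a standard NumPy `.npz` archive, as the file extension suggests, the file can be inspected with plain NumPy. So that the sketch runs on its own, it builds a throwaway archive with invented array names (`clause_weights`, `clause_states`); the keys in a real Green Tsetlin state file will differ.

```python
import os
import tempfile

import numpy as np

# Hypothetical state: array names and shapes are invented for illustration only.
state = {
    "clause_weights": np.ones((2, 5), dtype=np.int16),
    "clause_states": np.zeros((5, 8), dtype=np.int8),
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "example_state.npz")
    np.savez_compressed(path, **state)

    # np.load on an .npz returns a lazily loaded NpzFile; the context manager closes it.
    with np.load(path) as archive:
        names = sorted(archive.files)
        restored = {name: archive[name] for name in names}

print(names)  # ['clause_states', 'clause_weights']
```

To peek at a real file, `with np.load("tsetlin_state.npz") as f: print(f.files)` lists the arrays it contains.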

Inference

Run inference with a trained Tsetlin Machine:

# ... create and train (or load) `tm` as shown above
predictor = tm.get_predictor()
predictor.predict(x)
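Conceptually, Tsetlin Machine inference evaluates every clause (an AND of included literals) and lets the clauses that fire vote. The pure-Python sketch below shows that idea; it is not the predictor's actual implementation. Giving each clause one signed vote per class is one common multi-class layout, and whether Green Tsetlin uses exactly this layout is an assumption. `Clause` and `predict` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class Clause:
    include: list[int]          # feature indices whose literal x[i] must be 1
    include_negated: list[int]  # feature indices whose negated literal must hold (x[i] == 0)
    weights: list[int]          # one signed vote per class

    def fires(self, x: list[int]) -> bool:
        if not self.include and not self.include_negated:
            return False  # empty clauses are commonly ignored at inference time
        return (all(x[i] == 1 for i in self.include)
                and all(x[i] == 0 for i in self.include_negated))


def predict(clauses: list[Clause], x: list[int], n_classes: int) -> tuple[int, list[int]]:
    """Sum the votes of all firing clauses per class and return the argmax."""
    sums = [0] * n_classes
    for clause in clauses:
        if clause.fires(x):
            for c in range(n_classes):
                sums[c] += clause.weights[c]
    return max(range(n_classes), key=lambda c: sums[c]), sums


clauses = [
    Clause(include=[0], include_negated=[1], weights=[+1, -1]),  # x0 AND NOT x1
    Clause(include=[1, 2], include_negated=[], weights=[-1, +1]),  # x1 AND x2
    Clause(include=[3], include_negated=[], weights=[0, +1]),      # x3
]
print(predict(clauses, [1, 0, 0, 0], n_classes=2))  # (0, [1, -1])
print(predict(clauses, [0, 1, 1, 1], n_classes=2))  # (1, [-1, 2])
```

Because each clause is a readable conjunction of literals, the firing clauses behind a prediction can be inspected directly.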

Green Tsetlin hpsearch

With the built-in hyperparameter search, you can optimize your Tsetlin Machine parameters.

from green_tsetlin.hpsearch import HyperparameterSearch

hyperparam_search = HyperparameterSearch(s_space=(2.0, 20.0),
                                        clause_space=(5, 10),
                                        threshold_space=(3, 20),
                                        max_epoch_per_trial=20,
                                        literal_budget=(1, train_x.shape[1]),
                                        seed=42,
                                        n_jobs=5,
                                        k_folds=4,
                                        minimize_literal_budget=False)

hyperparam_search.set_train_data(train_x, train_y)
hyperparam_search.set_eval_data(test_x, test_y)

hyperparam_search.optimize(trials=10)
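As a rough mental model only, the sketch below runs plain random search with k-fold cross-validation over the same ranges used above (`s_space`, `clause_space`, `threshold_space`). Green Tsetlin's own search strategy may be more sophisticated. `k_fold_splits`, `random_search` and `toy_score` are hypothetical helpers; `toy_score` stands in for training a Tsetlin Machine on each fold and scoring it.

```python
import random


def k_fold_splits(n_samples: int, k: int, seed: int) -> list[tuple[list[int], list[int]]]:
    """Shuffle sample indices and split them into k (train, validation) pairs."""
    idx = list(range(n_samples))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]
    return [
        ([j for f, fold in enumerate(folds) if f != i for j in fold], folds[i])
        for i in range(k)
    ]


def random_search(score_fold, splits, n_trials, seed, s_space, clause_space, threshold_space):
    """Sample parameters, average the score over all folds, keep the best trial."""
    rng = random.Random(seed)
    best_score, best_params = float("-inf"), None
    for _ in range(n_trials):
        params = {
            "s": rng.uniform(*s_space),               # continuous range
            "n_clauses": rng.randint(*clause_space),  # inclusive integer range
            "threshold": rng.randint(*threshold_space),
        }
        score = sum(score_fold(params, tr, va) for tr, va in splits) / len(splits)
        if score > best_score:
            best_score, best_params = score, params
    return best_score, best_params


def toy_score(params, train_idx, val_idx):
    # Hypothetical placeholder: a real version would train on train_idx
    # and return accuracy on val_idx.
    return -abs(params["s"] - 10.0) - abs(params["threshold"] - 10)


splits = k_fold_splits(n_samples=100, k=4, seed=42)
score, params = random_search(toy_score, splits, n_trials=10, seed=42,
                              s_space=(2.0, 20.0), clause_space=(5, 10),
                              threshold_space=(3, 20))
print(score, params)
```

Averaging over `k_folds` splits makes each trial's score less sensitive to one lucky train/validation partition, at the cost of k training runs per trial.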
