A low-code Pythonic implementation of a Coalesced Tsetlin Machine

PyTsetlin

A low-code, feature-POOR, Pythonic implementation of a Coalesced Tsetlin Machine (TM). It is not intended to be feature-rich or speed-optimized; see related repositories such as TMU and green-tsetlin for that. Instead, it aims to be an easy-to-use TM written in Python, so that new ideas can be plugged in and tested quickly, whether at the input level or at the TM memory level. Because the implementation is written entirely in Python, the code can also be compared directly with the theoretical concepts presented in the papers, which may make them easier to grasp.

Even though this repo is not focused on speed, I have made some functions compatible with Numba compilation. Without this, the code would be too slow to be usable.

Installation

  1. Install the package into your environment to use it in other projects:
pip install pytsetlin
  2. Alternatively, clone or template this repository and install the required dependencies:
cd pytsetlin
pip install -r requirements.txt

Examples

Basic training example

Here's a basic example of how to use the Tsetlin Machine:

>>> from pytsetlin import TsetlinMachine
>>> from pytsetlin.data.mnist import get_mnist

>>> X_train, X_test, y_train, y_test = get_mnist()

>>> tm = TsetlinMachine(n_clauses=500,
                        threshold=625,
                        s=10.0,
                        n_threads=20)

>>> tm.set_train_data(X_train, y_train)

>>> tm.set_eval_data(X_test, y_test)

>>> r = tm.train(training_epochs=10)

# progress bar for visualization
Train Acc: 95.78, Eval Acc: 96.22, Best Eval Acc: 96.22 (10): 100%|███████████| 10/10 [00:55<00:00,  5.60s/it]

>>> print(r)
{'train_time': [10.82, 5.95, 5.08, 4.9, 4.65, 4.58, 4.44, 4.38, 4.35, 4.25], 'train_acc': [86.81, 92.18, 93.47, 94.04, 94.54, 94.91, 95.22, 95.53, 95.58, 95.78], 'eval_acc': [91.06, 93.01, 93.62, 94.3, 94.44, 94.73, 94.82, 94.97, 95.22, 96.22], 'best_eval_acc': 96.22, 'best_eval_epoch': 10}

Note that training times and accuracies may vary depending on your system!
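The dictionary returned by `train` can be post-processed with plain Python. The following is a minimal sketch that reuses the result dictionary printed above; the helper name `summarize` is hypothetical and not part of pytsetlin, and it assumes that each list holds one entry per epoch, with epochs counted from 1 as in the printed `best_eval_epoch`.

```python
# Result dictionary copied from the training run printed above.
r = {
    'train_time': [10.82, 5.95, 5.08, 4.9, 4.65, 4.58, 4.44, 4.38, 4.35, 4.25],
    'train_acc': [86.81, 92.18, 93.47, 94.04, 94.54, 94.91, 95.22, 95.53, 95.58, 95.78],
    'eval_acc': [91.06, 93.01, 93.62, 94.3, 94.44, 94.73, 94.82, 94.97, 95.22, 96.22],
    'best_eval_acc': 96.22,
    'best_eval_epoch': 10,
}


def summarize(results):
    """Hypothetical helper: condense a training result dict into a few numbers."""
    eval_acc = results['eval_acc']
    # Index of the best evaluation accuracy, converted to a 1-based epoch number.
    best_epoch = max(range(len(eval_acc)), key=eval_acc.__getitem__) + 1
    return {
        'total_train_time': round(sum(results['train_time']), 2),
        'best_eval_acc': max(eval_acc),
        'best_eval_epoch': best_epoch,
        # Train minus eval accuracy after the last epoch (negative: eval is higher).
        'final_gap': round(results['train_acc'][-1] - eval_acc[-1], 2),
    }


summary = summarize(r)
print(summary)
```

A large positive `final_gap` would hint at overfitting; in the run above the evaluation accuracy is slightly higher than the training accuracy.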

Investigating TM structure

Since the code is Pythonic, the TM structure can easily be investigated from the TsetlinMachine object:

>>> import numpy as np
>>> from pytsetlin import TsetlinMachine

>>> # xor gate
>>> x = np.array([[0, 0],
                  [0, 1],
                  [1, 0],
                  [1, 1]], dtype=np.uint8)

>>> y = np.array([0, 1, 1, 0], dtype=np.uint32)

>>> tm = TsetlinMachine(n_clauses=4)

>>> tm.set_train_data(x, y)

>>> tm.train()

>>> print(tm.C) # get clause matrix
[[-35  25  24 -30]
 [-33 -41  12  23]
 [ 18 -38 -34  16]
 [ 17  15 -33 -42]]

>>> print(tm.W) # get weight matrix
[[-19  17 -20  16]
 [ 18 -19  18 -18]]
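To make these matrices easier to read, the sketch below turns each row of the clause matrix into a logical expression. It rests on two assumptions that are not confirmed by the library: a literal is included in a clause when its state is greater than 0, and the columns are ordered as all features first, followed by their negations. Both assumptions happen to reproduce XOR for the values printed above. The helper `describe_clause` is hypothetical and not part of pytsetlin.

```python
import numpy as np

# Clause and weight matrices copied from the XOR example above.
C = np.array([[-35,  25,  24, -30],
              [-33, -41,  12,  23],
              [ 18, -38, -34,  16],
              [ 17,  15, -33, -42]])
W = np.array([[-19,  17, -20,  16],
              [ 18, -19,  18, -18]])


def describe_clause(states, n_features):
    """Hypothetical helper: list the literals a clause includes.

    Assumes state > 0 means "include" and a column layout of
    [x0, ..., x(n-1), NOT x0, ..., NOT x(n-1)].
    """
    terms = []
    for k, state in enumerate(states):
        if state > 0:
            name = f"x{k % n_features}"
            terms.append(name if k < n_features else f"NOT {name}")
    return " AND ".join(terms) if terms else "(empty clause)"


clauses = [describe_clause(row, n_features=2) for row in C]
# Class that each clause votes for most strongly (row index of W).
favoured_class = W.argmax(axis=0)

for j, text in enumerate(clauses):
    print(f"clause {j}: {text} -> class {favoured_class[j]}")
```

Under these assumptions, the two clauses that match exactly one active input vote for class 1, and the other two vote for class 0, which is the XOR truth table.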

Saving and loading

Any TM state can easily be saved during or after training:

>>> from pytsetlin import TsetlinMachine
>>> from pytsetlin.data.imdb import get_imdb

>>> X_train, X_test, y_train, y_test = get_imdb()

>>> tm = TsetlinMachine(n_clauses=500,
                        threshold=625,
                        s=2.0)

>>> tm.set_train_data(X_train, y_train)

>>> tm.set_eval_data(X_test, y_test)

>>> r = tm.train(training_epochs=10, save_best_state=True) # save during training

>>> tm.save_state(file_name='tm_state.npz') # save after training

The saved memory, or any other memory, can then be used for predictions:

>>> import numpy as np

>>> tm = TsetlinMachine()

>>> state = np.load('tm_state.npz') # file written by save_state above

>>> C = state['C'] # load clause matrix
>>> W = state['W'] # load weight matrix

>>> # `instance` is a single binary input sample
>>> clause_outputs = tm.evaluate_clauses(instance, memory=C) # which clauses matched the input
[0, 1, 0, 0, 1]

>>> class_sums = np.dot(W, clause_outputs) # majority voting
[-32, 55]

>>> prediction = np.argmax(class_sums)
1
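The same flow can be reproduced without the library, which may help when checking a saved state by hand. The sketch below is a stand-in, not pytsetlin's own code: it writes the XOR matrices from "Investigating TM structure" to an `.npz` file with `np.savez` (in place of `save_state`), loads them back, and predicts with a hypothetical `clause_outputs_sketch` function. It assumes that a literal is included when its state is greater than 0 and that the columns are ordered as features followed by their negations.

```python
import os
import tempfile

import numpy as np

# XOR clause and weight matrices from "Investigating TM structure".
C = np.array([[-35,  25,  24, -30],
              [-33, -41,  12,  23],
              [ 18, -38, -34,  16],
              [ 17,  15, -33, -42]])
W = np.array([[-19,  17, -20,  16],
              [ 18, -19,  18, -18]])


def clause_outputs_sketch(x, C):
    """Hypothetical stand-in for clause evaluation.

    Assumes state > 0 means "include" and literals ordered [x, NOT x].
    A clause outputs 1 when every included literal is 1.
    """
    literals = np.concatenate([x, 1 - x]).astype(bool)
    include = C > 0
    # An excluded literal never blocks the clause, hence the `| ~include`.
    return np.all(literals | ~include, axis=1).astype(np.int64)


# Round trip through an .npz file using the same keys as above ('C' and 'W').
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tm_state.npz")
    np.savez(path, C=C, W=W)
    with np.load(path) as state:
        C_loaded, W_loaded = state["C"], state["W"]

X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.uint8)
predictions = []
for x in X:
    outputs = clause_outputs_sketch(x, C_loaded)
    class_sums = W_loaded @ outputs  # weighted vote per class
    predictions.append(int(np.argmax(class_sums)))

print(predictions)
```

Note that a clause with no included literals would output 1 for every input under this sketch; real implementations often treat that case specially during inference.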

Notes

  1. Input features must be binary with dtype=np.uint8; labels must have dtype=np.uint32
  2. The implementation uses Numba for efficient computation
  3. Memory is allocated automatically when training begins
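As an illustration of the first note, real-valued data has to be binarized before it is passed to the TM. The following is a minimal sketch using a single fixed threshold; the helper `binarize` is hypothetical and not part of pytsetlin, and other encodings (for example several thresholds per feature) may suit your data better.

```python
import numpy as np


def binarize(X, threshold):
    """Hypothetical helper: 1 where a value exceeds the threshold, else 0."""
    return (np.asarray(X) > threshold).astype(np.uint8)


# Toy real-valued features and integer class labels.
X_raw = np.array([[0.1, 0.9, 0.5],
                  [0.7, 0.2, 0.8]])
y_raw = [1, 0]

X = binarize(X_raw, threshold=0.5)      # uint8 features containing only 0 and 1
y = np.asarray(y_raw, dtype=np.uint32)  # uint32 labels

print(X, X.dtype)
print(y, y.dtype)
```

Arrays prepared this way match the dtypes listed in the notes above.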

License

MIT License
