
Torchwnn is a Python library for Weightless Neural Networks


Torchwnn: Weightless Neural Network

Torchwnn is a Python library for Weightless Neural Networks (WNNs), also known as RAM-based or n-tuple-based neural networks.
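
To make the "RAM-based" and "n-tuple" terminology concrete, here is a minimal plain-Python sketch of the core WiSARD idea. The binary input is split into n-tuples by a fixed random mapping, each tuple addresses one RAM in the discriminator of each class, training writes to the addressed cells, and a class's response is the number of RAMs that recognize their tuple. TinyWisard and its methods are hypothetical names for illustration, not the torchwnn API. The constructor arguments mirror those passed to torchwnn's Wisard in the Quick Start, but the internals here are our own, and storing each RAM as a set of written addresses is a simplification.

```python
import random


class TinyWisard:
    """Hypothetical WiSARD-style classifier (not the torchwnn API).

    Each class owns a discriminator: a list of RAMs addressed by n-tuples
    of input bits. A RAM is modeled here as a set of written addresses.
    """

    def __init__(self, entry_size, num_classes, tuple_size, seed=0):
        if entry_size % tuple_size:
            raise ValueError("this sketch assumes tuple_size divides entry_size")
        self.tuple_size = tuple_size
        # One fixed pseudo-random mapping of input bits, shared by all classes.
        self.mapping = list(range(entry_size))
        random.Random(seed).shuffle(self.mapping)
        num_rams = entry_size // tuple_size
        self.discriminators = [
            [set() for _ in range(num_rams)] for _ in range(num_classes)
        ]

    def _addresses(self, x):
        """Yield (ram_index, address) for each n-tuple of the mapped input."""
        bits = [x[i] for i in self.mapping]
        for ram in range(len(bits) // self.tuple_size):
            address = 0
            for bit in bits[ram * self.tuple_size:(ram + 1) * self.tuple_size]:
                address = (address << 1) | bit
            yield ram, address

    def fit(self, X, y):
        # Training writes 1 at the addressed cell of every RAM of the true class.
        for x, label in zip(X, y):
            for ram, address in self._addresses(x):
                self.discriminators[label][ram].add(address)
        return self

    def scores(self, x):
        # A class's response counts the RAMs that recognize their tuple.
        addresses = list(self._addresses(x))
        return [
            sum(address in rams[ram] for ram, address in addresses)
            for rams in self.discriminators
        ]

    def predict(self, X):
        # Highest response wins; ties go to the lowest class index.
        predictions = []
        for x in X:
            s = self.scores(x)
            predictions.append(s.index(max(s)))
        return predictions


X = [[1, 1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 0, 1, 1, 1, 1]]
y = [0, 1]
model = TinyWisard(entry_size=8, num_classes=2, tuple_size=4).fit(X, y)
print(model.predict(X))  # [0, 1]
```

Because the two training patterns are bitwise complements, every RAM sees a different address for each of them, so each pattern is recognized only by its own class.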

Usage

Installation

First, install PyTorch by following its official installation instructions. Then install Torchwnn with:

pip install torchwnn

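Requirements: PyTorch, plus ucimlrepo for loading datasets from the UCI Machine Learning Repository. The examples below also use scikit-learn (train_test_split and accuracy_score).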

Quick Start

Iris Example

To quickly get started with Torchwnn, here's an example using the Iris dataset. Full training code is available in the examples/iris.py file.

import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from torchwnn.datasets.iris import Iris
from torchwnn.classifiers import Wisard
from torchwnn.encoding import Thermometer

# Use the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using {} device".format(device))

iris = Iris()
X = iris.features
X = torch.tensor(X.values).to(device)
y = list(iris.labels)
y = torch.tensor(y).squeeze().to(device)

bits_encoding = 20
encoding = Thermometer(bits_encoding).fit(X)    
X_bin = encoding.binarize(X).flatten(start_dim=1)

X_train, X_test, y_train, y_test = train_test_split(X_bin, y, test_size=0.3, random_state=0)

entry_size = X_train.shape[1]
tuple_size = 8
model = Wisard(entry_size, iris.num_classes, tuple_size)

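with torch.no_grad():
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
    # scikit-learn metrics need CPU data, so move tensors off the GPU first
    acc = accuracy_score(y_test.cpu(), torch.as_tensor(predictions).cpu())
    print("Wisard: Accuracy = ", acc)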
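
The Quick Start turns each real-valued feature into a thermometer code before training. The sketch below is a hypothetical, plain-Python illustration of thermometer encoding with evenly spaced thresholds between a per-feature minimum and maximum; torchwnn's Thermometer may place its thresholds differently, so the helper names and the threshold rule are assumptions. It also works through the size arithmetic for the settings above, assuming binarize emits one bits_encoding-bit code per feature: 4 Iris features × 20 bits give entry_size = 80, and tuple_size = 8 splits each discriminator into 80 / 8 = 10 RAMs of 2^8 = 256 addresses each.

```python
def thermometer_bits(value, num_bits, vmin, vmax):
    """Encode one value as a thermometer code: k ones followed by zeros.

    Hypothetical helper: thresholds are assumed to be evenly spaced
    between vmin and vmax, which may differ from torchwnn's Thermometer.
    """
    frac = (value - vmin) / (vmax - vmin) if vmax > vmin else 0.0
    frac = min(max(frac, 0.0), 1.0)  # clip values outside the fitted range
    ones = round(frac * num_bits)
    return [1] * ones + [0] * (num_bits - ones)


def encode_row(row, num_bits, mins, maxs):
    """Concatenate per-feature codes, analogous to binarize(...).flatten(start_dim=1)."""
    bits = []
    for value, vmin, vmax in zip(row, mins, maxs):
        bits.extend(thermometer_bits(value, num_bits, vmin, vmax))
    return bits


# Illustrative per-feature ranges for the four Iris measurements.
mins = [4.3, 2.0, 1.0, 0.1]
maxs = [7.9, 4.4, 6.9, 2.5]
bits_encoding = 20
tuple_size = 8

x_bin = encode_row([5.1, 3.5, 1.4, 0.2], bits_encoding, mins, maxs)
entry_size = len(x_bin)                  # 4 features * 20 bits = 80
num_rams = entry_size // tuple_size      # 80 / 8 = 10 RAMs per discriminator
ram_addresses = 2 ** tuple_size          # 2^8 = 256 addresses per RAM
print(entry_size, num_rams, ram_addresses)  # 80 10 256
```

A thermometer code keeps nearby values close in Hamming distance, which suits WNNs because each RAM matches tuples of bits exactly.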

Examples

More examples are available in the repository's examples/ directory.

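Bleaching

The following variant of the Iris example enables bleaching with bleaching=True, trains the model, and then calls fit_bleach on the training data to select a bleaching threshold (model.bleach) before predicting again: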

import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from torchwnn.datasets.iris import Iris
from torchwnn.classifiers import Wisard
from torchwnn.encoding import Thermometer

# Use the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using {} device".format(device))

iris = Iris()
X = iris.features
X = torch.tensor(X.values).to(device)
y = list(iris.labels)
y = torch.tensor(y).squeeze().to(device)

bits_encoding = 20
encoding = Thermometer(bits_encoding).fit(X)    
X_bin = encoding.binarize(X).flatten(start_dim=1)

X_train, X_test, y_train, y_test = train_test_split(X_bin, y, test_size=0.3, random_state=0)

entry_size = X_train.shape[1]
tuple_size = 8
model = Wisard(entry_size, iris.num_classes, tuple_size, bleaching=True)

with torch.no_grad():
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
    # scikit-learn metrics need CPU data, so move tensors off the GPU first
    acc = accuracy_score(y_test.cpu(), torch.as_tensor(predictions).cpu())
    print("Wisard: Accuracy = ", acc)

    # Applying bleaching: select a bleaching threshold on the training data
    model.fit_bleach(X_train, y_train)
    print("Selected bleach: ", model.bleach)
    predictions = model.predict(X_test)
    acc = accuracy_score(y_test.cpu(), torch.as_tensor(predictions).cpu())
    print("Wisard with bleaching = ", model.bleach, ": Accuracy = ", acc)
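
Conceptually, bleaching replaces each one-bit RAM cell with an access counter, and at prediction time a RAM votes only if the counter of its addressed cell reaches the threshold b; with b = 1 this reduces to plain WiSARD. The sketch below illustrates that idea in plain Python. CountingDiscriminator is a hypothetical name, and torchwnn's internal storage and its fit_bleach selection procedure may differ from this.

```python
from collections import Counter


class CountingDiscriminator:
    """Hypothetical WiSARD discriminator whose RAM cells count accesses.

    Not the torchwnn API: a sketch of the bleaching idea only.
    """

    def __init__(self, num_rams):
        # Counter returns 0 for unseen addresses without storing them.
        self.rams = [Counter() for _ in range(num_rams)]

    def train(self, addresses):
        """addresses: one integer tuple address per RAM for one sample."""
        for ram, addr in zip(self.rams, addresses):
            ram[addr] += 1

    def response(self, addresses, bleach=1):
        """Count RAMs whose cell for this address was written >= bleach times."""
        return sum(ram[addr] >= bleach for ram, addr in zip(self.rams, addresses))


# Class 0 saw the pattern three times, class 1 only once.
class0, class1 = CountingDiscriminator(3), CountingDiscriminator(3)
for _ in range(3):
    class0.train([1, 2, 3])
class1.train([1, 2, 3])

query = [1, 2, 3]
for b in (1, 2):
    print(b, class0.response(query, b), class1.response(query, b))
# b=1 ties (3 vs 3); b=2 keeps only class 0's frequent cells (3 vs 0)
```

Raising b discards patterns that were seen rarely, which is how bleaching breaks ties between discriminators that have saturated on a large training set.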

BloomWisard

An example using BloomWisard is available in the examples/iris_filters.py file.
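
As we understand the BloomWiSARD design, each RAM of a discriminator is replaced by a Bloom filter that stores the tuple addresses seen during training, so memory no longer has to grow as 2^tuple_size cells per RAM. The sketch below shows a minimal Bloom filter used in the role of a RAM. The class, its sizes (64 bits, 3 hashes), the salted-SHA-256 hashing scheme, and the training addresses are all hypothetical choices for illustration, not torchwnn's implementation.

```python
import hashlib


class BloomFilter:
    """Minimal Bloom filter used in place of one RAM (hypothetical helper)."""

    def __init__(self, num_bits, num_hashes):
        self.num_bits = num_bits
        self.num_hashes = num_hashes
        self.bits = bytearray(num_bits)  # one byte per bit, for clarity

    def _positions(self, item):
        # Derive num_hashes positions from SHA-256 with a different salt each time.
        for seed in range(self.num_hashes):
            digest = hashlib.sha256(f"{seed}:{item}".encode()).digest()
            yield int.from_bytes(digest[:8], "big") % self.num_bits

    def add(self, item):
        for pos in self._positions(item):
            self.bits[pos] = 1

    def __contains__(self, item):
        # May return a false positive, never a false negative.
        return all(self.bits[pos] for pos in self._positions(item))


# A discriminator with 10 "RAMs" (as for entry_size=80, tuple_size=8),
# each a 64-bit Bloom filter instead of a 256-cell table.
rams = [BloomFilter(num_bits=64, num_hashes=3) for _ in range(10)]
train_addresses = [7, 200, 13, 255, 0, 42, 99, 128, 64, 3]  # made-up addresses
for ram, addr in zip(rams, train_addresses):
    ram.add(addr)

response = sum(addr in ram for ram, addr in zip(rams, train_addresses))
print(response)  # every stored address is found: 10
```

The trade-off is that an unseen address can occasionally be reported as present (a false positive), while the memory saving matters most for large tuple sizes, where a full RAM needs 2^tuple_size cells.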

Supported WNN models

Currently, the library supports the following WNN models:

  • WiSARD
  • BloomWiSARD

Supported techniques:

  • B-bleaching - Bleaching based on binary search. Reference: B-bleaching: Agile Overtraining Avoidance in the WiSARD Weightless Neural Classifier.
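
B-bleaching is described above as bleaching based on binary search. As a rough illustration only, the sketch below bisects on the bleaching threshold b for a single test sample: on a tie it raises b, when every response has dropped to 0 it lowers b, and it stops at the first unique winner. This search rule is our assumption, not taken from the referenced paper or from torchwnn (whose fit_bleach, per the example above, selects a single threshold from the training data), and the function name and counts are hypothetical. Because tie resolution is not guaranteed to be monotone in b, bisection can settle on a different threshold than a linear scan from b = 1; its appeal is needing about log2(max_bleach) response evaluations instead of up to max_bleach.

```python
def b_bleach_predict(responses_at, max_bleach):
    """Pick a class by bisecting on the bleaching threshold b (hypothetical sketch).

    responses_at(b) returns the per-class responses at threshold b; responses
    never increase as b grows. Rule assumed here: on a tie raise b, when every
    response is 0 lower b, stop at the first unique winner.
    Returns (predicted_class, threshold).
    """
    lo, hi = 1, max_bleach
    fallback = None
    while lo <= hi:
        mid = (lo + hi) // 2
        resp = responses_at(mid)
        top = max(resp)
        if top == 0:
            hi = mid - 1                       # over-bleached: lower the threshold
        elif resp.count(top) > 1:
            fallback = (resp.index(top), mid)  # tie: remember it, raise the threshold
            lo = mid + 1
        else:
            return resp.index(top), mid        # unique winner
    if fallback is not None:
        return fallback
    resp = responses_at(1)                     # nothing better found: plain WiSARD
    return resp.index(max(resp)), 1


# Hypothetical access counts of the cells addressed by one test sample,
# one row per class and one entry per RAM.
counts = [[5, 4, 1], [2, 2, 1]]


def responses_at(b):
    return [sum(c >= b for c in row) for row in counts]


print(b_bleach_predict(responses_at, max_bleach=5))  # (0, 3)
```

Here b = 1 and b = 2 both tie the two classes, and the first bisection step (b = 3) already yields a unique winner.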


Download files

Source Distribution

torchwnn-0.0.1a0.tar.gz (22.4 kB)

File details

Details for the file torchwnn-0.0.1a0.tar.gz.

File metadata

  • Download URL: torchwnn-0.0.1a0.tar.gz
  • Upload date:
  • Size: 22.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for torchwnn-0.0.1a0.tar.gz
Algorithm Hash digest
SHA256 230551129bbfd071063d9b847846bef3a38dc2d8f0cba8f95b2be3b1aa483ea5
MD5 afe699f35dba056a4daa13d3e5b7e954
BLAKE2b-256 67df233c9145cdc1f06bdeee3bf902639d4bbaa39d358396905665904b604767

Provenance

The following attestation bundles were made for torchwnn-0.0.1a0.tar.gz:

Publisher: torchwnn-publish.yml on leandro-santiago/torchwnn
