Torchwnn is a Python library for Weightless Neural Networks

Torchwnn: Weightless Neural Network

Torchwnn is a Python library for Weightless Neural Networks (also known as RAM-based or N-tuple-based Neural Networks).
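
To make the RAM-based, n-tuple idea concrete, here is a minimal pure-Python sketch of a single discriminator. It is an illustration only and not Torchwnn's implementation: the class name `ToyDiscriminator` and its methods are hypothetical, and RAM nodes are modeled as plain Python sets.

```python
import random


class ToyDiscriminator:
    """Toy n-tuple discriminator (hypothetical; not Torchwnn's implementation)."""

    def __init__(self, entry_size, tuple_size, seed=0):
        # Pseudo-randomly map input bit positions to tuples of `tuple_size` bits.
        order = list(range(entry_size))
        random.Random(seed).shuffle(order)
        self.tuples = [order[i:i + tuple_size] for i in range(0, entry_size, tuple_size)]
        # One RAM node per tuple; each RAM stores the addresses seen in training.
        self.rams = [set() for _ in self.tuples]

    def _addresses(self, bits):
        # The bits selected by each tuple form the address into that tuple's RAM.
        return [tuple(bits[i] for i in positions) for positions in self.tuples]

    def fit(self, bits):
        for ram, address in zip(self.rams, self._addresses(bits)):
            ram.add(address)

    def score(self, bits):
        # Number of RAM nodes that recognize their address.
        return sum(address in ram for ram, address in zip(self.rams, self._addresses(bits)))


pattern = [1, 1, 1, 1, 0, 0, 0, 0]
disc = ToyDiscriminator(entry_size=8, tuple_size=2)
disc.fit(pattern)
print(disc.score(pattern))                   # 4: every RAM node responds
print(disc.score([1 - b for b in pattern]))  # 0: no stored address matches
```

A classifier built this way would keep one discriminator per class and predict the class whose discriminator returns the highest score.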

Usage

Installation

First, install PyTorch by following its official installation instructions. Then, use the following command to install Torchwnn:

pip install torchwnn

Requirements: PyTorch, plus ucimlrepo to load datasets from the UCI repository. The examples below also import scikit-learn.
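
A quick way to confirm the requirements are in place is to check whether each package can be located without importing it. The sketch below uses only the standard library; the helper name `missing_packages` is hypothetical, and the import names `torch`, `ucimlrepo`, and `torchwnn` are assumed to match the installed distributions.

```python
from importlib.util import find_spec


def missing_packages(names=("torch", "ucimlrepo", "torchwnn")):
    """Return the names from `names` that cannot be found in this environment."""
    return [name for name in names if find_spec(name) is None]


missing = missing_packages()
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are installed.")
```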

Quick Start

Iris Example

To quickly get started with Torchwnn, here's an example using the Iris dataset. Full training code is available in the examples/iris.py file.

import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from torchwnn.datasets.iris import Iris
from torchwnn.classifiers import Wisard
from torchwnn.encoding import Thermometer

# Use the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using {} device".format(device))

iris = Iris()
X = iris.features
X = torch.tensor(X.values).to(device)
y = list(iris.labels)
y = torch.tensor(y).squeeze().to(device)

bits_encoding = 20
encoding = Thermometer(bits_encoding).fit(X)    
X_bin = encoding.binarize(X).flatten(start_dim=1)

X_train, X_test, y_train, y_test = train_test_split(X_bin, y, test_size=0.3, random_state = 0)  

entry_size = X_train.shape[1]
tuple_size = 8
model = Wisard(entry_size, iris.num_classes, tuple_size)

with torch.no_grad():
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
    # accuracy_score expects the true labels first, then the predictions
    acc = accuracy_score(y_test, predictions)
    print("Wisard: Accuracy = ", acc)
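
The example above binarizes each feature with a thermometer encoding before training. As a rough illustration of the idea, the sketch below encodes one value in pure Python. It assumes thresholds evenly spaced between a feature's minimum and maximum; the helper name `thermometer_encode` is hypothetical, and Torchwnn's `Thermometer` may choose its thresholds differently.

```python
def thermometer_encode(value, low, high, bits):
    """Encode `value` as `bits` thermometer bits (hypothetical helper).

    Assumes `bits` thresholds evenly spaced strictly between `low` and `high`.
    """
    step = (high - low) / (bits + 1)
    thresholds = [low + step * (i + 1) for i in range(bits)]
    # A bit is set when the value exceeds its threshold, so ones fill from the left.
    return [1 if value > t else 0 for t in thresholds]


print(thermometer_encode(0.0, 0.0, 1.0, 4))  # [0, 0, 0, 0]
print(thermometer_encode(0.5, 0.0, 1.0, 4))  # [1, 1, 0, 0]
print(thermometer_encode(1.0, 0.0, 1.0, 4))  # [1, 1, 1, 1]
```

Under this scheme, the four Iris features with `bits_encoding = 20` would flatten to 80 input bits per sample, which is presumably the value `entry_size` takes in the example.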

Examples

There are several examples in the repository.

Bleaching

import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from torchwnn.datasets.iris import Iris
from torchwnn.classifiers import Wisard
from torchwnn.encoding import Thermometer

# Use the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using {} device".format(device))

iris = Iris()
X = iris.features
X = torch.tensor(X.values).to(device)
y = list(iris.labels)
y = torch.tensor(y).squeeze().to(device)

bits_encoding = 20
encoding = Thermometer(bits_encoding).fit(X)    
X_bin = encoding.binarize(X).flatten(start_dim=1)

X_train, X_test, y_train, y_test = train_test_split(X_bin, y, test_size=0.3, random_state = 0)  

entry_size = X_train.shape[1]
tuple_size = 8
model = Wisard(entry_size, iris.num_classes, tuple_size, bleaching=True)

with torch.no_grad():
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
    # accuracy_score expects the true labels first, then the predictions
    acc = accuracy_score(y_test, predictions)
    print("Wisard: Accuracy = ", acc)

    # Apply bleaching: fit_bleach selects the bleach value reported below
    model.fit_bleach(X_train, y_train)
    print("Selected bleach: ", model.bleach)
    predictions = model.predict(X_test)
    acc = accuracy_score(y_test, predictions)
    print("Wisard with bleaching = ", model.bleach, ": Accuracy = ", acc)
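
For intuition about what the bleach value does, the sketch below shows the general bleaching idea in pure Python: RAM nodes hold counters rather than single bits, and a node responds only when its counter exceeds the bleach threshold. The function names are hypothetical, tuples are taken from consecutive bits for simplicity, and this is not a description of how Torchwnn's `fit_bleach` selects its value.

```python
from collections import Counter


def train_counts(patterns, tuple_size):
    """Count how often each tuple address occurs (one Counter per RAM node)."""
    n_rams = len(patterns[0]) // tuple_size
    rams = [Counter() for _ in range(n_rams)]
    for bits in patterns:
        for r in range(n_rams):
            rams[r][tuple(bits[r * tuple_size:(r + 1) * tuple_size])] += 1
    return rams


def score(rams, bits, tuple_size, bleach=0):
    """Number of RAM nodes whose counter for this input exceeds `bleach`."""
    return sum(
        ram[tuple(bits[r * tuple_size:(r + 1) * tuple_size])] > bleach
        for r, ram in enumerate(rams)
    )


# Three training patterns for one class; the first pattern appears twice.
rams = train_counts([[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 1]], tuple_size=2)
query = [1, 1, 1, 1]
print(score(rams, query, tuple_size=2, bleach=0))  # 2
print(score(rams, query, tuple_size=2, bleach=1))  # 1
```

Raising the threshold discounts addresses that were seen only rarely, which can help separate classes whose discriminators would otherwise all respond strongly.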

BloomWisard

Example using BloomWisard is available in the examples/iris_filters.py file.
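
BloomWiSARD is generally described as a WiSARD variant whose RAM nodes are replaced by Bloom filters to save memory. The sketch below shows a toy Bloom filter standing in for one RAM node. The class name `ToyBloomFilter`, the filter size, and the salted SHA-256 hashing are all assumptions made for illustration and do not reflect Torchwnn's implementation.

```python
import hashlib


class ToyBloomFilter:
    """Toy Bloom filter (hypothetical; not Torchwnn's implementation)."""

    def __init__(self, size=64, num_hashes=3):
        self.size = size
        self.num_hashes = num_hashes
        self.bits = [0] * size

    def _positions(self, address):
        # Derive `num_hashes` positions by salting a SHA-256 hash of the address.
        key = "".join(str(b) for b in address)
        for salt in range(self.num_hashes):
            digest = hashlib.sha256(f"{salt}:{key}".encode()).digest()
            yield int.from_bytes(digest[:8], "big") % self.size

    def add(self, address):
        for pos in self._positions(address):
            self.bits[pos] = 1

    def __contains__(self, address):
        # May return a false positive, but never a false negative.
        return all(self.bits[pos] for pos in self._positions(address))


ram = ToyBloomFilter()
ram.add((1, 0, 1, 1, 0, 0, 1, 0))
print((1, 0, 1, 1, 0, 0, 1, 0) in ram)  # True: stored addresses are always found
```

The trade-off is that a lookup for an address that was never stored can occasionally succeed, in exchange for a fixed memory footprint per node.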

Supported WNN models

Currently, the library supports the following WNN models:

Supported techniques:

  • B-bleaching - Bleaching based on binary search. Reference: "B-bleaching: Agile Overtraining Avoidance in the WiSARD Weightless Neural Classifier".
    • WiSARD
    • BloomWiSARD

Source Distribution

torchwnn-0.0.0.tar.gz (12.3 kB)

File metadata

  • Download URL: torchwnn-0.0.0.tar.gz
  • Upload date:
  • Size: 12.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for torchwnn-0.0.0.tar.gz
Algorithm Hash digest
SHA256 4717291ce7f15d1278b1bfb6366e5dd5f1086ccae0a27b2095e2e4557f9eaf33
MD5 beded90c1b5bc6a4b43cafcf8b2c014a
BLAKE2b-256 36066f7f7ed2338f7a8baf692ff5408e7fe1270a5a35d26383ff71ad734a45c7

Provenance

The following attestation bundles were made for torchwnn-0.0.0.tar.gz:

Publisher: torchwnn-publish.yml on leandro-santiago/torchwnn
