NNClass
A simple bit of code for training classification neural networks.
Installation
Install from pip3:
pip3 install --user NNClass
Or by cloning this repository:
#clone the repo
git clone https://github.com/mattkjames7/NNClass
cd NNClass
#Either create a wheel and use pip: (X.X.X should be replaced with the current version)
python3 setup.py bdist_wheel
pip3 install --user dist/NNClass-X.X.X-py3-none-any.whl
#Or by using setup.py directly
python3 setup.py install --user
Usage
Start by training a network:
import NNClass as nnc
#create the network, defining the activation functions and the number of nodes in each layer
net = nnc.NNClass(s,AF='softplus',Output='linear')
#note that s should be a list in which each element is the number of nodes in the corresponding layer
#input training data
net.AddData(X,y)
#Input matrix X should be of the shape (m,n) - where m is the number of samples and n is the number of input features
#The labels y should be either
# an integer array of shape (m,) giving the class of each sample
# or a matrix of shape (m,k) of one-hot labels, where k is the number of classes
#optionally add validation and test data
net.AddValidationData(Xv,yv)
#Note that the validation data are ignored during training if kfolds > 1
net.AddTestData(Xt,yt)
#Train the network
net.Train(nEpoch,kfolds=k)
#nEpoch is the number of training epochs
#kfolds is the number of folds for k-fold cross-validation: if kfolds > 1, the training data are
#split into kfolds sets, each of which takes a turn as the validation set. This results in
#kfolds networks being trained in total (net.model)
#see the docstring (net.Train? in IPython) for more options
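The two label formats accepted by AddData correspond one-to-one: an integer label c becomes a row of length k with a 1 in column c. The small, dependency-free sketch below shows that correspondence. The helper to_one_hot is hypothetical and not part of NNClass; plain Python lists stand in for array inputs, and it assumes the classes are numbered 0 to k-1.

```python
def to_one_hot(labels, k=None):
    """Convert integer class labels of shape (m,) to an (m, k) one-hot matrix.

    Hypothetical helper (not part of NNClass). Assumes classes are 0..k-1;
    if k is not given it is inferred as max(labels) + 1.
    """
    if k is None:
        k = max(labels) + 1 if labels else 0
    one_hot = []
    for label in labels:
        if not 0 <= label < k:
            raise ValueError(f"label {label} is outside 0..{k - 1}")
        row = [0] * k      # one column per class
        row[label] = 1     # mark the sample's class
        one_hot.append(row)
    return one_hot


y = [2, 0, 1, 2]       # m = 4 samples, classes 0..2
print(to_one_hot(y))   # [[0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1]]
```

Passing k explicitly (e.g. k=10 for digits) guards against a batch that happens not to contain the highest class.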
After training, the cost function may be plotted:
net.PlotCost(k=k)
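The kfolds option of Train is described only in outline: the training data are split into kfolds sets and each set takes one turn as the validation set. A minimal sketch of standard k-fold splitting, which is one plausible reading of that description, is given below. The helper kfold_indices is hypothetical; it does not shuffle, and NNClass may order or size its folds differently.

```python
def kfold_indices(m, k):
    """Split sample indices 0..m-1 into k contiguous folds.

    Returns one (train_idx, val_idx) pair per fold, so every sample is used
    for validation exactly once. The first m % k folds get one extra sample.
    Hypothetical helper, not part of NNClass.
    """
    if not 1 < k <= m:
        raise ValueError("need 1 < k <= m")
    base, extra = divmod(m, k)
    splits = []
    start = 0
    for fold in range(k):
        size = base + (1 if fold < extra else 0)
        val_idx = list(range(start, start + size))
        # everything outside the validation fold is used for training
        train_idx = list(range(0, start)) + list(range(start + size, m))
        splits.append((train_idx, val_idx))
        start += size
    return splits


for train_idx, val_idx in kfold_indices(10, 3):
    print(len(train_idx), val_idx)
# 6 [0, 1, 2, 3]
# 7 [4, 5, 6]
# 7 [7, 8, 9]
```

Each of the k splits would train a separate network, which is why k-fold training yields kfolds models rather than one.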
We can use the network on other data:
#X in this case is a new matrix
y = net.Predict(X)
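The README does not say what Predict returns. If it returns an (m,k) matrix of per-class scores or probabilities (an assumption), the predicted class of each sample is the column holding the largest value; if it already returns integer labels, this step is unnecessary. The helper scores_to_labels below is hypothetical.

```python
def scores_to_labels(scores):
    """Return, for each row of an (m, k) score matrix, the index of its largest entry.

    Hypothetical helper. Ties resolve to the lowest index because max() keeps
    the first maximal element. Also works on one-hot rows.
    """
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


scores = [[0.1, 0.7, 0.2],
          [0.8, 0.1, 0.1],
          [0.2, 0.3, 0.5]]
print(scores_to_labels(scores))  # [1, 0, 2]
```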
The networks can be saved and reloaded:
#save
net.Save(fname='networkname.bin')
#reload
net = nnc.LoadANN(fname='networkname.bin')
Running mnist = nnc.Test() will test the code by training a neural network to classify a set of hand-written digits (0-9) from the MNIST dataset (https://deepai.org/dataset/mnist). The function then plots the cost, the accuracy and an example of a classified digit.
The 10,000-sample MNIST data set can be accessed through the NNClass.MNIST object.
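Classification accuracy is usually the fraction of samples whose predicted class matches the true class. A minimal sketch of that standard definition follows; the helper name accuracy is hypothetical, and NNClass may compute or report it differently (for example as a percentage).

```python
def accuracy(predicted, actual):
    """Fraction of samples whose predicted class equals the true class."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must have the same length")
    if not actual:
        raise ValueError("need at least one sample")
    correct = sum(p == a for p, a in zip(predicted, actual))
    return correct / len(actual)


print(accuracy([1, 0, 2, 2], [1, 0, 1, 2]))  # 0.75
```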
File details
Details for the file NNClass-0.0.1-py3-none-any.whl.
File metadata
- Download URL: NNClass-0.0.1-py3-none-any.whl
- Upload date:
- Size: 1.7 MB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.19.0 setuptools/40.6.2 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.6.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 9afe6d5c91f71fd240729d4bd9343dd64455cd19f30cc582742718b12b598bc8
MD5 | 642cf8f68c50e3d216470daec2b11e2a
BLAKE2b-256 | 34bb03aba3ed341be44cad576db33c8247563e7744bc3012e675c8a1d37d89a0