scikit-learn like interface and stacked autoencoder for chainer
Project description
A scikit-learn-like interface and stacked autoencoder for Chainer.
Requirements
numpy
scikit-learn
chainer >= 1.5
Installation
pip install zChainer
Usage
Autoencoder
import numpy as np
import chainer.functions as F
import chainer.links as L
from chainer import ChainList, optimizers
from zChainer import NNAutoEncoder, utility
data = (..).astype(np.float32)
# Encoder stack: 784 -> 200 -> 100.
encoder = ChainList(
L.Linear(784, 200),
L.Linear(200, 100))
# Decoder links are listed in the same order as the encoder links, not in
# data-flow order: decoder[0] (200 -> 784) undoes encoder[0] (784 -> 200) and
# decoder[1] (100 -> 200) undoes encoder[1] (200 -> 100).
# NOTE(review): presumably NNAutoEncoder pairs encoder/decoder links by index
# for layer-wise training — confirm against the zChainer implementation.
decoder =ChainList(
L.Linear(200, 784),
L.Linear(100, 200))
# You can set your own forward function. Default is as below.
#def forward(self, x):
#    h = F.dropout(F.relu(self.model[0](x)))
#    return F.dropout(F.relu(self.model[1](h)))
#
#NNAutoEncoder.forward = forward
ae = NNAutoEncoder(encoder, decoder, optimizers.Adam(), epoch=100, batch_size=100,
log_path="./ae_"+utility.now()+"_log.csv", export_path="./ae_"+utility.now()+".model")
ae.fit(data)
Training and Testing
import numpy as np
import chainer.functions as F
import chainer.links as L
from chainer import ChainList, optimizers
from zChainer import NNManager, utility
import pickle
X_train = (..).astype(np.float32)
y_train = (..).astype(np.int32)
X_test = (..).astype(np.float32)
y_test = (..).astype(np.int32)
# Create a new network
model = ChainList(L.Linear(784, 200), L.Linear(200, 100), L.Linear(100, 10))
# or load a serialized model
#f = open("./ae_2015-12-01_11-26-45.model", "rb")
#model = pickle.load(f)
#f.close()
#model.add_link(L.Linear(100,10))
def forward(self, x):
    """Propagate ``x`` through the three links of ``self.model``.

    A ReLU is applied to the output of every link, including the last one.
    Intended to be assigned to ``NNManager.forward`` to override the default
    forward pass.

    Args:
        x: Input passed to the first link, ``self.model[0]``.

    Returns:
        The ReLU-activated output of the third link, ``self.model[2]``.
    """
    h = F.relu(self.model[0](x))
    h = F.relu(self.model[1](h))
    # NOTE(review): the final link's output is also passed through ReLU before
    # it reaches the loss function given to NNManager — confirm this is
    # intended rather than returning the raw output of self.model[2].
    return F.relu(self.model[2](h))
NNManager.forward = forward
nn = NNManager(model, optimizers.Adam(), F.softmax_cross_entropy, epoch=100, batch_size=100,
log_path="./training_"+utility.now()+"_log.csv", export_path="./training_"+utility.now()+".model")
nn.fit(X_train, y_train)
nn.predict(X_test, y_test)
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
zChainer-0.1.4.tar.gz
(3.6 kB
view hashes)