A scikit-learn-like interface and stacked autoencoder for Chainer

Requirements

  • numpy

  • scikit-learn

  • chainer >= 1.5

Installation

pip install zChainer

Usage

Autoencoder

import numpy as np
import chainer.functions as F
import chainer.links as L
from chainer import ChainList, optimizers
from zChainer import NNAutoEncoder, utility

data = (..).astype(np.float32)

encoder = ChainList(
    L.Linear(784, 200),
    L.Linear(200, 100))
decoder = ChainList(
    L.Linear(200, 784),
    L.Linear(100, 200))

# You can override the forward function. The default is shown below.
#def forward(self, x):
#    h = F.dropout(F.relu(self.model[0](x)))
#    return F.dropout(F.relu(self.model[1](h)))
#
#NNAutoEncoder.forward = forward
ae = NNAutoEncoder(encoder, decoder, optimizers.Adam(), epoch=100, batch_size=100,
    log_path="./ae_log_"+utility.now()+".csv", export_path="./ae_"+utility.now()+".model")

ae.fit(data)
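
The `data` placeholder above is only shown being cast to `np.float32`. As a minimal sketch of an array with a compatible shape and dtype, the snippet below builds synthetic input. The random values, the sample count of 1000, and the scaling to [0, 1] are assumptions for illustration only; the 784 columns are chosen to match the input size of `L.Linear(784, 200)`.

```python
import numpy as np

# Hypothetical stand-in for real data: 1000 samples with 784 features each,
# matching the input size of the first encoder layer, L.Linear(784, 200).
rng = np.random.RandomState(0)
raw = rng.randint(0, 256, size=(1000, 784))

# Scale to [0, 1] (an assumption, not a documented requirement) and cast to
# float32, the dtype used in the example above.
data = (raw / 255.0).astype(np.float32)

print(data.shape, data.dtype)  # (1000, 784) float32
```

An array built this way could be substituted for the `(..)` placeholder when trying out the example, assuming the library accepts a plain two-dimensional float32 array.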

Training and Testing

import numpy as np
import chainer.functions as F
import chainer.links as L
from chainer import ChainList, optimizers
from zChainer import NNManager, utility
import pickle

X_train = (..).astype(np.float32)
y_train = (..).astype(np.int32)
X_test = (..).astype(np.float32)
y_test = (..).astype(np.int32)

# Create a new network
model = ChainList(L.Linear(784, 200), L.Linear(200, 100), L.Linear(100, 10))

# or load a serialized model
#with open("./ae_2015-12-01_11-26-45.model", "rb") as f:
#    model = pickle.load(f)
#model.add_link(L.Linear(100,10))

def forward(self, x):
    h = F.relu(self.model[0](x))
    h = F.relu(self.model[1](h))
    return F.relu(self.model[2](h))

def output(self, y):
    y_trimmed = y.data.argmax(axis=1)
    return np.array(y_trimmed, dtype=np.int32)

NNManager.forward = forward
NNManager.output = output
nn = NNManager(model, optimizers.Adam(), F.softmax_cross_entropy,
    epoch=100, batch_size=100,
    log_path="./training_log_"+utility.now()+".csv")

nn.fit(X_train, y_train, is_classification=True)
nn.predict(X_test, y_test)
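
The `output` function above turns per-class scores into integer labels by taking the arg-max of each row. The sketch below exercises that logic with NumPy alone. `FakeVariable` is a hypothetical stand-in that only exposes a `.data` attribute, as the function expects; it is not part of chainer or zChainer.

```python
import numpy as np


class FakeVariable(object):
    """Hypothetical stand-in for a network output: only exposes `.data`."""

    def __init__(self, data):
        self.data = data


def output(self, y):
    # Same logic as the `output` function above:
    # pick the index of the highest score in each row.
    y_trimmed = y.data.argmax(axis=1)
    return np.array(y_trimmed, dtype=np.int32)


# Two samples, three classes.
scores = FakeVariable(np.array([[0.1, 2.0, 0.3],
                                [1.5, 0.2, 0.1]], dtype=np.float32))
labels = output(None, scores)
print(labels)  # [1 0]
```

Returning `np.int32` labels keeps the predictions in the same dtype as `y_train` and `y_test` in the example.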
