
A Lightweight Keras API for Training PyTorch Models


Keras4Torch

"A Lightweight Keras API for Training PyTorch Models❤"


Keras4Torch implements a subset of the Keras API on top of PyTorch. You can use keras4torch.Model to wrap any torch.nn.Module and get the core training features of Keras through model.fit(), model.evaluate() and model.predict(). Most Keras training code works in Keras4Torch with little or no change.

Installation

pip install keras4torch

Keras4Torch supports Python 3.6 or newer.

Quick Start

Using Keras4Torch is almost the same as using Keras.

Let's start with a simple example of MNIST!

import torch
import torchvision
from torch import nn

import keras4torch

(1) Preprocess Data

mnist = torchvision.datasets.MNIST(root='./', download=True)
X, y = mnist.data, mnist.targets    # train_data / train_labels are deprecated aliases

X = X.float() / 255.0    # scale the pixels to [0, 1]

x_train, y_train = X[:40000], y[:40000]
x_test, y_test = X[40000:], y[40000:]

(2) Define the Model

model = torch.nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 512), nn.ReLU(),
    nn.Linear(512, 128), nn.ReLU(),
    nn.Linear(128, 10)
)

model = keras4torch.Model(model)    # note this line: it wraps the nn.Module
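To get a feel for the size of the network defined above, the following sketch counts its trainable parameters in plain Python. It relies only on the fact that an nn.Linear layer holds an (out_features x in_features) weight matrix plus one bias per output, while nn.Flatten and nn.ReLU have no parameters. The helper name dense_param_count is made up for this illustration and is not part of Keras4Torch.

```python
def dense_param_count(in_features, out_features, bias=True):
    """Parameters of one fully connected layer: weights plus optional biases."""
    return in_features * out_features + (out_features if bias else 0)


# (in_features, out_features) of the three nn.Linear layers in the model above
layer_shapes = [(28 * 28, 512), (512, 128), (128, 10)]

per_layer = [dense_param_count(i, o) for i, o in layer_shapes]
total = sum(per_layer)

print(per_layer)  # parameters per Linear layer
print(total)      # total trainable parameters
```

With PyTorch installed, the same total should be obtainable from the unwrapped nn.Sequential via sum(p.numel() for p in module.parameters()).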

(3) Compile Loss and Metric

model.compile(optimizer='adam', loss=nn.CrossEntropyLoss(), metrics=['acc'])
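For readers unfamiliar with the two quantities being compiled here, the sketch below computes them by hand in plain Python, without PyTorch. The cross-entropy part follows the documented behaviour of nn.CrossEntropyLoss (raw logits in, mean over the batch by default). The accuracy part assumes that 'acc' means "fraction of samples whose arg-max logit equals the label"; that is the usual definition, but the exact implementation inside Keras4Torch is not shown in this document. Both function names are hypothetical.

```python
import math


def cross_entropy(logits, labels):
    """Mean cross-entropy over a batch of raw logits (log-sum-exp form)."""
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


def accuracy(logits, labels):
    """Fraction of rows whose largest logit sits at the labelled index."""
    hits = 0
    for row, label in zip(logits, labels):
        predicted = max(range(len(row)), key=row.__getitem__)
        hits += int(predicted == label)
    return hits / len(labels)


# Toy batch: 3 samples, 2 classes
logits = [[2.0, 1.0], [0.1, 0.9], [3.0, 0.0]]
labels = [0, 1, 1]

print(cross_entropy(logits, labels))
print(accuracy(logits, labels))
```

A model that outputs identical logits for all 10 digit classes has a cross-entropy of ln(10), roughly 2.30, which is a handy reference point when reading the loss column of the training log.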

(4) Training

history = model.fit(x_train, y_train,
                    epochs=30,
                    batch_size=512,
                    validation_split=0.2)
Train on 32000 samples, validate on 8000 samples:
Epoch 1/30 - 0.7s - loss: 0.7440 - acc: 0.8149 - val_loss: 0.3069 - val_acc: 0.9114 - lr: 1e-03
Epoch 2/30 - 0.5s - loss: 0.2650 - acc: 0.9241 - val_loss: 0.2378 - val_acc: 0.9331 - lr: 1e-03
Epoch 3/30 - 0.5s - loss: 0.1946 - acc: 0.9435 - val_loss: 0.1940 - val_acc: 0.9431 - lr: 1e-03
Epoch 4/30 - 0.5s - loss: 0.1513 - acc: 0.9555 - val_loss: 0.1663 - val_acc: 0.9524 - lr: 1e-03
... ...
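The first line of the log follows directly from the arguments passed to fit(): 40000 samples with validation_split=0.2 leave 32000 for training and 8000 for validation. The sketch below reproduces that arithmetic. The number of batches per epoch assumes the final, smaller batch is kept rather than dropped, which is an assumption about Keras4Torch and not something this document states. The function name split_counts is hypothetical.

```python
import math


def split_counts(n_samples, validation_split):
    """Return (n_train, n_val) for a fractional validation split."""
    n_val = round(n_samples * validation_split)
    return n_samples - n_val, n_val


n_train, n_val = split_counts(40000, 0.2)
print(f"Train on {n_train} samples, validate on {n_val} samples")

# Batches per epoch, assuming the last partial batch is not dropped
batch_size = 512
steps_per_epoch = math.ceil(n_train / batch_size)
last_batch = n_train - (steps_per_epoch - 1) * batch_size
print(steps_per_epoch, last_batch)
```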

(5) Plot Learning Curve

history.plot(kind='line', y=['acc', 'val_acc'])

(6) Evaluate on Test Set

model.evaluate(x_test, y_test)
OrderedDict([('loss', 0.121063925), ('acc', 0.9736)])
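The returned mapping can be read like any dictionary. As a small worked example, the sketch below converts the accuracy shown above into a count of misclassified images. It assumes the held-out set is X[40000:] of the 60000-image MNIST training set, as in step (1), so it contains 20000 samples. The values are copied from the output above, not recomputed.

```python
from collections import OrderedDict

# Result copied from the evaluate() output above
result = OrderedDict([('loss', 0.121063925), ('acc', 0.9736)])

n_test = 60000 - 40000  # size of X[40000:] for the MNIST training set
n_correct = round(result['acc'] * n_test)
n_wrong = n_test - n_correct

print(f"{n_correct} correct, {n_wrong} misclassified out of {n_test}")
```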

Feature Support

                   keras4torch   torchkeras   keras
callbacks          √             x            √
metrics            √             √            √
numpy dataset      √             x            √
GPU support        √             √            √
shape inference    x             x            √
functional API     x             x            √
multi-input        x             x            √

Communication

If you run into problems when using Keras4Torch, check the GitHub Issues or send an email to blueloveTH@foxmail.com.

We also welcome Pull Requests.


