A Lightweight Keras API for Training PyTorch Models

Keras4Torch

Keras4Torch implements a subset of the Keras API on top of PyTorch. You can use keras4torch.Model to wrap any torch.nn.Module and get the core training features of Keras through model.fit(), model.evaluate() and model.predict(). Most Keras training code works in Keras4Torch with little or no change.

Installation

pip install keras4torch

Keras4Torch supports Python 3.6 or newer.

Quick Start

Using Keras4Torch is almost the same as using Keras.

Let's start with a simple example of MNIST!

import torch
import torchvision
from torch import nn

import keras4torch

(1) Preprocess Data

mnist = torchvision.datasets.MNIST(root='./', download=True)
X, y = mnist.data, mnist.targets    # train_data / train_labels are deprecated aliases

X = X.float() / 255.0    # scale the pixels to [0, 1]

x_train, y_train = X[:40000], y[:40000]
x_test, y_test = X[40000:], y[40000:]

(2) Define the Model

model = torch.nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 512), nn.ReLU(),
    nn.Linear(512, 128), nn.ReLU(),
    nn.Linear(128, 10)
)

model = keras4torch.Model(model)    # note this line: it wraps the nn.Module
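
As a rough sanity check on the size of this network, the number of trainable parameters can be worked out by hand. The sketch below is plain Python and does not call PyTorch or Keras4Torch; it relies only on the facts that nn.Linear stores a weight matrix plus a bias vector by default and that nn.Flatten and nn.ReLU have no parameters. The helper name linear_params is made up for illustration.

```python
# Layer sizes taken from the nn.Linear calls in the model above.
layer_sizes = [(28 * 28, 512), (512, 128), (128, 10)]


def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer (hypothetical helper)."""
    return in_features * out_features + out_features


per_layer = [linear_params(i, o) for i, o in layer_sizes]
total_params = sum(per_layer)

print(per_layer)     # parameters of each Linear layer
print(total_params)  # parameters of the whole network
```

Almost all of the parameters sit in the first layer, which maps the 784 flattened pixels to 512 hidden units.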

(3) Compile Loss and Metric

model.compile(optimizer='adam', loss=nn.CrossEntropyLoss(), metrics=['acc'])
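
nn.CrossEntropyLoss expects raw logits, which is why the model above ends in a plain nn.Linear layer without a softmax. The following is a minimal pure-Python sketch of the quantity it computes with its default mean reduction (log-softmax followed by negative log-likelihood). It is an illustration of the formula, not the PyTorch implementation, and the function name cross_entropy is an assumption made for this example.

```python
import math


def cross_entropy(logits, targets):
    """Mean negative log-likelihood of the target class, from raw logits."""
    losses = []
    for row, target in zip(logits, targets):
        # Subtract the row maximum before exponentiating for numerical stability.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_sum_exp - row[target])
    return sum(losses) / len(losses)


# Ten equal logits: the model is maximally unsure, so the loss is log(10).
uniform = cross_entropy([[0.0] * 10], [3])

# One large logit on the correct class: the loss is close to zero.
confident = cross_entropy([[0.0, 10.0, 0.0]], [1])

print(uniform, confident)
```

An untrained 10-class classifier should therefore start near a loss of log(10), about 2.3, and move toward zero as training progresses.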

(4) Training

history = model.fit(x_train, y_train,
                    epochs=30,
                    batch_size=512,
                    validation_split=0.2)
Train on 32000 samples, validate on 8000 samples:
Epoch 1/30 - 0.7s - loss: 0.7440 - acc: 0.8149 - val_loss: 0.3069 - val_acc: 0.9114 - lr: 1e-03
Epoch 2/30 - 0.5s - loss: 0.2650 - acc: 0.9241 - val_loss: 0.2378 - val_acc: 0.9331 - lr: 1e-03
Epoch 3/30 - 0.5s - loss: 0.1946 - acc: 0.9435 - val_loss: 0.1940 - val_acc: 0.9431 - lr: 1e-03
Epoch 4/30 - 0.5s - loss: 0.1513 - acc: 0.9555 - val_loss: 0.1663 - val_acc: 0.9524 - lr: 1e-03
... ...
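
The sample counts in the first log line follow directly from the arguments to fit(). The sketch below reproduces that arithmetic in plain Python. It says nothing about which samples Keras4Torch places in the validation set, and the batch count assumes the final partial batch is kept rather than dropped, which is an assumption and not something the log confirms. The helper name split_sizes is hypothetical.

```python
import math


def split_sizes(n_samples, validation_split):
    """Return (n_train, n_val) for a fractional validation split (hypothetical helper)."""
    n_val = int(n_samples * validation_split)
    return n_samples - n_val, n_val


# 40000 samples were passed to fit() with validation_split=0.2.
n_train, n_val = split_sizes(40000, 0.2)

# Assumption: the last, smaller batch is still used in each epoch.
batches_per_epoch = math.ceil(n_train / 512)

print(n_train, n_val, batches_per_epoch)
```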

(5) Plot Learning Curve

history.plot(kind='line', y=['acc', 'val_acc'])

(6) Evaluate on Test Set

model.evaluate(x_test, y_test)
OrderedDict([('loss', 0.121063925), ('acc', 0.9736)])
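
The 'acc' value is presumably top-1 classification accuracy: the fraction of samples whose highest-scoring class matches the label. That reading is an assumption based on Keras conventions, and the sketch below is a plain-Python illustration of the idea, not the Keras4Torch metric code. The function name accuracy and the toy data are made up for this example.

```python
def accuracy(logits, targets):
    """Fraction of rows whose arg-max class equals the target label."""
    correct = 0
    for row, target in zip(logits, targets):
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == target)
    return correct / len(targets)


# Four toy samples with three classes each; the last one is misclassified.
logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5], [0.0, 3.0, 0.5], [1.0, 0.9, 0.8]]
targets = [0, 2, 1, 2]

acc = accuracy(logits, targets)
print(acc)
```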

Feature Support

Feature           keras4torch   torchkeras   keras
callbacks         √             x            √
metrics           √             √            √
numpy dataset     √             x            √
GPU support       √             √            √
shape inference   x             x            √
functional API    x             x            √
multi-input       x             x            √

Communication

If you run into problems using Keras4Torch, check GitHub Issues or send an email to blueloveTH@foxmail.com.

We also welcome Pull Requests.
