A Lightweight Keras API for Training PyTorch Models
Keras4Torch
"A Lightweight Keras API for Training PyTorch Models❤"
Keras4Torch implements a subset of the Keras API in PyTorch. You can use keras4torch.Model to wrap any torch.nn.Module and get Keras's core training features through model.fit(), model.evaluate() and model.predict(). Most Keras training code works in Keras4Torch with little or no change.
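To make the compile/fit/evaluate/predict workflow concrete without depending on any library, here is a hypothetical, library-free sketch of the pattern that keras4torch.Model exposes. It is not keras4torch's implementation: the class name KerasStyleModel, its compile(lr=...) signature and the 1-D linear model trained with mini-batch SGD are all illustrative assumptions.

```python
import random

# Hypothetical, library-free sketch of the Keras-style workflow; NOT keras4torch's code.
class KerasStyleModel:
    """Wraps a 1-D linear model y = w * x + b behind compile/fit/evaluate/predict."""

    def __init__(self):
        self.w, self.b = 0.0, 0.0
        self.lr = None

    def compile(self, lr=0.1):
        # Configure training once, before fit(), as in Keras
        self.lr = lr

    def predict(self, xs):
        return [self.w * x + self.b for x in xs]

    def evaluate(self, xs, ys):
        # Mean squared error over the whole dataset, returned as a dict of metrics
        preds = self.predict(xs)
        return {"loss": sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(ys)}

    def fit(self, xs, ys, epochs=1, batch_size=32, seed=0):
        if self.lr is None:
            raise RuntimeError("call compile() before fit()")
        rng = random.Random(seed)
        order = list(range(len(xs)))
        history = []
        for _ in range(epochs):
            rng.shuffle(order)  # new sample order every epoch
            for start in range(0, len(order), batch_size):
                batch = order[start:start + batch_size]
                # Gradients of the batch-mean squared error w.r.t. w and b
                grad_w = grad_b = 0.0
                for i in batch:
                    err = self.w * xs[i] + self.b - ys[i]
                    grad_w += 2 * err * xs[i]
                    grad_b += 2 * err
                self.w -= self.lr * grad_w / len(batch)
                self.b -= self.lr * grad_b / len(batch)
            history.append(self.evaluate(xs, ys)["loss"])  # one entry per epoch
        return history


# Usage: learn y = 2x + 1 from 100 noiseless points
xs = [i / 100 for i in range(100)]
ys = [2 * x + 1 for x in xs]
model = KerasStyleModel()
model.compile(lr=0.1)
history = model.fit(xs, ys, epochs=200, batch_size=10)
print(model.evaluate(xs, ys), model.predict([0.5]))
```

In Keras4Torch the same calling convention sits on top of a torch.nn.Module, a PyTorch loss and an optimizer, as the Quick Start shows.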
Installation
pip install keras4torch
Keras4Torch supports Python 3.6 or newer.
Quick Start
Using Keras4Torch is almost the same as using Keras.
Let's start with a simple MNIST example!
import torch
import torchvision
from torch import nn
import keras4torch
(1) Preprocess Data
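mnist = torchvision.datasets.MNIST(root='./', download=True)  # train=True by default: 60,000 images
X, y = mnist.data, mnist.targets  # train_data/train_labels are deprecated aliases
X = X.float() / 255.0  # scale the pixels to [0, 1]
x_train, y_train = X[:40000], y[:40000]  # first 40,000 samples for training
x_test, y_test = X[40000:], y[40000:]  # remaining 20,000 samples for testing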
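Step (1) uses torchvision's default train=True split, which holds 60,000 images, so slicing at 40,000 yields 40,000 training and 20,000 test samples. This "test set" is carved out of the official training split rather than being the separate 10,000-image MNIST test set. The pure-Python stand-in below mirrors the scaling and the unshuffled head/tail slicing on plain lists; the helper names are hypothetical.

```python
# Pure-Python stand-in for step (1); no torch needed. Helper names are hypothetical.
MNIST_TRAIN_IMAGES = 60000  # size of torchvision's default (train=True) MNIST split
N_TRAIN = 40000             # first 40,000 samples become x_train

def scale(pixels):
    """Map 8-bit pixel values (0..255) to [0, 1], like X.float() / 255.0."""
    return [p / 255.0 for p in pixels]

def head_tail_split(samples, n_train):
    """Same slicing as X[:n_train] and X[n_train:]: no shuffling, no overlap."""
    return samples[:n_train], samples[n_train:]

indices = list(range(MNIST_TRAIN_IMAGES))
train_idx, test_idx = head_tail_split(indices, N_TRAIN)
print(len(train_idx), len(test_idx))  # 40000 20000
print(scale([0, 51, 255]))            # [0.0, 0.2, 1.0]
```

Because the split is not shuffled, it depends on the order in which the dataset stores its samples.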
(2) Define the Model
model = torch.nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 512), nn.ReLU(),
    nn.Linear(512, 128), nn.ReLU(),
    nn.Linear(128, 10)
)
model = keras4torch.Model(model)  # note: wrap the nn.Module to get the Keras-style API
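As a worked check on step (2), the size of the network follows from each nn.Linear holding an in_features × out_features weight matrix plus an out_features bias vector; nn.Flatten and nn.ReLU have no parameters. The helper name linear_params is hypothetical.

```python
def linear_params(in_features, out_features, bias=True):
    """Trainable parameters of nn.Linear: weight matrix plus optional bias vector."""
    return in_features * out_features + (out_features if bias else 0)

# (in, out) sizes of the three nn.Linear layers in the Sequential above
layers = [(28 * 28, 512), (512, 128), (128, 10)]
total = sum(linear_params(i, o) for i, o in layers)
print(total)  # 468874
```

In PyTorch itself, sum(p.numel() for p in module.parameters()) on the unwrapped nn.Sequential should report the same total.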
(3) Compile Loss and Metric
model.compile(optimizer='adam', loss=nn.CrossEntropyLoss(), metrics=['acc'])
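The model in step (2) ends with nn.Linear(128, 10) and no softmax layer. That is intended: nn.CrossEntropyLoss expects raw logits, applies log-softmax internally and averages over the batch by default. The sketch below spells out that per-sample loss and the argmax accuracy an 'acc' metric typically computes; that keras4torch's 'acc' matches this argmax definition is an assumption, and the function names are hypothetical.

```python
import math

# Hypothetical helpers illustrating the loss and metric configured in step (3).

def cross_entropy(logits, target):
    """-log(softmax(logits)[target]) for one sample, computed from raw logits
    with the log-sum-exp trick for numerical stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def mean_cross_entropy(batch_logits, targets):
    # nn.CrossEntropyLoss averages over the batch by default (reduction='mean')
    return sum(cross_entropy(z, t) for z, t in zip(batch_logits, targets)) / len(targets)

def accuracy(batch_logits, targets):
    """Fraction of samples whose highest-scoring class equals the label."""
    hits = sum(max(range(len(z)), key=z.__getitem__) == t
               for z, t in zip(batch_logits, targets))
    return hits / len(targets)

logits = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]]  # two samples, three classes
labels = [0, 2]
print(accuracy(logits, labels))  # 0.5: only the first prediction is correct
print(mean_cross_entropy(logits, labels))
```

optimizer='adam' presumably selects torch.optim.Adam, whose default learning rate of 1e-3 matches the lr column of the training log.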
(4) Training
history = model.fit(x_train, y_train,
                    epochs=30,
                    batch_size=512,
                    validation_split=0.2,
                    )
Train on 32000 samples, validate on 8000 samples:
Epoch 1/30 - 0.7s - loss: 0.7440 - acc: 0.8149 - val_loss: 0.3069 - val_acc: 0.9114 - lr: 1e-03
Epoch 2/30 - 0.5s - loss: 0.2650 - acc: 0.9241 - val_loss: 0.2378 - val_acc: 0.9331 - lr: 1e-03
Epoch 3/30 - 0.5s - loss: 0.1946 - acc: 0.9435 - val_loss: 0.1940 - val_acc: 0.9431 - lr: 1e-03
Epoch 4/30 - 0.5s - loss: 0.1513 - acc: 0.9555 - val_loss: 0.1663 - val_acc: 0.9524 - lr: 1e-03
... ...
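The first line of the log follows from the fit() arguments. In Keras, validation_split holds out the last fraction of the samples (before shuffling); assuming Keras4Torch follows the same convention, 20% of the 40,000 training samples leaves 32,000 for training and 8,000 for validation, and batch_size=512 gives 63 optimizer steps per epoch, the last one on 256 samples. The helpers below are hypothetical.

```python
import math

# Hypothetical helpers reproducing the numbers in the first log line.

def validation_split(samples, fraction):
    """Keras convention (assumed here for Keras4Torch): the LAST `fraction` of the
    samples is held out. round() absorbs float error in products like 40000 * 0.2."""
    n_val = round(len(samples) * fraction)
    cut = len(samples) - n_val
    return samples[:cut], samples[cut:]

def batches_per_epoch(n_train, batch_size):
    # The last batch is allowed to be smaller than batch_size
    return math.ceil(n_train / batch_size)

train_part, val_part = validation_split(list(range(40000)), 0.2)
steps = batches_per_epoch(len(train_part), 512)
last_batch = len(train_part) - (steps - 1) * 512
print(len(train_part), len(val_part))  # 32000 8000
print(steps, last_batch)               # 63 256
```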
(5) Plot Learning Curve
history.plot(kind='line', y=['acc', 'val_acc'])
(6) Evaluate on Test Set
model.evaluate(x_test, y_test)
OrderedDict([('loss', 0.121063925), ('acc', 0.9736)])
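model.evaluate() reports one value per metric for the whole test set; with the 20,000 test samples from step (1), an accuracy of 0.9736 corresponds to 19,472 correct predictions. When a metric is computed batch by batch, the per-batch means must be weighted by batch size, otherwise a short final batch counts as much as a full one. The sketch below shows that reduction; whether Keras4Torch aggregates exactly this way is an assumption, and aggregate is a hypothetical helper.

```python
# Hypothetical reduction of per-batch metric means into one dataset-level value.

def aggregate(batch_results):
    """batch_results: list of (batch_size, mean_value_over_that_batch) pairs.
    Returns the mean over all samples, weighting each batch by its size."""
    total = sum(n for n, _ in batch_results)
    return sum(n * value for n, value in batch_results) / total

batches = [(512, 0.5), (512, 1.0), (256, 0.0)]  # two full batches, one short one
weighted = aggregate(batches)
naive = sum(value for _, value in batches) / len(batches)
print(weighted)  # 0.6
print(naive)     # 0.5: gives the short final batch as much weight as a full one
```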
Feature Support
Feature | keras4torch | torchkeras | keras
---|---|---|---
callbacks | √ | x | √
metrics | √ | √ | √
numpy dataset | √ | x | √
GPU support | √ | √ | √
shape inference | x | x | √
functional API | x | x | √
multi-input | x | x | √
Communication
If you run into problems using Keras4Torch, check GitHub Issues or send an email to blueloveTH@foxmail.com.
Pull requests are welcome.
Download files
File details
Details for the file keras4torch-0.3.0.tar.gz
File metadata
- Download URL: keras4torch-0.3.0.tar.gz
- Size: 10.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/50.3.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.7.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 1f4077dde147bb47eacf83f87070af255f2e021e025e744c8fab7700c283d807
MD5 | 031abf8e393e244a3ec98b2c7689de77
BLAKE2b-256 | 7b30c49736988d7518a9312a4b2a4d98df703801c9fe44122cb2598a3f132e6c
File details
Details for the file keras4torch-0.3.0-py3-none-any.whl
File metadata
- Download URL: keras4torch-0.3.0-py3-none-any.whl
- Size: 11.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/50.3.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.7.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 33ffb4b8147c374f7b4296a341a75b3122761dee42ac6081c9fef14bed803f48
MD5 | 44fce718d879633bc830e571bbd33fb1
BLAKE2b-256 | ebb7497c3d4dcbd71ffdab4dc8c3c8c3a91a921824fe1fb1dc25f9106ab14820