A compatible-with-Keras wrapper for training PyTorch models ✨
Documentation • Dev Logs • Mini Version
keras4torch provides a high-level API to train PyTorch models, compatible with Keras. This project is designed for beginners, with these objectives:
- Help people who are new to PyTorch but familiar with Keras
- Reduce the cost of migrating Keras model implementations to PyTorch
Use keras4torch in Kaggle code competitions! Check out this package dataset and starter notebook.
Installation
pip install keras4torch
PyTorch 1.6+ and Python 3.6+ are required.
Quick start
Suppose you have an nn.Module to train.
model = torchvision.models.resnet18(num_classes=10)
All you need to do is wrap it via k4t.Model().
import keras4torch as k4t
model = k4t.Model(model)
Now there are two workflows you can use for training.

The NumPy workflow is compatible with Keras:
- .compile(optimizer, loss, metrics) configures the optimizer, loss and metrics
- .fit(x, y, epochs, batch_size, ...) takes raw NumPy input for training
- .evaluate(x, y) outputs a dict with the results of your metrics
- .predict(x) runs predictions
The DataLoader workflow is more flexible and closer to plain PyTorch:
- .compile(optimizer, loss, metrics) same as the NumPy workflow
- .fit_dl(train_loader, val_loader, epochs) trains the model via DataLoader
- .evaluate_dl(data_loader) same as the NumPy workflow but takes a DataLoader
- .predict_dl(data_loader) same as the NumPy workflow but takes a DataLoader
The two workflows can be mixed.
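For example, here is a minimal sketch of the DataLoader workflow. The TensorDataset setup and the tensors x_train, y_train, x_val, y_val are assumptions for illustration; fit_dl's signature follows the list above.

from torch import nn
from torch.utils.data import TensorDataset, DataLoader

# Wrap existing tensors in standard PyTorch loaders (illustrative setup)
train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=512, shuffle=True)
val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=512)

model.compile(optimizer='adam', loss=nn.CrossEntropyLoss(), metrics=['acc'])
history = model.fit_dl(train_loader, val_loader, epochs=30)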
MNIST example
Here we show a complete example of training a ConvNet on MNIST.
import torch
import torchvision
from torch import nn
import keras4torch as k4t
Step 1: Preprocess data
mnist = torchvision.datasets.MNIST(root='./', download=True)
x, y = mnist.data.unsqueeze(1), mnist.targets  # train_data/train_labels are deprecated aliases
x = x.float() / 255.0 # scale the pixels to [0, 1]
x_train, y_train = x[:40000], y[:40000]
x_test, y_test = x[40000:], y[40000:]
Step 2: Define the model
If you already have an nn.Module, just wrap it via k4t.Model. For example,
model = torchvision.models.resnet50(num_classes=10)
model = k4t.Model(model)
For building models from scratch, you can use KerasLayer (located in k4t.layers) for automatic shape inference, which frees you from calculating input channels by hand.

As shown below, k4t.layers.Conv2d(32, kernel_size=3) is equivalent to nn.Conv2d(?, 32, kernel_size=3), where the first parameter ? (i.e. in_channels) will be inferred automatically.
model = torch.nn.Sequential(
k4t.layers.Conv2d(32, kernel_size=3), nn.ReLU(),
nn.MaxPool2d(2, 2),
k4t.layers.Conv2d(64, kernel_size=3), nn.ReLU(),
nn.Flatten(),
k4t.layers.Linear(10)
)
A model containing KerasLayer needs an extra .build(input_shape) operation.
model = k4t.Model(model).build([1, 28, 28])
Step 3: Summarize the model
model.summary()
=========================================================================================
Layer (type:depth-idx) Output Shape Param #
=========================================================================================
├─Conv2d*: 1-1 [-1, 32, 26, 26] 320
├─ReLU: 1-2 [-1, 32, 26, 26] --
├─MaxPool2d: 1-3 [-1, 32, 13, 13] --
├─Conv2d*: 1-4 [-1, 64, 11, 11] 18,496
├─ReLU: 1-5 [-1, 64, 11, 11] --
├─Flatten: 1-6 [-1, 7744] --
├─Linear*: 1-7 [-1, 10] 77,450
=========================================================================================
Total params: 96,266
Trainable params: 96,266
Non-trainable params: 0
Total mult-adds (M): 2.50
=========================================================================================
Step 4: Configure optimizer, loss and metrics
model.compile(optimizer='adam', loss=nn.CrossEntropyLoss(), metrics=['acc'])
If a GPU is available, it will be used automatically. You can also pass the device parameter to .compile() explicitly.
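For example (the device string below is an assumption; any valid torch device identifier should work):

model.compile(optimizer='adam', loss=nn.CrossEntropyLoss(), metrics=['acc'],
              device='cuda:0')  # or device='cpu' to force CPU training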
Step 5: Training
history = model.fit(x_train, y_train,
epochs=30,
batch_size=512,
validation_split=0.2,
)
Train on 32000 samples, validate on 8000 samples:
Epoch 1/30 - 2.8s - loss: 0.6109 - acc: 0.8372 - val_loss: 0.2712 - val_acc: 0.9235 - lr: 1e-03
Epoch 2/30 - 1.5s - loss: 0.2061 - acc: 0.9402 - val_loss: 0.1494 - val_acc: 0.9579 - lr: 1e-03
Epoch 3/30 - 1.5s - loss: 0.1202 - acc: 0.9653 - val_loss: 0.0974 - val_acc: 0.9719 - lr: 1e-03
Epoch 4/30 - 1.5s - loss: 0.0835 - acc: 0.9757 - val_loss: 0.0816 - val_acc: 0.9769 - lr: 1e-03
... ...
Step 6: Plot learning curve
history.plot(kind='line', y=['loss', 'val_loss'])
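Since history supports the pandas DataFrame.plot API (as the call above suggests), you can post-process it with the usual matplotlib tools; the output file name is just an illustration:

import matplotlib.pyplot as plt

ax = history.plot(kind='line', y=['loss', 'val_loss'])  # returns a matplotlib Axes
ax.set_xlabel('epoch')
plt.savefig('learning_curve.png')  # hypothetical output path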
Step 7: Evaluate on test set
model.evaluate(x_test, y_test)
{'loss': 0.06655170023441315, 'acc': 0.9839999675750732}
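To turn raw outputs into class labels, here is a short sketch using .predict() from the NumPy workflow; the argmax post-processing assumes the model outputs one score per class:

y_score = model.predict(x_test)  # per-class scores, shape [N, 10]
y_pred = y_score.argmax(-1)      # predicted class indices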
Communication
We have enabled GitHub Discussions for Q&A and general topics!
For bug reports, please use GitHub Issues.