
A Ready-to-Use Wrapper for Training PyTorch Models✨


Keras4Torch




keras4torch provides a high-level, Keras-compatible API for training PyTorch models. The project is designed for beginners, with these objectives:

  • Help people who are new to PyTorch but familiar with Keras
  • Reduce the cost of migrating Keras model implementations to PyTorch

Use keras4torch in Kaggle code competitions! Check out the package dataset and starter notebook.


Installation

pip install keras4torch

PyTorch 1.6+ and Python 3.6+ are required.

Quick start

Suppose you have an nn.Module to train.

import torchvision
model = torchvision.models.resnet18(num_classes=10)

All you need to do is wrap it with k4t.Model().

import keras4torch as k4t

model = k4t.Model(model)

Now there are two workflows you can use for training.

The NumPy workflow is compatible with Keras.

  • .compile(optimizer, loss, metrics) sets the optimizer, loss and metrics
  • .fit(x, y, epochs, batch_size, ...) takes raw NumPy arrays as training input
  • .evaluate(x, y) returns a dict of your metric results
  • .predict(x) for doing predictions

The DataLoader workflow is more flexible and follows PyTorch conventions.

  • .compile(optimizer, loss, metrics) is the same as in the NumPy workflow
  • .fit_dl(train_loader, val_loader, epochs) trains the model from DataLoaders
  • .evaluate_dl(data_loader) is the same as .evaluate() but takes a DataLoader
  • .predict_dl(data_loader) is the same as .predict() but takes a DataLoader

The two workflows can be mixed.
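In the NumPy workflow, .fit(x, y, batch_size=...) has to cut the raw arrays into mini-batches itself, which is roughly what a DataLoader does for you in the second workflow. The following is a minimal pure-Python sketch of that slicing, not keras4torch's actual code; the helper name `iter_batches` and its `shuffle`/`seed` parameters are hypothetical.

```python
import random
from typing import Iterator, List, Sequence, Tuple


def iter_batches(
    x: Sequence, y: Sequence, batch_size: int, shuffle: bool = False, seed: int = 0
) -> Iterator[Tuple[List, List]]:
    """Yield (x_batch, y_batch) pairs (hypothetical helper).

    The final batch may be smaller than batch_size.
    """
    if len(x) != len(y):
        raise ValueError("x and y must have the same number of samples")
    indices = list(range(len(x)))
    if shuffle:
        # Shuffle sample order reproducibly; x and y stay paired via shared indices.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        yield [x[i] for i in chunk], [y[i] for i in chunk]


# Usage: 10 samples with batch_size=4 give batches of 4, 4 and 2.
xs = list(range(10))
ys = [v % 2 for v in xs]
sizes = [len(bx) for bx, _ in iter_batches(xs, ys, batch_size=4)]
print(sizes)  # [4, 4, 2]
```

Shuffling indices rather than the arrays themselves keeps each input aligned with its label, which is the property any batching scheme must preserve.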

MNIST example

Here we show a complete example of training a ConvNet on MNIST.

import torch
import torchvision
from torch import nn

import keras4torch as k4t

Step 1: Preprocess the data

mnist = torchvision.datasets.MNIST(root='./', download=True)
x, y = mnist.data.unsqueeze(1), mnist.targets    # add a channel dim: [N, 28, 28] -> [N, 1, 28, 28]

x = x.float() / 255.0    # scale the pixels to [0, 1]

x_train, y_train = x[:40000], y[:40000]
x_test, y_test = x[40000:], y[40000:]
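Since MNIST's training split has 60,000 images, the slicing above keeps 40,000 for training and uses the remaining 20,000 as the "test" set; that is a held-out part of the training split, not MNIST's official test split (torchvision exposes the latter via train=False). A quick pure-Python check of the split sizes and the [0, 1] scaling, using a hypothetical `split_and_scale` helper on plain lists instead of tensors:

```python
from typing import List, Tuple


def split_and_scale(
    pixels: List[int], labels: List[int], n_train: int
) -> Tuple[List[float], List[int], List[float], List[int]]:
    """Scale 8-bit pixel values to [0, 1], then split at index n_train (hypothetical helper)."""
    scaled = [p / 255.0 for p in pixels]  # mirrors x.float() / 255.0
    return scaled[:n_train], labels[:n_train], scaled[n_train:], labels[n_train:]


# Stand-in data: one "pixel" per sample so the arithmetic stays visible.
n_samples = 60000  # size of MNIST's training split
pixels = [i % 256 for i in range(n_samples)]
labels = [i % 10 for i in range(n_samples)]
x_tr, y_tr, x_te, y_te = split_and_scale(pixels, labels, n_train=40000)
print(len(x_tr), len(x_te))  # 40000 20000
print(min(x_tr), max(x_tr))  # 0.0 1.0
```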

Step 2: Define the model

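If you already have an nn.Module, just wrap it with k4t.Model. For example (note that torchvision's ResNets expect 3-channel images, so this one would need adapting before it could train on the 1-channel MNIST tensors above):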

model = torchvision.models.resnet50(num_classes=10)

model = k4t.Model(model)

When building a model from scratch, you can use KerasLayers (located in k4t.layers) for automatic shape inference, which frees you from calculating input channels by hand.

As shown below, k4t.layers.Conv2d(32, kernel_size=3) is equivalent to nn.Conv2d(?, 32, kernel_size=3), where the first parameter ? (i.e. in_channels) is inferred automatically. Likewise, k4t.layers.Linear(10) infers its in_features.

model = torch.nn.Sequential(
    k4t.layers.Conv2d(32, kernel_size=3), nn.ReLU(),
    nn.MaxPool2d(2, 2), 
    k4t.layers.Conv2d(64, kernel_size=3), nn.ReLU(),
    nn.Flatten(),
    k4t.layers.Linear(10)
)

A model containing KerasLayers needs an extra .build(input_shape) call, where input_shape excludes the batch dimension.

model = k4t.Model(model).build([1, 28, 28])
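The idea behind .build() can be illustrated without PyTorch: walk the layers in order, take each lazy layer's missing input size from the shape produced by the previous layer, and compute its output shape. The layer spec tuples and `infer_shapes` below are hypothetical stand-ins, not k4t.layers, and they only cover what this model uses: unpadded stride-1 convolutions, pooling with stride equal to the kernel size, flattening and linear layers.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def infer_shapes(layers: List[tuple], input_shape: Shape) -> List[Tuple[str, int, Shape]]:
    """Resolve each layer's inferred input size and output shape (batch dim excluded).

    Hypothetical layer specs:
      ("conv", out_channels, kernel_size)   stride 1, no padding
      ("pool", kernel_size)                 stride equal to kernel_size
      ("relu",), ("flatten",), ("linear", out_features)
    Returns (kind, inferred_input_size, output_shape) per layer, where the
    inferred size is in_channels / in_features, or 0 for parameter-free layers.
    """
    shape = tuple(input_shape)
    resolved = []
    for spec in layers:
        kind = spec[0]
        inferred = 0
        if kind == "conv":
            out_ch, k = spec[1], spec[2]
            inferred, h, w = shape  # the "?" (in_channels) comes from the incoming shape
            shape = (out_ch, h - k + 1, w - k + 1)
        elif kind == "pool":
            c, h, w = shape
            shape = (c, h // spec[1], w // spec[1])
        elif kind == "flatten":
            n = 1
            for d in shape:
                n *= d
            shape = (n,)
        elif kind == "linear":
            inferred = shape[0]  # in_features comes from the flattened size
            shape = (spec[1],)
        elif kind != "relu":
            raise ValueError(f"unknown layer kind: {kind}")
        resolved.append((kind, inferred, shape))
    return resolved


# The same architecture as the Sequential model above, as hypothetical specs.
model_spec = [("conv", 32, 3), ("relu",), ("pool", 2), ("conv", 64, 3),
              ("relu",), ("flatten",), ("linear", 10)]
for kind, inferred, shape in infer_shapes(model_spec, (1, 28, 28)):
    print(kind, inferred, shape)
```

The resolved shapes agree with the Output Shape column of the summary in Step 3, with the second convolution inferring 32 input channels and the linear layer inferring 7744 input features.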

Step 3: Summarize the model

model.summary()
=========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
=========================================================================================
├─Conv2d*: 1-1                           [-1, 32, 26, 26]          320
├─ReLU: 1-2                              [-1, 32, 26, 26]          --
├─MaxPool2d: 1-3                         [-1, 32, 13, 13]          --
├─Conv2d*: 1-4                           [-1, 64, 11, 11]          18,496
├─ReLU: 1-5                              [-1, 64, 11, 11]          --
├─Flatten: 1-6                           [-1, 7744]                --
├─Linear*: 1-7                           [-1, 10]                  77,450
=========================================================================================
Total params: 96,266
Trainable params: 96,266
Non-trainable params: 0
Total mult-adds (M): 2.50
=========================================================================================
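The parameter counts in the summary can be checked by hand: a k x k convolution has in_channels * out_channels * k * k weights plus out_channels biases, and a linear layer has in_features * out_features weights plus out_features biases. The "Total mult-adds" figure appears to count one multiply-add per weight per output position, ignoring biases; that is an inference from the numbers, not documented behaviour. A pure-Python check, with hypothetical helper names:

```python
def conv_params(c_in: int, c_out: int, k: int) -> int:
    """Weights plus biases of a k x k Conv2d."""
    return c_in * c_out * k * k + c_out


def linear_params(f_in: int, f_out: int) -> int:
    """Weights plus biases of a Linear layer."""
    return f_in * f_out + f_out


# Sizes taken from the summary table above.
params = {
    "conv1": conv_params(1, 32, 3),       # 320
    "conv2": conv_params(32, 64, 3),      # 18,496
    "linear": linear_params(7744, 10),    # 77,450
}
total = sum(params.values())

# Mult-adds: weight count times number of output positions (biases excluded).
mult_adds = (
    (1 * 32 * 3 * 3) * (26 * 26)      # conv1: 194,688
    + (32 * 64 * 3 * 3) * (11 * 11)   # conv2: 2,230,272
    + 7744 * 10                       # linear: 77,440
)
print(total)                     # 96266
print(f"{mult_adds / 1e6:.2f}")  # 2.50
```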

Step 4: Configure the optimizer, loss and metrics

model.compile(optimizer='adam', loss=nn.CrossEntropyLoss(), metrics=['acc'])

If a GPU is available, it is used automatically. You can also pass the device parameter to .compile() explicitly.
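nn.CrossEntropyLoss expects raw logits and integer class indices as targets, so the model above correctly ends with a plain Linear(10) rather than a softmax. The sketch below shows, in pure Python, what that loss and an 'acc'-style metric compute for one batch; it illustrates the math only and is not keras4torch's metric code, and the function names are hypothetical.

```python
import math
from typing import List


def cross_entropy(logits: List[List[float]], targets: List[int]) -> float:
    """Mean of -log(softmax(row)[target]) over the batch (mean reduction)."""
    total = 0.0
    for row, t in zip(logits, targets):
        m = max(row)  # subtract the max before exp() for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[t]
    return total / len(targets)


def accuracy(logits: List[List[float]], targets: List[int]) -> float:
    """Fraction of rows whose arg-max logit equals the target class."""
    hits = sum(
        max(range(len(row)), key=row.__getitem__) == t
        for row, t in zip(logits, targets)
    )
    return hits / len(targets)


logits = [[2.0, 0.0, 0.0], [0.0, 1.0, 3.0]]
targets = [0, 1]
print(cross_entropy(logits, targets))
print(accuracy(logits, targets))  # 0.5
```

The log-sum-exp form avoids overflow for large logits, which is why loss functions are usually computed from logits rather than from already-softmaxed probabilities.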

Step 5: Train the model

history = model.fit(x_train, y_train,
                    epochs=30,
                    batch_size=512,
                    validation_split=0.2)
Train on 32000 samples, validate on 8000 samples:
Epoch 1/30 - 2.8s - loss: 0.6109 - acc: 0.8372 - val_loss: 0.2712 - val_acc: 0.9235 - lr: 1e-03
Epoch 2/30 - 1.5s - loss: 0.2061 - acc: 0.9402 - val_loss: 0.1494 - val_acc: 0.9579 - lr: 1e-03
Epoch 3/30 - 1.5s - loss: 0.1202 - acc: 0.9653 - val_loss: 0.0974 - val_acc: 0.9719 - lr: 1e-03
Epoch 4/30 - 1.5s - loss: 0.0835 - acc: 0.9757 - val_loss: 0.0816 - val_acc: 0.9769 - lr: 1e-03
... ...
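The first log line follows from the arguments: validation_split=0.2 holds out 20% of the 40,000 training samples (8,000), leaving 32,000, and with batch_size=512 one epoch takes ceil(32000 / 512) = 63 optimizer steps if the final partial batch of 256 samples is kept. The lr: 1e-03 column is consistent with the default learning rate of torch.optim.Adam. Which samples are held out (the last 20%, as Keras does, or a random subset) is not stated here, so the sketch below only reproduces the arithmetic; `split_sizes` is a hypothetical helper.

```python
import math
from typing import Tuple


def split_sizes(n_samples: int, validation_split: float, batch_size: int) -> Tuple[int, int, int, int]:
    """Return (n_train, n_val, steps_per_epoch, last_batch_size).

    Hypothetical helper: assumes the validation count is rounded to the nearest
    integer and the final partial batch is kept rather than dropped.
    """
    n_val = round(n_samples * validation_split)
    n_train = n_samples - n_val
    steps = math.ceil(n_train / batch_size)
    last = n_train - (steps - 1) * batch_size
    return n_train, n_val, steps, last


print(split_sizes(40000, 0.2, 512))  # (32000, 8000, 63, 256)
```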

Step 6: Plot the learning curve

history.plot(kind='line', y=['loss', 'val_loss'])

Step 7: Evaluate on the test set

model.evaluate(x_test, y_test)
{'loss': 0.06655170023441315, 'acc': 0.9839999675750732}

Communication

We have enabled GitHub Discussions for Q&A and most general topics!

For bug reports, please use GitHub Issues.


Download files


Source Distribution

keras4torch-1.1.4a1.tar.gz (28.3 kB)

Built Distribution

keras4torch-1.1.4a1-py3-none-any.whl (32.8 kB)

File details

Details for the file keras4torch-1.1.4a1.tar.gz.

File metadata

  • Download URL: keras4torch-1.1.4a1.tar.gz
  • Size: 28.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.6.1 requests/2.25.1 setuptools/51.0.0.post20201207 requests-toolbelt/0.9.1 tqdm/4.55.0 CPython/3.8.5

File hashes

Hashes for keras4torch-1.1.4a1.tar.gz
  Algorithm     Hash digest
  SHA256        72b83879732fc8c3235d46b0db678cdc449153899404e2adaa61f0769064a997
  MD5           70325a712514423cd4def1bcb9815f83
  BLAKE2b-256   249322951814159617bda8389dab21d5dcff3df3fc5871b904d53d614a0a7717


File details

Details for the file keras4torch-1.1.4a1-py3-none-any.whl.

File metadata

  • Download URL: keras4torch-1.1.4a1-py3-none-any.whl
  • Size: 32.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.6.1 requests/2.25.1 setuptools/51.0.0.post20201207 requests-toolbelt/0.9.1 tqdm/4.55.0 CPython/3.8.5

File hashes

Hashes for keras4torch-1.1.4a1-py3-none-any.whl
  Algorithm     Hash digest
  SHA256        2da2679521151d4aa8ebee08a436aeae78b4824eebdad680357fb2676b3b2ce6
  MD5           05b2f3ca130c05ba0b2812c9c0224ffd
  BLAKE2b-256   333c39ccfc9657e5ad0a6cf9c06a7ed750b20cb353847b00e6421ec69ddb02a3
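To check that a downloaded file matches the SHA256 digests listed above, you can hash it locally and compare. A minimal sketch using Python's standard hashlib; the function names are hypothetical, and the path is wherever you saved the download:

```python
import hashlib


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def matches(path: str, expected_hex: str) -> bool:
    """Compare a file's digest with a published one (case-insensitive)."""
    return sha256_of(path) == expected_hex.strip().lower()


# Usage, assuming the sdist was downloaded into the current directory:
# matches("keras4torch-1.1.4a1.tar.gz",
#         "72b83879732fc8c3235d46b0db678cdc449153899404e2adaa61f0769064a997")
```

pip can enforce the same check at install time in hash-checking mode, e.g. a requirements line keras4torch==1.1.4a1 --hash=sha256:<digest> installed with pip install --require-hashes -r requirements.txt (every dependency then needs a pinned hash as well).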
