Skip to main content

A Ready-to-Use Wrapper for Training PyTorch Models✨

Project description

Keras4Torch

A Ready-to-Use Wrapper for Training PyTorch Models✨

PyPI Downloads CodeCov License

DocumentationsDev LogsMini Version

keras4torch provides a high-level API to train PyTorch models compatible with Keras. This project is designed for beginner with these objectives:

  • Help people who are new to PyTorch but familar with Keras
  • Reduce the cost for migrating Keras model implementation to PyTorch

Use keras4torch for Kaggle's code competition! Check this package dataset and starter notebook.


Installation

pip install keras4torch

PyTorch 1.6+ and Python 3.6+ is required.

Quick start

Suppose you have a nn.Module to train.

model = torchvision.models.resnet18(num_classes=10)

All you need to do is wrapping it via k4t.Model().

import keras4torch as k4t

model = k4t.Model(model)

Now, there're two workflows can be used for training.

The NumPy workflow is compatible with Keras.

  • .compile(optimizer, loss, metrics) for settings of optimizer, loss and metrics
  • .fit(x, y, epochs, batch_size, ...) takes raw numpy input for training
  • .evaluate(x, y) outputs a dict result of your metrics
  • .predict(x) for doing predictions

And DataLoader workflow is more flexible and of pytorch style.

  • .compile(optimizer, loss, metrics) same as NumPy workflow
  • .fit_dl(train_loader, val_loader, epochs) for training the model via DataLoader
  • .evaluate_dl(data_loader) same as NumPy workflow but takes DataLoader
  • .predict_dl(data_loader) same as NumPy workflow but takes DataLoader

The two workflows can be mixed.

MNIST example

Here we show a complete example of training a ConvNet on MNIST.

import torch
import torchvision
from torch import nn

import keras4torch as k4t

Step1: Preprocess data

mnist = torchvision.datasets.MNIST(root='./', download=True)
x, y = mnist.train_data.unsqueeze(1), mnist.train_labels

x = x.float() / 255.0    # scale the pixels to [0, 1]

x_train, y_train = x[:40000], y[:40000]
x_test, y_test = x[40000:], y[40000:]

Step2: Define the model

If you have a nn.Module already, just wrap it via k4t.Model. For example,

model = torchvision.models.resnet50(num_classes=10)

model = k4t.Model(model)

For building models from scratch, you can use KerasLayer (located in k4t.layers) for automatic shape inference, which can free you from calculating the input channels.

As is shown below, k4t.layers.Conv2d(32, kernel_size=3) equals nn.Conv2d(?, 32, kernel_size=3) where the first parameter ? (i.e. in_channels) will be determined by itself.

model = torch.nn.Sequential(
    k4t.layers.Conv2d(32, kernel_size=3), nn.ReLU(),
    nn.MaxPool2d(2, 2), 
    k4t.layers.Conv2d(64, kernel_size=3), nn.ReLU(),
    nn.Flatten(),
    k4t.layers.Linear(10)
)

A model containing KerasLayer needs an extra .build(input_shape) operation.

model = k4t.Model(model).build([1, 28, 28])

Step3: Summary the model

model.summary()
=========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
=========================================================================================
├─Conv2d*: 1-1                           [-1, 32, 26, 26]          320
├─ReLU: 1-2                              [-1, 32, 26, 26]          --
├─MaxPool2d: 1-3                         [-1, 32, 13, 13]          --
├─Conv2d*: 1-4                           [-1, 64, 11, 11]          18,496
├─ReLU: 1-5                              [-1, 64, 11, 11]          --
├─Flatten: 1-6                           [-1, 7744]                --
├─Linear*: 1-7                           [-1, 10]                  77,450
=========================================================================================
Total params: 96,266
Trainable params: 96,266
Non-trainable params: 0
Total mult-adds (M): 2.50
=========================================================================================

Step4: Config optimizer, loss and metrics

model.compile(optimizer='adam', loss=nn.CrossEntropyLoss(), metrics=['acc'])

If GPU is available, it will be used automatically. You can also pass device parameter to .compile() explicitly.

Step5: Training

history = model.fit(x_train, y_train,
                	epochs=30,
                	batch_size=512,
                	validation_split=0.2,
                	)
Train on 32000 samples, validate on 8000 samples:
Epoch 1/30 - 2.8s - loss: 0.6109 - acc: 0.8372 - val_loss: 0.2712 - val_acc: 0.9235 - lr: 1e-03
Epoch 2/30 - 1.5s - loss: 0.2061 - acc: 0.9402 - val_loss: 0.1494 - val_acc: 0.9579 - lr: 1e-03
Epoch 3/30 - 1.5s - loss: 0.1202 - acc: 0.9653 - val_loss: 0.0974 - val_acc: 0.9719 - lr: 1e-03
Epoch 4/30 - 1.5s - loss: 0.0835 - acc: 0.9757 - val_loss: 0.0816 - val_acc: 0.9769 - lr: 1e-03
... ...

Step6: Plot learning curve

history.plot(kind='line', y=['loss', 'val_loss'])

Step7: Evaluate on test set

model.evaluate(x_test, y_test)
{'loss': 0.06655170023441315, 'acc': 0.9839999675750732}

Communication

We have activated Github Discussion for Q&A and most general topics!

For bugs report, please use Github Issues.

Project details


Release history Release notifications | RSS feed

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

keras4torch-1.1.0a0.tar.gz (27.5 kB view details)

Uploaded Source

Built Distribution

keras4torch-1.1.0a0-py3-none-any.whl (32.0 kB view details)

Uploaded Python 3

File details

Details for the file keras4torch-1.1.0a0.tar.gz.

File metadata

  • Download URL: keras4torch-1.1.0a0.tar.gz
  • Upload date:
  • Size: 27.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.6.1 requests/2.25.1 setuptools/51.0.0.post20201207 requests-toolbelt/0.9.1 tqdm/4.55.0 CPython/3.8.5

File hashes

Hashes for keras4torch-1.1.0a0.tar.gz
Algorithm Hash digest
SHA256 2b9c7ef2d82968ed86b71fd33eaf86bbf818edf2945876d0189eb1ba785d9a33
MD5 36c694157233780f60b9e7599166e0ac
BLAKE2b-256 cd6ff18ca92503396856365d45939a6cdc5a713c988e354af947b04689315cdf

See more details on using hashes here.

Provenance

File details

Details for the file keras4torch-1.1.0a0-py3-none-any.whl.

File metadata

  • Download URL: keras4torch-1.1.0a0-py3-none-any.whl
  • Upload date:
  • Size: 32.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.6.1 requests/2.25.1 setuptools/51.0.0.post20201207 requests-toolbelt/0.9.1 tqdm/4.55.0 CPython/3.8.5

File hashes

Hashes for keras4torch-1.1.0a0-py3-none-any.whl
Algorithm Hash digest
SHA256 e7f718a611703138b83660c5a7de8bbffc52acda239a2cc329f4e216b32367c2
MD5 f54a19dae7be90f71d067d1c89cdadd5
BLAKE2b-256 100fac1caf20fc3a36284749526669d865d1e41c6f69af416c1666c81e9e08c8

See more details on using hashes here.

Provenance

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page