Skip to main content

A Ready-to-Use Wrapper for Training PyTorch Models✨

Project description

Keras4Torch

A Ready-to-Use Wrapper for Training PyTorch Models✨

PyPI Downloads CodeCov License

DocumentationsDev LogsMini Version

keras4torch provides a high-level API to train PyTorch models compatible with Keras. This project is designed for beginner with these objectives:

  • Help people who are new to PyTorch but familar with Keras
  • Reduce the cost for migrating Keras model implementation to PyTorch
  • Enable Keras users to test PyTorch model via their familar pipeline

Installation

pip install keras4torch

PyTorch 1.6+ and Python 3.6+ is required.

Quick start

Suppose you have a nn.Module to train.

model = torchvision.models.resnet18(num_classes=10)

All you need to do is wrapping it via k4t.Model().

import keras4torch as k4t

model = k4t.Model(model)

Now, there're two workflows can be used for training.

The NumPy workflow is compatible with Keras.

  • .compile(optimizer, loss, metrics) for settings of optimizer, loss and metrics
  • .fit(x, y, epochs, batch_size, ...) takes raw numpy input for training
  • .evaluate(x, y) outputs a dict result of your metrics
  • .predict(x) for doing predictions

And DataLoader workflow is more flexible and of pytorch style.

  • .compile(optimizer, loss, metrics) same as NumPy workflow
  • .fit_dl(train_loader, epochs, val_loader) for training the model via DataLoader
  • .evaluate_dl(data_loader) same as NumPy workflow but takes DataLoader
  • .predict_dl(data_loader) same as NumPy workflow but takes DataLoader

The two workflows can be mixed.

MNIST example

Here we show a complete example of training a ConvNet on MNIST.

import torch
import torchvision
from torch import nn

import keras4torch as k4t

Step1: Preprocess data

mnist = torchvision.datasets.MNIST(root='./', download=True)
x, y = mnist.train_data.unsqueeze(1), mnist.train_labels

x = x.float() / 255.0    # scale the pixels to [0, 1]

x_train, y_train = x[:40000], y[:40000]
x_test, y_test = x[40000:], y[40000:]

Step2: Define the model

If you have a nn.Module already, just wrap it via k4t.Model. For example,

model = torchvision.models.resnet50(num_classes=10)

model = k4t.Model(model)

For building models from scratch, you can use KerasLayer (located in k4t.layers) for automatic shape inference, which can free you from calculating the input channels.

As is shown below, k4t.layers.Conv2d(32, kernel_size=3) equals nn.Conv2d(?, 32, kernel_size=3) where the first parameter ? (i.e. in_channels) will be determined by itself.

model = torch.nn.Sequential(
    k4t.layers.Conv2d(32, kernel_size=3), nn.ReLU(),
    nn.MaxPool2d(2, 2), 
    k4t.layers.Conv2d(64, kernel_size=3), nn.ReLU(),
    nn.Flatten(),
    k4t.layers.Linear(10)
)

A model containing KerasLayer needs an extra .build(input_shape) operation.

model = k4t.Model(model).build([1, 28, 28])

Step3: Summary the model

model.summary()
=========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
=========================================================================================
├─Conv2d*: 1-1                           [-1, 32, 26, 26]          320
├─ReLU: 1-2                              [-1, 32, 26, 26]          --
├─MaxPool2d: 1-3                         [-1, 32, 13, 13]          --
├─Conv2d*: 1-4                           [-1, 64, 11, 11]          18,496
├─ReLU: 1-5                              [-1, 64, 11, 11]          --
├─Flatten: 1-6                           [-1, 7744]                --
├─Linear*: 1-7                           [-1, 10]                  77,450
=========================================================================================
Total params: 96,266
Trainable params: 96,266
Non-trainable params: 0
Total mult-adds (M): 2.50
=========================================================================================

Step4: Config optimizer, loss and metrics

model.compile(optimizer='adam', loss=nn.CrossEntropyLoss(), metrics=['acc'])

If GPU is available, it will be used automatically. You can also pass device parameter to .compile() explicitly.

Step5: Training

history = model.fit(x_train, y_train,
                	epochs=30,
                	batch_size=512,
                	validation_split=0.2,
                	)
Train on 32000 samples, validate on 8000 samples:
Epoch 1/30 - 2.8s - loss: 0.6109 - acc: 0.8372 - val_loss: 0.2712 - val_acc: 0.9235 - lr: 1e-03
Epoch 2/30 - 1.5s - loss: 0.2061 - acc: 0.9402 - val_loss: 0.1494 - val_acc: 0.9579 - lr: 1e-03
Epoch 3/30 - 1.5s - loss: 0.1202 - acc: 0.9653 - val_loss: 0.0974 - val_acc: 0.9719 - lr: 1e-03
Epoch 4/30 - 1.5s - loss: 0.0835 - acc: 0.9757 - val_loss: 0.0816 - val_acc: 0.9769 - lr: 1e-03
... ...

Step6: Plot learning curve

history.plot(kind='line', y=['loss', 'val_loss'])

Step7: Evaluate on test set

model.evaluate(x_test, y_test)
{'loss': 0.06655170023441315, 'acc': 0.9839999675750732}

Communication

We have activated Github Discussion for Q&A and most general topics!

For bugs report, please use Github Issues.

Project details


Release history Release notifications | RSS feed

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

keras4torch-1.0.7a2.tar.gz (26.8 kB view details)

Uploaded Source

Built Distribution

keras4torch-1.0.7a2-py3-none-any.whl (30.6 kB view details)

Uploaded Python 3

File details

Details for the file keras4torch-1.0.7a2.tar.gz.

File metadata

  • Download URL: keras4torch-1.0.7a2.tar.gz
  • Upload date:
  • Size: 26.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.6.1 requests/2.25.1 setuptools/51.0.0.post20201207 requests-toolbelt/0.9.1 tqdm/4.55.0 CPython/3.8.5

File hashes

Hashes for keras4torch-1.0.7a2.tar.gz
Algorithm Hash digest
SHA256 c55a8373058d389a0bb5fc8ab94730484c71ec73a56f2b9ed598ef50c9df2e59
MD5 fe94665bb71525fbfd713ec378832901
BLAKE2b-256 c0617ff20b7b93857f293546f9ff2b75bd4bb72e2be53f357999b2790dcb9794

See more details on using hashes here.

Provenance

File details

Details for the file keras4torch-1.0.7a2-py3-none-any.whl.

File metadata

  • Download URL: keras4torch-1.0.7a2-py3-none-any.whl
  • Upload date:
  • Size: 30.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.6.1 requests/2.25.1 setuptools/51.0.0.post20201207 requests-toolbelt/0.9.1 tqdm/4.55.0 CPython/3.8.5

File hashes

Hashes for keras4torch-1.0.7a2-py3-none-any.whl
Algorithm Hash digest
SHA256 0e06657d03a9516127d6506b7b3ba54fd2e2da65ee8123a243fddb0273bbe0c3
MD5 7ffd7160c2ac8069a2036c09c865b166
BLAKE2b-256 fba367d19dbc834d697a6ea0f2600800d4ef1c6fbcc03eca3f250c0672146832

See more details on using hashes here.

Provenance

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page