
A Keras-like wrapper for PyTorch

Project description

KeraTorch

An implementation of a Keras clone with a PyTorch backend.

Install

pip install keratorch

How to use

from keraTorch.model import Sequential
from keraTorch.layers import *
from keraTorch.losses import *

The data:

x_train.shape, y_train.shape, x_valid.shape, y_valid.shape
((50000, 784), (50000,), (10000, 784), (10000,))
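The source does not say where these arrays come from. Since 784 = 28 × 28 and the final Dense(10) layer below implies 10 classes, they look like flattened 28×28 images (MNIST-style), but that is an assumption. To try the example without the real data, a hypothetical random stand-in with the same shapes can be generated with NumPy; the float32 pixels in [0, 1) and integer labels 0 to 9 are assumptions about what keraTorch accepts.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in data (NOT the original dataset):
# random float32 "pixels" in [0, 1) and integer class labels 0-9.
x_train = rng.random((50000, 784), dtype=np.float32)
y_train = rng.integers(0, 10, size=50000)
x_valid = rng.random((10000, 784), dtype=np.float32)
y_valid = rng.integers(0, 10, size=10000)

print(x_train.shape, y_train.shape, x_valid.shape, y_valid.shape)
```

Random labels carry no signal, so a model trained on this stand-in should stay near chance accuracy; it is only useful for checking that the code runs.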

Model definition:

model = Sequential()
model.add(Dense(100, x_train.shape[1], activation='relu'))
model.add(Dense(50, activation='relu'))
model.add(Dense(10))
model.add(Activation('softmax'))
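To make the shapes concrete, here is a NumPy sketch of the forward pass this stack describes, assuming (as in Keras) that Dense(units, input_dim) is a fully connected layer computing x @ W + b and that only the first layer needs the input size. The Gaussian initialisation and helper names are hypothetical, not keraTorch's internals.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0)

def softmax(z):
    # Subtract the row max for numerical stability before exponentiating.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
sizes = [784, 100, 50, 10]  # input -> Dense(100) -> Dense(50) -> Dense(10)

# Hypothetical initialisation: small Gaussian weights, zero biases.
params = [(rng.normal(0, 0.01, (n_in, n_out)), np.zeros(n_out))
          for n_in, n_out in zip(sizes[:-1], sizes[1:])]

def forward(x):
    h = x
    for i, (w, b) in enumerate(params):
        h = h @ w + b
        if i < len(params) - 1:  # ReLU on the two hidden layers only
            h = relu(h)
    return softmax(h)  # Activation('softmax') on the 10 outputs

probs = forward(rng.random((4, 784)))
print(probs.shape)          # (4, 10)
print(probs.sum(axis=-1))   # each row sums to 1 (up to rounding)
```

If every Dense layer has a bias, as Keras's does by default, the model has 784·100+100 + 100·50+50 + 50·10+10 = 84,060 trainable parameters.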

model.compile doesn't actually compile anything; to keep the Keras look, it is where the loss is specified. ce4softmax stands for cross-entropy loss on softmax outputs.

model.compile(ce4softmax)
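The source doesn't show how ce4softmax is implemented. Because the model already ends in Activation('softmax'), a loss of this kind presumably takes probabilities and averages the negative log-probability of the true class. A minimal NumPy sketch under that assumption; the function name and the eps clipping are hypothetical:

```python
import numpy as np

def cross_entropy_from_probs(probs, targets, eps=1e-7):
    # probs: (n, classes) softmax output; targets: (n,) integer class ids.
    # Pick each row's probability for its true class, clipped so log() never sees 0.
    p = np.clip(probs[np.arange(len(targets)), targets], eps, 1.0)
    return float(-np.log(p).mean())

uniform = np.full((3, 10), 0.1)  # a model that has learned nothing yet
loss = cross_entropy_from_probs(uniform, np.array([0, 4, 9]))
print(loss)  # ≈ 2.302585, i.e. ln(10)
```

A near-uniform softmax over 10 classes scores ln(10) ≈ 2.303, which is in line with the first-epoch losses in the training log.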

Borrowed from fastai, the learning rate finder helps pick a good learning rate:

bs = 256
model.lr_find(x_train, y_train, bs=bs)
Min numerical gradient: 9.12E-03
Min loss divided by 10: 1.45E-02
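A learning rate finder typically trains briefly while raising the learning rate exponentially, records the loss at each step, and then reads suggestions off the resulting curve. The sketch below only reproduces the two printed suggestions from a recorded curve, assuming (not confirmed by the source) that "Min numerical gradient" is the learning rate where numpy.gradient of the losses is most negative and "Min loss divided by 10" is the learning rate at the lowest loss divided by 10. The helper names and the loss curve are made up.

```python
import numpy as np

def lr_schedule(start, end, n):
    # Exponentially spaced learning rates, as swept in an LR range test.
    return np.geomspace(start, end, n)

def suggest_lrs(lrs, losses):
    lrs = np.asarray(lrs, dtype=float)
    losses = np.asarray(losses, dtype=float)
    # Where the loss falls fastest (most negative numerical gradient).
    steepest = lrs[np.gradient(losses).argmin()]
    # Learning rate at the lowest recorded loss, backed off by a factor of 10.
    min_div10 = lrs[losses.argmin()] / 10
    return steepest, min_div10

# Made-up curve: the loss falls, bottoms out, then diverges.
lrs = lr_schedule(1e-4, 10, 6)
losses = [3.0, 2.9, 2.0, 1.5, 1.4, 2.5]
steepest, min_div10 = suggest_lrs(lrs, losses)
print(f"Min numerical gradient: {steepest:.2E}")   # Min numerical gradient: 1.00E-02
print(f"Min loss divided by 10: {min_div10:.2E}")  # Min loss divided by 10: 1.00E-01
```

Dividing the minimum-loss learning rate by 10 backs away from the region where the loss is about to blow up; the example's fit call uses lr=1e-2, close to both suggestions printed above.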


As in Keras, there are .fit and .predict methods:

model.fit(x_train, y_train, bs, epochs=10, lr=1e-2)
epoch  train_loss  valid_loss  time
0      2.298158    2.270433    00:01
1      2.249195    2.054905    00:01
2      2.082948    1.474771    00:01
3      1.806854    0.904923    00:01
4      1.526004    0.737786    00:01
5      1.293055    0.705958    00:01
6      1.105806    0.666755    00:01
7      0.958004    0.687373    00:01
8      0.838495    0.696255    00:01
9      0.741785    0.697341    00:01
preds = model.predict(x_valid)
accuracy = (preds.argmax(axis=-1) == y_valid).mean()
print(f'Predicted accuracy is {accuracy:.2f}')
Predicted accuracy is 0.81
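The fit call takes the batch size bs=256 as its third argument. The source doesn't show the training loop, but a typical loop visits the training set once per epoch in shuffled mini-batches: with 50,000 samples and bs=256 that is 195 full batches plus one batch of 80. The sketch below assumes the last partial batch is kept rather than dropped; iterate_minibatches and accuracy are hypothetical helpers, the latter simply wrapping the accuracy expression used above.

```python
import numpy as np

def iterate_minibatches(x, y, bs, shuffle=True, seed=0):
    # Yield (xb, yb) pairs that cover every sample exactly once per epoch.
    idx = np.arange(len(x))
    if shuffle:
        np.random.default_rng(seed).shuffle(idx)
    for start in range(0, len(x), bs):
        batch = idx[start:start + bs]
        yield x[batch], y[batch]

def accuracy(preds, targets):
    # Fraction of rows whose highest-scoring class matches the label.
    return float((preds.argmax(axis=-1) == targets).mean())

# Small dummy arrays: only the number of rows matters for batching.
x = np.zeros((50000, 1), dtype=np.float32)
y = np.zeros(50000, dtype=np.int64)
sizes = [len(xb) for xb, _ in iterate_minibatches(x, y, bs=256)]
print(len(sizes), sizes[-1])  # 196 80

preds = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
print(accuracy(preds, np.array([1, 0, 0])))  # 2 of 3 correct
```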


Download files


Source Distribution

keraTorch-0.0.4.tar.gz (12.6 kB)

Uploaded Source

Built Distribution

keraTorch-0.0.4-py3-none-any.whl (13.5 kB)

Uploaded Python 3

File details

Details for the file keraTorch-0.0.4.tar.gz.

File metadata

  • Download URL: keraTorch-0.0.4.tar.gz
  • Upload date:
  • Size: 12.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/45.2.0.post20200210 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6

File hashes

Hashes for keraTorch-0.0.4.tar.gz
Algorithm Hash digest
SHA256 baea52f02ed0ee350bcbd4be5c15000f7d5c0aa5f3b7ba65ba2c25e117b5a4dc
MD5 f1aff4f3d0cdf9caa9ed43c4f63afafc
BLAKE2b-256 7ac2eb8a62e3c1ee01fd5f5fd6cf35b06b5a6c3f4267c19ea8534036fd0c4244


File details

Details for the file keraTorch-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: keraTorch-0.0.4-py3-none-any.whl
  • Upload date:
  • Size: 13.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/45.2.0.post20200210 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6

File hashes

Hashes for keraTorch-0.0.4-py3-none-any.whl
Algorithm Hash digest
SHA256 bda7a972c64897cdf37ca89f84a6d97d9d3369405bb28102a82945e87030259f
MD5 35dc6d3f5d6f5e582bde6efeab488960
BLAKE2b-256 26218178801a92016eecbc9663005b3273dd0e8b3fe335163abe4b1522bc0340
