A Keras-like wrapper for PyTorch
KeraTorch
An implementation of a Keras clone with a PyTorch backend.
Install
pip install keratorch
How to use
from keraTorch.model import Sequential
from keraTorch.layers import *
from keraTorch.losses import *
The data (50,000 training and 10,000 validation samples with 784 features each):
x_train.shape, y_train.shape, x_valid.shape, y_valid.shape
((50000, 784), (50000,), (10000, 784), (10000,))
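If you want to try the snippets without the original dataset, placeholder arrays with the same shapes are enough to check that everything runs. This is a sketch under assumptions: float32 features, integer labels from 0 to 9 (ten classes, matching the final `Dense(10)` layer) and purely random values, so any loss or accuracy you get from them is meaningless.

```python
import numpy as np

# Placeholder data only: random values with the same shapes as the dataset
# above (assumed float32 features and integer class labels 0..9).
rng = np.random.default_rng(0)
n_train, n_valid, n_features, n_classes = 50000, 10000, 784, 10

x_train = rng.random((n_train, n_features), dtype=np.float32)
y_train = rng.integers(0, n_classes, size=n_train)
x_valid = rng.random((n_valid, n_features), dtype=np.float32)
y_valid = rng.integers(0, n_classes, size=n_valid)

print(x_train.shape, y_train.shape, x_valid.shape, y_valid.shape)
# (50000, 784) (50000,) (10000, 784) (10000,)
```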
Model definition:
model = Sequential()
model.add(Dense(100, x_train.shape[1], activation='relu'))
model.add(Dense(50, activation='relu'))
model.add(Dense(10))
model.add(Activation('softmax'))
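Only the first `Dense` layer is given an input size (`x_train.shape[1]`, i.e. 784); the later `Dense` calls get none, so presumably each layer takes its input size from the previous layer's unit count, as in Keras. The NumPy sketch below illustrates what this stack computes. It is not keraTorch's code: the `build_dense_stack` and `forward` helpers and the He-style weight initialisation are hypothetical choices for illustration.

```python
import numpy as np

def build_dense_stack(n_in, units, seed=0):
    """Hypothetical helper: create (W, b) pairs for a chain of Dense layers.

    Only the first layer needs n_in; every later layer uses the previous
    layer's unit count as its input size, mirroring the Keras-style calls.
    """
    rng = np.random.default_rng(seed)
    params = []
    for n_out in units:
        W = rng.normal(0.0, np.sqrt(2.0 / n_in), size=(n_in, n_out))  # He-style init
        b = np.zeros(n_out)
        params.append((W, b))
        n_in = n_out  # next layer's input size = this layer's output size
    return params

def forward(x, params):
    """ReLU after every hidden layer, then a row-wise softmax on the last."""
    for W, b in params[:-1]:
        x = np.maximum(x @ W + b, 0.0)           # Dense(..., activation='relu')
    W, b = params[-1]
    logits = x @ W + b                           # Dense(10)
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract max for stable exp
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)     # Activation('softmax')

params = build_dense_stack(784, [100, 50, 10])
n_params = sum(W.size + b.size for W, b in params)
probs = forward(np.random.default_rng(1).random((4, 784)), params)
print(n_params, probs.shape)  # 84060 (4, 10)
```

The parameter count follows from the layer sizes: (784·100 + 100) + (100·50 + 50) + (50·10 + 10) = 78500 + 5050 + 510 = 84060.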
compile doesn't actually compile anything, but to keep the Keras look we specify the loss through it, as below. ce4softmax is the cross-entropy loss for softmax outputs.
model.compile(ce4softmax)
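Because the model already ends in a softmax, the loss has to work on probabilities rather than raw logits (PyTorch's `nn.CrossEntropyLoss`, by contrast, expects logits and applies log-softmax itself), which is presumably why a dedicated `ce4softmax` exists. Here is a minimal NumPy sketch of cross-entropy on softmax outputs, assuming mean reduction over the batch and a small clamp to avoid `log(0)`; keraTorch's actual implementation may differ.

```python
import numpy as np

def cross_entropy_from_probs(probs, targets, eps=1e-7):
    """Mean negative log-likelihood of the true class.

    probs:   (n, n_classes) array of softmax outputs (rows sum to 1)
    targets: (n,) integer class indices
    eps:     clamp so that log(0) never occurs (an assumed safeguard)
    """
    p_true = probs[np.arange(len(targets)), targets]  # probability of the true class
    return float(-np.log(np.clip(p_true, eps, 1.0)).mean())

probs = np.array([[0.5, 0.5],
                  [0.25, 0.75]])
targets = np.array([0, 1])
print(round(cross_entropy_from_probs(probs, targets), 4))  # 0.4904
```

As a sanity check, a model that predicts a uniform distribution over 10 classes has a loss of ln 10 ≈ 2.303, which is close to the first-epoch losses in the training table below.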
Borrowing fastai's learning rate finder to find a good learning rate:
bs = 256
model.lr_find(x_train, y_train, bs=bs)
Min numerical gradient: 9.12E-03
Min loss divided by 10: 1.45E-02
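The two numbers printed above are fastai's (v1) suggestions from a learning-rate range test: the learning rate is increased exponentially over a number of mini-batches while the loss is recorded. "Min numerical gradient" is the learning rate where the loss curve falls most steeply, and "Min loss divided by 10" is the learning rate at the lowest loss, divided by 10. The sketch below reproduces that logic on a made-up loss curve. The function names and the synthetic curve are hypothetical, and fastai's loss smoothing and trimming of the first and last points are left out, so treat it as an illustration rather than fastai's code.

```python
import numpy as np

def lr_schedule(start_lr, end_lr, num_it):
    """Exponentially spaced learning rates, as used in an LR range test."""
    return start_lr * (end_lr / start_lr) ** (np.arange(num_it) / (num_it - 1))

def suggest_lrs(lrs, losses):
    """Return (steepest-descent LR, min-loss LR / 10) from a recorded curve."""
    losses = np.asarray(losses)
    steepest = lrs[np.gradient(losses).argmin()]  # most negative slope
    min_loss_div10 = lrs[losses.argmin()] / 10
    return steepest, min_loss_div10

# Made-up loss curve: flat, then falling, then blowing up at high LRs.
lrs = lr_schedule(1e-7, 10.0, 100)
x = np.log10(lrs)
losses = 2.3 - 1.5 / (1 + np.exp(-3 * (x + 2.5))) + np.exp(2 * (x - 0.5))

steepest, min_div10 = suggest_lrs(lrs, losses)
print(f"Min numerical gradient: {steepest:.2E}")
print(f"Min loss divided by 10: {min_div10:.2E}")
```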
The .fit and .predict methods work the same way as in Keras:
model.fit(x_train, y_train, bs, epochs=10, lr=1e-2)
| epoch | train_loss | valid_loss | time |
|---|---|---|---|
| 0 | 2.298158 | 2.270433 | 00:01 |
| 1 | 2.249195 | 2.054905 | 00:01 |
| 2 | 2.082948 | 1.474771 | 00:01 |
| 3 | 1.806854 | 0.904923 | 00:01 |
| 4 | 1.526004 | 0.737786 | 00:01 |
| 5 | 1.293055 | 0.705958 | 00:01 |
| 6 | 1.105806 | 0.666755 | 00:01 |
| 7 | 0.958004 | 0.687373 | 00:01 |
| 8 | 0.838495 | 0.696255 | 00:01 |
| 9 | 0.741785 | 0.697341 | 00:01 |
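`fit` is called above with the batch size `bs` (256) as its third argument. To illustrate what that means for one epoch, here is a generic shuffled mini-batch iterator; it is a hypothetical helper, not keraTorch's or fastai's data loader. With 50,000 training samples and a batch size of 256 it yields 196 batches, the last holding the remaining 80 samples. Some loaders can be configured to drop that incomplete batch instead (for example PyTorch's `DataLoader(drop_last=True)`).

```python
import numpy as np

def iterate_minibatches(x, y, bs, shuffle=True, seed=0):
    """Yield (x_batch, y_batch) pairs covering every sample once per epoch.

    Hypothetical helper for illustration; the final batch is smaller when
    len(x) is not a multiple of bs (it is kept, not dropped).
    """
    idx = np.arange(len(x))
    if shuffle:
        np.random.default_rng(seed).shuffle(idx)  # new sample order each call
    for start in range(0, len(x), bs):
        batch = idx[start:start + bs]
        yield x[batch], y[batch]

# Small stand-in arrays with the training-set length (50,000 rows).
x = np.arange(50000).reshape(-1, 1)
y = np.arange(50000)
sizes = [len(yb) for _, yb in iterate_minibatches(x, y, bs=256)]
print(len(sizes), sizes[-1])  # 196 80
```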
preds = model.predict(x_valid)
accuracy = (preds.argmax(axis=-1) == y_valid).mean()
print(f'Predicted accuracy is {accuracy:.2f}')
Predicted accuracy is 0.81
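Assuming `predict` returns a NumPy array with one row of softmax outputs per sample, `argmax(axis=-1)` picks the predicted class of each row and the mean of the boolean matches is the accuracy. A tiny worked example with made-up numbers:

```python
import numpy as np

# Made-up softmax outputs for 3 samples and 2 classes.
preds = np.array([[0.1, 0.9],
                  [0.8, 0.2],
                  [0.3, 0.7]])
y_true = np.array([1, 0, 0])

pred_class = preds.argmax(axis=-1)        # predicted classes: [1, 0, 1]
accuracy = (pred_class == y_true).mean()  # 2 of 3 correct
print(f'Predicted accuracy is {accuracy:.2f}')  # Predicted accuracy is 0.67
```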
File details
Details for the file keraTorch-0.0.4.tar.gz.
File metadata
- Download URL: keraTorch-0.0.4.tar.gz
- Upload date:
- Size: 12.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/45.2.0.post20200210 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | baea52f02ed0ee350bcbd4be5c15000f7d5c0aa5f3b7ba65ba2c25e117b5a4dc |
| MD5 | f1aff4f3d0cdf9caa9ed43c4f63afafc |
| BLAKE2b-256 | 7ac2eb8a62e3c1ee01fd5f5fd6cf35b06b5a6c3f4267c19ea8534036fd0c4244 |
File details
Details for the file keraTorch-0.0.4-py3-none-any.whl.
File metadata
- Download URL: keraTorch-0.0.4-py3-none-any.whl
- Upload date:
- Size: 13.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/45.2.0.post20200210 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bda7a972c64897cdf37ca89f84a6d97d9d3369405bb28102a82945e87030259f |
| MD5 | 35dc6d3f5d6f5e582bde6efeab488960 |
| BLAKE2b-256 | 26218178801a92016eecbc9663005b3273dd0e8b3fe335163abe4b1522bc0340 |