An Easy-to-Use Wrapper for Training PyTorch Models
Project description
Keras4Torch
(README in English) (Documentations)
“开箱即用”的PyTorch模型训练高级API
-
对keras爱好者来说,Keras4Torch保留了绝大多数的Keras特性。你能够使用和keras相同的代码运行pytorch模型。
-
对pytorch爱好者来说,Keras4Torch使你只需要几行代码就可以完成pytorch模型的训练、评估和推理。
安装与配置
pip install keras4torch
支持PyTorch 1.6及以上。使用早期版本的torch可能导致部分功能不可用。
快速开始
作为示例,让我们开始编写一个MNIST手写数字识别程序!
import torch
import torchvision
from torch import nn
import keras4torch
Step1: 数据预处理
首先,从torchvision.datasets
中加载MNIST数据集,并将每个像素点缩放到[0, 1]之间。
其中前40000张图片作为训练集,后20000张图片作为测试集。
mnist = torchvision.datasets.MNIST(root='./', download=True)
X, y = mnist.train_data, mnist.train_labels
X = X.float() / 255.0
x_train, y_train = X[:40000], y[:40000]
x_test, y_test = X[40000:], y[40000:]
Step2: 构建模型
我们使用torch.nn.Sequential
定义一个由三层全连接组成的线性模型,激活函数为ReLU。
接着,使用keras4torch.Model
封装Sequential模型,以集成训练API。
model = torch.nn.Sequential(
nn.Flatten(),
nn.Linear(28*28, 512), nn.ReLU(),
nn.Linear(512, 128), nn.ReLU(),
nn.Linear(128, 10)
)
model = keras4torch.Model(model)
News (v0.4.1): 您也可以使用keras4torch.layers提供的KerasLayer
,以自动推算输入维度。
包含KerasLayer
的模型需要调用model.build()
,其参数是样本的维度。具体示例如下:
import keras4torch.layers as layers
model = torch.nn.Sequential(
nn.Flatten(),
layers.Linear(512), nn.ReLU(),
layers.Linear(128), nn.ReLU(),
layers.Linear(10)
)
model = keras4torch.Model(model).build(input_shape=[28, 28])
Step3: 设置优化器、损失函数和度量
model.compile()
函数对模型进行必要的配置。
参数既可以使用字符串,也可以使用torch.nn
模块中提供的类实例。
model.compile(optimizer='adam', loss=nn.CrossEntropyLoss(), metrics=['acc'])
Step4: 训练模型
model.fit()
是训练模型的方法,将以batch_size=512运行30轮次。
validation_split=0.2
指定80%数据用于训练集,剩余20%用作验证集。
history = model.fit(x_train, y_train,
epochs=30,
batch_size=512,
validation_split=0.2,
)
Train on 32000 samples, validate on 8000 samples:
Epoch 1/30 - 0.7s - loss: 0.7440 - acc: 0.8149 - val_loss: 0.3069 - val_acc: 0.9114 - lr: 1e-03
Epoch 2/30 - 0.5s - loss: 0.2650 - acc: 0.9241 - val_loss: 0.2378 - val_acc: 0.9331 - lr: 1e-03
Epoch 3/30 - 0.5s - loss: 0.1946 - acc: 0.9435 - val_loss: 0.1940 - val_acc: 0.9431 - lr: 1e-03
Epoch 4/30 - 0.5s - loss: 0.1513 - acc: 0.9555 - val_loss: 0.1663 - val_acc: 0.9524 - lr: 1e-03
... ...
Step5: 打印学习曲线
model.fit()
方法在结束时,返回关于训练历史数据的pandas.DataFrame
实例。
history.plot(kind='line', y=['acc', 'val_acc'])
Step6: 在测试集上评估
评估测试集上的损失和准确率。
model.evaluate(x_test, y_test)
{'loss': 0.121063925, 'acc': 0.9736}
社群交流
如果您在使用中遇到问题,可通过如下方式获取支持:
贡献
如果您有任何的想法和建议,请随时和我们联系,您的想法对我们非常重要。
同时也欢迎您加入我们,一同维护这个项目。
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Hashes for keras4torch-0.8.0-py3-none-any.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 8e425f3b98f68205e40fbeaeda747319efef06e1c8988039cc40dc50f6b8d1b3 |
|
MD5 | 0773511e246a85c1dd1f0258473f7b58 |
|
BLAKE2b-256 | ca0e660295d27f19c5f2ce06ad1330a0d2c4fe9b91724f85d2940a2c8fc079bc |