Use torch like keras
Project description
torch4keras
Use torch like keras
Documentation | Examples | Source code
1. 下载安装
安装稳定版
pip install torch4keras
安装最新版
pip install git+https://github.com/Tongjilibo/torch4keras.git
2. 功能
-
简述:抽象出来的Trainer,适用于一般神经网络的训练,仅需关注网络结构代码
-
特色:进度条展示训练过程,自定义metric,自带Evaluator, Checkpoint, Tensorboard, Logger等Callback,也可自定义Callback
-
初衷:前期功能是作为bert4torch和rec4torch的Trainer
-
训练:
2022-10-28 23:16:10 - Start Training 2022-10-28 23:16:10 - Epoch: 1/5 5000/5000 [==============================] - 13s 3ms/step - loss: 0.1351 - acc: 0.9601 Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 798.09it/s] test_acc: 0.98045. best_test_acc: 0.98045 2022-10-28 23:16:27 - Epoch: 2/5 5000/5000 [==============================] - 13s 3ms/step - loss: 0.0465 - acc: 0.9862 Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 635.78it/s] test_acc: 0.98280. best_test_acc: 0.98280 2022-10-28 23:16:44 - Epoch: 3/5 5000/5000 [==============================] - 15s 3ms/step - loss: 0.0284 - acc: 0.9915 Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 673.60it/s] test_acc: 0.98365. best_test_acc: 0.98365 2022-10-28 23:17:03 - Epoch: 4/5 5000/5000 [==============================] - 15s 3ms/step - loss: 0.0179 - acc: 0.9948 Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 692.34it/s] test_acc: 0.98265. best_test_acc: 0.98365 2022-10-28 23:17:21 - Epoch: 5/5 5000/5000 [==============================] - 14s 3ms/step - loss: 0.0129 - acc: 0.9958 Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 701.77it/s] test_acc: 0.98585. best_test_acc: 0.98585 2022-10-28 23:17:37 - Finish Training
3. 快速上手
- 参考bert4torch的训练过程
- 简单示例: turorials_mnist
4. 版本说明
- v0.0.7.post220230517 Checkpoint Calback增加保存scheduler, save_weights可自行创建目录,Logger, Tensorboard模块加入lr, 修改predict和add_trainer
- v0.0.7:20230505 独立出callbacks.py文件, fit允许输入形式为字典,load_weights支持list输入,save_weights支持仅保存可训练参数
- v0.0.6:20230212 增加resume_from_checkpoint和save_to_checkpoint;增加add_trainer方法,重构了Trainer(BaseModel)的实现(增加几个成员变量、增加initilize()、删除对forward参数个数的判断、dp和ddp不解析module、修改use_amp参数为mixed_precision),增加了AccelerateCallback
- v0.0.5:20221217 增加Summary的Callback, 增加Tqdm的进度条展示,保留原有BaseModel的同时,增加Trainer(不从nn.Module继承), 从bert4torch的snippets迁移部分通用函数
- v0.0.4:20221127 为callback增加on_train_step_end方法, 修复BaseModel(net)方式的bug
- v0.0.3.post2:20221107 修复DDP下打印的bug
- v0.0.3:20221106 参考Keras修改了callback的逻辑
- v0.0.2:20221023 增加Checkpoint, Evaluator等自带Callback, 修改BaseModel(net)方式,修复DP和DDP的__init__()
- v0.0.1:20221019 初始版本
5. 更新:
- 20230517:Checkpoint Calback增加保存scheduler, save_weights可自行创建目录,Logger, Tensorboard模块加入lr, 修改predict和add_trainer
- 20230505:独立出callbacks.py文件, fit允许输入形式为字典,load_weights支持list输入,save_weights支持仅保存可训练参数
- 20230212:增加hf的accelerator测试用例, ddp需要外部控制执行callback, 混合精度支持bf16
- 20230116:增加resume_from_checkpoint和save_to_checkpoint,动态为nn.Module增加Trainer的方法add_trainer
- 20221217:保留原有BaseModel的同时,增加Trainer(不从nn.Module继承), 从bert4torch的snippets迁移部分通用函数
- 20221203:增加Summary的Callback, 增加Tqdm的进度条展示
- 20221127:为callback增加on_train_step_end方法, 修复BaseModel(net)方式的bug
- 20221107:修复DDP下打印的bug,metrics中加入detach和auc
- 20221106:默认的Tensorboard的global_step+1, 参考Keras修改了callback的逻辑
- 20221020:增加Checkpoint, Evaluator等自带Callback, 修改BaseModel(net)方式,修复DP和DDP的__init__()
- 20221019:初版提交
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
torch4keras-0.0.7.post2.tar.gz
(22.1 kB
view hashes)