Simple deep learning toolkit
Project description
介绍
simpledl是Simple Deep Learning Toolkit的缩写,基于PyTorch,旨在用极简的接口构建神经网络和训练,能方便地复用模型,以及修改训练流程的各个环节,主要用于复现经典模型和进行一些实验。
关键接口
Data
不同任务的数据格式和处理方法都不相同,如果对数据进行抽象处理的话,会使整个代码变得极其复杂,simpledl没有针对数据抽象单独的类,而是复用PyTorch的Dataset,只要求返回的batch数据是dict格式,例如
batch = {
'src': src_seq,
'src_size': src_size,
'tgt': tgt_seq
}
在simpledl/data中提供了一些常用的数据处理类,可以方便地复用,比如TranslationDataset,根据两种语言的文本文件,构建一个Dataset。
Model
Model的接口是forward函数,表示模型的前向计算,函数的参数可以是batch的key,或者是batch,输出也要求是一个dict,比如
def forward(self, src, src_size, tgt):
# do something
output_logits = calc_output_logits()
return {
'logits': output_logits,
}
Loss
Loss的接口是数据的一个batch和模型的前向计算的输出,输出是dict,要求有key为loss,便于后面进行梯度更新 比如
def forward(self, batch, out):
return {
'loss': loss_calc(batch, out),
}
Optimizer
使用PyTorch的Optimizer,进行简单的包装,可以根据参数str构建需要的optimizer
Trainer
Trainer是主要的类,包括构建数据集、Model、Loss、Optimizer、callbacks。callbacks是训练过程中使用的,用于统计每一步训练的结果、保存断点、模型评估、earlystopping等操作。
训练的流程如下:
def train(...):
prepare_callbacks()
restore_training_if_necessary()
for e in training_epochs:
call_callbacks_epoch_begin()
for batch in dataset:
call_callbacks_batch_begin()
model_out = forward_model(batch) # 根据batch,调用Model的forward函数
loss = calc_loss(model_out, batch)
gradient_update_step() # 进行优化
call_callbacks_batch_end() # 统计结果、保存断点、调整学习率等
call_callbacks_epoch_end()
evaluate_if_necessary()
通过修改trainer,可以方便地修改数据、模型、以及训练的每个环节
示例
Machine Translation: Transformer, IWSLT2014
BLEU结果:
Model | de -> en | en -> de |
---|---|---|
Transformer | 33.27 | 27.72 |
Tied Transformers | 35.10 | 29.07 |
fairseq | 34.54 | 28.61 |
Ours | 34.36 | 28.33 |
Reference
fairseq
AllenNLP
Read TFRecord
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
File details
Details for the file simple-dl-toolkit-0.0.2.tar.gz
.
File metadata
- Download URL: simple-dl-toolkit-0.0.2.tar.gz
- Upload date:
- Size: 6.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.8.8
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | 4cf50cf4a5dab6852b9821a9fc0d0415324eef533ad900e8676de7d9b5cfe57c |
|
MD5 | 59de53e2bb35a6549e063918ed4ca772 |
|
BLAKE2b-256 | befdeee0e7dda6a961470b41c2bb3ee0cc32eac1d734835d152e657eb618e593 |
File details
Details for the file simple_dl_toolkit-0.0.2-py3-none-any.whl
.
File metadata
- Download URL: simple_dl_toolkit-0.0.2-py3-none-any.whl
- Upload date:
- Size: 6.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.8.8
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | 7cd84b8af09a9456fd65eeb5fdf852c27a861549b09dced087e73da14268fef1 |
|
MD5 | dc77be2890c37da31f81ad147d81e7e3 |
|
BLAKE2b-256 | 2725247dc3690d35be8beb6c849fe7a162836db2e7a0ee51c5cecfe5b0b0e939 |