Simple deep learning toolkit

Project description

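Introduction

simpledl is short for Simple Deep Learning Toolkit. Built on PyTorch, it aims to provide a minimal interface for building and training neural networks, making it easy to reuse models and to modify each stage of the training process. It is mainly intended for reproducing classic models and running experiments.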

Key interfaces

Data

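Data formats and processing methods differ from task to task, and abstracting over them would make the whole codebase extremely complex. simpledl therefore has no dedicated class for data abstraction. It reuses PyTorch's Dataset and only requires that each returned batch be a dict, for example: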

batch = {
    'src': src_seq,
    'src_size': src_size, 
    'tgt': tgt_seq
}

simpledl/data provides some commonly used data-processing classes that can be reused directly, such as TranslationDataset, which builds a Dataset from text files in two languages.
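As a rough illustration of the dict-batch convention, the sketch below builds a small parallel dataset and collates samples into a batch with the keys src, src_size, and tgt. The names ParallelTextDataset, pad, and collate are hypothetical and are not part of simpledl. The class is written in pure Python following the map-style protocol (__getitem__ and __len__); in practice it would presumably subclass torch.utils.data.Dataset and the lists would be tensors.

```python
class ParallelTextDataset:
    """Map-style dataset over two aligned lists of token-id sequences.

    Hypothetical stand-in, not simpledl's TranslationDataset.
    """

    def __init__(self, src_sentences, tgt_sentences):
        assert len(src_sentences) == len(tgt_sentences)
        self.src = src_sentences
        self.tgt = tgt_sentences

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        # Each sample is already a dict, so collation stays simple.
        return {'src': self.src[idx], 'tgt': self.tgt[idx]}


def pad(seqs, pad_id=0):
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_id] * (max_len - len(s)) for s in seqs]


def collate(samples):
    """Merge a list of samples into one dict-formatted batch."""
    return {
        'src': pad([s['src'] for s in samples]),
        'src_size': [len(s['src']) for s in samples],  # unpadded lengths
        'tgt': pad([s['tgt'] for s in samples]),
    }


dataset = ParallelTextDataset([[5, 6, 7], [8, 9]], [[1, 2], [3, 4, 5, 6]])
batch = collate([dataset[i] for i in range(len(dataset))])
```

A function like collate could be passed as collate_fn to a PyTorch DataLoader, so that every batch reaching the model is a dict.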

Model

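The Model interface is the forward function, which represents the model's forward computation. Its parameters can be keys of the batch or the batch itself, and its output must also be a dict, for example: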

def forward(self, src, src_size, tgt):
    # do something 
    output_logits = calc_output_logits()
    return {
        'logits': output_logits,
    }
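One way a trainer could support both calling styles (batch keys as parameters, or the whole batch) is to inspect the signature of forward. The helper call_forward and the class ToyModel below are assumptions made for illustration; how simpledl actually dispatches arguments is not shown in this description.

```python
import inspect


def call_forward(model, batch):
    """Call model.forward with the whole batch or with the matching batch keys."""
    params = list(inspect.signature(model.forward).parameters)
    if params == ['batch']:
        # forward(self, batch): hand over the dict unchanged.
        return model.forward(batch)
    # forward(self, src, src_size, ...): pick only the keys it asks for.
    return model.forward(**{name: batch[name] for name in params})


class ToyModel:
    """Hypothetical model: averages each source sequence instead of a real network."""

    def forward(self, src, src_size):
        return {'logits': [sum(seq) / size for seq, size in zip(src, src_size)]}


batch = {'src': [[1, 2, 3], [4, 6]], 'src_size': [3, 2], 'tgt': [[0], [1]]}
out = call_forward(ToyModel(), batch)
```

Here the tgt key is simply ignored because ToyModel.forward does not list it as a parameter.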

Loss

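The Loss interface takes a batch of data and the output of the model's forward computation. Its output is a dict that must contain the key loss, which is used for the subsequent gradient update. For example: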

def forward(self, batch, out):
    return {
        'loss': loss_calc(batch, out),
    }
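To make the contract concrete, here is a minimal sketch of a loss that follows the same (batch, out) signature and returns a dict with a loss key. It computes a mean cross-entropy over plain Python lists, assuming one logit vector and one target class index per example; the class name and these data shapes are illustrative assumptions, and a real implementation would operate on tensors so that gradients can flow.

```python
import math


class CrossEntropyLoss:
    """Hypothetical loss following the forward(batch, out) -> dict convention."""

    def forward(self, batch, out):
        total = 0.0
        for logits, target in zip(out['logits'], batch['tgt']):
            # Numerically stable log-sum-exp: subtract the max before exponentiating.
            m = max(logits)
            log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
            total += log_norm - logits[target]
        return {'loss': total / len(batch['tgt'])}


result = CrossEntropyLoss().forward(
    {'tgt': [0, 1]},
    {'logits': [[0.0, 0.0], [0.0, 0.0]]},
)
```

With uniform logits over two classes, the loss equals log 2 for every example, which gives a quick sanity check.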

Optimizer

simpledl uses PyTorch's Optimizer with a thin wrapper, so the required optimizer can be built from a str argument.
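The string format that simpledl accepts is not documented here. As a sketch only, assume a spec such as 'Adam,lr=0.001,weight_decay=0.01': a class name from torch.optim followed by comma-separated keyword arguments. The functions parse_optimizer_spec and build_optimizer are hypothetical names, not simpledl's API.

```python
import ast


def parse_optimizer_spec(spec):
    """Split a spec such as 'Adam,lr=0.001' into (name, kwargs).

    Limitation of this sketch: values that contain commas, such as tuples,
    are not supported.
    """
    name, *pairs = [part.strip() for part in spec.split(',')]
    kwargs = {}
    for pair in pairs:
        key, value = pair.split('=', 1)
        # literal_eval turns '0.001' into a float, 'True' into a bool, etc.
        kwargs[key.strip()] = ast.literal_eval(value.strip())
    return name, kwargs


def build_optimizer(spec, params):
    """Look the class up in torch.optim by name and instantiate it."""
    import torch  # imported lazily so the parser alone works without PyTorch

    name, kwargs = parse_optimizer_spec(spec)
    return getattr(torch.optim, name)(params, **kwargs)


name, kwargs = parse_optimizer_spec('Adam,lr=0.001,weight_decay=0.01')
```

With PyTorch installed, build_optimizer('SGD,lr=0.1', model.parameters()) would then return a torch.optim.SGD instance.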

Trainer

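Trainer is the main class. It covers building the dataset, Model, Loss, Optimizer, and callbacks. Callbacks are used during training for operations such as collecting the results of each training step, saving checkpoints, evaluating the model, and early stopping.

The training flow is as follows: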

def train(...):
    prepare_callbacks()
    restore_training_if_necessary()
    for e in training_epochs:
        call_callbacks_epoch_begin()
        for batch in dataset:
            call_callbacks_batch_begin()
            model_out = forward_model(batch)  # call the Model's forward function on the batch
            loss = calc_loss(model_out, batch)
            gradient_update_step()  # run the optimizer
            call_callbacks_batch_end()  # collect results, save checkpoints, adjust the learning rate, etc.
        call_callbacks_epoch_end()
        evaluate_if_necessary()

By modifying the trainer, you can easily change the data, the model, and every stage of training.
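The callback hooks in the training flow can be sketched in a few lines of plain Python. Everything below (the Callback base class, its hook names, LossHistory, and the step_fn argument) is an assumption for illustration rather than simpledl's actual classes; step_fn stands in for the forward pass, loss computation, and gradient update.

```python
class Callback:
    """Base class: every hook is a no-op unless a subclass overrides it."""

    def on_epoch_begin(self, epoch):
        pass

    def on_batch_begin(self, step):
        pass

    def on_batch_end(self, step, loss):
        pass

    def on_epoch_end(self, epoch):
        pass


class LossHistory(Callback):
    """Records the loss reported after every batch."""

    def __init__(self):
        self.losses = []

    def on_batch_end(self, step, loss):
        self.losses.append(loss)


def train(batches, step_fn, callbacks, epochs=1):
    step = 0
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        for batch in batches:
            for cb in callbacks:
                cb.on_batch_begin(step)
            loss = step_fn(batch)  # forward pass, loss, and gradient update
            for cb in callbacks:
                cb.on_batch_end(step, loss)
            step += 1
        for cb in callbacks:
            cb.on_epoch_end(epoch)


history = LossHistory()
train([1.0, 2.0], step_fn=lambda b: b * 0.5, callbacks=[history], epochs=2)
```

Checkpointing or early stopping would fit the same pattern: a subclass overrides only the hooks it needs, and the loop itself stays unchanged.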

Examples

Machine Translation: Transformer, IWSLT2014

BLEU results:

Model              de -> en  en -> de
Transformer        33.27     27.72
Tied Transformers  35.10     29.07
fairseq            34.54     28.61
Ours               34.36     28.33

Reference

fairseq
AllenNLP
Read TFRecord
