Simple deep learning toolkit

Introduction

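lessdl stands for Simple Deep Learning Toolkit. Built on PyTorch, it aims to build and train neural networks through a minimal interface, making it easy to reuse models and to modify each stage of the training process. It is mainly intended for reproducing classic models and running experiments.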

Key interfaces

Data

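Data formats and processing methods differ from task to task, and abstracting over them would make the whole codebase extremely complex. lessdl therefore defines no dedicated data class. It reuses PyTorch's Dataset and only requires that each returned batch be a dict, for example: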

batch = {
    'src': src_seq,
    'src_size': src_size, 
    'tgt': tgt_seq
}

lessdl/data provides some commonly used data-processing classes that can be reused directly. For example, TranslationDataset builds a Dataset from the text files of two languages.
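As a rough illustration of the dict-shaped batch described above, the sketch below merges per-sample dicts into one batch with the keys src, src_size, and tgt. The function name collate_translation and the padding index PAD_ID are assumptions made for this example and are not part of lessdl. Plain lists stand in for tensors so the snippet runs without PyTorch.

```python
from typing import Dict, List, Sequence

PAD_ID = 0  # hypothetical padding index, not taken from lessdl


def collate_translation(samples: Sequence[Dict[str, List[int]]]) -> Dict[str, list]:
    """Merge per-sample dicts into one dict-shaped batch, padding to equal length."""

    def pad(seqs: List[List[int]]) -> List[List[int]]:
        width = max(len(s) for s in seqs)
        return [s + [PAD_ID] * (width - len(s)) for s in seqs]

    src = [s["src"] for s in samples]
    tgt = [s["tgt"] for s in samples]
    return {
        "src": pad(src),
        "src_size": [len(s) for s in src],  # true lengths before padding
        "tgt": pad(tgt),
    }


samples = [
    {"src": [5, 6, 7], "tgt": [8, 9]},
    {"src": [3], "tgt": [4, 2, 1]},
]
batch = collate_translation(samples)
```

In a PyTorch setting, a function of this shape could be passed as the collate_fn argument of torch.utils.data.DataLoader, with the lists converted to tensors before returning.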

Model

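The Model interface is the forward function, which performs the model's forward computation. Its parameters can be either keys of the batch or the batch itself, and its output must also be a dict, for example: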

def forward(self, src, src_size, tgt):
    # do something 
    output_logits = calc_output_logits()
    return {
        'logits': output_logits,
    }
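One way a trainer could support both calling conventions (batch keys as parameters, or the whole batch) is to inspect the signature of forward. The sketch below is an assumption about how such dispatch might work, not lessdl's actual implementation. The helper call_with_batch and both toy model classes are hypothetical.

```python
import inspect
from typing import Any, Callable, Dict


def call_with_batch(forward: Callable[..., Dict[str, Any]],
                    batch: Dict[str, Any]) -> Dict[str, Any]:
    """Call forward with the batch entries its signature names, or with the whole batch."""
    params = inspect.signature(forward).parameters
    if list(params) == ["batch"]:
        return forward(batch)
    # Pass only the batch entries that forward asks for by name.
    kwargs = {name: batch[name] for name in params if name in batch}
    return forward(**kwargs)


class ToyModel:
    def forward(self, src, src_size, tgt):
        return {"logits": [len(src), src_size, len(tgt)]}


class WholeBatchModel:
    def forward(self, batch):
        return {"logits": sorted(batch)}


batch = {"src": [1, 2, 3], "src_size": 3, "tgt": [4, 5], "extra": None}
out_by_key = call_with_batch(ToyModel().forward, batch)
out_whole = call_with_batch(WholeBatchModel().forward, batch)
```

Entries that forward does not name, such as extra here, are simply not passed, so a dataset can carry more fields than a given model uses.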

Loss

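The Loss interface takes a batch of data and the output of the model's forward computation. Its output is a dict that must contain the key loss, which is used for the subsequent gradient update. For example: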

def forward(self, batch, out):
    return {
        'loss': loss_calc(batch, out),
    }
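To make the contract concrete, here is a toy loss that follows the same forward(batch, out) shape and returns a dict with a loss key. It computes token-level cross entropy over plain lists so it runs without PyTorch. The class name, the PAD_ID constant, and the extra ntokens key are assumptions for illustration only.

```python
import math
from typing import Any, Dict

PAD_ID = 0  # hypothetical padding index, skipped when averaging


class ToyCrossEntropy:
    """Token-level cross entropy over plain lists (tensors in real code)."""

    def forward(self, batch: Dict[str, Any], out: Dict[str, Any]) -> Dict[str, Any]:
        total, count = 0.0, 0
        for logits, target in zip(out["logits"], batch["tgt"]):
            if target == PAD_ID:
                continue
            # Negative log-softmax of the target class, computed stably.
            peak = max(logits)
            log_z = peak + math.log(sum(math.exp(x - peak) for x in logits))
            total += log_z - logits[target]
            count += 1
        # 'loss' is the required key; 'ntokens' is an optional extra.
        return {"loss": total / max(count, 1), "ntokens": count}


batch = {"tgt": [1, 2, 0]}
out = {"logits": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [9.0, 0.0, 0.0]]}
result = ToyCrossEntropy().forward(batch, out)
```

Returning a dict rather than a bare scalar leaves room for additional statistics that callbacks may want to log.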

Optimizer

lessdl uses PyTorch's Optimizer with a thin wrapper, so that the required optimizer can be built from a string argument.
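Building an optimizer from a string is commonly done with a name-to-class registry. The sketch below shows that pattern under stated assumptions: OPTIMIZERS, register_optimizer, build_optimizer, and ToySGD are all hypothetical names, and ToySGD is a stand-in so the example runs without PyTorch. In a real project the registry values would be classes such as torch.optim.SGD or torch.optim.Adam.

```python
from typing import Any, Callable, Dict

# Hypothetical registry mapping a lowercase name to an optimizer factory.
OPTIMIZERS: Dict[str, Callable[..., Any]] = {}


def register_optimizer(name: str):
    def decorator(factory):
        OPTIMIZERS[name] = factory
        return factory
    return decorator


def build_optimizer(name: str, params, **kwargs):
    """Look up an optimizer by name (case-insensitive) and construct it."""
    try:
        factory = OPTIMIZERS[name.lower()]
    except KeyError:
        raise ValueError(
            f"unknown optimizer {name!r}; choose from {sorted(OPTIMIZERS)}"
        ) from None
    return factory(params, **kwargs)


@register_optimizer("sgd")
class ToySGD:
    """Stand-in optimizer: plain gradient descent on a list of floats."""

    def __init__(self, params, lr=0.1):
        self.params, self.lr = params, lr

    def step(self, grads):
        for i, g in enumerate(grads):
            self.params[i] -= self.lr * g


weights = [1.0, 2.0]
opt = build_optimizer("SGD", weights, lr=0.5)
opt.step([1.0, -2.0])
```

With this layout, a command-line or config value can select the optimizer without the trainer importing each optimizer class by hand.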

Trainer

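Trainer is the main class. It is responsible for building the dataset, Model, Loss, Optimizer, and callbacks. Callbacks are used during training for operations such as collecting the results of each training step, saving checkpoints, evaluating the model, and early stopping.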

The training flow is as follows:

def train(...):
    prepare_callbacks()
    restore_training_if_necessary()
    for e in training_epochs:
        call_callbacks_epoch_begin()
        for batch in dataset:
            call_callbacks_batch_begin()
            model_out = forward_model(batch)  # call the Model's forward function on the batch
            loss = calc_loss(model_out, batch)
            gradient_update_step()  # run the optimization step
            call_callbacks_batch_end()  # collect results, save checkpoints, adjust the learning rate, etc.
        call_callbacks_epoch_end()
        evaluate_if_necessary()
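The callback hooks in the flow above can be pictured with a small runnable sketch. It is only an illustration of the pattern: the hook names (on_epoch_begin, on_batch_end, on_epoch_end), the LossHistory class, and the step_fn parameter are assumptions and may not match lessdl's actual callback API.

```python
from typing import Any, Callable, Dict, List, Sequence


class Callback:
    """Base class with no-op hooks; subclasses override what they need."""

    def on_epoch_begin(self, epoch: int) -> None: ...
    def on_batch_end(self, step: int, logs: Dict[str, Any]) -> None: ...
    def on_epoch_end(self, epoch: int) -> None: ...


class LossHistory(Callback):
    """Records the mean loss of every epoch."""

    def __init__(self) -> None:
        self.epoch_means: List[float] = []
        self._losses: List[float] = []

    def on_epoch_begin(self, epoch: int) -> None:
        self._losses = []

    def on_batch_end(self, step: int, logs: Dict[str, Any]) -> None:
        self._losses.append(logs["loss"])

    def on_epoch_end(self, epoch: int) -> None:
        self.epoch_means.append(sum(self._losses) / len(self._losses))


def train(batches: Sequence[Any], epochs: int, callbacks: List[Callback],
          step_fn: Callable[[Any], Dict[str, Any]]) -> None:
    step = 0
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        for batch in batches:
            logs = step_fn(batch)  # forward, loss, and update would happen here
            step += 1
            for cb in callbacks:
                cb.on_batch_end(step, logs)
        for cb in callbacks:
            cb.on_epoch_end(epoch)


history = LossHistory()
train([1.0, 3.0], epochs=2, callbacks=[history], step_fn=lambda b: {"loss": b})
```

Checkpointing, evaluation, and early stopping would fit the same shape as additional Callback subclasses, which keeps the loop itself unchanged.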

By modifying the trainer, you can easily change the data, the model, and every stage of training.

Examples

Machine Translation: Transformer, IWSLT2014

BLEU results:

Model              de -> en   en -> de
Transformer        33.27      27.72
Tied Transformers  35.10      29.07
fairseq            34.54      28.61
Ours               34.36      28.33

Reference

fairseq
AllenNLP
Read TFRecord
