building deep learning framework
Project description
stensor介绍
stensor是一种轻量化的深度学习训练/推理框架,对标pytorch提供的接口。
主要特性包括: - 自动微分 - 算子入参类型校验 - tensor接口注册机制
快速入门
参考test_transfomer实现翻译任务,主要包括以下流程:
1. 设置超参数
1).arg = ArgConfig()
2. 读取数据
1).创建Dateset
2).Dataset传递给DataLoader
3. 创建模型
1).模型初始化 model = Model(arg)
4. 创建优化器
1).优化器初始化 opt = Optimizer(arg_opt, model.parameters())
5. 训练
1).DataLoader迭代产生训练数据提供给模型
2).模型计算 y = model(x)
3).损失函数计算 loss = loss_f(y, target)
4).反向传播 loss.backward()
5).优化器更新参数 opt.update()
6. 推理
目录层级
model/—— 模型库stensor/—— 源代码目录common/—— tensor类定义dataset/—— dataload类nn/—— nn部分代码ops/—— 算子库__init__.py—— 暴露的所有api接口config.py—— 上下文环境管理
tests/—— 测试用例requirements.txt—— 依赖的三方库README.md—— 当前文件LICENSE—— 许可证文件.gitignore—— Git忽略文件列表
common模块
common/—— 源代码目录__init__.py—— 暴露接口_register_for_tensor.py—— 在functional.py中注册tensor接口tensor.py—— Tensor类utils.py—— 工具函数
Tensor类的设计思路:
1. Tensor类属性self.data中承载真实数据。
2. 以loss的输出Tensor开始,调用backward接口进行自动微分流程。
Parameter类的设计思路: 1. 继承自Tensor类,定义为模型层中的权重参数。以属性self.required_grad来判断是否进行权重更新。
dataset模块
common/—— 源代码目录__init__.py—— 暴露接口dataloaders.py—— Dataset负责建立索引到样本的映射datasets.py—— DataLoader负责以特定的方式从数据集中迭代的产生 一个个batch的样本集合transformer.py—— dataloader时的变换工具函数utils.py—— 工具函数
nn模块
nn/—— 源代码目录layer/—— 所有nn层接口__init__.py—— 暴露接口activation.py—— 激活函数层convolution.py—— 卷积层embedding.py—— 编码层linear.py—— 线性层normalization.py—— 正则化层pooling.py—— 池化层rnn.py—— RNN层
loss/—— 损失函数层__init__.py—— 暴露接口loss.py—— 损失函数层,通常作为独立的一层与模型组合
opt/—— 优化器__init__.py—— 暴露接口optimizer.py—— 优化器更新参数
__init__.py—— 暴露接口container.py—— module的容器类metric.py—— 评测指标module.py—— 模型构建的基本单元utils.py—— 工具函数
Module类的设计思路:
1.'__init__'和forward两个函数进行模型的初始化和正向计算过程的搭建。
2.用_params储存所有的参数。用_submodules储存所有的子模块。通过重载__setattr__魔术方法,在构建model时,自动将当前module的parameter和submodule存储,最终树状结构通过names_and_parameters和names_and_submodules两个接口使用yield打印出来。
3.to_gpu和to_cpu接口分别将params转化为numpy/cupy的格式。
4.load_weights和load_weights接口将numpy格式的params保存和加载。
5.plot接口序列化为dot格式文件,并调用第三方工具dot画出计算图。
Optimizer类的设计思路: 1.初始化时传入需要更新的参数。
2.zero_grad接口进行梯度重置。
3.step接口进行一次梯度更新。
4.add_hook接口在参数更新前进行hook函数的操作。
ops模块
nn/—— 源代码目录operations/—— 所有ops接口__init__.py—— 暴露接口activation_ops.py—— 激活函数算子common_ops.py—— Tensor操作算子math_ops.py—— 数学计算类算子nn_ops.py—— nn类算子utils.py—— 工具函数
__init__.py—— 暴露接口functional.py—— 所有functional接口primitive.py—— 算子构建基本单元
Primitive类的设计思路:
1.重载__call__方法,取出输入的Tensor类中承载的真实数据,调用forward进行正向计算,并封装成Tensor返回。
2.Tensor.backward接口中使用链式法则调用单算子反向计算函数Primitive.backward完成自动微分。
TODOList
- 使用cupy支持GPU。
- 使用pybind支持算子在CPU和cuda的实现。
- 完善_type_check,支持可变输入的Tensor。
- 完善dataset模块,支持Sample类。
展望
未来支持的特性包括:
元模块模版
(基础模版)提供元模块模版,用户可以通过yaml定义组建模型,自动生成编译好的图,从而进行组合。
(支持定制)对于自定义的模型, 可以通过最小粒度的元类型模版进行搭建,再调用编译接口,生成编译好的子图并保存。
动静统一
合并动态图和静态图的概念,用户组件好的模型将自动生成编译成整图,达到最优性能执行。
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
stensor-0.1.tar.gz
(66.5 kB
view details)
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
stensor-0.1-py3-none-any.whl
(82.4 kB
view details)
File details
Details for the file stensor-0.1.tar.gz.
File metadata
- Download URL: stensor-0.1.tar.gz
- Upload date:
- Size: 66.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.19
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
3cea0a32a15286ee7291ae2205200237b62f3b4b9ba33f464922d622de80ffd2
|
|
| MD5 |
148c59d2b703d52ea42b6b57c563071e
|
|
| BLAKE2b-256 |
d780dbbda6fd85b3b48cfd9438f11c8318e5e4a68bb09feeee2e89477a08d11d
|
File details
Details for the file stensor-0.1-py3-none-any.whl.
File metadata
- Download URL: stensor-0.1-py3-none-any.whl
- Upload date:
- Size: 82.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.19
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
b713ae9afa28b6f061395faa7bcdf7b105f7c02785078dae6f69b30f1ca8ddfc
|
|
| MD5 |
96aa54057bb7f43aa6c0fd1deb5d2372
|
|
| BLAKE2b-256 |
53417545edf37b4507870832cfcbb8614132fd63116854217e403834e228e2a3
|