building deep learning framework
Project description
stensor介绍
stensor是一种轻量化的深度学习训练/推理框架,对标pytorch提供的接口。
主要特性包括:
- 自动微分
- 算子入参类型校验
- tensor接口注册机制
快速入门
参考test_transfomer实现翻译任务,主要包括以下流程:
1. 设置超参数
1).arg = ArgConfig()
2. 读取数据
1).创建Dateset
2).Dataset传递给DataLoader
3. 创建模型
1).模型初始化 model = Model(arg)
4. 创建优化器
1).优化器初始化 opt = Optimizer(arg_opt, model.parameters())
5. 训练
1).DataLoader迭代产生训练数据提供给模型
2).模型计算 y = model(x)
3).损失函数计算 loss = loss_f(y, target)
4).反向传播 loss.backward()
5).优化器更新参数 opt.update()
6. 推理
目录层级
model/—— 模型库stensor/—— 源代码目录common/—— tensor类dataset/—— datase -common/—— tensor类定义dataset/—— dataload类nn/—— nn部分代码ops/—— 算子库__init__.py—— 暴露的所有api接口config.py—— 上下文环境管理
tests/—— 测试用例requirements.txt—— 依赖的三方库README.md—— 当前文件LICENSE—— 许可证文件.gitignore—— Git忽略文件列表
common模块
common/—— 源代码目录__init__.py—— 暴露接口_register_for_tensor.py—— 在functional.py中注册tensor接口tensor.py—— Tensor类utils.py—— 工具函数
Tensor类的设计思路:
1. Tensor类属性self.data中承载真实数据。
2. 以loss的输出Tensor开始,调用backward接口进行自动微分流程。
Parameter类的设计思路:
1. 继承自Tensor类,定义为模型层中的权重参数。以属性self.required_grad来判断是否进行权重更新。
dataset模块
common/—— 源代码目录__init__.py—— 暴露接口dataloaders.py—— Dataset负责建立索引到样本的映射datasets.py—— DataLoader负责以特定的方式从数据集中迭代的产生 一个个batch的样本集合 —— 源代码目录layer/—— 所有nn层接口__init__.py—— 暴露接口activation.py—— 激活函数层convolution.py—— 卷积层embedding.py—— 编码层linear.py—— 线性层normalization.py—— 正则化层pooling.py—— 池化层rnn.py—— RNN层
loss/—— 损失函数层__init__.py—— 暴露接口loss.py—— 损失函数层,通常作为独立的一层与模型组合
opt/—— 优化器__init__.py—— 暴露接口optimizer.py—— 优化器更新参数
__init__.py—— 暴露接口container.py—— module的容器类metric.py—— 评测指标module.py—— 模型构建的基本单元utils.py—— 工具函数
Module类的设计思路:
1.'__init__'和forward两个函数进行模型的初始化和正向计算过程的搭建。
2.用_params储存所有的参数。用_submodules储存所有的子模块。通过重载__setattr__魔术方法,在构建model时,自动将当前module的parameter和submodule存储,最终树状结构通过names_and_parameters和names_and_submodules两个接口使用yield打印出来。
3.to_gpu和to_cpu接口分别将params转化为numpy/cupy的格式。
4.load_weights和load_weights接口将numpy格式的params保存和加载。
5.plot接口序列化为dot格式文件,并调用第三方工具dot画出计算图。
Optimizer类的设计思路:
1.初始化时传入需要更新的参数。
2.zero_grad接口进行梯度重置。
3.step接口进行一次梯度更新。
4.add_hook接口在参数更新前进行hook函数的操作。
ops模块
nn/—— 源代码目录operations/—— 所有ops接口__init__.py—— 暴露接口activation_ops.py—— 激活函数算子common_ops.py—— Tensor操作算子math_ops.py—— 数学计算类算子nn_ops.py—— nn类算子utils.py—— 工具函数
__init__.py—— 暴露接口functional.py—— 所有functional接口primitive.py—— 算子构建基本单元
Primitive类的设计思路:
1.重载__call__方法,取出输入的Tensor类中承载的真实数据,调用forward进行正向计算,并封装成Tensor返回。
2.Tensor.backward接口中使用链式法则调用单算子反向计算函数Primitive.backward完成自动微分。
TODOList
- 使用cupy支持GPU。
- 使用pybind支持算子在CPU和cuda的实现。
- 完善_type_check,支持可变输入的Tensor。
- 完善dataset模块,支持Sample类。
展望
未来支持的特性包括:
元模块模版
(基础模版)提供元模块模版,用户可以通过yaml定义组建模型,自动生成编译好的图,从而进行组合。
(支持定制)对于自定义的模型, 可以通过最小粒度的元类型模版进行搭建,再调用编译接口,生成编译好的子图并保存。
动静统一
合并动态图和静态图的概念,用户组件好的模型将自动生成编译成整图,达到最优性能执行。
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
stensor-0.2.0.tar.gz
(99.1 kB
view details)
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
stensor-0.2.0-py3-none-any.whl
(122.3 kB
view details)
File details
Details for the file stensor-0.2.0.tar.gz.
File metadata
- Download URL: stensor-0.2.0.tar.gz
- Upload date:
- Size: 99.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.19
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
2192a002977f20cea6bd9ac3636df34f0fb9cb24f735acfd87d828e12cf9b8ee
|
|
| MD5 |
c70e4b60d3649ad5a28bb5742ee9b0fe
|
|
| BLAKE2b-256 |
b11762a647623a16039fe542a7ff04ac43aefb149ccc8a99070ac3ccabad6a03
|
File details
Details for the file stensor-0.2.0-py3-none-any.whl.
File metadata
- Download URL: stensor-0.2.0-py3-none-any.whl
- Upload date:
- Size: 122.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.19
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
ad4f94d315c76c5c24af273bab0cfd5a9fba4bda6eeb12725ba7660968d7a93d
|
|
| MD5 |
4491c37f2b4da1d13e6cfafb8d671494
|
|
| BLAKE2b-256 |
e77e6c3efd12560a04fd5f83f8026914515a5da36bdbae024722f04108352835
|