Wttch's train helper
Project description
wttch 的 AI 训练工具包
一、消息通知
1.1 钉钉 webhook 通知
from wth.notification import DingtalkNotification
# 钉钉机器人 webhook 的链接
webhook_url = ''
# 消息签名的密钥
secret = ''
notification = DingtalkNotification(webhook_url, secret)
# 发送文本通知
notification.send_text("")
# 发送markdown
notification.send_markdown('')
1.2 企业微信 webhook 通知
from wth.notification import WechatNotification
# 企业微信机器人 webhook 的链接
webhook_url = ''
notification = WechatNotification(webhook_url)
# 发送文本通知
notification.send_text("")
# 发送markdown
notification.send_markdown('')
二、训练工具包
2.1 缓存工具
不需要修改太多代码就可以帮助你缓存数据到指定的缓存文件去。
(1). 添加 cache_wrapper; (2). 正常调用你的函数。
from wth.utils import cache_wrapper
# 缓存的文件名字前缀,函数的参数会被添加到该名字后面
prefix = 'dataset'
# 缓存的文件夹位置
save_path = './dataset_cache'
@cache_wrapper(prefix=prefix, save_path=save_path)
def you_load_dataset_function():
return {'a': 1, 'b': 2}
you_load_dataset_function()
2.2 计时器
from wth.utils import StopWatch
stopwatch = StopWatch()
stopwatch.start("job 1")
# 费时操作
stopwatch.stop()
stopwatch.start("job 2")
# 费时操作
stopwatch.stop()
# 格式化打印
stopwatch.display()
2.3 进度条
简单包装了 tqdm
工具。
固定一种还不错的进度条格式。
2.3.1 循环模式
from wth.utils.progress import Progress
with Progress(total=1000) as progress:
for i in range(1000):
# 在这里训练
# 进度条末尾显示训练结果
progress.train_result(loss=0.01, acc=0.02)
2.3.2 迭代器模式
from wth.utils.progress import Progress
dataset = [1, 2, 3]
progress = Progress(dataset)
for data in progress:
# 使用 data 进行训练
# 进度条末尾显示训练结果
progress.train_result(loss=0.01, acc=0.02)
2.4 快捷获取函数执行时间
from wth.utils import get_time
from time import sleep
@get_time
def test():
sleep(1)
执行结果:
"test()" took 1.0051 seconds to execute.
三、torch 工具包
3.1 训练设备获取
3.1.1 获取设备
先尝试获取 cuda, 如果不支持获取 mps(macOS), 还不支持就 cpu。
可以添加 device_no
参数,但是只对 cuda 有效,表示 cuda 的序号。
from wth.torch.utils import try_gpu
try_gpu(device_no=2)
3.1.2 ThreadLocal 的设备变量的设置、获取
- 将使用的设备写入 thread local;
- 需要训练设备的地方, 从 thread local 中获取设备数据;
from wth.torch.utils import try_gpu, get_device_local, set_device_local
# 尝试获取 gpu 并写入 thread local
set_device_local(try_gpu(device_no=0))
# 从 thread local 读取设备
device = get_device_local()
3.2 ThreadLocal 的训练类型 dtype 的设置、获取
- 将使用的 dtype 写入 thread local;
- 需要训练类型的地方, 从 thread local 中获取设备数据;
import torch
from wth.torch.utils import get_dtype_local, set_dtype_local
# 将训练的 dtype 数据类型写入 thread local
set_dtype_local(torch.float32)
# 从 thread local 读取数据类型
dtype = get_dtype_local()
3.3 yml 多环境配置
环境配置,使用 yml 文件。
首先书写你的 yml 文件
例如:
# 激活的环境
active: local
# 本地环境
local:
epochs: 10
batch-size: 256
dataset-path: /Volumes/Wttch/datasets/torchvision/cache
dtype: float32
cuda-no: 1
# 你的其他属性
# 学校环境
school-server:
epochs: 10
batch-size: 256
dataset-path:
active 必需标识启用的环境。 后续使用 key: 环境配置 的方式可以定义多个环境。
获取你启用的环境的变量
# 获取你的配置
from wth.torch import Config
# 定义配置
config = Config()
# 获取预定义的属性
print(config.epochs)
# 获取其他你自己的属性
print(config['you-property-key'])
3.4 TrainAndTestDataLoader 测试和训练集
from wth.torch.utils.data import TrainAndTestDataLoader
from torch.utils.data import Dataset, DataLoader
# 数据集声明, 需要是 Dataset 的子类
dataset = ... # type: Dataset
# train_rate 训练用集合的百分比
loader = TrainAndTestDataLoader(dataset, train_rate=0.8)
# 训练用 DataLoader
loader1 = loader.train_loader # type: DataLoader
loader2 = loader.test_loader # type: DataLoader
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
wttch_train_helper-0.0.31.tar.gz
(15.8 kB
view details)
Built Distribution
File details
Details for the file wttch_train_helper-0.0.31.tar.gz
.
File metadata
- Download URL: wttch_train_helper-0.0.31.tar.gz
- Upload date:
- Size: 15.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.6
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | 004f95b157357e1344486498a4701fcf70b78beb7dff6567aaa2b11d80a7f8ed |
|
MD5 | 8fb2f47e88777bf58a6a7b34e2f4d381 |
|
BLAKE2b-256 | 011a7e98cef0649a4513590df3f812d91f06a3d2cea5507892919d4f604bc94e |
File details
Details for the file wttch_train_helper-0.0.31-py3-none-any.whl
.
File metadata
- Download URL: wttch_train_helper-0.0.31-py3-none-any.whl
- Upload date:
- Size: 17.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.6
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | 65fa4eae2502aadd0e1285f26971251aa6b613bcfb94f07e9a0d3ca99b648814 |
|
MD5 | bf35553e894fa69d07d6ef7bc2f9b1e9 |
|
BLAKE2b-256 | 1484c67e0639b2d265bee2a69ace475d1ec62d20c1aa5ef4aaafdc9d233595ab |