Skip to main content

Wttch's train helper

Project description

wttch 的 AI 训练工具包

一、消息通知

1.1 钉钉 webhook 通知

钉钉 webhook 机器人API

from wttch.train.notification import DingtalkNotification

# 钉钉机器人 webhook 的链接
webhook_url = ''
# 消息签名的密钥
secret = ''

notification = DingtalkNotification(webhook_url, secret)

# 发送文本通知
notification.send_text("")
# 发送markdown
notification.send_markdown('')

1.2 企业微信 webhook 通知

企业微信机器人API

from wttch.train.notification import WechatNotification

# 企业微信机器人 webhook 的链接
webhook_url = ''
notification = WechatNotification(webhook_url)

# 发送文本通知
notification.send_text("")
# 发送markdown
notification.send_markdown('')

二、训练工具包

2.1 缓存工具

不需要修改太多代码就可以帮助你缓存数据到指定的缓存文件去。

(1). 添加 cache_wrapper; (2). 正常调用你的函数。

from wttch.train.utils import cache_wrapper

# 缓存的文件名字前缀,函数的参数会被添加到该名字后面
prefix = 'dataset'
# 缓存的文件夹位置
save_path = './dataset_cache'


@cache_wrapper(prefix=prefix, save_path=save_path)
def you_load_dataset_function():
    return {'a': 1, 'b': 2}


you_load_dataset_function()

2.2 计时器

from wttch.train.utils import StopWatch

stopwatch = StopWatch()
stopwatch.start("job 1")
# 费时操作
stopwatch.stop()
stopwatch.start("job 2")
# 费时操作
stopwatch.stop()

# 格式化打印
stopwatch.display()

2.3 进度条

简单包装了 tqdm 工具。

固定一种还不错的进度条格式。

2.3.1 循环模式

from wttch.train.utils.progress import Progress

with Progress(total=1000) as progress:
    for i in range(1000):
        # 在这里训练

        # 进度条末尾显示训练结果
        progress.train_result(loss=0.01, acc=0.02)

2.3.2 迭代器模式

from wttch.train.utils.progress import Progress

dataset = [1, 2, 3]

progress = Progress(dataset)

for data in progress:
    # 使用 data 进行训练

    # 进度条末尾显示训练结果
    progress.train_result(loss=0.01, acc=0.02)

三、torch 工具包

3.1 训练设备获取

3.1.1 获取设备

先尝试获取 cuda, 如果不支持获取 mps(macOS), 还不支持就 cpu。

可以添加 device_no 参数,但是只对 cuda 有效,表示 cuda 的序号。

from wttch.train.torch.utils import try_gpu

try_gpu(device_no=2)

3.1.2 ThreadLocal 的设备变量的设置、获取

  1. 将使用的设备写入 thread local;
  2. 需要训练设备的地方, 从 thread local 中获取设备数据;
from wttch.train.torch.utils import try_gpu, get_device_local, set_device_local

# 尝试获取 gpu 并写入 thread local
set_device_local(try_gpu(device_no=0))

# 从 thread local 读取设备
device = get_device_local()

3.2 ThreadLocal 的训练类型 dtype 的设置、获取

  1. 将使用的 dtype 写入 thread local;
  2. 需要训练类型的地方, 从 thread local 中获取设备数据;
import torch
from wttch.train.torch.utils import get_dtype_local, set_dtype_local

# 将训练的 dtype 数据类型写入 thread local
set_dtype_local(torch.float32)

# 从 thread local 读取数据类型
dtype = get_dtype_local()

3.3 argparse.ArgumentParser 预处理

from wttch.train.torch import ArgParser
from torch import nn

# 这一句话就可以使用 argparse.ArgumentParse 了,预处理了一些的东西。
args = ArgParser()
# 获取批处理...等
batch_size = args.batch_size
# 保存模型
args.save_module(nn.Linear(100, 10))

具体使用方式:

# --cuda:       使用 cuda 0 训练
# --batch-size: 批处理 64
# -m:           模型名称为 module1 可以调用
# 更多 -h: 查看帮助
python **.py --cuda 0 --batch-size 64 -m module1

# usage: 设置训练参数 [-h] [-e EPOCHS] [-b BATCH_SIZE] [-d CUDA_NO] [-t {32,64}] [-m MODULE_NAME]
# 
# options:
#   -h, --help            show this help message and exit
#   -e EPOCHS, --epochs EPOCHS
#                         训练轮次
#   -b BATCH_SIZE, --batch-size BATCH_SIZE
#                         批处理大小
#   -d CUDA_NO, --cuda-no CUDA_NO
#                         使用的 cuda 设备序号
#   -t {32,64}, --dtype {32,64}
#                         dtype 使用 float32 还是 float64
#   -m MODULE_NAME, --module-name MODULE_NAME
#                         模型保存名称

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

wttch-train-helper-0.0.18.tar.gz (13.5 kB view details)

Uploaded Source

Built Distribution

wttch_train_helper-0.0.18-py3-none-any.whl (13.8 kB view details)

Uploaded Python 3

File details

Details for the file wttch-train-helper-0.0.18.tar.gz.

File metadata

  • Download URL: wttch-train-helper-0.0.18.tar.gz
  • Upload date:
  • Size: 13.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.6

File hashes

Hashes for wttch-train-helper-0.0.18.tar.gz
Algorithm Hash digest
SHA256 df12a853720e393fb28515dfec10a757bccfa907aed124e2fd6df4e26b82c864
MD5 f256dfad7331f1f54fe0cb339214ed8e
BLAKE2b-256 7c5e1c34ebe6167bc5ff06d15c2d3cd0bf697a6ab923175439937a3c75dc132c

See more details on using hashes here.

File details

Details for the file wttch_train_helper-0.0.18-py3-none-any.whl.

File metadata

File hashes

Hashes for wttch_train_helper-0.0.18-py3-none-any.whl
Algorithm Hash digest
SHA256 41e0d6d18dca3fc9a9a7cfd3a1c7d0d4c5bcff569c2e46eb33935780d2a129f9
MD5 f6e9ee221b49c90074ff74014b473228
BLAKE2b-256 50afb5a8da7a396a6bfb3b4eb86976b676fbbde5e475e38fd28d5a1003116b98

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page