TorchCat🐱 是用于封装 PyTorch 模型的工具
Project description
TorchCat 🐱
简介
TorchCat 是用于封装 PyTorch 模型的工具
提供以下功能:
- 封装模型
- 简化模型训练
- 简化模型评估
- 记录训练日志
- 加载数据集
封装模型
使用 torchcat.Cat 封装你的模型。如果不进行训练,也可以忽略 loss_fn、torchcat 参数
net = nn.Sequential(
nn.Flatten(),
nn.Linear(28*28, 128),
nn.ReLU(),
nn.Linear(128, 10),
)
net = torchcat.Cat(model=net,
loss_fn=nn.CrossEntropyLoss(),
optimizer=torch.optim.Adam(net.parameters(), lr=0.0003))
| 参数 | 说明 |
|---|---|
| model | 你的模型 |
| loss_fn | 损失函数 |
| optimizer | 优化器 |
查看架构
在封装模型后,使用 net.summary(),可以查看模型的架构。input_size 参数需填写模型的输入形状,如:net.summary(1, 28, 28)
训练模型
使用 net.train(),可以开始模型的训练。训练结束后会返回训练日志
log = net.train(train_set=train_set, epochs=5, valid_set=test_set)
log 记录了训练时的日志,包括以下内容
- 训练集损失(log['train loss'])
- 训练集准确率(log['train acc'])
- 验证集损失(log['valid loss'])
- 验证集准确率(log['validacc'])
| 参数 | 说明 |
|---|---|
| train_set | 训练集 |
| epochs | 训练轮次 |
| valid_set | 验证集(默认 None) |
评估模型
使用 net.valid(valid_set, show=True, train=False),能够验证模型在给定验证集上的性能,包括损失值、准确率。验证后模型将切换为推理模式
| 参数 | 说明 |
|---|---|
| valid_set | 验证集 |
| show | 是否输出验证集损失值、准确率(默认 True) |
| train | 验证后,模型是否且切换为训练模式(默认 False) |
其他
切换计算设备
TorchCat 提供了方法 to_cpu()、to_cuda() 用于切换计算设备(CPU 或 GPU🚀)
检查模型当前模式
使用 training 方法,模型当前是否处于训练模式。返回 True 表示处于训练模式,False 表示处于推理模式
模型推理
- 方法名:
__call__ - 功能描述:执行模型的前向推理过程
- 参数:
x- 输入数据 - 返回值:模型对输入数据
x的推理结果
加载数据集
使用 torchcat.ImageFolder 用于加载图片数据集
# 图像预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
# 加载数据集
dataset = torchcat.ImageFolder(path='train image', one_hot=True, transform=transform)
| 参数 | 说明 |
|---|---|
| path | 数据集路径 |
| one_hot | 是否 One-Hot 编码(默认 False) |
| transform | 图像预处理方案(默认 None) |
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file torchcat-0.0.9.tar.gz.
File metadata
- Download URL: torchcat-0.0.9.tar.gz
- Upload date:
- Size: 18.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.9.19
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
5d6eb9caa20822a80947c62667c7b5f2f24d3e245bc1262991fe39c53acfbfb0
|
|
| MD5 |
d625de65ac0f61840d310bfb7f580dc3
|
|
| BLAKE2b-256 |
d8a45efb531feca1e72cbd91b8d368d0b6f7a3e3a0ccba84e349438724000749
|
File details
Details for the file torchcat-0.0.9-py3-none-any.whl.
File metadata
- Download URL: torchcat-0.0.9-py3-none-any.whl
- Upload date:
- Size: 18.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.9.19
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
3054674b0165e5f86990bafbc191d01bd996699d8c724557ab56bb45b19b2fc8
|
|
| MD5 |
5a7c79ee41ff1047fabf51d2826a3e96
|
|
| BLAKE2b-256 |
f8ae7cf6f0410c487fcb48b6aed71ccf7b513953c2dd342d253cd6d0766566db
|