Skip to main content

TorchCat🐱 是用于封装 PyTorch 模型的工具

Project description

TorchCat 🐱

简介

TorchCat 是用于封装 PyTorch 模型的工具

提供以下功能:

  • 封装模型
    • 简化模型训练
    • 简化模型评估
    • 记录训练日志
  • 加载数据集

封装模型

使用 torchcat.Cat 封装你的模型。如果不进行训练,也可以忽略 loss_fntorchcat 参数

net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

net = torchcat.Cat(model=net,
                   loss_fn=nn.CrossEntropyLoss(),
                   optimizer=torch.optim.Adam(net.parameters(), lr=0.0003))
参数 说明
model 你的模型
loss_fn 损失函数
optimizer 优化器

查看架构

在封装模型后,使用 net.summary(),可以查看模型的架构。input_size 参数需填写模型的输入形状,如:net.summary(1, 28, 28)

训练模型

使用 net.train(),可以开始模型的训练。训练结束后会返回训练日志

log = net.train(train_set=train_set, epochs=5, valid_set=test_set)

log 记录了训练时的日志,包括以下内容

  • 训练集损失(log['train loss']
  • 训练集准确率(log['train acc']
  • 验证集损失(log['valid loss']
  • 验证集准确率(log['validacc']
参数 说明
train_set 训练集
epochs 训练轮次
valid_set 验证集(默认 None)

评估模型

使用 net.valid(valid_set, show=True, train=False),能够验证模型在给定验证集上的性能,包括损失值、准确率。验证后模型将切换为推理模式

参数 说明
valid_set 验证集
show 是否输出验证集损失值、准确率(默认 True)
train 验证后,模型是否且切换为训练模式(默认 False)

其他

切换计算设备

TorchCat 提供了方法 to_cpu()to_cuda() 用于切换计算设备(CPU 或 GPU🚀)

检查模型当前模式

使用 training 方法,模型当前是否处于训练模式。返回 True 表示处于训练模式,False 表示处于推理模式

模型推理

  • 方法名:__call__
  • 功能描述:执行模型的前向推理过程
  • 参数:x - 输入数据
  • 返回值:模型对输入数据 x 的推理结果

加载数据集

使用 torchcat.ImageFolder 用于加载图片数据集

# 图像预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# 加载数据集
dataset = torchcat.ImageFolder(path='train image', one_hot=True, transform=transform)
参数 说明
path 数据集路径
one_hot 是否进行 One-Hot 编码(默认 False)
transform 图像预处理方案(默认 None)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchcat-0.1.1.tar.gz (18.3 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

torchcat-0.1.1-py3-none-any.whl (18.8 kB view details)

Uploaded Python 3

File details

Details for the file torchcat-0.1.1.tar.gz.

File metadata

  • Download URL: torchcat-0.1.1.tar.gz
  • Upload date:
  • Size: 18.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.1.1.tar.gz
Algorithm Hash digest
SHA256 b96a53b5a657e5ff0421c534869cf8dac3979f32e378d3a14459486ca18f6f59
MD5 0c465d58939b9fc3904bb3c839448634
BLAKE2b-256 44f4715a149aaaaa1ddd8bfc7ca750d31ee4e387c5e9d9ee061a8b98dfd4d5ec

See more details on using hashes here.

File details

Details for the file torchcat-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: torchcat-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 18.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 a709ad3e9c0fbc1d59703820a17cb14715cba017d31e421c3d957a081296cc7d
MD5 e1648003cc5f631f4de34d0f3b938f2a
BLAKE2b-256 bcb022ee350e089e3c178fd90a7864b7c8aed14e84b7730f41a224490ac0ba3a

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page