Skip to main content

TorchCat🐱 是用于封装 PyTorch 模型的工具

Project description

TorchCat 🐱

简介

TorchCat 是用于封装 PyTorch 模型的工具

提供以下功能:

  • 加载数据
  • 封装模型
  • 训练模型
  • 评估模型
  • 记录日志

加载数据

使用 torchcat.ImageFolder 用于加载图片数据集

# 图像预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# 加载数据集
train_set = torchcat.ImageFolder(path='train-image', transform=data_transorms, one_hot=True)
test_set = torchcat.ImageFolder(path='test-image', transform=data_transorms, one_hot=True)

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True)
参数 说明
path 数据集路径
transform 图像预处理方案
one_hot 是否进行 One-Hot 编码(默认 False)

封装模型

使用 torchcat.Cat 封装你的模型。如果不进行训练,也可以忽略 loss_fnoptimizer 参数

net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
).cuda()

net = torchcat.Cat(model=net,
                   loss_fn=nn.CrossEntropyLoss(),
                   optimizer=torch.optim.Adam(net.parameters()),
                   metrics=[torchcat.metrics.CrossEntropyAccuracy()])
参数 说明
model 你的模型
loss_fn 损失函数
optimizer 优化器
metrics 评估指标

查看结构

在封装模型后,可使用 net.summary(),可以查看模型的结构。input_size 参数需填写模型的输入形状,如:net.summary(1, 28, 28)

训练模型

使用 net.train(),可以开始模型的训练。训练结束后会返回训练日志

log = net.train(epochs=10, train_set=train_loader, valid_set=test_loader)

log 记录了训练时的日志,包含 loss 和 metrics 所定义的指标

参数 说明
epochs 训练轮次
train_set 训练集
valid_set 验证集(默认 None)

评估模型

使用 net.valid(valid_set, show=True, train=False),能够验证模型在给定验证集上的性能,包括损失值、评估指标。验证后模型将保留推理模式

参数 说明
valid_set 验证集
show 是否输出验证集上损失值、评估指标(默认 True)
train 验证后是否将模型切换为训练模式(默认 False)

其他

模型推理

使用 net(x) 执行模型前向推理

切换计算设备

TorchCat 提供了方法 to_cpu()to_cuda() 用于切换计算设备(CPU 或 GPU🚀)

检查模型当前模式

使用 training 方法,查看模型当前是否处于训练模式。返回 True 表示处于训练模式,False 表示处于推理模式

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchcat-0.1.5.tar.gz (19.9 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

torchcat-0.1.5-py3-none-any.whl (19.4 kB view details)

Uploaded Python 3

File details

Details for the file torchcat-0.1.5.tar.gz.

File metadata

  • Download URL: torchcat-0.1.5.tar.gz
  • Upload date:
  • Size: 19.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.1.5.tar.gz
Algorithm Hash digest
SHA256 3181f65db0daa7dda2fdd3824429c5ab2bf885a4a32bc184f8f4a1747e0dafb0
MD5 cb0fcfa1878c0368ecfa6f4f0814c7cc
BLAKE2b-256 cd624048c1d61c1a3c002e262b571fed07d523ff4da92eb00d22157cc5132ae5

See more details on using hashes here.

File details

Details for the file torchcat-0.1.5-py3-none-any.whl.

File metadata

  • Download URL: torchcat-0.1.5-py3-none-any.whl
  • Upload date:
  • Size: 19.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.1.5-py3-none-any.whl
Algorithm Hash digest
SHA256 ee1e2ede0ae7a6f82a0eaf70d5b0855b0ed48e9d7e8701f328fc3673e9f3dc30
MD5 475ecb01f83e5b4f94fa91f2042abb96
BLAKE2b-256 23189bc89b7f9dc70517a193551f6aa0f1cbaf59da34322b055bceb828c3a3e5

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page