Skip to main content

TorchCat🐱 是用于封装 PyTorch 模型的工具

Project description

TorchCat 🐱

简介

TorchCat 是用于封装 PyTorch 模型的工具

提供以下功能:

  • 封装模型
    • 简化模型训练
    • 简化模型评估
    • 记录训练日志
  • 加载数据集

封装模型

使用 torchcat.Cat 封装你的模型。如果不进行训练,也可以忽略 loss_fntorchcat 参数

net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

net = torchcat.Cat(model=net,
                   loss_fn=nn.CrossEntropyLoss(),
                   optimizer=torch.optim.Adam(net.parameters(), lr=0.0003))
参数 说明
model 你的模型
loss_fn 损失函数
optimizer 优化器

查看架构

在封装模型后,使用 net.summary(),可以查看模型的架构。input_size 参数需填写模型的输入形状,如:net.summary(1, 28, 28)

训练模型

使用 net.train(),可以开始模型的训练。训练结束后会返回训练日志

log = net.train(train_set=train_set, epochs=5, valid_set=test_set)

log 记录了训练时的日志,包括以下内容

  • 训练集损失(log['train loss']
  • 训练集准确率(log['train acc']
  • 验证集损失(log['valid loss']
  • 验证集准确率(log['validacc']
参数 说明
train_set 训练集
epochs 训练轮次
valid_set 验证集(默认 None)

评估模型

使用 net.valid(valid_set, show=True, train=False),能够验证模型在给定验证集上的性能,包括损失值、准确率。验证后模型将切换为推理模式

参数 说明
valid_set 验证集
show 是否输出验证集损失值、准确率(默认 True)
train 验证后,模型是否且切换为训练模式(默认 False)

其他

切换计算设备

TorchCat 提供了方法 to_cpu()to_cuda() 用于切换计算设备(CPU 或 GPU🚀)

检查模型当前模式

使用 training 方法,模型当前是否处于训练模式。返回 True 表示处于训练模式,False 表示处于推理模式

模型推理

  • 方法名:__call__
  • 功能描述:执行模型的前向推理过程
  • 参数:x - 输入数据
  • 返回值:模型对输入数据 x 的推理结果

加载数据集

使用 torchcat.ImageFolder 用于加载图片数据集

# 图像预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# 加载数据集
dataset = torchcat.ImageFolder(path='train image', one_hot=True, transform=transform)
参数 说明
path 数据集路径
one_hot 是否进行 One-Hot 编码(默认 False)
transform 图像预处理方案(默认 None)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchcat-0.1.3.tar.gz (19.9 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

torchcat-0.1.3-py3-none-any.whl (19.4 kB view details)

Uploaded Python 3

File details

Details for the file torchcat-0.1.3.tar.gz.

File metadata

  • Download URL: torchcat-0.1.3.tar.gz
  • Upload date:
  • Size: 19.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.1.3.tar.gz
Algorithm Hash digest
SHA256 afa84dcb5c94761a8814d81e944dc65c0c1e6b46f00d16c22b6029f03d619e6f
MD5 c5e044ad56a6ee8f1efa9b6f503d00c3
BLAKE2b-256 edcebdfcf93af99f43f5533914e7520613a82128e4d92ae8788b3ecc246dc590

See more details on using hashes here.

File details

Details for the file torchcat-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: torchcat-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 19.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 d851047f842826acb2237aafb2bd242a7d31cdb2ddad4cefa3e4e07b87cf7dc4
MD5 2f69ab1e5ae4ee02f7fc5a1e3f6ba22e
BLAKE2b-256 248f440f84e6d771a1baf8498f289ad96a35da6900904e61d9f8eb56b804ffea

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page