Skip to main content

TorchCat 是用于封装 PyTorch 模型的工具

Project description

TorchCat 🐱

简介

TorchCat 是用于封装 PyTorch 模型的工具

提供以下功能:

  • 简化训练过程
  • 简化测试过程
  • 记录训练日志

用法

导入 TorchCat

import torchcat

封装模型

使用 Cat 封装你的模型。如果不进行训练,也可以忽略 loss_fntorchcat 参数

net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

net = torchcat.Cat(model=net,
                   loss_fn=nn.CrossEntropyLoss(),
                   optimizer=torch.optim.Adam(net.parameters(), lr=0.0003))
参数 说明
model 你的模型
loss_fn 损失函数
optimizer 优化器

查看架构

在封装模型后,使用 net.summary(),可以查看模型的架构。input_size 参数需填写模型的输入形状,如:net.summary(1, 28, 28)

训练模型

使用 net.train(),可以开始模型的训练。训练结束后会返回训练日志

log = net.train(train_set=train_set, epochs=5, valid_set=test_set)

log 记录了训练时的日志,包括以下内容

  • 训练集损失(log['train loss']
  • 训练集准确率(log['train acc']
  • 验证集损失(log['valid loss']
  • 验证集准确率(log['validacc']
参数 说明
train_set 训练集
epochs 训练轮次
valid_set 验证集(默认 None)

验证模型

使用 net.valid(valid_set, show=True, train=False),能够验证模型在给定验证集上的性能,包括损失值、准确率。验证后模型将切换为推理模式

参数 说明
valid_set 验证集
show 是否输出验证集损失值、准确率(默认 True)
train 验证后,模型是否且切换为训练模式(默认 False)

其他

切换计算设备

TorchCat 提供了方法 to_cpu()to_cuda() 用于切换计算设备(CPU 或 GPU🚀)

检查模型当前模式

使用 training 方法,模型当前是否处于训练模式。返回 True 表示处于训练模式,False 表示处于推理模式

模型推理

  • 方法名:__call__
  • 功能描述:执行模型的前向推理过程
  • 参数:x - 输入数据
  • 返回值:模型对输入数据 x 的推理结果

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchcat-0.0.8.tar.gz (17.1 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

torchcat-0.0.8-py3-none-any.whl (17.4 kB view details)

Uploaded Python 3

File details

Details for the file torchcat-0.0.8.tar.gz.

File metadata

  • Download URL: torchcat-0.0.8.tar.gz
  • Upload date:
  • Size: 17.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.0.8.tar.gz
Algorithm Hash digest
SHA256 e5b8f60f4f57f2d98c73c9a86cf183920cb2a3817142086160719be49f5ddd67
MD5 cd7ac0ccb97a5b180cbbf888a3301629
BLAKE2b-256 315b3f96b50f3782ab5ff7cdd7bddac4f6336eb4f5ff36bac5b2eb7fc8d53975

See more details on using hashes here.

File details

Details for the file torchcat-0.0.8-py3-none-any.whl.

File metadata

  • Download URL: torchcat-0.0.8-py3-none-any.whl
  • Upload date:
  • Size: 17.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.0.8-py3-none-any.whl
Algorithm Hash digest
SHA256 27a86d39d41d486446676f984877a0862d15c35fbeaf6e3e7bdb73f7ab10ea07
MD5 fb7a0c0c2b8f9c1d8aab65626fa51b7a
BLAKE2b-256 3dcab5e85d5904ac891ffe0564bd7c376b771b628d96ab16d7b88e5ff11eaa17

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page