Skip to main content

TorchCat🐱 是用于封装 PyTorch 模型的工具

Project description

TorchCat 🐱

简介

TorchCat 是用于封装 PyTorch 模型的工具

提供以下功能:

  • 封装模型
    • 简化模型训练
    • 简化模型评估
    • 记录训练日志
  • 加载数据集

封装模型

使用 torchcat.Cat 封装你的模型。如果不进行训练,也可以忽略 loss_fntorchcat 参数

net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

net = torchcat.Cat(model=net,
                   loss_fn=nn.CrossEntropyLoss(),
                   optimizer=torch.optim.Adam(net.parameters(), lr=0.0003))
参数 说明
model 你的模型
loss_fn 损失函数
optimizer 优化器

查看架构

在封装模型后,使用 net.summary(),可以查看模型的架构。input_size 参数需填写模型的输入形状,如:net.summary(1, 28, 28)

训练模型

使用 net.train(),可以开始模型的训练。训练结束后会返回训练日志

log = net.train(train_set=train_set, epochs=5, valid_set=test_set)

log 记录了训练时的日志,包括以下内容

  • 训练集损失(log['train loss']
  • 训练集准确率(log['train acc']
  • 验证集损失(log['valid loss']
  • 验证集准确率(log['validacc']
参数 说明
train_set 训练集
epochs 训练轮次
valid_set 验证集(默认 None)

评估模型

使用 net.valid(valid_set, show=True, train=False),能够验证模型在给定验证集上的性能,包括损失值、准确率。验证后模型将切换为推理模式

参数 说明
valid_set 验证集
show 是否输出验证集损失值、准确率(默认 True)
train 验证后,模型是否且切换为训练模式(默认 False)

其他

切换计算设备

TorchCat 提供了方法 to_cpu()to_cuda() 用于切换计算设备(CPU 或 GPU🚀)

检查模型当前模式

使用 training 方法,模型当前是否处于训练模式。返回 True 表示处于训练模式,False 表示处于推理模式

模型推理

  • 方法名:__call__
  • 功能描述:执行模型的前向推理过程
  • 参数:x - 输入数据
  • 返回值:模型对输入数据 x 的推理结果

加载数据集

使用 torchcat.ImageFolder 用于加载图片数据集

# 图像预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# 加载数据集
dataset = torchcat.ImageFolder(path='train image', one_hot=True, transform=transform)
参数 说明
path 数据集路径
one_hot 是否进行 One-Hot 编码(默认 False)
transform 图像预处理方案(默认 None)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchcat-0.1.4.tar.gz (19.9 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

torchcat-0.1.4-py3-none-any.whl (19.4 kB view details)

Uploaded Python 3

File details

Details for the file torchcat-0.1.4.tar.gz.

File metadata

  • Download URL: torchcat-0.1.4.tar.gz
  • Upload date:
  • Size: 19.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.1.4.tar.gz
Algorithm Hash digest
SHA256 f5629b85acb0a4e8318f65a8b2d366182dadde99b9615bf5e869f0e3a5380792
MD5 5e95825bf300ff1eb38e4ce6bc4492ef
BLAKE2b-256 1bc58dc1005d086d77d85bc0b6a4989054921df57d48e1ea896419c9e94d5f7c

See more details on using hashes here.

File details

Details for the file torchcat-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: torchcat-0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 19.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.1.4-py3-none-any.whl
Algorithm Hash digest
SHA256 c4087d9b337f6dba523a16ededd083e7935f20d9e5701c47db973e1ca9ba3602
MD5 55c0f0f98b261843271968cb22faae0b
BLAKE2b-256 34d753d1e55ddd452224053fa873566b728f3bf7149f096ec544efb61af1d75a

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page