
A tool for simplifying PyTorch model training

Project description

TorchCat

Introduction

TorchCat simplifies the training of your PyTorch models.

Usage

Import the library

from torchcat import Cat

Wrap your model

import torch
from torch import nn

# Your model
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

# Wrap it together with a loss function and an optimizer
net = Cat(model=net,
          loss_fn=nn.CrossEntropyLoss(),
          optimizer=torch.optim.Adam(net.parameters(), lr=0.0001))

| Parameter | Description |
| --- | --- |
| model | Your model |
| loss_fn | The loss function to use |
| optimizer | The optimizer to use |
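To illustrate how these three arguments typically work together, here is a minimal sketch of a wrapper in the same spirit. It is not TorchCat's implementation: the class name MiniCat, the per-epoch averaging, and the assumption that the training and validation sets are iterables of (inputs, labels) batches are all illustrative choices. The log keys simply mirror the ones this README lists for Cat.train().

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class MiniCat:
    """Hypothetical sketch of a Cat-like wrapper; not TorchCat's actual code."""

    def __init__(self, model, loss_fn, optimizer):
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer

    def _run_epoch(self, batches, training):
        # Put layers such as dropout/batch norm into the right mode.
        self.model.train(training)
        total_loss, correct, seen = 0.0, 0, 0
        # Track gradients only while training.
        with torch.set_grad_enabled(training):
            for inputs, labels in batches:
                logits = self.model(inputs)
                loss = self.loss_fn(logits, labels)
                if training:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                # Accumulate sample-weighted loss and correct predictions.
                total_loss += loss.item() * labels.size(0)
                correct += (logits.argmax(dim=1) == labels).sum().item()
                seen += labels.size(0)
        return total_loss / seen, correct / seen

    def train(self, train_set, epochs, valid_set=None):
        # Assumed log layout: one value per epoch under each key.
        log = {"train loss": [], "train acc": [], "valid loss": [], "valid acc": []}
        for _ in range(epochs):
            loss, acc = self._run_epoch(train_set, training=True)
            log["train loss"].append(loss)
            log["train acc"].append(acc)
            if valid_set is not None:
                loss, acc = self._run_epoch(valid_set, training=False)
                log["valid loss"].append(loss)
                log["valid acc"].append(acc)
        return log


# Usage with random placeholder data shaped like MNIST.
torch.manual_seed(0)
data = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
loader = DataLoader(data, batch_size=16)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))
cat = MiniCat(model, nn.CrossEntropyLoss(), torch.optim.Adam(model.parameters(), lr=0.0001))
log = cat.train(loader, epochs=2, valid_set=loader)
```

Computing accuracy from the largest logit assumes a classification model such as the example network.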

Cat.summary()

After wrapping the model, call net.summary() to view its architecture. Pass the model's input shape as the input_size argument, e.g. net.summary(1, 28, 28).
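As a hand check on what a summary of the example network should report, its parameter count can be worked out directly: the Flatten layer turns each 1×28×28 input into a 784-element vector, a Linear(in, out) layer holds in × out weights plus out biases, and Flatten and ReLU hold no parameters. The plain-Python script below does this arithmetic; the helper linear_params is hypothetical, and whether net.summary() prints exactly this total is an assumption based on summary tools usually counting trainable parameters.

```python
def linear_params(in_features: int, out_features: int) -> int:
    # Weight matrix (out x in) plus one bias per output unit.
    return in_features * out_features + out_features


# Layers of the example network; Flatten and ReLU have no parameters.
layers = [
    ("Flatten", 0),
    ("Linear(784, 128)", linear_params(28 * 28, 128)),
    ("ReLU", 0),
    ("Linear(128, 10)", linear_params(128, 10)),
]

for name, count in layers:
    print(f"{name:<18}{count:>8}")

total = sum(count for _, count in layers)
print(f"{'Total':<18}{total:>8}")  # 101770 parameters in total
```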

Cat.train()

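Call net.train() to start training the model.

It returns log, a record of the training history, which includes:

  • Training loss (log['train loss'])

  • Training accuracy (log['train acc'])

  • Validation loss (log['valid loss'])

  • Validation accuracy (log['valid acc'])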

log = net.train(train_set=train_set, epochs=5, valid_set=test_set)
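Assuming each entry of log holds one value per epoch (the README does not state the exact format), the history can be inspected with ordinary Python, for example to find the epoch with the best validation accuracy. The helper best_epoch and the numbers below are hypothetical placeholders, not real training results.

```python
def best_epoch(log, key="valid acc", higher_is_better=True):
    # Hypothetical helper: return the 1-based epoch with the best value for `key`.
    values = log[key]
    pick = max if higher_is_better else min
    best_index = pick(range(len(values)), key=values.__getitem__)
    return best_index + 1


# Placeholder log in the assumed layout (one value per epoch); not real results.
log = {
    "train loss": [0.90, 0.50, 0.40],
    "train acc": [0.70, 0.84, 0.88],
    "valid loss": [0.60, 0.45, 0.47],
    "valid acc": [0.80, 0.86, 0.85],
}

print(best_epoch(log))  # epoch with the highest validation accuracy
print(best_epoch(log, "valid loss", higher_is_better=False))  # lowest validation loss
```

Comparing the training and validation curves this way is a quick check for overfitting, e.g. training loss still falling while validation loss starts to rise.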

| Parameter | Description |
| --- | --- |
| train_set | Training set |
| epochs | Number of training epochs |
| valid_set | Validation set |
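The README does not say what type train_set and valid_set must be. Assuming Cat.train() accepts the usual PyTorch input of (inputs, labels) batches, such as a torch.utils.data.DataLoader, the two sets used in the example call could be built as follows. The random tensors are placeholders with MNIST-like shapes; real data could come from torchvision.datasets.MNIST instead.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder tensors with MNIST-like shapes: N images of 1x28x28, labels 0..9.
# Assumption: Cat.train() consumes (inputs, labels) batches like a DataLoader yields.
train_x = torch.rand(256, 1, 28, 28)
train_y = torch.randint(0, 10, (256,))
valid_x = torch.rand(64, 1, 28, 28)
valid_y = torch.randint(0, 10, (64,))

# Shuffle the training batches each epoch; keep validation order fixed.
train_set = DataLoader(TensorDataset(train_x, train_y), batch_size=32, shuffle=True)
test_set = DataLoader(TensorDataset(valid_x, valid_y), batch_size=32)

x, y = next(iter(train_set))
print(x.shape, y.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])
```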



Download files

Download the file for your platform.

Source Distribution

torchcat-0.0.5.tar.gz (10.3 kB)

Uploaded Source

Built Distribution


torchcat-0.0.5-py3-none-any.whl (10.0 kB)

Uploaded Python 3

File details

Details for the file torchcat-0.0.5.tar.gz.

File metadata

  • Download URL: torchcat-0.0.5.tar.gz
  • Upload date:
  • Size: 10.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.0.5.tar.gz
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | ca284568f1ad1720e112d634d97e242c6dcca6a8c9692f958fc49facfed97df9 |
| MD5 | bb39bcfa6575738117e5cde5c07b78ec |
| BLAKE2b-256 | e765c4769d0ce17afa979405f5484d8223929eb594f3927c2d7a8cdd5652dd37 |


File details

Details for the file torchcat-0.0.5-py3-none-any.whl.

File metadata

  • Download URL: torchcat-0.0.5-py3-none-any.whl
  • Upload date:
  • Size: 10.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.0.5-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 389cb5d11cce33cf51c37dc5db69f2032355625dc42084dfa74ec4cd55d17d92 |
| MD5 | 90e33fb52ce50b8d4093358e3ee0140a |
| BLAKE2b-256 | 9637c33c9b8590c3d1fd1c2996fb369eee1b7604a08857fa6abf164c4e90de7f |

