
A tool that simplifies training torch models

Project description

TorchCat

Introduction

TorchCat simplifies model training by wrapping your model, loss function and optimizer in a single object.

Usage

Import the libraries

import torch
from torch import nn
from torchcat import Cat

Wrap your model

# Your model
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

net = Cat(model=net,
          loss_fn=nn.CrossEntropyLoss(),
          optimizer=torch.optim.Adam(net.parameters(), lr=0.0001))

Parameter    Description
model        The model to train
loss_fn      The loss function
optimizer    The optimizer

Cat.summary()

After wrapping the model, call net.summary() to view the model architecture. Pass the model's input shape as the input_size argument, e.g. net.summary(1, 28, 28).
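The input shape passed to summary() has to agree with the first layers of the model. As a worked check for the example model: ignoring the batch dimension, nn.Flatten() turns a (1, 28, 28) input into 1 × 28 × 28 = 784 features, which is exactly the in_features of nn.Linear(28*28, 128). The helper below is hypothetical and not part of TorchCat.

```python
import math


def flattened_features(input_size):
    """Features nn.Flatten() yields for one sample of this shape (hypothetical helper)."""
    return math.prod(input_size)


# in_features of the first nn.Linear in the example model
first_linear_in_features = 28 * 28

input_size = (1, 28, 28)  # the shape passed as net.summary(1, 28, 28)
assert flattened_features(input_size) == first_linear_in_features
print(flattened_features(input_size))  # 784
```

If the two numbers disagree, the first nn.Linear raises a shape-mismatch error on the first forward pass.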

Cat.train()

Call net.train() to start training the model.

It returns log, which records the training history, including:

  • Training loss (log['train loss'])

  • Training accuracy (log['train acc'])

  • Validation loss (log['valid loss'])

  • Validation accuracy (log['valid acc'])

log = net.train(train_set=train_set, epochs=5, valid_set=test_set)
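The README does not say how TorchCat computes the logged values. A common convention for classification, assumed here, is that accuracy is the fraction of samples whose highest-scoring output matches the label, and the epoch loss is the per-sample mean over all batches. A stdlib-only sketch of both follows; batch_correct and epoch_metrics are hypothetical names, not part of TorchCat.

```python
def batch_correct(logits, labels):
    """Count samples whose argmax output equals the true label."""
    correct = 0
    for row, label in zip(logits, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax of the row
        if predicted == label:
            correct += 1
    return correct


def epoch_metrics(batches):
    """batches: iterable of (mean_batch_loss, logits, labels) tuples.

    Returns (per-sample mean loss, accuracy) over the whole epoch.
    """
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for batch_loss, logits, labels in batches:
        n = len(labels)
        # Undo the per-batch mean so batches of different sizes are weighted correctly.
        total_loss += batch_loss * n
        total_correct += batch_correct(logits, labels)
        total_samples += n
    return total_loss / total_samples, total_correct / total_samples
```

Weighting by batch size matters because the last batch of an epoch is often smaller than the others.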

Parameter    Description
train_set    The training set
epochs       Number of training epochs
valid_set    The validation set
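To make the roles of train_set, epochs and valid_set concrete, here is a framework-free sketch of the loop structure a train() method like this commonly follows: one pass over the training set per epoch, then an optional pass over the validation set, with one value per epoch appended to each log list. This is an assumption about the behavior, not TorchCat's actual implementation; fit, run_pass, train_step and eval_step are hypothetical names, and each step is assumed to return (loss_sum, correct, count) for one batch.

```python
def run_pass(step, dataset):
    """Apply `step` to every batch; return (mean loss, accuracy) over the pass."""
    loss_sum = 0.0
    correct = 0
    count = 0
    for batch in dataset:
        batch_loss_sum, batch_correct, batch_count = step(batch)
        loss_sum += batch_loss_sum
        correct += batch_correct
        count += batch_count
    return loss_sum / count, correct / count


def fit(train_step, eval_step, train_set, epochs, valid_set=None):
    """Hypothetical training loop; returns a log dict with one value per epoch."""
    log = {"train loss": [], "train acc": []}
    if valid_set is not None:
        log["valid loss"] = []
        log["valid acc"] = []
    for _ in range(epochs):
        # Training pass (a real step would also update the weights).
        loss, acc = run_pass(train_step, train_set)
        log["train loss"].append(loss)
        log["train acc"].append(acc)
        if valid_set is not None:
            # Evaluation pass: metrics only, no weight updates.
            loss, acc = run_pass(eval_step, valid_set)
            log["valid loss"].append(loss)
            log["valid acc"].append(acc)
    return log
```

In a real PyTorch version, train_step would put the model in training mode, compute the loss and step the optimizer, while eval_step would run under torch.no_grad(); those details are left out to keep the sketch framework-free.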



Download files

Download the file for your platform.

Source Distribution

torchcat-0.0.4.tar.gz (9.8 kB)

Uploaded Source

Built Distribution


torchcat-0.0.4-py3-none-any.whl (9.5 kB)

Uploaded Python 3

File details

Details for the file torchcat-0.0.4.tar.gz.

File metadata

  • Download URL: torchcat-0.0.4.tar.gz
  • Upload date:
  • Size: 9.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.0.4.tar.gz
Algorithm Hash digest
SHA256 ad0f6c3876bad335eef546e796e584af8b10a58eb4819729cf469950d52ad8b8
MD5 9de170737ba7743bbd70f61eba1f19d3
BLAKE2b-256 cee090705ef9578d629309e5a6127f2a921f995b77e9f5a55f4424dcc9c86913


File details

Details for the file torchcat-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: torchcat-0.0.4-py3-none-any.whl
  • Upload date:
  • Size: 9.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.0.4-py3-none-any.whl
Algorithm Hash digest
SHA256 2a0776b70054b7e40a94ee69456c9b282ee9edaad7a17936d9d0910fbc6e55e5
MD5 920c24e48f4ea5cc52e167f3077f771d
BLAKE2b-256 e410b1693b27ee86a444ae9cde17eac6292da4ab3145d3d1c2b5d04b8c05f335

