
This is a test of the setup

Project description

TorchCat

Introduction

TorchCat simplifies the training of your models.

Usage

Import the library

import torch
from torch import nn
from torchcat import Cat

Wrap your model

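# Your model
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

# Wrap the model together with its loss function and optimizer.
# net.parameters() is evaluated before the name `net` is rebound to the Cat wrapper.
net = Cat(model=net,
          loss_fn=nn.CrossEntropyLoss(),
          optimizer=torch.optim.Adam(net.parameters(), lr=0.0001))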
Parameter   Description
model       Your model
loss_fn     The loss function to use
optimizer   The optimizer to use

Cat.summary()

After wrapping the model, call net.summary() to view its architecture. Pass the model's input shape as the input_size argument, e.g. net.summary(1, 28, 28).
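Assuming net.summary() reports per-layer parameter counts in the style of common PyTorch summary tools, its total for the example network can be cross-checked by hand. An nn.Linear(in, out) layer with bias holds in × out weights plus out biases, while Flatten and ReLU hold no learnable parameters. The helper linear_params below is hypothetical and not part of TorchCat:

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Weights (in * out) plus one bias per output unit when bias is enabled."""
    return in_features * out_features + (out_features if bias else 0)


# Flatten and ReLU have no learnable parameters, so only the two Linear layers count.
layer_params = [
    linear_params(28 * 28, 128),  # 784 * 128 + 128 = 100,480
    linear_params(128, 10),       # 128 * 10 + 10 = 1,290
]
total = sum(layer_params)
print(f"Total params: {total:,}")  # Total params: 101,770
```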

Cat.train()

Call net.train() to start training the model.

It returns log, a record of the training run, which contains:

  • Training loss (log['train loss'])
  • Training accuracy (log['train acc'])
  • Validation loss (log['valid loss'])
  • Validation accuracy (log['valid acc'])
log = net.train(train_set=train_set, epochs=5, valid_set=test_set)
Parameter   Description
train_set   Training set
epochs      Number of training epochs
valid_set   Validation set
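The entries in log can be inspected once training finishes. The sketch below assumes, without confirmation from the TorchCat docs, that each log entry is a per-epoch list of floats and that "accuracy" means the fraction of samples whose highest-scoring class matches the label. Both helpers, accuracy and best_epoch, are hypothetical and written in plain Python so they run without TorchCat:

```python
from typing import List, Sequence


def accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the true label."""
    correct = 0
    for scores, label in zip(logits, labels):
        # Index of the largest score (the first one wins on ties), i.e. the predicted class.
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        if predicted == label:
            correct += 1
    return correct / len(labels)


def best_epoch(values: List[float], higher_is_better: bool = True) -> int:
    """1-based epoch with the best value: highest for accuracy, lowest for loss."""
    pick = max if higher_is_better else min
    return pick(range(len(values)), key=lambda i: values[i]) + 1
```

For example, best_epoch(log['valid loss'], higher_is_better=False) would give the epoch with the lowest validation loss, provided log['valid loss'] really is a per-epoch list.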



Download files


Source Distribution

torchcat-0.0.6.tar.gz (3.3 kB)

Built Distribution


torchcat-0.0.6-py3-none-any.whl (3.4 kB)

File details

Details for the file torchcat-0.0.6.tar.gz.

File metadata

  • Download URL: torchcat-0.0.6.tar.gz
  • Upload date:
  • Size: 3.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.0.6.tar.gz
Algorithm     Hash digest
SHA256        9910f67b84d88b734bc350dc9119ad2e8c4361c78bfdd414ac8d642f8b3c63ce
MD5           c1e45ddd12f9da97a82166990ea3e138
BLAKE2b-256   5139c8f4ada3ff1a0fe3c0af4dd276d7e3b0f4d68f078abfddb7745c00fac644


File details

Details for the file torchcat-0.0.6-py3-none-any.whl.

File metadata

  • Download URL: torchcat-0.0.6-py3-none-any.whl
  • Upload date:
  • Size: 3.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for torchcat-0.0.6-py3-none-any.whl
Algorithm     Hash digest
SHA256        f445bb431926bd4a4a8235209b069615ea5a4fae130d32067e44d7b6a13995b5
MD5           0af03192a5ad789640994700712c37ba
BLAKE2b-256   7b14c1769dc8c30d4fea6da33e66dfcaa82e24f118f976c6cff26d856aecf6ed

