
Project description

Pytorch❤️Keras

English | 简体中文

The torchkeras library is a simple tool for training neural networks in PyTorch in a Keras style. 😋😋

1, Introduction

With torchkeras, you do not need to write a training loop with many lines of code. All you need are the two steps below:

(i) create your network and wrap it together with your loss_fn in torchkeras.KerasModel, like this: model = torchkeras.KerasModel(net, loss_fn=nn.BCEWithLogitsLoss()).

(ii) fit your model with the training data and validation data.

The main code for using torchkeras looks like this:

import torch
from torch import nn
import torchmetrics
import torchkeras

# net is your torch.nn.Module; dl_train and dl_val are your DataLoaders
model = torchkeras.KerasModel(net,
                              loss_fn=nn.BCEWithLogitsLoss(),
                              optimizer=torch.optim.Adam(net.parameters(), lr=0.001),
                              metrics_dict={"acc": torchmetrics.Accuracy(task='binary')}
                             )
dfhistory = model.fit(train_data=dl_train,
                      val_data=dl_val,
                      epochs=20,
                      patience=3,           # early stopping patience
                      ckpt_path='checkpoint',
                      monitor="val_acc",    # metric to monitor
                      mode="max",           # maximize the monitored metric
                      plot=True             # live metric plot in the notebook
                     )
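In the snippet above, net, dl_train and dl_val are assumed to be defined elsewhere. A minimal, self-contained sketch of one way to set them up for a binary classification task (the toy data and the network architecture here are illustrative assumptions, not part of torchkeras):

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

# toy binary classification data (hypothetical, for illustration only)
X = torch.randn(1000, 8)
y = (X.sum(dim=1, keepdim=True) > 0).float()

ds = TensorDataset(X, y)
ds_train, ds_val = torch.utils.data.random_split(ds, [800, 200])
dl_train = DataLoader(ds_train, batch_size=32, shuffle=True)
dl_val = DataLoader(ds_val, batch_size=32)

# a small network producing raw logits, to match nn.BCEWithLogitsLoss
net = nn.Sequential(
    nn.Linear(8, 32),
    nn.ReLU(),
    nn.Linear(32, 1)
)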

Besides, you can use torchkeras.VLog to get a dynamic training visualization anywhere you like ~

import time
import math, random
from torchkeras import VLog

epochs = 10
batches = 30

# 0, init vlog
vlog = VLog(epochs, monitor_metric='val_loss', monitor_mode='min')

# 1, log_start
vlog.log_start()

for epoch in range(epochs):

    # train
    for step in range(batches):

        # 2, log_step (for training step)
        vlog.log_step({'train_loss': 100 - 2.5*epoch + math.sin(2*step/batches)})
        time.sleep(0.05)

    # eval
    for step in range(20):

        # 3, log_step (for eval step)
        vlog.log_step({'val_loss': 100 - 2*epoch + math.sin(2*step/batches)}, training=False)
        time.sleep(0.05)

    # 4, log_epoch
    vlog.log_epoch({'val_loss': 100 - 2*epoch + 2*random.random() - 1,
                    'train_loss': 100 - 2.5*epoch + 2*random.random() - 1})

# 5, log_end
vlog.log_end()
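The loss values above are simulated just to drive the plot. The same four calls can be dropped into a real PyTorch training loop; the tiny model, data and loss below are illustrative assumptions, and only the VLog methods shown above are used:

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from torchkeras import VLog

# toy regression setup (hypothetical, for illustration only)
X, y = torch.randn(512, 4), torch.randn(512, 1)
dl = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)
net = nn.Linear(4, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

epochs = 5
vlog = VLog(epochs, monitor_metric='train_loss', monitor_mode='min')
vlog.log_start()
for epoch in range(epochs):
    epoch_losses = []
    for features, labels in dl:
        loss = loss_fn(net(features), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
        vlog.log_step({'train_loss': loss.item()})   # per-step logging
    vlog.log_epoch({'train_loss': sum(epoch_losses) / len(epoch_losses)})
vlog.log_end()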

This project may look powerful, but its source code is actually very simple: only about 200 lines of Python code.

If you want to understand or modify any details of this project, feel free to read and change the source code!


2, Features

The main features supported by torchkeras are listed below, together with the version in which each feature was introduced and the library it uses or is inspired by.

feature                                supported from version   used or inspired by
✅ training progress bar               3.0.0                     tqdm, inspired by keras
✅ training metrics                    3.0.0                     inspired by pytorch_lightning
✅ notebook visualization in training  3.8.0                     inspired by fastai
✅ early stopping                      3.0.0                     inspired by keras
✅ gpu training                        3.0.0                     accelerate
✅ multi-gpu training (ddp)            3.6.0                     accelerate
✅ fp16/bf16 training                  3.6.0                     accelerate
✅ tensorboard callback                3.7.0                     tensorboard
✅ wandb callback                      3.7.0                     wandb
✅ VLog                                3.9.5                     matplotlib
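
For example, the early stopping and best-checkpoint features are controlled through the fit() arguments already shown in the basic example (patience, monitor, mode, ckpt_path). A minimal sketch, reusing the model and dataloaders defined earlier:

# stop after 3 epochs without improvement in val_acc
dfhistory = model.fit(train_data=dl_train,
                      val_data=dl_val,
                      epochs=100,
                      patience=3,          # early stopping patience (epochs)
                      monitor="val_acc",   # metric watched for early stopping / checkpointing
                      mode="max",          # "max" because higher accuracy is better
                      ckpt_path='checkpoint')

Training stops once val_acc has not improved for 3 consecutive epochs, with the best weights saved to ckpt_path.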

3, Basic Examples

You can follow these full examples to get started with torchkeras.

example                                             read notebook code                               run example in kaggle
① kerasmodel basic 🔥🔥                              torchkeras.KerasModel example                    Kaggle notebook available
② kerasmodel wandb 🔥🔥🔥                            torchkeras.KerasModel with wandb demo            Kaggle notebook available
③ kerasmodel tuning 🔥🔥🔥                           torchkeras.KerasModel with wandb sweep demo      Kaggle notebook available
④ kerasmodel tensorboard                            torchkeras.KerasModel with tensorboard example   -
⑤ kerasmodel ddp/tpu                                torchkeras.KerasModel ddp tpu examples           Kaggle notebook available
⑥ VLog for lightgbm/ultralytics/transformers 🔥🔥🔥  VLog example                                     -

4, Advanced Examples

In some use cases, you need to rewrite the StepRunner of KerasModel because the model input types differ. Here are some examples; a rough sketch of the general customization pattern follows the table.

example                                        model library                  notebook
RL
ReinforcementLearning——Q-Learning 🔥🔥          -                              Q-learning
ReinforcementLearning——DQN                     -                              DQN
Tabular
BinaryClassification——LightGBM                 -                              LightGBM
MultiClassification——FTTransformer 🔥🔥🔥🔥🔥    -                              FTTransformer
BinaryClassification——FM                       -                              FM
BinaryClassification——DeepFM                   -                              DeepFM
BinaryClassification——DeepCross                -                              DeepCross
CV
ImageClassification——Resnet                    -                              Resnet
ImageSegmentation——UNet                        -                              UNet
ObjectDetection——SSD                           -                              SSD
OCR——CRNN 🔥🔥                                  -                              CRNN-CTC
ImageClassification——SwinTransformer           timm                           Swin
ObjectDetection——FasterRCNN                    torchvision                    FasterRCNN
ImageSegmentation——DeepLabV3++                 segmentation_models_pytorch    Deeplabv3++
InstanceSegmentation——MaskRCNN                 detectron2                     MaskRCNN
ObjectDetection——YOLOv8 🔥🔥🔥                   ultralytics                    YOLOv8
InstanceSegmentation——YOLOv8 🔥🔥🔥              ultralytics                    YOLOv8
NLP
Seq2Seq——Transformer 🔥🔥                       -                              Transformer
TextGeneration——Llama 🔥                        -                              Llama
TextClassification——BERT                       transformers                   BERT
TokenClassification——BERT                      transformers                   BERT_NER
FinetuneLLM——ChatGLM2_LoRA 🔥🔥🔥                transformers, peft             ChatGLM2_LoRA
FinetuneLLM——ChatGLM2_AdaLoRA 🔥                transformers, peft             ChatGLM2_AdaLoRA
FinetuneLLM——ChatGLM2_QLoRA 🔥                  transformers                   ChatGLM2_QLoRA_Kaggle
FinetuneLLM——BaiChuan13B_QLoRA 🔥               transformers                   BaiChuan13B_QLoRA
FinetuneLLM——BaiChuan13B_NER 🔥🔥🔥              transformers                   BaiChuan13B_NER
FinetuneLLM——BaiChuan13B_MultiRounds 🔥         transformers                   BaiChuan13B_MultiRounds
FinetuneLLM——Qwen7B_MultiRounds 🔥🔥🔥           transformers                   Qwen7B_MultiRounds
FinetuneLLM——BaiChuan2_13B 🔥                   transformers                   BaiChuan2_13B
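
The notebooks above each define their own StepRunner for the corresponding input format. As a rough sketch of the general pattern (the constructor signature and the KerasModel.StepRunner assignment below are assumptions based on typical torchkeras usage and may differ between versions, so treat the linked notebooks as the reference):

from torchkeras import KerasModel

class MyStepRunner:
    # assumed constructor arguments; check your torchkeras version
    def __init__(self, net, loss_fn, accelerator=None, stage="train",
                 metrics_dict=None, optimizer=None, lr_scheduler=None):
        self.net, self.loss_fn, self.stage = net, loss_fn, stage
        self.metrics_dict = metrics_dict or {}
        self.optimizer, self.lr_scheduler = optimizer, lr_scheduler
        self.accelerator = accelerator
        self.net.train() if stage == "train" else self.net.eval()

    def __call__(self, batch):
        # adapt this unpacking to your model's input type
        features, labels = batch
        preds = self.net(features)
        loss = self.loss_fn(preds, labels)

        if self.stage == "train" and self.optimizer is not None:
            # with accelerate-based training, self.accelerator.backward(loss) would be used instead
            loss.backward()
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()

        step_losses = {self.stage + "_loss": loss.item()}
        step_metrics = {self.stage + "_" + name: fn(preds, labels).item()
                        for name, fn in self.metrics_dict.items()}
        return step_losses, step_metrics

# attach the custom runner (assumed override point)
KerasModel.StepRunner = MyStepRunner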


For any other questions, you can contact the author via the WeChat official account below:

算法美食屋



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchkeras-4.0.2.tar.gz (6.6 MB)

Uploaded Source

Built Distribution

torchkeras-4.0.2-py3-none-any.whl (6.6 MB)

Uploaded Python 3

File details

Details for the file torchkeras-4.0.2.tar.gz.

File metadata

  • Download URL: torchkeras-4.0.2.tar.gz
  • Upload date:
  • Size: 6.6 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.14

File hashes

Hashes for torchkeras-4.0.2.tar.gz
Algorithm     Hash digest
SHA256        f33256fecc5e4de9c55abd2a2b1f82c883aa279c89881e300757c3455d57e2eb
MD5           fe80e27cb5169bd862c4d773b73c4a8a
BLAKE2b-256   764400c5a4d70ef484506f17efc7ce8f42b0b18baa262ba75bf73497d6934d48

See more details on using hashes here.
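
For example, a downloaded archive can be checked against the SHA256 digest above with Python's standard hashlib (a small illustrative snippet, not part of torchkeras):

import hashlib

expected = "f33256fecc5e4de9c55abd2a2b1f82c883aa279c89881e300757c3455d57e2eb"

with open("torchkeras-4.0.2.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print("OK" if digest == expected else "hash mismatch!")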

File details

Details for the file torchkeras-4.0.2-py3-none-any.whl.

File metadata

  • Download URL: torchkeras-4.0.2-py3-none-any.whl
  • Upload date:
  • Size: 6.6 MB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.14

File hashes

Hashes for torchkeras-4.0.2-py3-none-any.whl
Algorithm     Hash digest
SHA256        2f51d4f8bee1ae0a53d726ba9a56ddda560dbf37eac2146c35aa5e83b27c39ee
MD5           dc0b945a5707e1c4d3014d48df0ff6d0
BLAKE2b-256   c1c4944a55b855d23f98c8150364ed5912e24685327db8ef200aacd3475a4ce5

See more details on using hashes here.
