Introduction
This repository contains the AI libraries shared across my AI projects.
Train
The MyTrain class trains subclasses of Hugging Face's PreTrainedModel. model.state_dict and model.load_state_dict must be consistent, so that a checkpoint written from one can be restored by the other.
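One way to read the consistency requirement is as a round trip: loading the output of state_dict into a fresh instance should reproduce the same state. The sketch below checks this on a toy pure-Python class. `ToyModel` and `round_trip_ok` are hypothetical names for illustration, not part of this library, and a real PreTrainedModel holds tensors, which would need a tensor-aware comparison such as `torch.equal` rather than `==`.

```python
class ToyModel:
    """Hypothetical stand-in for a model; not part of this library."""

    def __init__(self):
        self.weight = [0.0, 0.0]
        self.bias = 0.0

    def state_dict(self):
        # Return copies so callers cannot mutate the model through the dict.
        return {"weight": list(self.weight), "bias": self.bias}

    def load_state_dict(self, state):
        # Must accept exactly what state_dict produces.
        self.weight = list(state["weight"])
        self.bias = state["bias"]


def round_trip_ok(model, fresh):
    """Load model's state into fresh and check both report the same state."""
    fresh.load_state_dict(model.state_dict())
    return fresh.state_dict() == model.state_dict()


trained = ToyModel()
trained.weight = [0.5, -1.25]
trained.bias = 2.0
restored = ToyModel()
consistent = round_trip_ok(trained, restored)
```

If a custom state_dict drops or renames a key that load_state_dict expects, a check of this kind fails immediately instead of surfacing later as a checkpoint that cannot be resumed.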
---
title: MyTrain.__call__
---
flowchart TD
INST[instantiate model and random generator] --> TRAININSTMETRICS[instantiate metrics] --> MODE{{evaluation only?}}
MODE -- yes --> EVALMODEL[<code>MyTrain.my_eval_model</code>]
subgraph EVALMODEL[<code>MyTrain.my_eval_model</code>]
direction TB
EVALLOOP[eval loop]
end
subgraph EVALLOOP[eval loop]
direction TB
CHECKCONSISTENCY[check config consistency] --> EVALLOADCHECKPOINT[load checkpoint for model and generator] --> EVALDEVICE[set model device] --> EVALDATALOADER[setup data loader] --> EVALEPOCHBRANCH{{implement <code>model.my_eval_epoch</code>?}}
EVALEPOCHBRANCH -- yes --> CUSTOMEVAL[<code>model.my_eval_epoch</code>]
EVALEPOCHBRANCH -- no --> COMMONEVAL[<code>MyTrain.my_eval_epoch</code>]
CUSTOMEVAL --> UPDATECONFIGPERFORM[update configuration]
COMMONEVAL --> UPDATECONFIGPERFORM
end
MODE -- no --> COMMONTRAIN[<code>MyTrain.my_train_model</code>]
subgraph COMMONTRAIN[<code>MyTrain.my_train_model</code>]
direction TB
CONTINUETRAIN{{last epoch is -1?}}
CONTINUETRAIN -- yes --> HASINIT{{implement <code>model.my_initialize_model</code>?}} -- yes --> CUSTOMINIT[<code>model.my_initialize_model</code>]
HASINIT -- no --> INITWEIGHT[initialize model weights by <code>my_initializer</code>]
CONTINUETRAIN -- no --> TRAINLOADCHECK[load checkpoint for model and random generator]
CUSTOMINIT --> TRAINDEVICE[set model device]
INITWEIGHT --> TRAINDEVICE
TRAINLOADCHECK --> TRAINDEVICE
TRAINDEVICE --> INSTOPSC[instantiate optimizer and lr scheduler] --> CONTINUETRAIN2{{last epoch is -1?}} -- no --> TRAINCHECKOPSC[load checkpoint for optimizer and lr scheduler] --> SETUPOPSC[setup optimizer and lr_scheduler]
CONTINUETRAIN2{{last epoch is -1?}} -- yes --> SETUPOPSC
SETUPOPSC --> TRAINDATALOADER[setup data loader] --> INSTEARLYSTOP[instantiate early stopping] --> TRAINLOOP[train loop]
end
subgraph TRAINLOOP[train loop]
direction TB
M{{implement <code>model.my_train_epoch</code>?}}
M -- yes --> N[<code>model.my_train_epoch</code>]
M -- no --> O[<code>MyTrain.my_train_epoch</code>]
N --> P{{implement <code>model.my_eval_epoch</code>?}}
O --> P
P -- yes --> Q[<code>model.my_eval_epoch</code>]
P -- no --> R[<code>MyTrain.my_eval_epoch</code>]
Q --> UPDATELR[update learning rate]
R --> UPDATELR
UPDATELR --> TRAINSAVE[save epoch configuration and checkpoint] --> EARLYSTOP[check early stopping]
end
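The "implement model.my_train_epoch?" and "implement model.my_eval_epoch?" branches in the chart amount to preferring a hook defined on the model and falling back to the MyTrain default otherwise. Here is a minimal sketch of that dispatch, assuming the check is a plain attribute lookup (the mechanism MyTrain actually uses is not shown here). `pick_epoch_fn`, `DefaultTrainer`, `PlainModel`, and `CustomModel` are hypothetical names, and the argument-free methods are placeholders for the real signatures.

```python
def pick_epoch_fn(model, trainer, name):
    """Return model.<name> if the model defines it, else trainer.<name>."""
    custom = getattr(model, name, None)
    if callable(custom):
        return custom
    return getattr(trainer, name)


class DefaultTrainer:
    """Hypothetical trainer providing the fallback epoch methods."""

    def my_train_epoch(self):
        return "trainer default"

    def my_eval_epoch(self):
        return "trainer default"


class PlainModel:
    """Hypothetical model that defines no custom hooks."""


class CustomModel:
    """Hypothetical model that overrides only the training epoch."""

    def my_train_epoch(self):
        return "model override"


trainer = DefaultTrainer()
plain_choice = pick_epoch_fn(PlainModel(), trainer, "my_train_epoch")()
custom_choice = pick_epoch_fn(CustomModel(), trainer, "my_train_epoch")()
eval_choice = pick_epoch_fn(CustomModel(), trainer, "my_eval_epoch")()
```

Because each hook is resolved independently, a model can override the training epoch while still using the default evaluation epoch, as `CustomModel` does here.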
Test
The MyTest class tests subclasses of Hugging Face's PreTrainedModel. MyTest loads the epoch saved by MyTrain. If model.my_train_model is implemented, a corresponding model.my_load_model is required.
---
title: MyTest.__call__
---
flowchart TD
INSTMODEL[instantiate model and random generator] --> INSTMETRIC[instantiate metrics] --> LOADCHECK[load checkpoint for model and random generator] --> TESTDEVICE[set model device] --> TESTDATA[setup data loader] --> TESTMODEL[test model] --> TESTSAVE[save metrics]
Metric
The metric classes should implement three methods:
- __init__ initializes the parameters and the metric state.
- step processes a batch. It receives:
  - df: the data frame returned by the model's eval_output method.
  - examples: the examples in the dataset.
  - batch: the batch returned by the model's data_collator.
- epoch accumulates all batch results, reinitializes the metric state, and returns the final metric.
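The three methods can be sketched as a small running-mean metric. This is an illustration of the shape only: `MeanLossMetric` and the `loss` column are hypothetical, the return type of epoch is assumed to be a single float, and the demo passes a plain dict of lists as a stand-in for the data frame so that it runs without pandas.

```python
class MeanLossMetric:
    """Hypothetical metric following the __init__ / step / epoch shape."""

    def __init__(self, column="loss"):
        # Parameter: which column of df to average (an assumed column name).
        self.column = column
        # Metric state: running sum and count across batches.
        self.total = 0.0
        self.count = 0

    def step(self, df, examples, batch):
        # examples and batch are accepted to match the described interface
        # but are unused by this simple metric.
        values = list(df[self.column])
        self.total += sum(values)
        self.count += len(values)

    def epoch(self):
        # Combine everything seen since the last call into one value.
        mean = self.total / self.count if self.count else float("nan")
        # Reinitialize the state so the next epoch starts clean.
        self.total = 0.0
        self.count = 0
        return mean


metric = MeanLossMetric()
metric.step({"loss": [1.0, 3.0]}, examples=None, batch=None)
metric.step({"loss": [2.0]}, examples=None, batch=None)
final = metric.epoch()
```

Keeping a sum and a count, rather than averaging per batch, means batches of different sizes are weighted correctly in the final value.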