Skip to main content

A trainer based on Pytorch

Project description

Torch Terinador

A PyTorch-based trainer. The highlight of this tool is the separation of the loop module and the computation module within the training module, allowing for a customizable computation process. Users can rewrite the calculate function according to their specific needs (see example). Additionally, this tool supports features such as checkpoint training, data preprocessing, loss function visualization, and a series of overfitting prevention mechanisms (e.g., cosine annealing and warmup). It also enables evaluation metrics to be stored in memory or saved as CSV files.

Installation

This package needs Python>=3.7 and the version of Pytorch used in development is 2.5.1 and cuda12.4, considering the different version of cuda, the package will not install Pytorch automatically. You should check your cuda's version, install the suitable pytorch first. Then, run the command below:

pip install tortreinador 

Quick Start

from tortreinador.utils.plot import plot_line_2
from tortreinador.utils.preprocessing import load_data
from tortreinador.train import TorchTrainer, config_generator
from tortreinador.models.MDN import mdn, Mixture, NLLLoss
from tortreinador.utils.tools import xavier_init
from tortreinador.utils.View import init_weights, split_weights
import torch
import pandas as pd

data = pd.read_excel('')
data['M_total (M_E)'] = data['Mcore (M_J/10^3)'] + data['Menv (M_E)']

# Support index, e.g input_parameters = [0, 1, 2]
input_parameters = [
    'Mass (M_J)',
    'Radius (R_E)',
    'T_sur (K)',
]

output_parameters = [
    'M_total (M_E)',
    'T_int (K)',
    'P_CEB (Mbar)',
    'T_CEB (K)'
]
# Load Data, random status default as 42
t_loader, v_loader, test_x, test_y, s_x, s_y = load_data(data=data, input_parameters=input_parameters,
                                                         output_parameters=output_parameters,
                                                         if_normal=True, if_shuffle=True, batch_size=512, feature_range=(0, 1), if_double=True, n_workers=4)

model = mdn(len(input_parameters), len(output_parameters), 20, 512)
criterion = NLLLoss()
optim = torch.optim.Adam(xavier_init(model), lr=0.0001, weight_decay=0.001)

'''
    Overwrite function 'calculate' 
'''
# class Trainer(TorchTrainer):
#     def calculate(self, x, y, mode='t'):
#         x_o, x_n = x.chunk(2, dim=1)
        
#         pi, mu, sig = model(x_o, x_n)
        
#         loss = self.criterion(pi, mu, sig, y)
#         pdf = mixture(pi, mu, sig)
#         y_pred = pdf.sample()
        
#         metric_per = r2_score(y, y_pred)
        
#         return self._standard_return(loss=loss, metric_per=metric_per, mode=mode, y=y, y_pred=y_pred)

# trainer = Trainer(is_gpu=True, epoch=50, optimizer=optim, model=model, criterion=criterion)


trainer = TorchTrainer(is_gpu=True, epoch=50, optimizer=optim, model=model, criterion=criterion)

save_file_path = '/notebooks/DeepExo/Resource/MDN_ATTN_15_error/'
config = config_generator(save_file_path, warmup_epochs=5, best_metric=0.8, lr_milestones=[12, 22, 36, 67, 75, 89, 106], lr_decay_rate=0.7)
# Training
result = trainer.fit(t_loader, v_loader, **config)


# Plot line chart
result_pd = pd.DataFrame()
result_pd['epoch'] = len(result[0])
result_pd['train_r2_avg'] = result[4]
result_pd['val_r2_avg'] = result[3]

plot_line_2(y_1='train_r2_avg', y_2='val_r2_avg', df=result_pd, fig_size=(10, 6))

# If specify 'mode' in TorchTrainer as 'csv'
saved_result = pd.read_csv('/notebooks/DeepExo/train_log/log_202408280744.csv')
plot_line_2(y_1='train_loss', y_2='val_loss', df=saved_result)

Functions

Please visit https://ardentex.github.io/tortreinador/

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

tortreinador-0.2.1.tar.gz (25.0 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

tortreinador-0.2.1-py3-none-any.whl (24.6 kB view details)

Uploaded Python 3

File details

Details for the file tortreinador-0.2.1.tar.gz.

File metadata

  • Download URL: tortreinador-0.2.1.tar.gz
  • Upload date:
  • Size: 25.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.23

File hashes

Hashes for tortreinador-0.2.1.tar.gz
Algorithm Hash digest
SHA256 035f855d6d66aef6842fe5cd96e1410f5f44ee21d4d6b7d745af7c624adc6e0e
MD5 a08ed3a8dbdda90fb91ba97efb93b66b
BLAKE2b-256 28854f36338fec978ca5349ea089ca2a7685f9c7170e56b929ac58cbd34d2c69

See more details on using hashes here.

File details

Details for the file tortreinador-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: tortreinador-0.2.1-py3-none-any.whl
  • Upload date:
  • Size: 24.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.23

File hashes

Hashes for tortreinador-0.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 033ea5baa19261b2a6cb6a3865e6a87bd81960d58ed84ac90b6bf3009e6f397c
MD5 fc998f4313870b2afba94cc4321e2274
BLAKE2b-256 89a6003338a9d56487f68e9e8ef8891542a38946713131a0f43169cb450813e1

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page