
A trainer based on PyTorch

Project description

Torch Terinador

A PyTorch-based trainer. Its main feature is that the training module keeps the training loop separate from the per-batch computation, so the computation can be customized: users can override the calculate function to suit their needs (see the Quick Start example). The tool also supports resuming training from checkpoints, data preprocessing, loss-curve visualization, and a set of mechanisms intended to reduce overfitting and stabilize training (e.g., learning-rate warmup and cosine annealing). Evaluation metrics can be kept in memory or saved as CSV files.
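
The separation between the loop and the computation can be sketched in plain Python. This is a minimal illustration of the pattern, not tortreinador's code: the classes LoopTrainer and LinearTrainer, the run_epoch method, the 'v' mode for validation, and the convention that calculate returns a float loss are all hypothetical. The base class only iterates over epochs and batches; the subclass decides how a batch becomes a loss and when parameters are updated.

```python
from typing import List, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]


class LoopTrainer:
    """Hypothetical base class: owns the epoch/batch loop only."""

    def __init__(self, epochs: int):
        self.epochs = epochs

    def calculate(self, x: Sequence[float], y: Sequence[float], mode: str = "t") -> float:
        # Subclasses decide how a batch becomes a loss (and whether to update).
        raise NotImplementedError

    def run_epoch(self, batches: List[Batch], mode: str) -> float:
        # The loop never inspects a batch; it only averages what calculate() returns.
        losses = [self.calculate(x, y, mode=mode) for x, y in batches]
        return sum(losses) / len(losses)

    def fit(self, train: List[Batch], val: List[Batch]) -> List[Tuple[float, float]]:
        history = []
        for _ in range(self.epochs):
            train_loss = self.run_epoch(train, mode="t")  # training pass, with updates
            val_loss = self.run_epoch(val, mode="v")      # validation pass, no updates
            history.append((train_loss, val_loss))
        return history


class LinearTrainer(LoopTrainer):
    """Custom computation: fit y = w * x by plain gradient descent on the MSE."""

    def __init__(self, epochs: int, lr: float):
        super().__init__(epochs)
        self.w = 0.0
        self.lr = lr

    def calculate(self, x, y, mode="t"):
        errors = [self.w * xi - yi for xi, yi in zip(x, y)]
        loss = sum(e * e for e in errors) / len(errors)
        if mode == "t":
            # d(MSE)/dw = mean(2 * error * x); applied only in training mode
            grad = sum(2 * e * xi for e, xi in zip(errors, x)) / len(errors)
            self.w -= self.lr * grad
        return loss


trainer = LinearTrainer(epochs=20, lr=0.05)
history = trainer.fit(train=[([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])], val=[([4.0], [8.0])])
print(round(trainer.w, 3))  # the learned slope approaches 2.0
```

Swapping in a different model or loss only requires a new calculate method; fit and run_epoch stay unchanged, which is the kind of decoupling the description above refers to.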

Installation

This package requires Python >= 3.7. It was developed with PyTorch 2.5.1 and CUDA 12.4. Because the appropriate PyTorch build depends on your CUDA version, installing this package does not install PyTorch automatically. Check your CUDA version and install a matching PyTorch build first, then run the command below:

pip install tortreinador 
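
Before or after installing, it can help to confirm which Python and PyTorch/CUDA builds are present. The sketch below uses only the standard library plus PyTorch's public torch.__version__, torch.version.cuda and torch.cuda.is_available(); the helper name check_environment is hypothetical and not part of tortreinador. Note that torch.version.cuda reports the CUDA version PyTorch was built against, not the installed driver version.

```python
import importlib.util
import sys


def check_environment() -> dict:
    """Hypothetical helper: report Python and (if installed) PyTorch/CUDA versions."""
    report = {
        "python": "%d.%d.%d" % sys.version_info[:3],
        "python_ok": sys.version_info >= (3, 7),  # tortreinador needs Python >= 3.7
        "torch": None,
        "torch_cuda": None,
        "cuda_available": False,
    }
    # Only import torch if it is installed, so the check also works before setup.
    if importlib.util.find_spec("torch") is not None:
        import torch

        report["torch"] = str(torch.__version__)
        report["torch_cuda"] = torch.version.cuda  # None for CPU-only builds
        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```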

Quick Start

from tortreinador.utils.plot import plot_line_2
from tortreinador.utils.preprocessing import load_data
from tortreinador.train import TorchTrainer, config_generator
from tortreinador.models.MDN import mdn, Mixture, NLLLoss
from tortreinador.utils.tools import xavier_init
from tortreinador.utils.View import init_weights, split_weights
import torch
import pandas as pd

data = pd.read_excel('')  # fill in the path to your dataset
data['M_total (M_E)'] = data['Mcore (M_J/10^3)'] + data['Menv (M_E)']

# Column indices are also supported, e.g. input_parameters = [0, 1, 2]
input_parameters = [
    'Mass (M_J)',
    'Radius (R_E)',
    'T_sur (K)',
]

output_parameters = [
    'M_total (M_E)',
    'T_int (K)',
    'P_CEB (Mbar)',
    'T_CEB (K)'
]
# Load data (the random state defaults to 42)
t_loader, v_loader, test_x, test_y, s_x, s_y = load_data(
    data=data,
    input_parameters=input_parameters,
    output_parameters=output_parameters,
    if_normal=True, if_shuffle=True, batch_size=512,
    feature_range=(0, 1), if_double=True, n_workers=4,
)

model = mdn(len(input_parameters), len(output_parameters), 20, 512)
criterion = NLLLoss()
optim = torch.optim.Adam(xavier_init(model), lr=0.0001, weight_decay=0.001)

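# To customize the computation, subclass TorchTrainer and override 'calculate'.
# This example assumes a model that takes the input split into two halves.
# 'mixture' and 'r2_score' are not defined in this snippet: supply your own
# (e.g. an instance of the imported Mixture class and an R^2 metric function).
# class Trainer(TorchTrainer):
#     def calculate(self, x, y, mode='t'):
#         x_o, x_n = x.chunk(2, dim=1)
#
#         pi, mu, sig = model(x_o, x_n)
#
#         loss = self.criterion(pi, mu, sig, y)
#         pdf = mixture(pi, mu, sig)
#         y_pred = pdf.sample()
#
#         metric_per = r2_score(y, y_pred)
#
#         return self._standard_return(loss=loss, metric_per=metric_per, mode=mode, y=y, y_pred=y_pred)

# trainer = Trainer(is_gpu=True, epoch=50, optimizer=optim, model=model, criterion=criterion)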


trainer = TorchTrainer(is_gpu=True, epoch=50, optimizer=optim, model=model, criterion=criterion)

save_file_path = '/notebooks/DeepExo/Resource/MDN_ATTN_15_error/'
config = config_generator(save_file_path, warmup_epochs=5, best_metric=0.8, lr_milestones=[12, 22, 36, 67, 75, 89, 106], lr_decay_rate=0.7)
# Training
result = trainer.fit(t_loader, v_loader, **config)


# Plot line chart
# One row per epoch: result[4] holds the training R^2 and result[3] the validation R^2
result_pd = pd.DataFrame({
    'epoch': range(1, len(result[0]) + 1),
    'train_r2_avg': result[4],
    'val_r2_avg': result[3],
})

plot_line_2(y_1='train_r2_avg', y_2='val_r2_avg', df=result_pd, fig_size=(10, 6))

# If 'mode' is set to 'csv' in TorchTrainer, the training log is saved as a CSV file that can be plotted later
saved_result = pd.read_csv('/notebooks/DeepExo/train_log/log_202408280744.csv')
plot_line_2(y_1='train_loss', y_2='val_loss', df=saved_result)
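
The warmup_epochs, lr_milestones and lr_decay_rate arguments passed to config_generator suggest a warmup phase followed by step decay at fixed epochs. The function below is a hypothetical sketch of such a schedule under that assumption (linear warmup to the optimizer's base rate, then multiplying the rate by the decay factor at each milestone); it is not tortreinador's implementation, so check the documentation for how these settings are actually applied.

```python
from bisect import bisect_right
from typing import List, Sequence


def lr_at_epoch(epoch: int, base_lr: float, warmup_epochs: int,
                milestones: Sequence[int], decay_rate: float) -> float:
    """Learning rate for a 0-based epoch: linear warmup, then step decay.

    Hypothetical sketch of a warmup + milestone schedule.
    """
    if epoch < warmup_epochs:
        # Ramp linearly from base_lr / warmup_epochs up to base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    # Count how many (sorted) milestones have been reached so far.
    n_decays = bisect_right(list(milestones), epoch)
    return base_lr * decay_rate ** n_decays


def schedule(epochs: int, **kwargs) -> List[float]:
    """Learning rate for every epoch in range(epochs)."""
    return [lr_at_epoch(e, **kwargs) for e in range(epochs)]


if __name__ == "__main__":
    # Values taken from the Quick Start: lr=0.0001, 50 epochs, and the config above.
    lrs = schedule(50, base_lr=1e-4, warmup_epochs=5,
                   milestones=[12, 22, 36, 67, 75, 89, 106], decay_rate=0.7)
    for epoch in (0, 4, 12, 22, 36, 49):
        print(epoch, lrs[epoch])
```

Under this interpretation, with epoch=50 as in the Quick Start, only the milestones 12, 22 and 36 would be reached; the later ones would have no effect unless training runs longer. In plain PyTorch, a similar shape can be built by chaining torch.optim.lr_scheduler.LinearLR and MultiStepLR with SequentialLR.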

Functions

See the documentation at https://ardentex.github.io/tortreinador/ for the full function reference.

Download files

Source Distribution

tortreinador-0.2.3.tar.gz (24.9 kB)

Built Distribution

tortreinador-0.2.3-py3-none-any.whl (24.6 kB)

File details

Details for the file tortreinador-0.2.3.tar.gz.

File metadata

  • Download URL: tortreinador-0.2.3.tar.gz
  • Upload date:
  • Size: 24.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.23

File hashes

Hashes for tortreinador-0.2.3.tar.gz
  • SHA256: efba96acb31fb168db6862f543e55e328ef80d088bfce23460fb30ed2cffcfea
  • MD5: 04b8c7a92e44ad64587a5762c44f99e5
  • BLAKE2b-256: 64a256939acbdf290a0af9db8540b8ed44043110d01b6df9690ab172cd5edb56

File details

Details for the file tortreinador-0.2.3-py3-none-any.whl.

File metadata

  • Download URL: tortreinador-0.2.3-py3-none-any.whl
  • Upload date:
  • Size: 24.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.23

File hashes

Hashes for tortreinador-0.2.3-py3-none-any.whl
  • SHA256: 7978488a2c1593ffe4b68f85961976d41e595a6462a0f8890a3e6daf97857751
  • MD5: ea069565958f3dd3516cee924d0d0540
  • BLAKE2b-256: 97e8ae6aebf7c6eea94617ef58f9bd4f76c29bd3e4ae88689576381d8631ad17