Skip to main content

A trainer based on Pytorch

Project description

Torch Terinador

A trainer based on pytorch including a train loop for MDN (Mixture Density Network), a data loader, plot line chart and a couple of techniques for avoid over fitting

Installation

This package needs Python>=3.7 and the version of Pytorch used in development is 1.13.1 and cuda11.2, considering the different version of cuda, the package will not install Pytorch automatically. You should check your cuda's version, install the suitable pytorch first. Then, run the command below:

pip install tortreinador 

Quick Start

from tortreinador.utils.plot import plot_line_2
from tortreinador.utils.preprocessing import load_data
from tortreinador.train import TorchTrainer, config_generator
from tortreinador.models.MDN import mdn, Mixture, NLLLoss
from tortreinador.utils.tools import xavier_init
from tortreinador.utils.View import init_weights, split_weights
import torch
import pandas as pd

data = pd.read_excel('')
data['M_total (M_E)'] = data['Mcore (M_J/10^3)'] + data['Menv (M_E)']

# Support index, e.g input_parameters = [0, 1, 2]
input_parameters = [
    'Mass (M_J)',
    'Radius (R_E)',
    'T_sur (K)',
]

output_parameters = [
    'M_total (M_E)',
    'T_int (K)',
    'P_CEB (Mbar)',
    'T_CEB (K)'
]
# Load Data, random status default as 42
t_loader, v_loader, test_x, test_y, s_x, s_y = load_data(data=data, input_parameters=input_parameters,
                                                         output_parameters=output_parameters,
                                                         if_normal=True, if_shuffle=True, batch_size=512, feature_range=(0, 1), if_double=True, n_workers=4)

model = mdn(len(input_parameters), len(output_parameters), 20, 512)
criterion = NLLLoss()
optim = torch.optim.Adam(xavier_init(model), lr=0.0001, weight_decay=0.001)

'''
    Overwrite function 'calculate' 
'''
# class Trainer(TorchTrainer):
#     def calculate(self, x, y, mode='t'):
#         x_o, x_n = x.chunk(2, dim=1)
        
#         pi, mu, sig = model(x_o, x_n)
        
#         loss = self.criterion(pi, mu, sig, y)
#         pdf = mixture(pi, mu, sig)
#         y_pred = pdf.sample()
        
#         metric_per = r2_score(y, y_pred)
        
#         return self._standard_return(loss=loss, metric_per=metric_per, mode=mode, y=y, y_pred=y_pred)

# trainer = Trainer(is_gpu=True, epoch=50, optimizer=optim, model=model, criterion=criterion)


trainer = TorchTrainer(is_gpu=True, epoch=50, optimizer=optim, model=model, criterion=criterion)

save_file_path = '/notebooks/DeepExo/Resource/MDN_ATTN_15_error/'
config = config_generator(save_file_path, warmup_epochs=5, best_metric=0.8, lr_milestones=[12, 22, 36, 67, 75, 89, 106], lr_decay_rate=0.7)
# Training
result = trainer.fit(t_loader, v_loader, **config)


# Plot line chart
result_pd = pd.DataFrame()
result_pd['epoch'] = len(result[0])
result_pd['train_r2_avg'] = result[4]
result_pd['val_r2_avg'] = result[3]

plot_line_2(y_1='train_r2_avg', y_2='val_r2_avg', df=result_pd, fig_size=(10, 6))

# If specify 'mode' in TorchTrainer as 'csv'
saved_result = pd.read_csv('/notebooks/DeepExo/train_log/log_202408280744.csv')
plot_line_2(y_1='train_loss', y_2='val_loss', df=saved_result)

Functions

Please visit https://ardentex.github.io/tortreinador/

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

tortreinador-0.1.7.tar.gz (23.4 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

tortreinador-0.1.7-py3-none-any.whl (23.3 kB view details)

Uploaded Python 3

File details

Details for the file tortreinador-0.1.7.tar.gz.

File metadata

  • Download URL: tortreinador-0.1.7.tar.gz
  • Upload date:
  • Size: 23.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.9.20

File hashes

Hashes for tortreinador-0.1.7.tar.gz
Algorithm Hash digest
SHA256 eb35490ee3e6d5b44d5c887d4a9b531c7cbbb08e3c8981a3a6e828e035e3a522
MD5 e25d39b13b216283dad55127fc71df8b
BLAKE2b-256 cf0eb219cc0618d78eaf6a9345d563edf3c60e4f0d002f3f7aca3f6b42565bf1

See more details on using hashes here.

File details

Details for the file tortreinador-0.1.7-py3-none-any.whl.

File metadata

  • Download URL: tortreinador-0.1.7-py3-none-any.whl
  • Upload date:
  • Size: 23.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.9.20

File hashes

Hashes for tortreinador-0.1.7-py3-none-any.whl
Algorithm Hash digest
SHA256 9a026fd999dd8853f1bf6da5525d5d12308ec5676045600556878dc8f555c631
MD5 ed2e04a27d2c4ab22dbd4d34c07c8d0b
BLAKE2b-256 8e44c7009ba09860c997ce44c6c68334a9144db4308970eacf7bc4ee48f75085

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page