Package for live visualization of metrics during training of a machine learning model
Traintorch (alpha)
Package for live visualization of model validation metrics during the training of a machine learning model in Jupyter notebooks. To reduce memory usage, the package keeps only a sliding window of the most recent values of each metric.
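A minimal sketch of the sliding-window idea, assuming the window simply keeps the last w_size values of a series (w_size is the argument used in the examples below). The class WindowedSeries is hypothetical and is not traintorch's implementation. It uses collections.deque with maxlen, which discards the oldest value whenever a new one is appended, so memory use stays constant however long training runs.

```python
from collections import deque


class WindowedSeries:
    """Hypothetical sliding-window history (not part of traintorch).

    Only the most recent w_size values are kept, so memory stays
    constant no matter how many training steps are logged.
    """

    def __init__(self, w_size: int) -> None:
        # A deque with maxlen drops the oldest item when a new one is appended
        self.values = deque(maxlen=w_size)

    def update(self, value: float) -> None:
        self.values.append(value)

    def window(self) -> list:
        return list(self.values)


series = WindowedSeries(w_size=3)
for step in range(5):
    series.update(step)
print(series.window())  # [2, 3, 4]
```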
Requirements:
pandas==0.25.1
matplotlib==3.1.1
ipython==7.8.0
numpy==1.17.2
pycm==2.2
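The versions above are exact pins. Below is a small sketch for checking an environment against them with importlib.metadata from the standard library (Python 3.8 or newer). The helper check_requirements is hypothetical; it only reports differences and says nothing about whether other versions work with traintorch.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the Requirements list
REQUIREMENTS = {
    "pandas": "0.25.1",
    "matplotlib": "3.1.1",
    "ipython": "7.8.0",
    "numpy": "1.17.2",
    "pycm": "2.2",
}


def check_requirements(pins: dict) -> dict:
    """Return {package: installed_version_or_None} for every pin that is not met."""
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != pinned:
            mismatches[name] = installed
    return mismatches


if __name__ == "__main__":
    for name, installed in check_requirements(REQUIREMENTS).items():
        print(f"{name}: pinned {REQUIREMENTS[name]}, found {installed}")
```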
Installation:
Latest release:
pip install traintorch
Latest version (from GitHub):
pip install git+https://github.com/rouzbeh-afrasiabi/traintorch.git
Examples
Simple usage
from traintorch import *
#custom metrics
first=metric('Loss',w_size=10,average=False)
second=metric('Accuracy',w_size=10,average=False)
#create an instance of traintorch
tracker=traintorch(n_custom_plots=2,main_grid_hspace=.1, figsize=(15,10),show_table=True)
#combine all metrics together
tracker.append([first,second])
range_max=1000
for i in range(0,range_max,1):
    first.update(train_loss=1/(i+1),test_loss=1/(i**2+1))
    second.update(y=i/(i*2+1))
    tracker.plot()
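To see what the loop above feeds into the tracker, here is a headless sketch that reproduces the three series without traintorch or plotting. It assumes each keyword passed to update() (train_loss, test_loss, y) becomes a separate series held in a window of w_size values; that storage model is an assumption, not documented behavior.

```python
from collections import deque

W_SIZE = 10       # same window size as w_size=10 in the example
RANGE_MAX = 1000  # same loop length as range_max in the example

# One bounded deque per series name passed to update() (assumed storage model)
loss = {"train_loss": deque(maxlen=W_SIZE), "test_loss": deque(maxlen=W_SIZE)}
accuracy = {"y": deque(maxlen=W_SIZE)}

for i in range(RANGE_MAX):
    loss["train_loss"].append(1 / (i + 1))
    loss["test_loss"].append(1 / (i**2 + 1))
    accuracy["y"].append(i / (i * 2 + 1))

# Only the last W_SIZE steps (i = 990..999) survive
print(len(loss["train_loss"]))      # 10
print(loss["train_loss"][-1])       # 0.001
print(round(accuracy["y"][-1], 4))  # 0.4997
```

After 1000 steps only the values from steps 990 to 999 remain, and y approaches 0.5 because i/(2i+1) tends to 1/2.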
Using pycm metrics and comparing them
import numpy as np
from traintorch import *
#custom metric
first=metric('Loss',w_size=10,average=False)
#pycm metrics
overall_selected=['ACC Macro']
cm_metrics_a=pycmMetrics(overall_selected,name='train',w_size=10)
cm_metrics_b=pycmMetrics(overall_selected,name='test',w_size=10)
#compare two metrics of the same kind
compare_a=collate(cm_metrics_a,cm_metrics_b,'ACC Macro')
#create an instance of traintorch
tracker=traintorch(n_custom_plots=1,main_grid_hspace=.1,figsize=(15,15),show_table=True)
#combine all metrics together
tracker.append([first,cm_metrics_a,cm_metrics_b,compare_a])
range_max=1000
for i in range(0,range_max,1):
    actual_a=np.random.choice([0, 1], size=(20,), p=[1./3, 2./3])
    predicted_a=np.random.choice([0, 1], size=(20,),p=[1-(i/range_max), i/range_max])
    actual_b=np.random.choice([0, 1], size=(20,), p=[1./3, 2./3])
    predicted_b=np.random.choice([0, 1], size=(20,),p=[1-(i/range_max), i/range_max])
    cm_metrics_a.update(actual_a,predicted_a)
    cm_metrics_b.update(actual_b,predicted_b)
    first.update(train=1/(i+1),test=1/(i**2+1))
    compare_a.update()
    tracker.plot()
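As we read the pycm documentation, 'ACC Macro' is the average over classes of the per-class accuracy (TP + TN) / N. The sketch below computes it without pycm and adds the expected value for the simulated data in the loop above. Because predictions are drawn independently of the labels, P(correct) = (2/3)q + (1/3)(1 - q) with q = i/range_max, and for two classes every per-class (TP + TN) counts exactly the correctly classified samples, so ACC Macro equals plain accuracy. The train and test curves should therefore climb roughly from 1/3 to 2/3, noisily, since each step uses only 20 samples. The function names are hypothetical; this is neither traintorch nor pycm code.

```python
import random

P_ACTUAL_ONE = 2 / 3  # actual labels are class 1 with probability 2/3, as in the example
RANGE_MAX = 1000      # same loop length as range_max in the example


def acc_macro(actual, predicted):
    """Mean over classes of the per-class accuracy (TP + TN) / N."""
    if not actual or len(actual) != len(predicted):
        raise ValueError("vectors must be non-empty and of equal length")
    n = len(actual)
    classes = sorted(set(actual) | set(predicted))
    per_class = []
    for c in classes:
        tp = sum(1 for a, p in zip(actual, predicted) if a == c and p == c)
        tn = sum(1 for a, p in zip(actual, predicted) if a != c and p != c)
        per_class.append((tp + tn) / n)
    return sum(per_class) / len(per_class)


def expected_acc(i, range_max=RANGE_MAX):
    """Expected accuracy at step i when predictions are independent of the labels."""
    q = i / range_max  # probability that a prediction is class 1
    return P_ACTUAL_ONE * q + (1 - P_ACTUAL_ONE) * (1 - q)


def simulate_step(i, rng, n=20, range_max=RANGE_MAX):
    """Draw one batch the way the example loop does and score it with acc_macro."""
    q = i / range_max
    actual = [1 if rng.random() < P_ACTUAL_ONE else 0 for _ in range(n)]
    predicted = [1 if rng.random() < q else 0 for _ in range(n)]
    return acc_macro(actual, predicted)


rng = random.Random(0)
for i in (0, 250, 500, 750, 999):
    print(f"step {i}: expected {expected_acc(i):.3f}, one 20-sample draw {simulate_step(i, rng):.3f}")
```

With more than two classes ACC Macro no longer equals plain accuracy; for example, actual [0, 1, 2] against predicted [0, 1, 1] gives 7/9 rather than 2/3.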
Source distribution: traintorch-1.0.2.tar.gz (8.7 kB)