
Package for live visualization of metrics during training of a machine learning model


Traintorch (alpha)


Package for live visualization of model validation metrics during the training of a machine learning model in Jupyter notebooks. The package uses a sliding-window mechanism to limit memory usage.
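
The sketch below shows the general sliding-window idea. It assumes that a window keeps only the most recent `w_size` values of each named series, which matches the `w_size` argument used in the examples. The `WindowedMetric` class is hypothetical and is not traintorch's implementation. Because `collections.deque(maxlen=...)` discards the oldest item once it is full, memory per series stays fixed however many steps are logged.

```python
from collections import deque


class WindowedMetric:
    """Hypothetical windowed metric (NOT traintorch's `metric` class).

    Each named series keeps at most `w_size` values, so memory use is
    bounded by (number of series) * w_size regardless of training length.
    """

    def __init__(self, name, w_size=10):
        self.name = name
        self.w_size = w_size
        self.series = {}  # series name -> deque of the most recent values

    def update(self, **values):
        for key, value in values.items():
            # Create the window on first use; a full deque drops its oldest item.
            window = self.series.setdefault(key, deque(maxlen=self.w_size))
            window.append(value)

    def latest(self):
        """Return the newest value of every series."""
        return {key: window[-1] for key, window in self.series.items()}


loss = WindowedMetric('Loss', w_size=10)
for i in range(1000):
    loss.update(train_loss=1 / (i + 1), test_loss=1 / (i ** 2 + 1))

print(len(loss.series['train_loss']))  # 10: only the last 10 steps are kept
print(loss.latest())
```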

Requirements:

pandas==0.25.1
matplotlib==3.1.1
ipython==7.8.0
numpy==1.17.2
pycm==2.2

Installation:

Latest release:

pip install traintorch

Latest development version (from GitHub):

pip install git+https://github.com/rouzbeh-afrasiabi/traintorch.git

Examples

Simple Usage

from traintorch import *

# custom metrics
first = metric('Loss', w_size=10, average=False)
second = metric('Accuracy', w_size=10, average=False)

# create an instance of traintorch
tracker = traintorch(n_custom_plots=2, main_grid_hspace=.1, figsize=(15, 10), show_table=True)
# combine all metrics together
tracker.append([first, second])

range_max = 1000
for i in range(range_max):
    # synthetic values standing in for real training/validation results
    first.update(train_loss=1/(i+1), test_loss=1/(i**2+1))
    second.update(y=i/(i*2+1))
    tracker.plot()

Using pycm metrics and comparing two metrics

import numpy as np
from traintorch import *

# custom metric
first = metric('Loss', w_size=10, average=False)

# pycm metrics
overall_selected = ['ACC Macro']
cm_metrics_a = pycmMetrics(overall_selected, name='train', w_size=10)
cm_metrics_b = pycmMetrics(overall_selected, name='test', w_size=10)

# compare two metrics of the same kind
compare_a = collate(cm_metrics_a, cm_metrics_b, 'ACC Macro')

# create an instance of traintorch
tracker = traintorch(n_custom_plots=1, main_grid_hspace=.1, figsize=(15, 15), show_table=True)

# combine all metrics together
tracker.append([first, cm_metrics_a, cm_metrics_b, compare_a])

range_max = 1000
for i in range(range_max):
    # simulated labels; predictions drift from mostly 0 to mostly 1 as i grows
    actual_a = np.random.choice([0, 1], size=(20,), p=[1./3, 2./3])
    predicted_a = np.random.choice([0, 1], size=(20,), p=[1-(i/range_max), i/range_max])
    actual_b = np.random.choice([0, 1], size=(20,), p=[1./3, 2./3])
    predicted_b = np.random.choice([0, 1], size=(20,), p=[1-(i/range_max), i/range_max])
    cm_metrics_a.update(actual_a, predicted_a)
    cm_metrics_b.update(actual_b, predicted_b)
    first.update(train=1/(i+1), test=1/(i**2+1))
    compare_a.update()
    tracker.plot()
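
To show what the tracked 'ACC Macro' value measures, here is a small NumPy sketch. It assumes pycm's definition, in which 'ACC Macro' is the unweighted mean of the per-class accuracies ACC_c = (TP_c + TN_c) / N. The helper `acc_macro` is hypothetical and is not part of traintorch or pycm. For binary labels, as in the example above, both per-class accuracies equal the ordinary accuracy, so 'ACC Macro' reduces to the fraction of correct predictions.

```python
import numpy as np


def acc_macro(actual, predicted):
    """Hypothetical helper: mean of per-class accuracy (assumed pycm 'ACC Macro')."""
    actual = np.asarray(actual)
    predicted = np.asarray(predicted)
    classes = np.union1d(actual, predicted)  # every label seen in either array
    per_class = []
    for c in classes:
        tp = np.sum((actual == c) & (predicted == c))  # correctly labelled as c
        tn = np.sum((actual != c) & (predicted != c))  # correctly labelled as not c
        per_class.append((tp + tn) / actual.size)
    return float(np.mean(per_class))


# One step of the simulated loop above, with a fixed seed for repeatability
rng = np.random.default_rng(0)
i, range_max = 500, 1000
actual_a = rng.choice([0, 1], size=20, p=[1/3, 2/3])
predicted_a = rng.choice([0, 1], size=20, p=[1 - i/range_max, i/range_max])
print(acc_macro(actual_a, predicted_a))
```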



Download files


Source Distribution

traintorch-1.0.2.tar.gz (8.7 kB)


File details

Details for the file traintorch-1.0.2.tar.gz.

File metadata

  • Download URL: traintorch-1.0.2.tar.gz
  • Upload date:
  • Size: 8.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.1 CPython/3.7.3

File hashes

Hashes for traintorch-1.0.2.tar.gz
Algorithm     Hash digest
SHA256        1bcf0c870db3aada9c1403ac3f79688eeaae4912943bccfb55ba1790cbe8cda5
MD5           ba24761f1bbc97cc02bec52676528e06
BLAKE2b-256   f570a993a6d454eb22d2346361e22732cc1bbaa7bdfd935cb8a40da554382054

