A dynamic progress indicator that displays train and validation losses.
Loss Watch
Loss-watch is a library that lets you watch your training and validation losses in a neat and orderly way, both in Jupyter notebooks and in the console.
I created it because I was tired of endless blocks of printed losses in my console and notebooks. Instead, loss-watch gives you a progress bar whose colors (or contrasts, in the terminal) indicate your model's progress.
Installation
pip install loss-watch
Usage
Similar to tqdm, a LossProgressBar acts as the iterable you loop through while training your model. The simplest way to use it is as follows:
from loss_watch import LossProgressBar

epochs = 100
for epoch, update in LossProgressBar(epochs):
    # Perform your model's training step and retrieve a float `train_loss`
    # (train_step is your own training function)
    train_loss = train_step()
    update(train_loss)
It does not really matter how you obtain your training loss here; any float works. This gives you a plot that looks something like this:
As you can see, your highest loss is displayed in red and your lowest in light cyan.
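To make the (epoch, update) protocol concrete, here is a minimal sketch of an iterable that hands out an epoch number together with an update callback. It is not loss-watch's implementation: the class RecordingLoop is hypothetical, and it only records every loss passed to update (the positional one under "train", keyword arguments under their own names) instead of drawing a progress bar.

```python
from typing import Callable, Dict, Iterator, List, Optional, Tuple


class RecordingLoop:
    """Hypothetical (epoch, update) iterable that records losses instead of drawing them."""

    def __init__(self, epochs: int) -> None:
        self.epochs = epochs
        # One per-epoch list per loss name; None means "no value for this epoch".
        self.history: Dict[str, List[Optional[float]]] = {}

    def _make_update(self, epoch: int) -> Callable[..., None]:
        def update(train_loss: Optional[float] = None, **named: float) -> None:
            values = dict(named)
            if train_loss is not None:
                values["train"] = train_loss
            for name, loss in values.items():
                # Create the series for a new loss name on first use.
                series = self.history.setdefault(name, [None] * self.epochs)
                series[epoch] = float(loss)

        return update

    def __iter__(self) -> Iterator[Tuple[int, Callable[..., None]]]:
        # Like tqdm, yield once per epoch; each item also carries an update callback.
        for epoch in range(self.epochs):
            yield epoch, self._make_update(epoch)
```

Looping over RecordingLoop(3), calling update(loss) every epoch and update(val=...) only in the last one, leaves a "train" series with three values and a "val" series that is None except in the final epoch.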
Validation
Every so often, you will probably want to validate your model on one or more validation sets. LossProgressBar can handle as many as you like, and, in contrast to the training loss, you can validate at any interval you like! Simply pass a float as a named (keyword) argument to the update function, and it will create another progress bar with that name for you.
If you extend our example to the following:
for epoch, update in LossProgressBar(epochs):
    # Perform your model's training step and retrieve a float `train_loss`
    train_loss = train_step()
    if epoch % 10 == 9:
        # A cheap validation, run every 10 epochs
        val_loss1 = val_step()
        update(val_step1=val_loss1)
    if epoch % 25 == 24:
        # A more expensive validation, run every 25 epochs
        val_loss2 = val_step2()
        update(val_step2=val_loss2)
    update(train_loss)
As you can see, you can call update multiple times per epoch. You can also skip calling update for any of your losses, including the training loss; missing values will be interpolated.
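This README does not say which interpolation loss-watch uses. As an illustration only, here is linear interpolation over a per-epoch series in which None marks an epoch without a value. The function name interpolate_missing is hypothetical, and leaving the gaps before the first and after the last recorded value as None (no data yet) is an assumption of this sketch.

```python
from typing import List, Optional


def interpolate_missing(series: List[Optional[float]]) -> List[Optional[float]]:
    """Linearly fill None entries that lie between two recorded values.

    Entries before the first or after the last recorded value stay None.
    """
    filled = list(series)
    known = [i for i, v in enumerate(series) if v is not None]
    # Walk over each pair of consecutive recorded epochs and fill the gap between them.
    for left, right in zip(known, known[1:]):
        lo, hi = series[left], series[right]
        span = right - left
        for i in range(left + 1, right):
            t = (i - left) / span
            filled[i] = lo + t * (hi - lo)
    return filled
```

For example, interpolate_missing([None, 2.0, None, 4.0, None]) returns [None, 2.0, 3.0, 4.0, None].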
The resulting plot will look something like this:
The black parts indicate that there was no validation data yet. You can also see that the maximum and minimum of the training loss are marked in red and cyan, respectively. This is because all plots share the same colors for the minimum and maximum loss, which makes them easier to compare: here, the marked training losses are also the overall maximum and minimum across all plots.
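To illustrate what a shared color scale means, here is a minimal sketch that takes one global minimum and maximum over all loss series and blends linearly from light cyan (minimum) to red (maximum). The endpoint RGB values, the linear blend and all names are assumptions for illustration; the README does not describe how loss-watch computes its colors.

```python
from typing import Dict, List, Optional, Tuple

RGB = Tuple[int, int, int]

# Assumed endpoint colors; not taken from loss-watch.
LOW_COLOR: RGB = (170, 255, 255)  # light cyan for the overall minimum loss
HIGH_COLOR: RGB = (255, 0, 0)     # red for the overall maximum loss


def shared_bounds(series: Dict[str, List[Optional[float]]]) -> Tuple[float, float]:
    """Return the minimum and maximum over every recorded value in every series."""
    values = [v for losses in series.values() for v in losses if v is not None]
    return min(values), max(values)


def loss_to_rgb(loss: float, lo: float, hi: float) -> RGB:
    """Blend linearly from LOW_COLOR at lo to HIGH_COLOR at hi."""
    t = 0.0 if hi == lo else (loss - lo) / (hi - lo)
    blended = [round(low + t * (high - low)) for low, high in zip(LOW_COLOR, HIGH_COLOR)]
    return (blended[0], blended[1], blended[2])


def colorize(series: Dict[str, List[Optional[float]]]) -> Dict[str, List[Optional[RGB]]]:
    """Color every series against the same global bounds; missing values stay None."""
    lo, hi = shared_bounds(series)
    return {
        name: [None if v is None else loss_to_rgb(v, lo, hi) for v in losses]
        for name, losses in series.items()
    }
```

With series = {"train": [4.0, 2.0, 1.0], "val": [None, 3.0, None]}, the training loss of 4.0 maps to red and 1.0 to light cyan, while the validation value of 3.0 is colored relative to the same bounds rather than its own.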
Quick Mode
While the examples above show the training and validation steps as functions returning a float, you can obtain these values however you like. If you have already packaged your training and validation logic into functions, however, you can plot your losses with a one-liner:
LossProgressBar.run(epochs=epochs, train_step=train_step, val_step1=val_step, val_step2=val_step2)
The above code will run each training and validation step once per epoch.
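Going only by the description above (each step runs once per epoch), the one-liner behaves roughly like the explicit loop below. The function run_all, the toy model and its step functions are hypothetical: run_all returns the recorded losses instead of drawing them, and the toy train_step performs one gradient-descent step fitting y = 2x and returns the mean squared error as a plain float.

```python
from typing import Callable, Dict, List

# Toy data for a one-parameter model y = w * x; the true weight is 2.0.
XS = [0.0, 1.0, 2.0, 3.0]
YS = [2.0 * x for x in XS]
state = {"w": 0.0}


def mse(w: float) -> float:
    return sum((w * x - y) ** 2 for x, y in zip(XS, YS)) / len(XS)


def train_step(lr: float = 0.05) -> float:
    """One gradient-descent step on w; returns the training loss as a float."""
    w = state["w"]
    grad = sum(2 * (w * x - y) * x for x, y in zip(XS, YS)) / len(XS)
    state["w"] = w - lr * grad
    return mse(state["w"])


def val_step() -> float:
    """Hypothetical validation: squared error on a held-out point (x=4, y=8)."""
    return (state["w"] * 4.0 - 8.0) ** 2


def run_all(
    epochs: int, train_step: Callable[[], float], **val_steps: Callable[[], float]
) -> Dict[str, List[float]]:
    """Call train_step and every validation step once per epoch, recording the losses."""
    history: Dict[str, List[float]] = {"train": []}
    history.update({name: [] for name in val_steps})
    for _ in range(epochs):
        history["train"].append(train_step())
        for name, step in val_steps.items():
            history[name].append(step())
    return history
```

Calling run_all(20, train_step, val_step1=val_step) yields 20 training and 20 validation losses that both shrink as w approaches 2.0.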
Hint: If you don't want to run every validation in every epoch, you have to define a stateful variable (e.g. epoch) yourself and check within the validation step whether it is time to validate.
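One way to follow this hint is to wrap a validation function in a small stateful object that counts its own calls and only validates every n-th epoch. Everything here is hypothetical; in particular, the README does not say what the one-liner expects from a step that skips validation, so returning None in that case is an assumption to check against the library.

```python
from typing import Callable, Optional


class EveryN:
    """Run the wrapped validation function only on every n-th call.

    Calls are counted from 0, so with n=10 the function runs on calls
    9, 19, 29, ..., matching the `epoch % 10 == 9` check in the earlier example.
    """

    def __init__(self, step: Callable[[], float], n: int) -> None:
        self.step = step
        self.n = n
        self.epoch = 0  # stateful counter, incremented once per call

    def __call__(self) -> Optional[float]:
        due = self.epoch % self.n == self.n - 1
        self.epoch += 1
        # Assumption: None signals "no validation this epoch".
        return self.step() if due else None
```

The wrapper would then be passed to the one-liner as, for example, val_step1=EveryN(val_step, 10).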
Palettes
If you don't like the standard palette, you can use a seaborn colormap instead. Make sure seaborn is installed, then add something like the following before your training loop:
from loss_watch import set_palette
set_palette("Spectral")
This selects the Spectral colormap and uses it to display your plot. Your losses will then look something like this:
Of course, this feature does not work in the terminal.
Download files
File details
Details for the file loss_watch-0.2.tar.gz.
File metadata
- Download URL: loss_watch-0.2.tar.gz
- Upload date:
- Size: 70.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 34005acd35fe1f209b30a15b5846633cbed290d7cc788dc43b6261b5fd5a2a1b
MD5 | 1f43c3cf8f19449bd1b68ee2f8590f88
BLAKE2b-256 | 6085c11d5eb8f935b87d52d3c6e1154eb18f39cf550af721fdc3a9b7b7426783
File details
Details for the file loss_watch-0.2-py3-none-any.whl.
File metadata
- Download URL: loss_watch-0.2-py3-none-any.whl
- Upload date:
- Size: 13.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.12.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 3f04d4cfbf6e9b6a44f85bdcdccf5a0b694b4977a7c4d9d94ae6e2ad81732b62
MD5 | a3e3b10845519e73ac103fb79a115846
BLAKE2b-256 | 0bfef4cfdc1ac0c819efb5c9662f15a550aae9450703a1626ec3dda4b31b0f41