Real-time learning curve for Jupyter notebooks
lrcurve
Creates a learning-curve plot in Jupyter and Colab notebooks that updates in real time.
There is a framework-agnostic interface, lrcurve.PlotLearningCurve, that works well with PyTorch and TensorFlow, and a Keras wrapper, lrcurve.KerasLearningCurve, that uses the Keras callback interface.
lrcurve works with Python 3.6 or newer and is distributed under the MIT license.
Install
pip install -U lrcurve
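Because lrcurve requires Python 3.6 or newer, a quick sanity check after installing can confirm both the interpreter version and that the package is importable. This is a minimal, standard-library-only sketch. The helper name `check_lrcurve_environment` is hypothetical and is not part of lrcurve.

```python
import importlib.util
import sys


def check_lrcurve_environment():
    """Hypothetical helper: report whether this interpreter meets lrcurve's
    Python requirement and whether the lrcurve package can be found."""
    return {
        # lrcurve supports Python 3.6 or newer
        "python_ok": sys.version_info >= (3, 6),
        # find_spec returns None when the package is not installed
        "lrcurve_installed": importlib.util.find_spec("lrcurve") is not None,
    }


if __name__ == "__main__":
    print(check_lrcurve_environment())
```

Run it in a notebook cell after `pip install -U lrcurve`. If `lrcurve_installed` is `False`, the notebook kernel is probably using a different environment than the one `pip` installed into.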
API
Examples
Keras example
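from tensorflow import keras
from lrcurve import KerasLearningCurve

# Assumes `model` is an existing keras.Model, and that `train` and `validation`
# hold the inputs in `.x` and the integer class labels in `.y`.
model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=[keras.metrics.SparseCategoricalAccuracy()])

model.fit(train.x, train.y,
          epochs=100,
          verbose=0,  # turn off Keras's own text progress output
          validation_data=(validation.x, validation.y),
          callbacks=[KerasLearningCurve()])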
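KerasLearningCurve plugs into Keras through its callback interface. Keras calls `on_epoch_end(epoch, logs)` on each callback and passes the epoch's metrics, such as `loss` and `val_loss`, in the `logs` dictionary. The sketch below illustrates that general pattern only and is not lrcurve's actual implementation. It forwards `logs` to any object exposing the `append(x, values)` and `draw()` methods that PlotLearningCurve uses in the examples below. `ForwardingCallback` and `RecordingPlot` are hypothetical names. The callback is duck-typed so it runs without Keras; in real code you would subclass `keras.callbacks.Callback` instead.

```python
class ForwardingCallback:
    """Hypothetical duck-typed sketch of a Keras-style callback.

    Real code would subclass keras.callbacks.Callback (and call
    super().__init__()); here we only rely on the on_epoch_end(epoch, logs)
    hook that Keras invokes after every epoch.
    """

    def __init__(self, plot):
        # Any object with append(x, values) and draw() methods
        self.plot = plot

    def on_epoch_end(self, epoch, logs=None):
        # Keras puts per-epoch metrics such as 'loss' and 'val_loss' in logs
        values = {name: float(value) for name, value in (logs or {}).items()}
        self.plot.append(epoch, values)
        self.plot.draw()


class RecordingPlot:
    """Hypothetical stand-in for a plot object that just records calls."""

    def __init__(self):
        self.points = []
        self.draws = 0

    def append(self, x, values):
        self.points.append((x, values))

    def draw(self):
        self.draws += 1


if __name__ == "__main__":
    plot = RecordingPlot()
    callback = ForwardingCallback(plot)
    # Simulate what Keras would do at the end of two epochs
    callback.on_epoch_end(0, {"loss": 0.9, "val_loss": 1.1})
    callback.on_epoch_end(1, {"loss": 0.5, "val_loss": 0.8})
    print(plot.points, plot.draws)
```

Keeping the callback ignorant of how plotting works means the same forwarding logic can drive a real plot in a notebook or a recording stand-in in a test.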
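Framework-agnostic example
import math
import time

from lrcurve import PlotLearningCurve

with PlotLearningCurve() as plot:
    for i in range(100):
        # synthetic, exponentially decaying losses stand in for real training
        plot.append(i, {
            'loss': math.exp(-(i+1)/10),
            'val_loss': math.exp(-i/10)
        })
        plot.draw()
        time.sleep(0.1)  # slow the loop down so the live updates are visible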
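The framework-agnostic example only needs three things from the plot: it works as a `with` block, it accepts `append(x, values)`, and it redraws on `draw()`. To dry-run such a loop outside a notebook, a text-only stand-in with the same calls can be handy. `TextLearningCurve` is hypothetical and not part of lrcurve: it prints the latest values instead of plotting them. `synthetic_metrics` reproduces the example's curves, where `val_loss` at step i equals `loss` at step i - 1, so the validation curve trails the training curve by one step.

```python
import math


class TextLearningCurve:
    """Hypothetical text-only stand-in mirroring the calls used above
    (with-block, append, draw). It is not part of lrcurve."""

    def __init__(self):
        self.history = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Print the final values; returning False lets exceptions propagate
        self.draw()
        return False

    def append(self, x, values):
        # Store one x-position (e.g. epoch) and a copy of its metric values
        self.history.append((x, dict(values)))

    def draw(self):
        # Print the most recent point instead of redrawing a plot
        if self.history:
            x, values = self.history[-1]
            print(x, ", ".join(f"{k}={v:.4f}" for k, v in sorted(values.items())))


def synthetic_metrics(i):
    # Same curves as the example: val_loss(i) == loss(i - 1)
    return {"loss": math.exp(-(i + 1) / 10), "val_loss": math.exp(-i / 10)}


if __name__ == "__main__":
    with TextLearningCurve() as plot:
        for i in range(5):
            plot.append(i, synthetic_metrics(i))
            plot.draw()
```

Swapping `TextLearningCurve()` back to `PlotLearningCurve()` should leave the loop body unchanged, because it only relies on the calls the example already makes.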
PyTorch example
import torch
import sklearn.metrics
from lrcurve import PlotLearningCurve

# Assumes `network`, `criterion`, `optimizer` and the CPU tensors
# x_train, y_train, x_test, y_test (integer class labels) already exist.
plot = PlotLearningCurve(
    mappings = {
        'loss': { 'line': 'train', 'facet': 'loss' },
        'val_loss': { 'line': 'validation', 'facet': 'loss' },
        'acc': { 'line': 'train', 'facet': 'acc' },
        'val_acc': { 'line': 'validation', 'facet': 'acc' }
    },
    facet_config = {
        'loss': { 'name': 'Cross-Entropy', 'limit': [0, None], 'scale': 'linear' },
        'acc': { 'name': 'Accuracy', 'limit': [0, 1], 'scale': 'linear' }
    },
    xaxis_config = { 'name': 'Epoch', 'limit': [0, 500] }
)

with plot:
    # optimize model
    for epoch in range(500):
        # compute the test loss without tracking gradients
        with torch.no_grad():
            z_test = network(x_test)
            loss_test = criterion(z_test, y_test)

        # one full-batch training step
        optimizer.zero_grad()
        z_train = network(x_train)
        loss_train = criterion(z_train, y_train)
        loss_train.backward()
        optimizer.step()

        # compute accuracy from the predicted classes (argmax over the logits)
        accuracy_test = sklearn.metrics.accuracy_score(y_test.numpy(), torch.argmax(z_test, 1).numpy())
        accuracy_train = sklearn.metrics.accuracy_score(y_train.numpy(), torch.argmax(z_train, 1).numpy())

        # append and update; .item() turns the scalar loss tensors into floats
        plot.append(epoch, {
            'loss': loss_train.item(),
            'val_loss': loss_test.item(),
            'acc': accuracy_train,
            'val_acc': accuracy_test
        })
        plot.draw()
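In the PyTorch example, `mappings` assigns every key passed to `plot.append` a facet and a line within that facet. For instance, `loss` and `val_loss` become the train and validation lines of the `loss` facet, which `facet_config` names "Cross-Entropy" and gives its own axis limits. The sketch below illustrates that routing with plain dictionaries. `group_by_facet` is a hypothetical helper written for illustration; it is not lrcurve's internal code.

```python
def group_by_facet(mappings, values):
    """Hypothetical helper: group a flat metrics dict into
    {facet: {line: value}} using a mappings dict shaped like the one passed
    to PlotLearningCurve above. Illustrative only, not lrcurve's internals."""
    grouped = {}
    for key, value in values.items():
        target = mappings[key]  # a KeyError flags a metric with no mapping
        grouped.setdefault(target["facet"], {})[target["line"]] = value
    return grouped


# Same mappings as in the PyTorch example
mappings = {
    "loss": {"line": "train", "facet": "loss"},
    "val_loss": {"line": "validation", "facet": "loss"},
    "acc": {"line": "train", "facet": "acc"},
    "val_acc": {"line": "validation", "facet": "acc"},
}

if __name__ == "__main__":
    print(group_by_facet(mappings, {"loss": 0.42, "val_loss": 0.51, "acc": 0.88, "val_acc": 0.84}))
```

The same structure explains why each facet needs its own entry in `facet_config`: losses and accuracies live on different scales, so they are drawn against separate axes rather than a shared one.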
Sponsor
Sponsored by NearForm Research.
Download files
Source Distribution
lrcurve-2.2.1.tar.gz (70.2 kB)
Built Distribution
lrcurve-2.2.1-py3-none-any.whl (71.0 kB)