Real-time learning curve for Jupyter notebooks
lrcurve
Creates a learning-curve plot for Jupyter/Colab notebooks that is updated in real-time.
There is a framework-agnostic interface, lrcurve.PlotLearningCurve, that works well with PyTorch and TensorFlow, and a Keras wrapper, lrcurve.KerasLearningCurve, that uses the Keras callback interface.
lrcurve works with Python 3.6 or newer and is distributed under the MIT license.
Install
pip install -U lrcurve
API
Examples
Keras example
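from tensorflow import keras  # assumption: the tf.keras API bundled with TensorFlow 2.x
from lrcurve import KerasLearningCurve

# `model`, `train` and `validation` are assumed to be defined elsewhere:
# a keras.Model and objects holding NumPy arrays in .x and .y
model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=[keras.metrics.SparseCategoricalAccuracy()])

# verbose=0 silences the text progress bar; the callback draws the live plot instead
model.fit(train.x, train.y,
          epochs=100,
          verbose=0,
          validation_data=(validation.x, validation.y),
          callbacks=[KerasLearningCurve()])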
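KerasLearningCurve plugs into model.fit through the Keras callback interface. At the end of every epoch Keras calls on_epoch_end(epoch, logs), where logs is a dict of that epoch's metrics (for example loss and val_loss). The sketch below shows how such an adapter could forward those logs to an object with the append/draw methods used in the other examples. It is not lrcurve's actual implementation: ForwardingCallback and RecordingPlot are hypothetical names. To keep the sketch runnable without TensorFlow, it calls the hook by hand instead of subclassing keras.callbacks.Callback.

```python
class RecordingPlot:
    """Hypothetical stand-in with the append/draw shape used in the examples."""

    def __init__(self):
        self.history = []  # list of (x, metrics) pairs
        self.draw_count = 0

    def append(self, x, metrics):
        self.history.append((x, dict(metrics)))

    def draw(self):
        self.draw_count += 1


class ForwardingCallback:
    """Hypothetical adapter; a real one would subclass keras.callbacks.Callback."""

    def __init__(self, plot):
        self.plot = plot

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes the epoch index and a dict of metric values for that epoch
        self.plot.append(epoch, logs or {})
        self.plot.draw()


plot = RecordingPlot()
callback = ForwardingCallback(plot)
# Simulate what model.fit would do at the end of three epochs (made-up values)
for epoch, (loss, val_loss) in enumerate([(0.9, 1.0), (0.5, 0.6), (0.3, 0.4)]):
    callback.on_epoch_end(epoch, {"loss": loss, "val_loss": val_loss})

print(plot.history[-1])  # (2, {'loss': 0.3, 'val_loss': 0.4})
print(plot.draw_count)   # 3
```

Because the adapter only relies on the logs dict, it is up to the plot object to decide which keys become which lines.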
Framework agnostic example
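import math
import time

from lrcurve import PlotLearningCurve

with PlotLearningCurve() as plot:
    for i in range(100):
        # synthetic, exponentially decaying curves stand in for real metrics
        plot.append(i, {
            'loss': math.exp(-(i+1)/10),
            'val_loss': math.exp(-i/10)
        })
        plot.draw()
        time.sleep(0.1)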
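The framework agnostic example plots synthetic values rather than metrics from a real model. Since loss = exp(-(i+1)/10) and val_loss = exp(-i/10), the validation curve is always exactly e^0.1 ≈ 1.105 times the training curve. The two lines therefore decay together with a fixed proportional gap. The short check below confirms this numerically. synthetic_metrics is a hypothetical helper that just repeats the example's formulas.

```python
import math

# The same synthetic curves as the framework agnostic example
def synthetic_metrics(i):
    return {
        'loss': math.exp(-(i + 1) / 10),
        'val_loss': math.exp(-i / 10),
    }

# val_loss / loss == exp(-i/10) / exp(-(i+1)/10) == e**0.1 for every step i
ratios = [synthetic_metrics(i)['val_loss'] / synthetic_metrics(i)['loss'] for i in range(100)]
print(round(min(ratios), 4), round(max(ratios), 4))  # 1.1052 1.1052

# At i = 0 the training loss starts at e**-0.1 and the validation loss at 1
first = synthetic_metrics(0)
print(round(first['loss'], 4), round(first['val_loss'], 4))  # 0.9048 1.0
```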
PyTorch example
import sklearn.metrics
import torch
from lrcurve import PlotLearningCurve

# `network`, `criterion`, `optimizer` and the tensors x_train, y_train,
# x_test, y_test are assumed to be defined elsewhere
plot = PlotLearningCurve(
    mappings = {
        'loss': { 'line': 'train', 'facet': 'loss' },
        'val_loss': { 'line': 'validation', 'facet': 'loss' },
        'acc': { 'line': 'train', 'facet': 'acc' },
        'val_acc': { 'line': 'validation', 'facet': 'acc' }
    },
    facet_config = {
        'loss': { 'name': 'Cross-Entropy', 'limit': [0, None], 'scale': 'linear' },
        'acc': { 'name': 'Accuracy', 'limit': [0, 1], 'scale': 'linear' }
    },
    xaxis_config = { 'name': 'Epoch', 'limit': [0, 500] }
)

with plot:
    # optimize model
    for epoch in range(500):
        # compute the test loss without tracking gradients
        with torch.no_grad():
            z_test = network(x_test)
            loss_test = criterion(z_test, y_test)

        # training step
        optimizer.zero_grad()
        z_train = network(x_train)
        loss_train = criterion(z_train, y_train)
        loss_train.backward()
        optimizer.step()

        # compute accuracy from the predicted class (argmax over the logits)
        accuracy_test = sklearn.metrics.accuracy_score(y_test, torch.argmax(z_test, 1).numpy())
        accuracy_train = sklearn.metrics.accuracy_score(y_train, torch.argmax(z_train, 1).numpy())

        # append and update; .item() converts the scalar loss tensors to floats
        plot.append(epoch, {
            'loss': loss_train.item(),
            'val_loss': loss_test.item(),
            'acc': accuracy_train,
            'val_acc': accuracy_test
        })
        plot.draw()
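In the PyTorch example, mappings assigns each key passed to plot.append to a facet and to a line within that facet. facet_config then describes each facet by its name, axis limit and scale, so a facet presumably corresponds to one panel of the plot. The sketch below uses a hypothetical helper, route_metrics, that groups one append payload the same way. It also checks that every mapped facet has a configuration entry. It illustrates the data layout only and is not lrcurve's internal code. The payload numbers are made up.

```python
def route_metrics(values, mappings, facet_config):
    """Group a {key: value} payload into {facet: {line: value}} using mappings.

    Hypothetical helper for illustration; lrcurve's own internals may differ.
    """
    grouped = {}
    for key, value in values.items():
        if key not in mappings:
            raise KeyError(f"no mapping for metric {key!r}")
        facet = mappings[key]['facet']
        line = mappings[key]['line']
        if facet not in facet_config:
            raise KeyError(f"facet {facet!r} has no entry in facet_config")
        grouped.setdefault(facet, {})[line] = float(value)
    return grouped


mappings = {
    'loss': {'line': 'train', 'facet': 'loss'},
    'val_loss': {'line': 'validation', 'facet': 'loss'},
    'acc': {'line': 'train', 'facet': 'acc'},
    'val_acc': {'line': 'validation', 'facet': 'acc'},
}
facet_config = {
    'loss': {'name': 'Cross-Entropy', 'limit': [0, None], 'scale': 'linear'},
    'acc': {'name': 'Accuracy', 'limit': [0, 1], 'scale': 'linear'},
}

# Made-up values with the same keys as the PyTorch example
payload = {'loss': 0.42, 'val_loss': 0.51, 'acc': 0.88, 'val_acc': 0.84}
print(route_metrics(payload, mappings, facet_config))
# {'loss': {'train': 0.42, 'validation': 0.51}, 'acc': {'train': 0.88, 'validation': 0.84}}
```

Under this reading, a key that has no mapping has nowhere to be drawn. That is why the sketch rejects it loudly instead of silently dropping it.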
Sponsor
Sponsored by NearForm Research.