Real-time learning curve for Jupyter notebooks
lrcurve
Creates a learning-curve plot for Jupyter/Colab notebooks that is updated in real time.
There is a framework-agnostic interface, lrcurve.PlotLearningCurve,
that works well with PyTorch and TensorFlow, and a Keras wrapper,
lrcurve.KerasLearningCurve, that uses the Keras callback interface.
lrcurve works with Python 3.6 or newer and is distributed under the
MIT license.
Install
pip install -U lrcurve
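To confirm the install from inside a notebook, a minimal sketch along these lines may help. The helper name `check_lrcurve_environment` is hypothetical and not part of lrcurve; it relies only on the standard library and on the Python 3.6 requirement stated above.

```python
import sys


def check_lrcurve_environment():
    """Return (python_ok, installed_version).

    installed_version is None when lrcurve is not installed, or when the
    version cannot be looked up (importlib.metadata needs Python 3.8+).
    """
    python_ok = sys.version_info >= (3, 6)
    try:
        from importlib.metadata import PackageNotFoundError, version
    except ImportError:
        # Python 3.6/3.7: no importlib.metadata, so the version is unknown.
        return python_ok, None
    try:
        return python_ok, version("lrcurve")
    except PackageNotFoundError:
        return python_ok, None


python_ok, installed = check_lrcurve_environment()
print("Python >= 3.6:", python_ok)
print("lrcurve version:", installed or "unknown or not installed")
```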
API
Examples
Keras example
from lrcurve import KerasLearningCurve

# Assumes `keras`, `model`, `train`, and `validation` are already defined.
model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=[keras.metrics.SparseCategoricalAccuracy()])

model.fit(train.x, train.y,
          epochs=100,
          verbose=0,
          validation_data=(validation.x, validation.y),
          callbacks=[KerasLearningCurve()])
Framework agnostic example
import math
import time
from lrcurve import PlotLearningCurve

with PlotLearningCurve() as plot:
    for i in range(100):
        plot.append(i, {
            'loss': math.exp(-(i+1)/10),
            'val_loss': math.exp(-i/10)
        })
        plot.draw()
        time.sleep(0.1)
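The loop above calls `plot.draw()` on every step. For long training runs, one possible variation is to append every step but redraw only occasionally. The sketch below is an illustration under assumptions: `append_and_maybe_draw` and `RecordingPlot` are hypothetical names, not part of lrcurve, and the sketch assumes that points appended between draws show up on the next `draw()` call. `RecordingPlot` is a stand-in exposing the same `append` and `draw` methods so the example runs outside a notebook.

```python
import math


def append_and_maybe_draw(plot, step, metrics, draw_every=10):
    """Append metrics every step; redraw only when step is a multiple of draw_every.

    Returns True if a redraw happened on this step.
    """
    plot.append(step, metrics)
    if step % draw_every == 0:
        plot.draw()
        return True
    return False


class RecordingPlot:
    """Hypothetical stand-in with append/draw, used only to make the sketch runnable."""

    def __init__(self):
        self.points = []
        self.draw_calls = 0

    def append(self, x, values):
        self.points.append((x, values))

    def draw(self):
        self.draw_calls += 1


recorder = RecordingPlot()
for i in range(100):
    append_and_maybe_draw(recorder, i, {
        'loss': math.exp(-(i + 1) / 10),
        'val_loss': math.exp(-i / 10),
    })

print(len(recorder.points), recorder.draw_calls)  # 100 10
```

In a notebook, the same helper could be passed the `plot` object from `with PlotLearningCurve() as plot:` in place of `recorder`.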
PyTorch example
from lrcurve import PlotLearningCurve

plot = PlotLearningCurve(
    mappings = {
        'loss': { 'line': 'train', 'facet': 'loss' },
        'val_loss': { 'line': 'validation', 'facet': 'loss' },
        'acc': { 'line': 'train', 'facet': 'acc' },
        'val_acc': { 'line': 'validation', 'facet': 'acc' }
    },
    facet_config = {
        'loss': { 'name': 'Cross-Entropy', 'limit': [0, None], 'scale': 'linear' },
        'acc': { 'name': 'Accuracy', 'limit': [0, 1], 'scale': 'linear' }
    },
    xaxis_config = { 'name': 'Epoch', 'limit': [0, 500] }
)

# Assumes `torch`, `sklearn.metrics`, `network`, `criterion`, `optimizer`,
# and the data tensors (x_train, y_train, x_test, y_test) are already defined.
with plot:
    # optimize model
    for epoch in range(500):
        # compute loss
        z_test = network(x_test)
        loss_test = criterion(z_test, y_test)

        optimizer.zero_grad()
        z_train = network(x_train)
        loss_train = criterion(z_train, y_train)
        loss_train.backward()
        optimizer.step()

        # compute accuracy (scikit-learn expects y_true first, then y_pred)
        accuracy_test = sklearn.metrics.accuracy_score(y_test, torch.argmax(z_test, 1).numpy())
        accuracy_train = sklearn.metrics.accuracy_score(y_train, torch.argmax(z_train, 1).numpy())

        # append and update; .item() converts the 0-d loss tensors to Python floats
        plot.append(epoch, {
            'loss': loss_train.item(),
            'val_loss': loss_test.item(),
            'acc': accuracy_train,
            'val_acc': accuracy_test
        })
        plot.draw()
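The `mappings` argument in the PyTorch example follows a regular pattern: each metric gets a `train` line and a `validation` line (key prefixed with `val_`) in a facet named after the metric. When tracking many metrics, that dictionary could be generated instead of written by hand. The helper below is a hypothetical sketch, not part of lrcurve; it only builds a plain dictionary in the shape shown in the example.

```python
def build_mappings(metric_names, train_line='train',
                   validation_line='validation', validation_prefix='val_'):
    """Build a mappings dict with one train and one validation line per metric.

    Each metric name becomes its own facet; the validation key is the
    metric name with `validation_prefix` prepended.
    """
    mappings = {}
    for name in metric_names:
        mappings[name] = {'line': train_line, 'facet': name}
        mappings[validation_prefix + name] = {'line': validation_line, 'facet': name}
    return mappings


mappings = build_mappings(['loss', 'acc'])
for key, value in mappings.items():
    print(key, value)
```

The result could then be passed as `PlotLearningCurve(mappings=mappings, ...)` in place of the hand-written dictionary.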
Sponsor
Sponsored by NearForm Research.