Real-time learning curve for Jupyter notebooks

lrcurve

Creates a learning-curve plot for Jupyter/Colab notebooks that is updated in real-time.

lrcurve provides two interfaces: the framework-agnostic lrcurve.PlotLearningCurve, which works well with PyTorch and TensorFlow, and the Keras wrapper lrcurve.KerasLearningCurve, which uses the Keras callback interface.

lrcurve works with Python 3.6 or newer and is distributed under the MIT license.

Gif of learning-curve

Install

pip install -U lrcurve

API

Examples

Keras example

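from tensorflow import keras  # assumption: `keras` below refers to tf.keras
from lrcurve import KerasLearningCurve

# Assumes `model`, `train`, and `validation` are already defined.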

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=[keras.metrics.SparseCategoricalAccuracy()])

model.fit(train.x, train.y,
          epochs=100,
          verbose=0,
          validation_data=(validation.x, validation.y),
          callbacks=[KerasLearningCurve()])

Gif of learning-curve for keras example

Framework agnostic example

import math
import time

from lrcurve import PlotLearningCurve

with PlotLearningCurve() as plot:
    for i in range(100):
        plot.append(i, {
            'loss': math.exp(-(i+1)/10),
            'val_loss': math.exp(-i/10)
        })
        plot.draw()
        time.sleep(0.1)

Gif of learning-curve for simple example
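
The loop above redraws on every iteration, which may be more than a fast training loop needs. The sketch below wraps the two calls shown in the example, append and draw, in a small helper that redraws only every few steps. The helper name record, the draw_every parameter, and the RecordingPlot stand-in are hypothetical and not part of lrcurve; the stand-in exists only so the snippet runs outside a notebook.

```python
import math


class RecordingPlot:
    """Hypothetical stand-in exposing the same append/draw calls used above."""

    def __init__(self):
        self.points = []      # (x, metrics) pairs passed to append
        self.draw_calls = 0   # number of times draw was called

    def append(self, x, metrics):
        self.points.append((x, metrics))

    def draw(self):
        self.draw_calls += 1


def record(plot, step, metrics, draw_every=10):
    """Append metrics at every step, but redraw only every `draw_every` steps."""
    plot.append(step, metrics)
    if (step + 1) % draw_every == 0:
        plot.draw()


plot = RecordingPlot()
for i in range(100):
    record(plot, i, {
        'loss': math.exp(-(i + 1) / 10),
        'val_loss': math.exp(-i / 10),
    })
```

With a real plot, the object from PlotLearningCurve() would be passed in place of the stand-in. This assumes, as the separate append and draw calls suggest, that appended points are held until the next draw, so a final plot.draw() after the loop is worth adding when the step count is not a multiple of draw_every.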

PyTorch example

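import sklearn.metrics
import torch

from lrcurve import PlotLearningCurve

# Assumes `network`, `criterion`, `optimizer`, and the tensors
# `x_train`, `y_train`, `x_test`, `y_test` are already defined.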

plot = PlotLearningCurve(
    mappings = {
        'loss': { 'line': 'train', 'facet': 'loss' },
        'val_loss': { 'line': 'validation', 'facet': 'loss' },
        'acc': { 'line': 'train', 'facet': 'acc' },
        'val_acc': { 'line': 'validation', 'facet': 'acc' }
    },
    facet_config = {
        'loss': { 'name': 'Cross-Entropy', 'limit': [0, None], 'scale': 'linear' },
        'acc': { 'name': 'Accuracy', 'limit': [0, 1], 'scale': 'linear' }
    },
    xaxis_config = { 'name': 'Epoch', 'limit': [0, 500] }
)

with plot:
    # optimize model
    for epoch in range(500):
        # evaluate on the test set (no gradients needed)
        with torch.no_grad():
            z_test = network(x_test)
            loss_test = criterion(z_test, y_test)

        # compute the training loss and take an optimizer step
        optimizer.zero_grad()
        z_train = network(x_train)
        loss_train = criterion(z_train, y_train)
        loss_train.backward()
        optimizer.step()

        # compute accuracy (accuracy_score expects y_true first, then y_pred)
        accuracy_test = sklearn.metrics.accuracy_score(y_test, torch.argmax(z_test, 1).numpy())
        accuracy_train = sklearn.metrics.accuracy_score(y_train, torch.argmax(z_train, 1).numpy())

        # append and update; .item() converts the scalar loss tensors to Python floats
        plot.append(epoch, {
            'loss': loss_train.item(),
            'val_loss': loss_test.item(),
            'acc': accuracy_train,
            'val_acc': accuracy_test
        })
        plot.draw()

Gif of learning-curve for pytorch example
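
The mappings dictionary in the PyTorch example follows a regular pattern: each metric gets a 'train' line, and its val_-prefixed counterpart gets a 'validation' line in the same facet. As a sketch, that dictionary could be generated from a list of metric names. The function build_mappings is a hypothetical helper, not part of lrcurve, and the val_ prefix is simply the naming used in the example above.

```python
def build_mappings(metrics, val_prefix='val_'):
    """Build a mappings dict with a train and a validation line per metric.

    Hypothetical helper: it only reproduces the dictionary layout used in
    the PyTorch example and does not call into lrcurve.
    """
    mappings = {}
    for name in metrics:
        # The training metric and its validation counterpart share one facet.
        mappings[name] = {'line': 'train', 'facet': name}
        mappings[val_prefix + name] = {'line': 'validation', 'facet': name}
    return mappings


mappings = build_mappings(['loss', 'acc'])
```

The result has the same keys and values as the hand-written mappings in the example, so it could be passed as the mappings argument in its place.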

Sponsor

Sponsored by NearForm Research.
