
Atalaya is a logger for PyTorch.


Atalaya

This framework provides a logger for PyTorch models. It lets you save the parameters, the state of the network and the state of the optimizer, and it can also visualize your data using tensorboardX or visdom.

Install

$ pip install atalaya

Examples

Examples are provided in the examples directory, where we simply add the logger to an example PyTorch implementation (source) in example_1. Each directory also contains the files created by the logger: a directory named logs and one named vizualize. The first contains the logs of each experiment, and the second contains the files needed for visualization, e.g. in TensorBoard.

Usage

Init

from atalaya import Logger

logger = Logger()

# by default, Logger uses no grapher
# you can set one up by specifying whether you want visdom or tensorboardX
logger = Logger(grapher='visdom')

    close(self)
        """Closes the grapher."""

    save(self)
        """Saves the grapher."""

Log Information

    info(self, *argv)
        """Adds an info to the logging file."""
    warning(self, *argv)
        """Adds a warning to the logging file."""

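The log-file format Atalaya uses is not documented here. As a rough illustration of what `info(*argv)` and `warning(*argv)` involve, the sketch below uses Python's standard `logging` module to append messages to a file. The helper `make_file_logger` and the `LEVEL: message` line format are hypothetical choices, not Atalaya's implementation.

```python
import logging
import os


def make_file_logger(log_dir, name="experiment"):
    """Return a stdlib logger appending to <log_dir>/<name>.log, plus its handler.

    Hypothetical helper for illustration; Atalaya's own file layout may differ.
    Calling it twice with the same name would attach a second handler.
    """
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler(os.path.join(log_dir, name + ".log"))
    handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
    logger.addHandler(handler)
    return logger, handler


def info(logger, *argv):
    # Join all positional arguments into one line, the way print() does.
    logger.info(" ".join(str(a) for a in argv))


def warning(logger, *argv):
    logger.warning(" ".join(str(a) for a in argv))
```

For example, `info(logger, "epoch", 1, "loss", 0.25)` appends the line `INFO: epoch 1 loss 0.25` to the file.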
Store your Parameters

    add_parameters(self, params)
        """Adds parameters."""
    restore_parameters(self, path)
        """Loads the parameters of a previous experiment given by path."""

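A minimal sketch of the parameter round trip, assuming the parameters are a flat dictionary of JSON-serializable values saved as `params.json` inside the experiment folder. Both the class `ParameterStore` and the file name are hypothetical.

```python
import json
import os


class ParameterStore:
    """Hypothetical stand-in for add_parameters / restore_parameters.

    Assumption: parameters live in <folder>/params.json; Atalaya's real
    format may differ.
    """

    def __init__(self, folder):
        self.folder = folder
        self.params = {}
        os.makedirs(folder, exist_ok=True)

    def add_parameters(self, params):
        # Merge the new values and persist the whole dict immediately.
        self.params.update(params)
        with open(os.path.join(self.folder, "params.json"), "w") as f:
            json.dump(self.params, f, indent=2)

    def restore_parameters(self, path):
        # Load the parameters saved by a previous experiment in `path`.
        with open(os.path.join(path, "params.json")) as f:
            self.params = json.load(f)
        return self.params
```

A new run can then call `restore_parameters(old_folder)` to reuse exactly the hyperparameters of an earlier experiment.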
Store and Restore (models and optimizers)

  1. Add the model (or the optimizer, or any other object that has a state_dict in PyTorch)

        add(self, name, obj, overwrite=False)
            """Adds an object to the state (dictionary)."""
    
  2. Store the model

        store(self, loss, save_every=1, overwrite=True)
            """Checks whether the state has to be stored or whether the current
            model is the best so far. If so, saves the best model and returns True."""
    
  3. Restore the model

        restore(self, folder=None, best=False)
            """Loads a state using torch.load()"""
    

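The three steps above can be sketched without PyTorch. This is not Atalaya's code: `Checkpointer`, `ToyModel`, and the file names `last.pkl`/`best.pkl` are hypothetical, `pickle` stands in for `torch.save()`/`torch.load()`, "best" is assumed to mean the lowest loss seen so far, and the `overwrite` argument of `store` is left out because its behaviour is not described above. Any object with `state_dict()` and `load_state_dict()`, such as a PyTorch module or optimizer, could be registered in place of `ToyModel`.

```python
import os
import pickle


class Checkpointer:
    """Hypothetical sketch of the add / store / restore workflow.

    pickle stands in for torch.save/torch.load so the sketch runs without
    PyTorch; "best" is assumed to mean the lowest loss seen so far.
    """

    def __init__(self, folder):
        self.folder = folder
        self.objects = {}
        self.best_loss = float("inf")
        self.calls = 0
        os.makedirs(folder, exist_ok=True)

    def add(self, name, obj, overwrite=False):
        # Register anything exposing state_dict() / load_state_dict().
        if name in self.objects and not overwrite:
            raise KeyError(f"{name!r} already added; pass overwrite=True")
        self.objects[name] = obj

    def _save(self, filename):
        # Snapshot every registered object's state into one file.
        state = {name: obj.state_dict() for name, obj in self.objects.items()}
        with open(os.path.join(self.folder, filename), "wb") as f:
            pickle.dump(state, f)

    def store(self, loss, save_every=1):
        # Periodic checkpoint every `save_every` calls...
        self.calls += 1
        if self.calls % save_every == 0:
            self._save("last.pkl")
        # ...plus a separate copy whenever the loss improves.
        if loss < self.best_loss:
            self.best_loss = loss
            self._save("best.pkl")
            return True
        return False

    def restore(self, folder=None, best=False):
        # Reload from another experiment's folder if given, else our own.
        folder = folder or self.folder
        filename = "best.pkl" if best else "last.pkl"
        with open(os.path.join(folder, filename), "rb") as f:
            state = pickle.load(f)
        for name, obj in self.objects.items():
            obj.load_state_dict(state[name])


class ToyModel:
    """Stand-in for a model or optimizer: it only needs the state_dict interface."""

    def __init__(self):
        self.weight = 0.0

    def state_dict(self):
        return {"weight": self.weight}

    def load_state_dict(self, state):
        self.weight = state["weight"]
```

Keeping a separate "best" file means a later, worse epoch never overwrites the checkpoint you most likely want to reload with `restore(best=True)`.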
Grapher

    add_scalar(self, tag, scalar_value, global_step=None, save_csv=True)
        """Adds a scalar to the grapher."""

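The `save_csv=True` flag suggests scalars are also written to CSV alongside the grapher. A hypothetical helper showing one way to do that, assuming one file per tag with a `step,value` header; the function name, file naming scheme and columns are all assumptions.

```python
import csv
import os


def append_scalar_csv(folder, tag, scalar_value, global_step=None):
    """Append one (step, value) row to <folder>/<tag>.csv and return its path.

    Hypothetical helper; not Atalaya's actual CSV format.
    """
    os.makedirs(folder, exist_ok=True)
    # Replace "/" so hierarchical tags like "train/loss" form a valid file name.
    path = os.path.join(folder, tag.replace("/", "_") + ".csv")
    new_file = not os.path.exists(path)
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        if new_file:
            writer.writerow(["step", "value"])  # header on first write only
        writer.writerow([global_step, float(scalar_value)])
    return path
```

A plain CSV per tag is easy to reload with any spreadsheet or plotting tool, independently of tensorboardX or visdom.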
    add_scalars(self, main_tag, tag_scalar_dict, global_step=None)
        """Adds scalars to the grapher."""

    export_scalars_to_json(self, path)
        """Exports scalars to json"""

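To illustrate `add_scalars` and `export_scalars_to_json`, here is a small in-memory history that groups sub-tags under `main_tag` and dumps everything to JSON. The `ScalarHistory` class, the `main_tag/tag` naming and the `[wall_time, step, value]` entry layout are assumptions made for this sketch.

```python
import json
import time
from collections import defaultdict


class ScalarHistory:
    """Hypothetical in-memory scalar store with a JSON export."""

    def __init__(self):
        # tag -> list of [wall_time, step, value]
        self.scalars = defaultdict(list)

    def add_scalar(self, tag, scalar_value, global_step=None):
        self.scalars[tag].append([time.time(), global_step, float(scalar_value)])

    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None):
        # Each sub-tag is stored under "<main_tag>/<tag>".
        for tag, value in tag_scalar_dict.items():
            self.add_scalar(f"{main_tag}/{tag}", value, global_step)

    def export_scalars_to_json(self, path):
        # defaultdict is a dict subclass, so json can serialize it directly.
        with open(path, "w") as f:
            json.dump(self.scalars, f)
```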
    add_histogram(self, tag, values, global_step=None, bins='tensorflow')
        """Add histogram to summary."""

    add_image(self, tag, img_tensor, global_step=None, caption=None)
        """Add image data to summary."""

    add_figure(self, tag, figure, global_step=None, close=True)
        """Render matplotlib figure into an image and add it to summary."""

    add_video(self, tag, vid_tensor, global_step=None, fps=4)
        """Add video data to summary."""

    add_audio(self, tag, snd_tensor, global_step=None, sample_rate=44100)
        """Add audio data to summary."""

    add_text(self, tag, text_string, global_step=None)
        """Add text data to summary."""

    add_graph_onnx(self, prototxt)
        """Adds an ONNX graph to the grapher."""

    add_graph(self, model, input_to_model=None, verbose=False, **kwargs)
        """Adds a graph to the grapher."""

    add_embedding(self, mat, metadata=None, label_img=None,
                      global_step=None, tag='default', metadata_header=None)
        """Adds an embedding to the grapher."""

    add_pr_curve(self, tag, labels, predictions, global_step=None,
                     num_thresholds=127, weights=None)
        """Adds precision recall curve."""

    add_pr_curve_raw(self, tag, true_positive_counts,
                         false_positive_counts,
                         true_negative_counts,
                         false_negative_counts,
                         precision,
                         recall, 
                         global_step=None, num_thresholds=127, weights=None)
        """Adds precision recall curve with raw data."""

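`add_pr_curve` takes raw labels and predictions, while `add_pr_curve_raw` expects the counts already computed per threshold. The sketch below derives those six lists in plain Python and returns them in the same order as the `add_pr_curve_raw` arguments. It assumes binary 0/1 labels, probabilities in [0, 1], evenly spaced thresholds, that a prediction counts as positive when it is at least the threshold, and that precision is 0.0 when nothing is predicted positive; `pr_curve_counts` is a hypothetical name.

```python
def pr_curve_counts(labels, predictions, num_thresholds=127):
    """Per-threshold TP/FP/TN/FN counts plus precision and recall.

    Thresholds are i / (num_thresholds - 1) for i in range(num_thresholds).
    """
    pairs = list(zip(labels, predictions))
    tp, fp, tn, fn, precision, recall = [], [], [], [], [], []
    for i in range(num_thresholds):
        t = i / (num_thresholds - 1)
        tp_i = sum(1 for y, p in pairs if y == 1 and p >= t)
        fp_i = sum(1 for y, p in pairs if y == 0 and p >= t)
        tn_i = sum(1 for y, p in pairs if y == 0 and p < t)
        fn_i = sum(1 for y, p in pairs if y == 1 and p < t)
        tp.append(tp_i)
        fp.append(fp_i)
        tn.append(tn_i)
        fn.append(fn_i)
        # Convention: 0.0 when the denominator is empty, to avoid dividing by zero.
        precision.append(tp_i / (tp_i + fp_i) if tp_i + fp_i else 0.0)
        recall.append(tp_i / (tp_i + fn_i) if tp_i + fn_i else 0.0)
    return tp, fp, tn, fn, precision, recall
```

For example, `pr_curve_counts([1, 0, 1, 0], [0.9, 0.8, 0.3, 0.1], num_thresholds=3)` evaluates thresholds 0.0, 0.5 and 1.0 and gives true positive counts `[2, 1, 0]` and recall `[1.0, 0.5, 0.0]`.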
    register_plots(self, values, epoch, prefix, apply_mean=True,
                       save_csv=True, info=True)
        """Helper to register a dictionary with multiple lists of scalars."""

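A sketch of what the `apply_mean=True` path plausibly does: each list of per-batch values in the dictionary is reduced to its mean and paired with a tag and the epoch, ready to be passed to `add_scalar(tag, scalar_value, global_step)`. The `epoch_means` helper and the `prefix/key` tag format are hypothetical, and the `apply_mean=False` case is not covered.

```python
from statistics import mean


def epoch_means(values, epoch, prefix):
    """Average each list of per-batch scalars into one point per epoch.

    Hypothetical helper; returns (tag, scalar_value, global_step) tuples.
    """
    return [(f"{prefix}/{key}", mean(series), epoch)
            for key, series in values.items()]
```

For example, `epoch_means({"loss": [0.5, 0.25], "acc": [0.5, 1.0]}, epoch=3, prefix="train")` returns `[("train/loss", 0.375, 3), ("train/acc", 0.75, 3)]`.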
Download files


Source distribution: atalaya-0.1.5.0.tar.gz (12.0 kB)

Built distribution: atalaya-0.1.5.0-py3-none-any.whl (13.0 kB, py3)
