
torchutil


General utilities for developing deep learning projects in PyTorch

pip install torchutil

Table of contents

- Checkpoint
- Download
- Notify
- Tensorboard
- Time

Checkpoint

import torch
import torchutil

# Checkpoint location
file = 'model.pt'

# Initialize PyTorch model
model = torch.nn.Sequential(torch.nn.Conv1d(1, 1, 3))

# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters())

# Save
torchutil.checkpoint.save(file, model, optimizer, step=0, epoch=0)

# Load for training
model, optimizer, state = torchutil.checkpoint.load(file, model, optimizer)
step, epoch = state['step'], state['epoch']

# Load for inference
model, *_ = torchutil.checkpoint.load(file, model)

torchutil.checkpoint.latest_path

def latest_path(
        directory: Union[str, bytes, os.PathLike],
        regex: str = '*.pt',
        latest_fn: Callable = highest_number,
    ) -> Union[str, bytes, os.PathLike]:
    """Retrieve the path to the most recent checkpoint in a directory

    Arguments
        directory - The directory to search for checkpoint files
        regex - The regular expression matching checkpoints
        latest_fn - Takes a list of checkpoint paths and returns the latest.
                    Default assumes checkpoint names are training step count.

    Returns
        The latest checkpoint in directory according to latest_fn
    """

torchutil.checkpoint.load

def load(
    file: Union[str, bytes, os.PathLike],
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    map_location: str = 'cpu') -> Tuple[
        torch.nn.Module,
        Union[None, torch.optim.Optimizer],
        Dict
    ]:
    """Load model checkpoint

    Arguments
        file - The checkpoint file
        model - The PyTorch model
        optimizer - Optional PyTorch optimizer for training
        map_location - The device to load the checkpoint on

    Returns
        model - The model with restored weights
        optimizer - Optional optimizer with restored parameters
        state - Additional values that the user defined during save
    """

torchutil.checkpoint.save

def save(
    file: Union[str, bytes, os.PathLike],
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    accelerator: Optional[accelerate.Accelerator] = None,
    **kwargs):
    """Save training checkpoint to disk

    Arguments
        file - The checkpoint file
        model - The PyTorch model
        optimizer - The PyTorch optimizer
        accelerator - HuggingFace Accelerator for device management
        kwargs - Additional values to save
    """

Download

torchutil.download.file

def file(url: 'str', path: Union[str, bytes, os.PathLike]):
    """Download file from url

    Arguments
        url - The URL to download
        path - The location to save results
    """

torchutil.download.tarbz2

def tarbz2(url: 'str', path: Union[str, bytes, os.PathLike]):
    """Download and extract tar bz2 file to location

    Arguments
        url - The URL to download
        path - The location to save results
    """

torchutil.download.targz

def targz(url: 'str', path: Union[str, bytes, os.PathLike]):
    """Download and extract tar gz file to location

    Arguments
        url - The URL to download
        path - The location to save results
    """

torchutil.download.zip

def zip(url: 'str', path: Union[str, bytes, os.PathLike]):
    """Download and extract zip file to location

    Arguments
        url - The URL to download
        path - The location to save results
    """

Notify

To use the torchutil notification system, set the PYTORCH_NOTIFICATION_URL environment variable to a supported webhook as explained in the Apprise documentation.
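
For example, the variable can be set from Python before training starts; the webhook below is a placeholder in Apprise's Slack URL format.

import os

# Placeholder webhook; any Apprise-supported service URL may be used
os.environ['PYTORCH_NOTIFICATION_URL'] = 'slack://tokenA/tokenB/tokenC'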

import torchutil

# Send notification when function returns
@torchutil.notify.on_return('train')
def train():
    ...

# Equivalent using context manager
def train():
    with torchutil.notify.on_exit('train'):
        ...

torchutil.notify.on_exit

@contextlib.contextmanager
def on_exit(
    description: str,
    track_time: bool = True,
    notify_on_fail: bool = True):
    """Context manager for sending job notifications

    Arguments
        description - The name of the job being run
        track_time - Whether to report time elapsed
        notify_on_fail - Whether to send a notification on failure
    """

torchutil.notify.on_return

def on_return(
    description: str,
    track_time: bool = True,
    notify_on_fail: bool = True) -> Callable:
    """Decorator for sending job notifications

    Arguments
        description - The name of the job being run
        track_time - Whether to report time elapsed
        notify_on_fail - Whether to send a notification on failure
    """

Tensorboard

import matplotlib.pyplot as plt
import torch
import torchutil

# Directory to write Tensorboard files
directory = 'tensorboard'

# Training step
step = 0

# Example audio
audio = torch.zeros(1, 16000)
sample_rate = 16000

# Example figure
figure = plt.figure()
plt.plot([0, 1, 2, 3])

# Example image
image = torch.zeros(256, 256, 3)

# Example scalar
loss = 0

# Update Tensorboard
torchutil.tensorboard.update(
    directory,
    step,
    audio={'audio': audio},
    sample_rate=sample_rate,
    figures={'figure': figure},
    images={'image': image},
    scalars={'loss': loss})

torchutil.tensorboard.update

def update(
    directory: Union[str, bytes, os.PathLike],
    step: int,
    audio: Optional[Dict[str, torch.Tensor]] = None,
    sample_rate: Optional[int] = None,
    figures: Optional[Dict[str, matplotlib.figure.Figure]] = None,
    images: Optional[Dict[str, torch.Tensor]] = None,
    scalars: Optional[Dict[str, Union[float, int, torch.Tensor]]] = None):
    """Update Tensorboard

    Arguments
        directory - Directory to write Tensorboard files
        step - Training step
        audio - Optional dictionary of 2D audio tensors to monitor
        sample_rate - Audio sample rate; required if audio is not None
        figures - Optional dictionary of Matplotlib figures to monitor
        images - Optional dictionary of 3D image tensors to monitor
        scalars - Optional dictionary of scalars to monitor
    """

Time

import time

import torchutil

# Perform timing
with torchutil.time.context('outer'):
    time.sleep(1)
    for i in range(2):
        time.sleep(1)
        with torchutil.time.context('inner'):
            time.sleep(1)

# Prints timing results; with the sleeps above, approximately
# {'outer': 5.0, 'inner': 2.0}
print(torchutil.time.results())

torchutil.time.context

@contextlib.contextmanager
def context(name: str):
    """Wrapper to handle context changes of global timer

    Arguments
        name - Name of the timer to add time to
    """

torchutil.time.results

def results() -> dict:
    """Get timing results

    Returns
        Timing results: {name: elapsed_time} for all names
    """
