Skip to main content

PyTorch utilities for developing deep learning frameworks

Project description

torchutil

PyPI License Downloads

General utilities for developing deep learning projects in PyTorch

pip install torchutil

Table of contents

Checkpoint

import torch
import torchutil

# Checkpoint location
file = 'model.pt'

# Initialize PyTorch model (Conv1d requires in_channels, out_channels, and
# kernel_size)
model = torch.nn.Sequential(torch.nn.Conv1d(1, 1, 1))

# Initialize optimizer (optimizers live in torch.optim, not torch.nn)
optimizer = torch.optim.Adam(model.parameters())

# Save
torchutil.checkpoint.save(file, model, optimizer, step=0, epoch=0)

# Load for training
model, optimizer, state = torchutil.checkpoint.load(file, model, optimizer)
step, epoch = state['step'], state['epoch']

# Load for inference (the optimizer argument is optional)
model, *_ = torchutil.checkpoint.load(file, model)

torchutil.checkpoint.latest_path

def latest_path(
        directory: Union[str, bytes, os.PathLike],
        regex: str = '*.pt',
        latest_fn: Callable = highest_number,
    ) -> Union[str, bytes, os.PathLike]:
    """Retrieve the path to the most recent checkpoint in a directory

    Arguments
        directory - The directory to search for checkpoint files
        regex - The pattern matching checkpoint files. Note that the default,
                '*.pt', is written in glob-style wildcard syntax; it is not a
                valid regular expression.
        latest_fn - Takes a list of checkpoint paths and returns the latest.
                    Default assumes checkpoint names are training step count.

    Returns
        The latest checkpoint in directory according to latest_fn
    """

torchutil.checkpoint.load

def load(
    file: Union[str, bytes, os.PathLike],
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    map_location: str = 'cpu') -> Tuple[
        torch.nn.Module,
        Union[None, torch.optim.Optimizer],
        Dict
    ]:
    """Load model checkpoint

    Arguments
        file - The checkpoint file
        model - The PyTorch model
        optimizer - Optional PyTorch optimizer for training; defaults to None
        map_location - The device to load the checkpoint on; defaults to 'cpu'

    Returns
        model - The model with restored weights
        optimizer - Optional optimizer with restored parameters; may be None
        state - Additional values that the user defined during save
    """

torchutil.checkpoint.save

def save(
    file: Union[str, bytes, os.PathLike],
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    accelerator: Optional[accelerate.Accelerator] = None,
    **kwargs):
    """Save training checkpoint to disk

    Arguments
        file - The checkpoint file
        model - The PyTorch model
        optimizer - The PyTorch optimizer
        accelerator - Optional HuggingFace Accelerator for device management;
                      defaults to None
        kwargs - Additional values to save as keyword arguments (e.g.,
                 step=0, epoch=0)
    """

Download

torchutil.download.file

def file(url: str, path: Union[str, bytes, os.PathLike]):
    """Download a file from a URL

    Arguments
        url - The URL to download
        path - The location to save results
    """

torchutil.download.tarbz2

def tarbz2(url: str, path: Union[str, bytes, os.PathLike]):
    """Download and extract a tar bz2 file to a location

    Arguments
        url - The URL to download
        path - The location to save results
    """

torchutil.download.targz

def targz(url: str, path: Union[str, bytes, os.PathLike]):
    """Download and extract a tar gz file to a location

    Arguments
        url - The URL to download
        path - The location to save results
    """

torchutil.download.zip

def zip(url: str, path: Union[str, bytes, os.PathLike]):
    """Download and extract a zip file to a location

    Arguments
        url - The URL to download
        path - The location to save results
    """

Notify

To use the torchutil notification system, set the PYTORCH_NOTIFICATION_URL environment variable to a supported webhook as explained in the Apprise documentation.

import torchutil

# Send notification when function returns; the argument is the description
# (name) of the job being run
@torchutil.notify.on_return('train')
def train():
    ...

# Equivalent using context manager; the notification is tied to leaving the
# with block
def train():
    with torchutil.notify.on_exit('train'):
        ...

torchutil.notify.on_exit

@contextlib.contextmanager
def on_exit(
    description: str,
    track_time: bool = True,
    notify_on_fail: bool = True):
    """Context manager for sending job notifications

    Arguments
        description - The name of the job being run
        track_time - Whether to report time elapsed; defaults to True
        notify_on_fail - Whether to send a notification on failure;
                         defaults to True
    """

torchutil.notify.on_return

def on_return(
    description: str,
    track_time: bool = True,
    notify_on_fail: bool = True) -> Callable:
    """Decorator for sending job notifications

    Arguments
        description - The name of the job being run
        track_time - Whether to report time elapsed; defaults to True
        notify_on_fail - Whether to send a notification on failure;
                         defaults to True

    Returns
        The decorator to apply to the function being run
    """

Tensorboard

import matplotlib.pyplot as plt
import torch
import torchutil

# Directory to write Tensorboard files
directory = 'tensorboard'

# Training step
step = 0

# Example audio (a 2D tensor) and its sample rate
audio = torch.zeros(1, 16000)
sample_rate = 16000

# Example figure
figure = plt.figure()
plt.plot([0, 1, 2, 3])

# Example image (a 3D tensor)
# NOTE(review): the shape here is (256, 256, 3); confirm which axis update()
# expects to hold the channels
image = torch.zeros(256, 256, 3)

# Example scalar
loss = 0

# Update Tensorboard; each dictionary maps a display name to the value to log
torchutil.tensorboard.update(
    directory,
    step,
    audio={'audio': audio},
    sample_rate=sample_rate,
    figures={'figure': figure},
    images={'image': image},
    scalars={'loss': loss})

torchutil.tensorboard.update

def update(
    directory: Union[str, bytes, os.PathLike],
    step: int,
    audio: Optional[Dict[str, torch.Tensor]] = None,
    sample_rate: Optional[int] = None,
    figures: Optional[Dict[str, matplotlib.figure.Figure]] = None,
    images: Optional[Dict[str, torch.Tensor]] = None,
    scalars: Optional[Dict[str, Union[float, int, torch.Tensor]]] = None):
    """Update Tensorboard with audio, figures, images, and scalars

    All monitored values are optional and default to None.

    Arguments
        directory - Directory to write Tensorboard files
        step - Training step
        audio - Optional dictionary of 2D audio tensors to monitor
        sample_rate - Audio sample rate; required if audio is not None
        figures - Optional dictionary of Matplotlib figures to monitor
        images - Optional dictionary of 3D image tensors to monitor
        scalars - Optional dictionary of scalars to monitor
    """

Time

import time

import torchutil

# Perform timing
with torchutil.time.context('outer'):
    time.sleep(1)
    for i in range(2):
        time.sleep(1)
        with torchutil.time.context('inner'):
            time.sleep(1)

# Prints a dictionary of the form
# {'outer': <elapsed_time>, 'inner': <elapsed_time>}
print(torchutil.time.results())

torchutil.time.context

@contextlib.contextmanager
def context(name: str):
    """Context manager wrapper to handle context changes of the global timer

    Arguments
        name - Name of the timer to add time to
    """

torchutil.time.results

def results() -> dict:
    """Get timing results

    Returns
        Timing results: a dictionary {name: elapsed_time} with one entry
        for each timer name
    """

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchutil-0.0.1.tar.gz (8.2 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

torchutil-0.0.1-py3-none-any.whl (9.1 kB view details)

Uploaded Python 3

File details

Details for the file torchutil-0.0.1.tar.gz.

File metadata

  • Download URL: torchutil-0.0.1.tar.gz
  • Upload date:
  • Size: 8.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.17

File hashes

Hashes for torchutil-0.0.1.tar.gz
Algorithm Hash digest
SHA256 75631ee98078bf69e54948fcf108e607deeb76606a3f09cb1d44bef5f1b86510
MD5 309c82945550ffe3e0758fb2cee2469e
BLAKE2b-256 ab8d7104724d7dc434bc78793fab99229d7e10c5cd1c56907ac9bf3477f08e01

See more details on using hashes here.

File details

Details for the file torchutil-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: torchutil-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 9.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.17

File hashes

Hashes for torchutil-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 4db13ff6912a756054f0145f79a23f1cc5c1ed406c71c1b2ca043afc3eac9e49
MD5 bb6cc44b83766665a7e7c61e9b3f1b43
BLAKE2b-256 9e6a13973ea1795dce62dc1b920558bca2b9dbd870953f083283ddf8f377ded7

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page