PyTorch utilities for developing deep learning frameworks
Project description
torchutil
Table of contents
Checkpoint
import torch
import torchutil
# Checkpoint location
file = 'model.pt'
# Initialize PyTorch model
model = torch.nn.Sequential(torch.nn.Conv1d(1, 1, 1))
# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters())
# Save
torchutil.checkpoint.save(file, model, optimizer, step=0, epoch=0)
# Load for training
model, optimizer, state = torchutil.checkpoint.load(file, model, optimizer)
step, epoch = state['step'], state['epoch']
# Load for inference
model, *_ = torchutil.checkpoint.load(file, model)
torchutil.checkpoint.latest_path
def latest_path(
directory: Union[str, bytes, os.PathLike],
regex: str = '*.pt',
latest_fn: Callable = highest_number,
) -> Union[str, bytes, os.PathLike]:
"""Retrieve the path to the most recent checkpoint in a directory
Arguments
directory - The directory to search for checkpoint files
regex - The regular expression matching checkpoints
latest_fn - Takes a list of checkpoint paths and returns the latest.
Default assumes checkpoint names are training step count.
Returns
The latest checkpoint in directory according to latest_fn
"""
torchutil.checkpoint.load
def load(
file: Union[str, bytes, os.PathLike],
model: torch.nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
map_location: str = 'cpu') -> Tuple[
torch.nn.Module,
Union[None, torch.optim.Optimizer],
Dict
]:
"""Load model checkpoint
Arguments
file - The checkpoint file
model - The PyTorch model
optimizer - Optional PyTorch optimizer for training
map_location - The device to load the checkpoint on
Returns
model - The model with restored weights
optimizer - Optional optimizer with restored parameters
state - Additional values that the user defined during save
"""
torchutil.checkpoint.save
def save(
file: Union[str, bytes, os.PathLike],
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
accelerator: Optional[accelerate.Accelerator] = None,
**kwargs):
"""Save training checkpoint to disk
Arguments
file - The checkpoint file
model - The PyTorch model
optimizer - The PyTorch optimizer
accelerator - HuggingFace Accelerator for device management
kwargs - Additional values to save
"""
Download
torchutil.download.file
def file(url: 'str', path: Union[str, bytes, os.PathLike]):
"""Download file from url
Arguments
url - The URL to download
path - The location to save results
"""
torchutil.download.tarbz2
def tarbz2(url: 'str', path: Union[str, bytes, os.PathLike]):
"""Download and extract tar bz2 file to location
Arguments
url - The URL to download
path - The location to save results
"""
torchutil.download.targz
def targz(url: 'str', path: Union[str, bytes, os.PathLike]):
"""Download and extract tar gz file to location
Arguments
url - The URL to download
path - The location to save results
"""
torchutil.download.zip
def zip(url: 'str', path: Union[str, bytes, os.PathLike]):
"""Download and extract zip file to location
Arguments
url - The URL to download
path - The location to save results
"""
Notify
To use the torchutil notification system, set the PYTORCH_NOTIFICATION_URL
environment variable to a supported webhook as explained in
the Apprise documentation.
import torchutil
# Send notification when function returns
@torchutil.notify.on_return('train')
def train():
...
# Equivalent using context manager
def train():
with torchutil.notify.on_exit('train'):
...
torchutil.notify.on_exit
@contextlib.contextmanager
def on_exit(
description: str,
track_time: bool = True,
notify_on_fail: bool = True):
"""Context manager for sending job notifications
Arguments
description - The name of the job being run
track_time - Whether to report time elapsed
notify_on_fail - Whether to send a notification on failure
"""
torchutil.notify.on_return
def on_return(
description: str,
track_time: bool = True,
notify_on_fail: bool = True) -> Callable:
"""Decorator for sending job notifications
Arguments
description - The name of the job being run
track_time - Whether to report time elapsed
notify_on_fail - Whether to send a notification on failure
"""
Tensorboard
import matplotlib.pyplot as plt
import torch
import torchutil
# Directory to write Tensorboard files
directory = 'tensorboard'
# Training step
step = 0
# Example audio
audio = torch.zeros(1, 16000)
sample_rate = 16000
# Example figure
figure = plt.figure()
plt.plot([0, 1, 2, 3])
# Example image
image = torch.zeros(256, 256, 3)
# Example scalar
loss = 0
# Update Tensorboard
torchutil.tensorboard.update(
directory,
step,
audio={'audio': audio},
sample_rate=sample_rate,
figures={'figure': figure},
images={'image': image},
scalars={'loss': loss})
torchutil.tensorboard.update
def update(
directory: Union[str, bytes, os.PathLike],
step: int,
audio: Optional[Dict[str, torch.Tensor]] = None,
sample_rate: Optional[int] = None,
figures: Optional[Dict[str, matplotlib.figure.Figure]] = None,
images: Optional[Dict[str, torch.Tensor]] = None,
scalars: Optional[Dict[str, Union[float, int, torch.Tensor]]] = None):
"""Update Tensorboard
Arguments
directory - Directory to write Tensorboard files
step - Training step
audio - Optional dictionary of 2D audio tensors to monitor
sample_rate - Audio sample rate; required if audio is not None
figures - Optional dictionary of Matplotlib figures to monitor
images - Optional dictionary of 3D image tensors to monitor
scalars - Optional dictionary of scalars to monitor
"""
Time
import time
import torchutil
# Perform timing
with torchutil.time.context('outer'):
time.sleep(1)
for i in range(2):
time.sleep(1)
with torchutil.time.context('inner'):
time.sleep(1)
# Prints {'outer': ~5.0, 'inner': ~2.0} (elapsed seconds per named timer)
print(torchutil.time.results())
torchutil.time.context
@contextlib.contextmanager
def context(name: str):
"""Wrapper to handle context changes of global timer
Arguments
name - Name of the timer to add time to
"""
torchutil.time.results
def results() -> dict:
"""Get timing results
Returns
Timing results: {name: elapsed_time} for all names
"""
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
torchutil-0.0.1.tar.gz
(8.2 kB
view details)
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file torchutil-0.0.1.tar.gz.
File metadata
- Download URL: torchutil-0.0.1.tar.gz
- Upload date:
- Size: 8.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.17
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
75631ee98078bf69e54948fcf108e607deeb76606a3f09cb1d44bef5f1b86510
|
|
| MD5 |
309c82945550ffe3e0758fb2cee2469e
|
|
| BLAKE2b-256 |
ab8d7104724d7dc434bc78793fab99229d7e10c5cd1c56907ac9bf3477f08e01
|
File details
Details for the file torchutil-0.0.1-py3-none-any.whl.
File metadata
- Download URL: torchutil-0.0.1-py3-none-any.whl
- Upload date:
- Size: 9.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.17
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
4db13ff6912a756054f0145f79a23f1cc5c1ed406c71c1b2ca043afc3eac9e49
|
|
| MD5 |
bb6cc44b83766665a7e7c61e9b3f1b43
|
|
| BLAKE2b-256 |
9e6a13973ea1795dce62dc1b920558bca2b9dbd870953f083283ddf8f377ded7
|