PyTorch utilities for developing deep learning frameworks
Project description
torchutil
Table of contents
Checkpoint
import torch
import torchutil
# Checkpoint location
file = 'model.pt'
# Initialize PyTorch model
model = torch.nn.Sequential(torch.nn.Conv1d(1, 1, 1))
# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters())
# Save
torchutil.checkpoint.save(file, model, optimizer, step=0, epoch=0)
# Load for training
model, optimizer, state = torchutil.checkpoint.load(file, model, optimizer)
step, epoch = state['step'], state['epoch']
# Load for inference
model, *_ = torchutil.checkpoint.load(file, model)
torchutil.checkpoint.best_path
def best_path(
directory: Union[str, bytes, os.PathLike],
glob: str = '*.pt',
best_fn: Callable = highest_score
) -> Tuple[Union[str, bytes, os.PathLike], float]:
"""Retrieve the path to the best checkpoint
Arguments
directory - The directory to search for checkpoint files
glob - The glob pattern matching checkpoint files
best_fn - Takes a list of checkpoint paths and returns the best.
Default assumes checkpoint names are training step count.
Returns
best_file - The filename of the checkpoint with the best score
best_score - The corresponding score
"""
torchutil.checkpoint.latest_path
def latest_path(
directory: Union[str, bytes, os.PathLike],
glob: str = '*.pt',
latest_fn: Callable = largest_number_filename,
) -> Union[str, bytes, os.PathLike]:
"""Retrieve the path to the most recent checkpoint in a directory
Arguments
directory - The directory to search for checkpoint files
glob - The glob pattern matching checkpoint files
latest_fn - Takes a list of checkpoint paths and returns the latest.
Default assumes checkpoint names are training step count.
Returns
The latest checkpoint in directory according to latest_fn
"""
torchutil.checkpoint.load
def load(
file: Union[str, bytes, os.PathLike],
model: torch.nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
map_location: str = 'cpu') -> Tuple[
torch.nn.Module,
Union[None, torch.optim.Optimizer],
Dict
]:
"""Load model checkpoint
Arguments
file - The checkpoint file
model - The PyTorch model
optimizer - Optional PyTorch optimizer for training
map_location - The device to load the checkpoint on
Returns
model - The model with restored weights
optimizer - Optional optimizer with restored parameters
state - Additional values that the user defined during save
"""
torchutil.checkpoint.save
def save(
file: Union[str, bytes, os.PathLike],
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
accelerator: Optional[accelerate.Accelerator] = None,
**kwargs):
"""Save training checkpoint to disk
Arguments
file - The checkpoint file
model - The PyTorch model
optimizer - The PyTorch optimizer
accelerator - HuggingFace Accelerator for device management
kwargs - Additional values to save
"""
Download
torchutil.download.file
def file(url: 'str', path: Union[str, bytes, os.PathLike]):
"""Download file from url
Arguments
url - The URL to download
path - The location to save results
"""
torchutil.download.tarbz2
def tarbz2(url: 'str', path: Union[str, bytes, os.PathLike]):
"""Download and extract tar bz2 file to location
Arguments
url - The URL to download
path - The location to save results
"""
torchutil.download.targz
def targz(url: 'str', path: Union[str, bytes, os.PathLike]):
"""Download and extract tar gz file to location
Arguments
url - The URL to download
path - The location to save results
"""
torchutil.download.zip
def zip(url: 'str', path: Union[str, bytes, os.PathLike]):
"""Download and extract zip file to location
Arguments
url - The URL to download
path - The location to save results
"""
Iterator
def iterator(
iterable: Iterable,
message: Optional[str],
initial: int = 0,
total: Optional[int] = None
) -> Iterable:
"""Create a tqdm iterator
Arguments
iterable
Items to iterate over
message
Static message to display
initial
Position to display corresponding to index zero of iterable
total
Length of the iterable; defaults to len(iterable)
"""
Metrics
import torch
import torchutil
# Define a custom, batch-updating loss metric
class Loss(torchutil.metrics.Average):
def update(self, predicted, target):
# Compute your loss and the number of elements to average over
loss = ...
count = ...
super().update(loss, count)
# Instantiate metrics
loss = Loss()
rmse = torchutil.metrics.RMSE()
# Generator that produces batches of predicted and target tensors
iterable = ...
# Update metrics
for predicted_tensor, target_tensor in iterable:
loss.update(predicted_tensor, target_tensor)
rmse.update(predicted_tensor, target_tensor)
# Get results
print('loss', loss())
print('rmse', rmse())
torchutil.metrics.Accuracy
class Accuracy(Metric):
"""Batch-updating accuracy metric"""
def __call__(self) -> float:
"""Retrieve the current accuracy value
The current accuracy value
"""
def update(self, predicted: torch.Tensor, target: torch.Tensor) -> None:
"""Update accuracy
Arguments
predicted
The model prediction
target
The corresponding ground truth
"""
def reset(self) -> None:
"""Reset accuracy"""
torchutil.metrics.Average
class Average(Metric):
"""Batch-updating average metric"""
def __call__(self) -> float:
"""Retrieve the current average value
The current average value
"""
def update(self, values: torch.Tensor, count: int) -> None:
"""Update running average
Arguments
values
The values to average
count
The number of values
"""
def reset(self) -> None:
"""Reset running average"""
torchutil.metrics.F1
class F1(Metric):
"""Batch-updating F1 score"""
def __call__(self) -> float:
"""Retrieve the current F1 value
The current F1 value
"""
def update(self, predicted: torch.Tensor, target: torch.Tensor) -> None:
"""Update F1
Arguments
predicted
The model prediction
target
The corresponding ground truth
"""
def reset(self) -> None:
"""Reset F1"""
torchutil.metrics.L1
class L1(Metric):
"""Batch updating L1 score"""
def __call__(self) -> float:
"""Retrieve the current L1 value
The current L1 value
"""
def update(self, predicted, target) -> None:
"""Update L1
Arguments
predicted
The model prediction
target
The corresponding ground truth
"""
def reset(self) -> None:
"""Reset L1"""
torchutil.metrics.MeanStd
class MeanStd(Metric):
"""Batch updating mean and standard deviation"""
def __call__(self) -> Tuple[float, float]:
"""Retrieve the current mean and standard deviation
Returns
The current mean and standard deviation
"""
def update(self, values: torch.Tensor) -> None:
"""Update mean and standard deviation
Arguments
values
The values to incorporate into the running mean and standard deviation
"""
def reset(self) -> None:
"""Reset mean and standard deviation"""
torchutil.metrics.PearsonCorrelation
class PearsonCorrelation(Metric):
"""Batch-updating Pearson correlation"""
def __init__(
self,
predicted_mean: float,
predicted_std: float,
target_mean: float,
target_std: float
) -> None:
"""
Arguments
predicted_mean - Mean of predicted values
predicted_std - Standard deviation of predicted values
target_mean - Mean of target values
target_std - Standard deviation of target values
"""
def __call__(self) -> float:
"""Retrieve the current correlation value
Returns
The current correlation value
"""
def update(self, predicted, target) -> None:
"""Update Pearson correlation
Arguments
predicted
The model prediction
target
The corresponding ground truth
"""
def reset(self) -> None:
"""Reset Pearson correlation"""
torchutil.metrics.Precision
class Precision(Metric):
"""Batch-updating precision metric"""
def __call__(self) -> float:
"""Retrieve the current precision value
Returns
The current precision value
"""
def update(self, predicted: torch.Tensor, target: torch.Tensor) -> None:
"""Update precision
Arguments
predicted
The model prediction
target
The corresponding ground truth
"""
def reset(self) -> None:
"""Reset precision"""
torchutil.metrics.Recall
class Recall(Metric):
"""Batch-updating recall metric"""
def __call__(self) -> float:
"""Retrieve the current recall value
The current recall value
"""
def update(self, predicted: torch.Tensor, target: torch.Tensor) -> None:
"""Update recall
Arguments
predicted
The model prediction
target
The corresponding ground truth
"""
def reset(self) -> None:
"""Reset recall"""
torchutil.metrics.RMSE
class RMSE(Metric):
"""Batch-updating RMSE metric"""
def __call__(self) -> float:
"""Retrieve the current rmse value
The current rmse value
"""
def update(self, predicted: torch.Tensor, target: torch.Tensor) -> None:
"""Update RMSE
Arguments
predicted
The model prediction
target
The corresponding ground truth
"""
def reset(self) -> None:
"""Reset RMSE"""
Notify
To use the torchutil
notification system, set the PYTORCH_NOTIFICATION_URL
environment variable to a supported webhook as explained in
the Apprise documentation.
import torchutil
# Send notification when function returns
@torchutil.notify('train')
def train():
...
# Equivalent using context manager
def train():
with torchutil.notify('train'):
...
torchutil.notify
@contextlib.contextmanager
def notify(
description: str,
track_time: bool = True,
notify_on_fail: bool = True):
"""Context manager for sending job notifications
Arguments
description - The name of the job being run
track_time - Whether to report time elapsed
notify_on_fail - Whether to send a notification on failure
"""
Paths
torchutil.paths.chdir
@contextlib.contextmanager
def chdir(directory: Union[str, bytes, os.PathLike]) -> None:
"""Context manager for changing the current working directory
Arguments
directory
The desired working directory
"""
This function is both a context manager and decorator.
import tempfile
from pathlib import Path
import torchutil
# Create a directory
directory = tempfile.TemporaryDirectory()
# Create a file
file = 'tmp.txt'
(Path(directory.name) / file).touch()
# File is not in current working directory
assert not Path(file).exists()
# Change working directory using context manager
with torchutil.paths.chdir(directory.name):
assert Path(file).exists()
# File is not in current working directory
assert not Path(file).exists()
# Change working directory using decorator
@torchutil.paths.chdir(directory.name)
def exists(file):
assert Path(file).exists()
exists(file)
# File is not in current working directory
assert not Path(file).exists()
# Remove temporary paths
directory.cleanup()
torchutil.paths.measure
def measure(
globs: Union[str, List[str]],
roots: Optional[
Union[
Union[str, bytes, os.PathLike],
List[Union[str, bytes, os.PathLike]]
]] = None,
recursive: bool = False,
unit='B'
) -> Union[int, float]:
"""Measure data usage of files and directories
Arguments
globs
Globs matching paths to measure
roots
Directories to apply glob searches; current directory by default
recursive
Apply globs to all subdirectories of root directories
unit
Unit of memory utilization (bytes to terabytes); default bytes
"""
This function also has a command-line interface.
python -m torchutil.paths.measure \
[-h] \
--globs GLOBS \
[--roots ROOTS] \
[--recursive] \
[--unit]
Measure data usage of files and directories
arguments:
--globs GLOBS
Globs matching paths to measure
optional arguments:
-h, --help
show this help message and exit
--roots ROOTS
Directories to apply glob searches; current directory by default
--recursive
Apply globs to all subdirectories of root directories
--unit
Unit of memory utilization (bytes to terabytes); default bytes
torchutil.paths.purge
def purge(
globs: Union[str, List[str]],
roots: Optional[
Union[
Union[str, bytes, os.PathLike],
List[Union[str, bytes, os.PathLike]]
]] = None,
recursive: bool = False
) -> None:
"""Remove all files and directories within directory matching glob
Arguments
globs
Globs matching files to delete
roots
Directories to apply glob searches; current directory by default
recursive
Apply globs to all subdirectories of root directories
"""
This function also has a command-line interface.
python -m torchutil.paths.purge \
[-h] \
--globs GLOBS \
[--roots ROOTS] \
[--recursive] \
[--force]
Remove files and directories
arguments:
--globs GLOBS
Globs matching paths to delete
optional arguments:
-h, --help
show this help message and exit
--roots ROOTS
Directories to apply glob searches; current directory by default
--recursive
Apply globs to all subdirectories of root directories
Tensorboard
import matplotlib.pyplot as plt
import torch
import torchutil
# Directory to write Tensorboard files
directory = 'tensorboard'
# Training step
step = 0
# Example audio
audio = torch.zeros(1, 16000)
sample_rate = 16000
# Example figure
figure = plt.figure()
plt.plot([0, 1, 2, 3])
# Example image
image = torch.zeros(256, 256, 3)
# Example scalar
loss = 0
# Update Tensorboard
torchutil.tensorboard.update(
directory,
step,
audio={'audio': audio},
sample_rate=sample_rate,
figures={'figure': figure},
images={'image': image},
scalars={'loss': loss})
torchutil.tensorboard.update
def update(
directory: Union[str, bytes, os.PathLike],
step: int,
audio: Optional[Dict[str, torch.Tensor]] = None,
sample_rate: Optional[int] = None,
figures: Optional[Dict[str, matplotlib.figure.Figure]] = None,
images: Optional[Dict[str, torch.Tensor]] = None,
scalars: Optional[Dict[str, Union[float, int, torch.Tensor]]] = None):
"""Update Tensorboard
Arguments
directory - Directory to write Tensorboard files
step - Training step
audio - Optional dictionary of 2D audio tensors to monitor
sample_rate - Audio sample rate; required if audio is not None
figures - Optional dictionary of Matplotlib figures to monitor
images - Optional dictionary of 3D image tensors to monitor
scalars - Optional dictionary of scalars to monitor
"""
Time
import time
import torchutil
# Perform timing
with torchutil.time.context('outer'):
time.sleep(1)
for i in range(2):
time.sleep(1)
with torchutil.time.context('inner'):
time.sleep(1)
# Prints timing results, e.g., {'outer': 5.0, 'inner': 2.0} (elapsed seconds)
print(torchutil.time.results())
torchutil.time.context
@contextlib.contextmanager
def context(name: str):
"""Wrapper to handle context changes of global timer
Arguments
name - Name of the timer to add time to
"""
torchutil.time.reset
def reset():
"""Clear timer state"""
torchutil.time.results
def results() -> dict:
"""Get timing results
Returns
Timing results: {name: elapsed_time} for all names
"""
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
torchutil-0.0.8.tar.gz
(18.1 kB
view hashes)
Built Distribution
torchutil-0.0.8-py3-none-any.whl
(18.9 kB
view hashes)
Close
Hashes for torchutil-0.0.8-py3-none-any.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 1d1a8777b541acb7fde8ba1b57a0f2ff926bba251b50966516d4f1495d0d1c87 |
|
MD5 | f9d38fb8be31748f8ba83cc7e737837c |
|
BLAKE2b-256 | 32671cc336170e665f3e9724e8684d927e29a4c570b4973079f29f5757618b39 |