cjm-pytorch-utils
Some utility functions for working with PyTorch.
Install
pip install cjm_pytorch_utils
How to use
pil_to_tensor
from cjm_pytorch_utils.core import pil_to_tensor
from PIL import Image
from torchvision import transforms
img_path = '../images/cat.jpg'
src_img = Image.open(img_path).convert('RGB')
print(f"Source Image Size: {src_img.size}")
img_tensor = pil_to_tensor(src_img, [0.5], [0.5])
img_tensor.shape, img_tensor.min(), img_tensor.max()
Source Image Size: (768, 512)
(torch.Size([1, 3, 512, 768]), tensor(-1.), tensor(1.))
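PIL reports size as (width, height), so the (768, 512) image becomes a tensor of shape [1, 3, 512, 768]: channels first, height before width, plus a leading batch dimension. The value range follows from normalization with mean 0.5 and std 0.5, since (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1]. Below is a minimal sketch of an equivalent conversion, assuming pil_to_tensor scales pixels to [0, 1], normalizes per channel, and adds a batch dimension. `pil_to_tensor_sketch` is a hypothetical name, not the library's code.

```python
import numpy as np
import torch
from PIL import Image


def pil_to_tensor_sketch(img, mean, std):
    """Hypothetical re-implementation for illustration only."""
    # HWC uint8 pixels -> float32 values in [0, 1]
    arr = np.asarray(img, dtype=np.float32) / 255.0
    tensor = torch.from_numpy(arr).permute(2, 0, 1)  # HWC -> CHW
    # Reshape mean/std to (C, 1, 1) so they broadcast over height and width;
    # a single value such as [0.5] is applied to every channel
    mean_t = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1)
    tensor = (tensor - mean_t) / std_t
    return tensor.unsqueeze(0)  # add batch dimension -> (1, C, H, W)


# A 4x2 black image with one white pixel spans the full [-1, 1] range
img = Image.new("RGB", (4, 2))
img.putpixel((0, 0), (255, 255, 255))
t = pil_to_tensor_sketch(img, [0.5], [0.5])
```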
tensor_to_pil
from cjm_pytorch_utils.core import tensor_to_pil
# Convert a CHW tensor with values in [0, 1] (the output of ToTensor) back to a PIL image
pil_img = tensor_to_pil(transforms.ToTensor()(src_img))
pil_img
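Judging from the example, tensor_to_pil takes a CHW tensor with values in [0, 1] and returns a PIL image. Here is a minimal sketch of such a conversion under that assumption. `tensor_to_pil_sketch` is hypothetical, and a normalized tensor like the one pil_to_tensor returns would first need to be un-normalized (x * std + mean).

```python
import torch
from PIL import Image


def tensor_to_pil_sketch(tensor):
    """Hypothetical inverse of ToTensor, for illustration only."""
    if tensor.dim() == 4:  # drop a leading batch dimension of size 1
        tensor = tensor.squeeze(0)
    # Clamp to [0, 1], rescale to 0-255, and round before casting to uint8
    arr = tensor.detach().cpu().clamp(0, 1).mul(255).round().to(torch.uint8)
    arr = arr.permute(1, 2, 0).contiguous().numpy()  # CHW -> HWC for PIL
    return Image.fromarray(arr)


# A 3-wide, 2-tall image filled with a single RGB colour
t = torch.tensor([10.0, 200.0, 255.0]).div(255).view(3, 1, 1).expand(3, 2, 3)
out = tensor_to_pil_sketch(t.unsqueeze(0))
```

Rounding before the cast matters: k / 255 * 255 can land just below k in float32, and a plain truncating cast would then be off by one.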
iterate_modules
from cjm_pytorch_utils.core import iterate_modules
import torch
from torchvision import models
vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
# Print the position of every ReLU layer in VGG16's feature extractor
for index, module in enumerate(iterate_modules(vgg)):
    if isinstance(module, torch.nn.ReLU):
        print(f"{index}: {module}")
1: ReLU(inplace=True)
3: ReLU(inplace=True)
6: ReLU(inplace=True)
8: ReLU(inplace=True)
11: ReLU(inplace=True)
13: ReLU(inplace=True)
15: ReLU(inplace=True)
18: ReLU(inplace=True)
20: ReLU(inplace=True)
22: ReLU(inplace=True)
25: ReLU(inplace=True)
27: ReLU(inplace=True)
29: ReLU(inplace=True)
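The output above only shows a flat `Sequential`, so whether iterate_modules also descends into nested containers is an assumption here. A generator along these lines would yield individual layers depth-first. `iterate_modules_sketch` is a hypothetical name. By contrast, PyTorch's built-in `Module.modules()` yields every module, including the containers and the root module itself.

```python
import torch.nn as nn


def iterate_modules_sketch(module):
    """Hypothetical: yield leaf modules depth-first, skipping containers."""
    for child in module.children():
        if next(child.children(), None) is None:
            yield child  # leaf: has no submodules of its own
        else:
            yield from iterate_modules_sketch(child)  # recurse into containers


model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 8, 3), nn.ReLU()),
    nn.Flatten(),
)
leaves = list(iterate_modules_sketch(model))
relu_positions = [i for i, m in enumerate(leaves) if isinstance(m, nn.ReLU)]
```

For a flat container such as `vgg16().features`, the enumeration index from such a walk coincides with each layer's position in the `Sequential`.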
tensor_stats_df
from cjm_pytorch_utils.core import tensor_stats_df
tensor_stats_df(torch.randn(1, 3, 256, 256))
|       | 0                |
|-------|------------------|
| mean  | 0.001601         |
| std   | 0.999375         |
| min   | -4.79662         |
| max   | 4.263451         |
| shape | (1, 3, 256, 256) |
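The numbers above come from random input, so they differ from run to run. A one-column summary like this can be sketched with pandas. `tensor_stats_df_sketch` is hypothetical, and whether the library uses the same standard-deviation convention is an assumption.

```python
import pandas as pd
import torch


def tensor_stats_df_sketch(tensor):
    """Hypothetical: summary statistics in a one-column DataFrame."""
    stats = {
        "mean": tensor.mean().item(),
        "std": tensor.std().item(),  # torch.std applies Bessel's correction by default
        "min": tensor.min().item(),
        "max": tensor.max().item(),
        "shape": tuple(tensor.shape),
    }
    # Series.to_frame() on an unnamed Series yields a single column labelled 0
    return pd.Series(stats).to_frame()


df = tensor_stats_df_sketch(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
```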
get_torch_device
from cjm_pytorch_utils.core import get_torch_device
get_torch_device()
'cuda'
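Here get_torch_device returned 'cuda' on a machine with a CUDA-capable GPU. Below is a minimal sketch of such a helper with a CPU fallback. Checking Apple's MPS backend, and the fallback order, are assumptions rather than documented library behaviour. `get_torch_device_sketch` is hypothetical.

```python
import torch


def get_torch_device_sketch():
    """Hypothetical device picker: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps only exists in newer PyTorch builds, so guard the lookup
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


device = get_torch_device_sketch()
# The returned string can be passed anywhere PyTorch expects a device
x = torch.ones(2, 2, device=device)
```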
Hashes for cjm_pytorch_utils-0.0.3-py3-none-any.whl
| Algorithm   | Hash digest |
|-------------|-------------|
| SHA256      | 8da5138471d3916e5e0037c0818815c686048cd21d30be97fc270da529eae88a |
| MD5         | 2d615ab776abd5d6fbfa295efa056cf6 |
| BLAKE2b-256 | c97d7b0253aefeedf398b9c0f29f636a916ac5a4f97550b492ec5f793c022f09 |
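A downloaded wheel can be checked against the SHA256 digest above using Python's standard `hashlib`. Here is a minimal sketch. The path you pass in is wherever you saved the wheel, and `sha256_matches` is a hypothetical helper name. pip can also enforce digests through `--hash=sha256:...` entries in a requirements file.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for cjm_pytorch_utils-0.0.3-py3-none-any.whl
EXPECTED_SHA256 = "8da5138471d3916e5e0037c0818815c686048cd21d30be97fc270da529eae88a"


def sha256_matches(path, expected=EXPECTED_SHA256, chunk_size=1 << 16):
    """Hash the file in chunks and compare against the expected hex digest."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        # Read fixed-size chunks until f.read returns b"" at end of file
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected.lower()
```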