Common PyTorch tools and extensions
Project description
pytorch-common
A PyPI package with common PyTorch tools such as:
Build release
Step 1: Bump the version in the following files:
pytorch_common/__init__.py
pyproject.toml
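For example, a 0.2.3 bump might look like this (illustrative; the exact variable name in __init__.py and the pyproject.toml layout are assumptions based on common Poetry conventions):
# pytorch_common/__init__.py (assumed variable name)
__version__ = '0.2.3'

# pyproject.toml should declare the same value, e.g. version = "0.2.3" under [tool.poetry].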
Step 2: Build release.
$ poetry build ✔
Building pytorch-common (0.2.3)
- Building sdist
- Built pytorch-common-0.2.3.tar.gz
- Building wheel
- Built pytorch_common-0.2.3-py3-none-any.whl
Step 3: Publish release to PyPI repository.
$ poetry publish ✔
Publishing pytorch-common (0.2.3) to PyPI
- Uploading pytorch-common-0.2.3.tar.gz 100%
- Uploading pytorch_common-0.2.3-py3-none-any.whl 100%
Features
- Callbacks (Keras style)
- Validation: Model validation.
- ReduceLROnPlateau:
- Reduce learning rate when a metric has stopped improving.
- Models often benefit from reducing the learning rate by a factor of 2-10 once learning stagnates. This scheduler reads a metric quantity and, if no improvement is seen for a 'patience' number of epochs, reduces the learning rate.
- EarlyStop:
- Stops training when the model has stopped improving on a specified metric.
- SaveBestModel:
- Saves model weights to a file whenever the validation metric improves.
- Logger:
- Logs context properties.
- Typically used to log performance metrics every n epochs.
- MetricsPlotter:
- Plots evaluation metrics.
- The plot is updated every n epochs during the training process.
- Allows saving the plot to a file.
- Callback and OutputCallback:
- Base classes.
- CallbackManager:
- Simplifies callback support when fitting custom models.
- StratifiedKFoldCV:
- Supports parallel fold processing on CPU.
- Mixins
  - FiMixin
  - CommonMixin
  - PredictMixin
  - PersistentMixin
- Utils
  - device management
  - stopwatch
  - data split
  - os
  - model
  - LoggerBuilder
  - Dict Utils
  - WeightsFileResolver
  - Plot
    - plot primitives such as plot_loss
Examples
Device management
import pytorch_common.util as pu

# Set up the preferred device.
pu.set_device_name('gpu')  # or 'cpu'

# Set up the GPU memory fraction for the process (%).
pu.set_device_memory(
    'gpu',  # or 'cpu'
    process_memory_fraction=0.5
)

# Get the preferred device.
# Note: if the preferred device is not found, CPU is returned as a fallback.
device = pu.get_device()
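The resolved device can then be used like any torch.device, e.g. to move a model and tensors (standard PyTorch usage, not specific to this library):
import torch
import torch.nn as nn
import pytorch_common.util as pu

device = pu.get_device()

model = nn.Linear(10, 1).to(device)     # move the model parameters to the device
batch = torch.randn(32, 10).to(device)  # move an input batch to the device
output = model(batch)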
Logging
import logging
import pytorch_common.util as pu
# Default logging to the console...
pu.LoggerBuilder() \
.on_console() \
.build()
# Set up format and level...
pu.LoggerBuilder() \
.level(logging.ERROR) \
.on_console('%(asctime)s - %(levelname)s - %(message)s') \
.build()
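Once the builder has run, standard logging calls can be used as usual (assuming it configures the root logger):
import logging

logging.info('Training started...')
logging.error('Something went wrong')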
Stopwatch
import logging
import pytorch_common.util as pu
sw = pu.Stopwatch()
# Call any demanding process...
# Get response time.
response_time = sw.elapsed_time()
# Log response time.
logging.info(sw.to_str())
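For instance, a minimal sketch where time.sleep stands in for the demanding process:
import time
import logging
import pytorch_common.util as pu

sw = pu.Stopwatch()
time.sleep(0.5)  # simulate a demanding process
logging.info('Elapsed: %s', sw.elapsed_time())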
Dataset split
import pytorch_common.util as pu
dataset = ...  # <-- torch.utils.data.Dataset
train_subset, test_subset = pu.train_val_split(
dataset,
train_percent = .7
)
train_subset, val_subset, test_subset = pu.train_val_test_split(
dataset,
train_percent = .7,
val_percent = .15
)
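The resulting subsets can then be fed to DataLoaders as usual (a sketch, assuming the splits behave like regular torch datasets):
from torch.utils.data import DataLoader

train_loader = DataLoader(train_subset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=64)
test_loader = DataLoader(test_subset, batch_size=64)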
Kfolding
import logging
from pytorch_common.kfoldcv import StratifiedKFoldCV, \
ParallelKFoldCVStrategy, \
NonParallelKFoldCVStrategy
# Train your model inside this function...
def train_fold_fn(dataset, train_idx, val_idx, params, fold):
pass
# Get dataset labels
def get_y_values_fn(dataset):
pass
cv = StratifiedKFoldCV(
    train_fold_fn,
    get_y_values_fn,
    strategy=NonParallelKFoldCVStrategy(),  # or ParallelKFoldCVStrategy()
    k_fold=5
)
# Model hyperparams...
params = {
'seed': 42,
'lr': 0.01,
'epochs': 50,
'batch_size': 4000,
...
}
# Train model...
result = cv.train(dataset, params)
logging.info('CV results: {}'.format(result))
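A rough sketch of how the two stub functions above might be filled in (what they must return is an assumption, not something this README specifies):
import numpy as np
from torch.utils.data import DataLoader, Subset

def get_y_values_fn(dataset):
    # Labels for the whole dataset, used to stratify the folds.
    return np.array([y for _, y in dataset])

def train_fold_fn(dataset, train_idx, val_idx, params, fold):
    train_loader = DataLoader(Subset(dataset, train_idx),
                              batch_size=params['batch_size'], shuffle=True)
    val_loader = DataLoader(Subset(dataset, val_idx),
                            batch_size=params['batch_size'])
    # ...build and train the model for params['epochs'] epochs...
    # Return the fold's validation score (assumed; check the library docs).
    return 0.0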
Assertions
from pytorch_common.error import Assertions, Checker
# Check function and constructor params using assertions...
param_value = -1
# Raises an exception with error code 404103 when the condition is not met.
Assertions.positive_int(404103, param_value, 'param name')
Assertions.positive_float(404103, param_value, 'param name')
# Other options
Assertions.is_class(404205, param_value, 'param name', aClass)
Assertions.is_tensor(404401, param_value, 'param name')
Assertions.has_shape(404401, param_value, (3, 4), 'param name')
# Assertions is implemented using a Checker builder:
Checker(error_code, value, name) \
.is_not_none() \
.is_int() \
.is_positive() \
.check()
# Other checker options..
# .is_not_none()
# .is_int()
# .is_float()
# .is_positive()
# .is_a(aclass)
# .is_tensor()
# .has_shape(shape)
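For example, validating constructor arguments in your own class (the Trainer class and the error codes below are illustrative, not part of the library):
from pytorch_common.error import Assertions

class Trainer:
    def __init__(self, epochs, lr):
        # Illustrative error codes; use whatever convention your project defines.
        Assertions.positive_int(404103, epochs, 'epochs')
        Assertions.positive_float(404104, lr, 'lr')
        self.epochs = epochs
        self.lr = lr

trainer = Trainer(epochs=50, lr=0.01)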
Callbacks
from pytorch_common.callbacks import CallbackManager
from pytorch_common.modules import FitContextFactory
from pytorch_common.callbacks import EarlyStop, \
                                      ReduceLROnPlateau, \
                                      SaveBestModel, \
                                      Validation
from pytorch_common.callbacks.output import Logger, \
                                             MetricsPlotter
def train_method(model, epochs, optimizer, loss_fn, callbacks):
    callback_manager = CallbackManager(
        ctx=FitContextFactory.create(model, loss_fn, epochs, optimizer),
        callbacks=callbacks
    )

    for epoch in range(epochs):
        callback_manager.on_epoch_start(epoch)

        # Train the model for one epoch and compute train_loss...
        callback_manager.on_epoch_end(train_loss)

        if callback_manager.break_training():
            break

    return callback_manager.ctx
model = ...      # Create my model...
optimizer = ...  # My optimizer...
loss_fn = ...    # My loss function...
callbacks = [
    # Log context variables after each epoch...
    Logger(['fold', 'time', 'epoch', 'lr', 'train_loss', 'val_loss', ...]),
    EarlyStop(metric='val_auc', mode='max', patience=3),
    ReduceLROnPlateau(metric='val_auc'),
    Validation(
        val_set,
        metrics={
            'my_metric_name': lambda y_pred, y_true: ...,  # compute the validation metric
            ...
        },
        each_n_epochs=5
    ),
    SaveBestModel(metric='val_loss'),
    MetricsPlotter(metrics=['train_loss', 'val_loss'])
]
train_method(model, epochs=100, optimizer=optimizer, loss_fn=loss_fn, callbacks=callbacks)
Utils
WeightsFileResolver
$ ls ./weights
2023-08-21_15-17-49--gfm--epoch_2--val_loss_1.877971887588501.pt
2023-08-21_15-13-09--gfm--epoch_3--val_loss_1.8183038234710693.pt
2023-08-19_20-00-19--gfm--epoch_10--val_loss_0.9969356060028076.pt
2023-08-19_19-59-39--gfm--epoch_4--val_loss_1.4990438222885132.pt
import pytorch_common.util as pu
resolver = pu.WeightsFileResolver('./weights')
file_path = resolver(experiment='gfm', metric='val_loss', min_value=True)
print(file_path)
'./weights/2023-08-19_20-00-19--gfm--epoch_10--val_loss_0.9969356060028076.pt'
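The resolved path can then be used to restore the weights with standard PyTorch calls (a sketch; it assumes the checkpoint contains a plain state_dict and that `model` is an already-instantiated, compatible model):
import torch

state_dict = torch.load(file_path, map_location=pu.get_device())
model.load_state_dict(state_dict)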
See the related projects for working code examples.