🔥 torchtrain 💪
A small tool for PyTorch training.
Features
- Avoids boilerplate training-loop code.
- Stepwise training (driven by optimization steps rather than epochs).
- Automatic TensorBoard logging and a tqdm progress bar.
- Counts model parameters and saves hyperparameters.
- DataParallel support.
- Early stopping.
- Saves and loads checkpoints, so training can be resumed.
- Catches out-of-memory exceptions to avoid breaking training (see the sketch after this list).
- Gradient accumulation.
- Gradient clipping.
- Runs only a few epochs, steps, and batches for quick code tests.
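As a rough illustration of the loop boilerplate these features replace, here is a plain-PyTorch sketch of gradient accumulation, gradient clipping, and out-of-memory recovery. This is not torchtrain's internal code; model, loader, loss_fn, accum_steps, and max_grad_norm are illustrative placeholders.

import torch

def train_epoch(model, loader, optimizer, loss_fn, accum_steps=4, max_grad_norm=1.0):
    model.train()
    optimizer.zero_grad()
    for i, (x, y) in enumerate(loader):
        try:
            # Scale the loss so accumulated gradients match a large-batch update.
            loss = loss_fn(model(x), y) / accum_steps
            loss.backward()
        except RuntimeError as e:
            if "out of memory" in str(e):
                # Drop any partially accumulated gradients and skip the batch
                # instead of letting the whole run crash.
                optimizer.zero_grad()
                torch.cuda.empty_cache()
                continue
            raise
        if (i + 1) % accum_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()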
Install
pip install torchtrain
Example
Check the docstring of the Trainer class for detailed configuration options.
An incomplete minimal example:
data_iter = get_data()  # user-defined data loading
model = Bert()  # any torch.nn.Module
optimizer = Adam(model.parameters(), lr=cfg["lr"])
criteria = {"loss": AverageAggregator(BCELoss())}  # metric name -> aggregated metric
trainer = Trainer(model, data_iter, criteria, cfg, optimizer)
trainer.train(stepwise=True)
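Here cfg is a plain dict of hyperparameters that Trainer reads its configuration from; the recognized keys are listed in the Trainer docstring. A hypothetical minimal cfg, borrowing key names from the fuller example below:

cfg = {
    "lr": 6e-5,                # learning rate, read by the Adam call above
    "save_path": "/tmp/runs",  # matches the fuller example's CLI default
    "model_name": "Bert",
}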
Or a fuller version:
from argparse import ArgumentParser
from sklearn.model_selection import ParameterGrid
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from transformers import AutoModel, BertTokenizer
from data.load import get_batch_size, get_data
from metrics import BCELoss
from models import BertSumExt
from torchtrain import Trainer
from torchtrain.metrics import AverageAggregator
from torchtrain.utils import set_random_seeds
def get_args():
    parser = ArgumentParser()
    parser.add_argument("--seed", type=int, default=233666)
    parser.add_argument("--run_ckp", default="")
    parser.add_argument("--run_dataset", default="val")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--warmup", type=int, default=10000)
    parser.add_argument("--stepwise", action="store_false")
    # torchtrain cfgs
    parser.add_argument("--max_n", type=int, default=50000)
    parser.add_argument("--val_step", type=int, default=1000)
    parser.add_argument("--save_path", default="/tmp/runs")
    parser.add_argument("--model_name", default="BertSumExt")
    parser.add_argument("--cuda_list", default="2,3")
    parser.add_argument("--grad_accum_batch", type=int, default=1)
    parser.add_argument("--train_few", action="store_true")
    return vars(parser.parse_args())
def get_param_grid():
    param_grid = [
        {"pretrained_model_name": ["voidful/albert_chinese_tiny"], "lr": [6e-5]},
    ]
    return ParameterGrid(param_grid)
def get_cfg(args=None, params=None):
    # Use None defaults to avoid the mutable-default-argument pitfall.
    cfg = {**(args or {}), **(params or {})}
    # other cfgs
    return cfg
def run(cfg):
    set_random_seeds(cfg["seed"])
    tokenizer = BertTokenizer.from_pretrained(cfg["pretrained_model_name"])
    bert = AutoModel.from_pretrained(cfg["pretrained_model_name"])
    data_iter = get_data(
        cfg["batch_size"], tokenizer, bert.config.max_position_embeddings
    )
    model = BertSumExt(bert)
    optimizer = Adam(model.parameters(), lr=cfg["lr"])
    # Noam-style schedule: linear warmup, then inverse-square-root decay.
    scheduler = LambdaLR(
        optimizer,
        lambda step: min(step ** (-0.5), step * (cfg["warmup"] ** (-1.5)))
        if step > 0
        else 0,
    )
    criteria = {"loss": AverageAggregator(BCELoss())}
    trainer = Trainer(
        model,
        data_iter,
        criteria,
        cfg,
        optimizer,
        scheduler,
        get_batch_size=get_batch_size,
    )
    if cfg["run_ckp"]:
        # Evaluate a saved checkpoint instead of training.
        return trainer.test(cfg["run_ckp"], cfg["run_dataset"])
    return trainer.train(stepwise=cfg["stepwise"])
def main():
    param_grid = get_param_grid()
    for i, params in enumerate(param_grid):
        print("Config", i + 1, "/", len(param_grid))
        cfg = get_cfg(get_args(), params)
        metrics = run(cfg)
        print("Best metrics:", metrics)

if __name__ == "__main__":
    main()
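A note on the scheduler in run(): the LambdaLR lambda implements a Noam-style schedule, where the learning-rate multiplier grows linearly as step * warmup**-1.5 during warmup and then decays as step**-0.5. The two branches cross, and the multiplier peaks, at step == warmup. A standalone sketch of the same formula:

import math

def noam_multiplier(step: int, warmup: int = 10000) -> float:
    # Linear warmup up to `warmup` steps, then inverse-square-root decay.
    if step <= 0:
        return 0.0
    return min(step ** -0.5, step * warmup ** -1.5)

# Both branches equal warmup ** -0.5 at the peak (0.01 for warmup=10000).
assert math.isclose(noam_multiplier(10_000), 10_000 ** -0.5)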
Download files
Source Distribution: torchtrain-0.4.9.tar.gz (9.4 kB)
Built Distribution: torchtrain-0.4.9-py3-none-any.whl (9.2 kB)
File details
Details for the file torchtrain-0.4.9.tar.gz:
- Size: 9.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/47.1.0 requests-toolbelt/0.9.1 tqdm/4.50.0 CPython/3.8.5
File hashes
Algorithm | Hash digest
--- | ---
SHA256 | 9a78b81f83245377ee7a4ff09120bb71ffc78481adece844ca23318a1d7fbb45
MD5 | 80db95479d6ba75572e9787f5e814069
BLAKE2b-256 | aaf44425c0166930b21b1489b9b5c85e210334e94deb0222c0deb45ed66d1a99
File details
Details for the file torchtrain-0.4.9-py3-none-any.whl:
- Size: 9.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/47.1.0 requests-toolbelt/0.9.1 tqdm/4.50.0 CPython/3.8.5
File hashes
Algorithm | Hash digest
--- | ---
SHA256 | 9114371e74bf3f409c261bdcbb2507d367d0ab1c4d47c264929f800d106af9a0
MD5 | 8d79582f05bbbf90fb86b8a704e0ab55
BLAKE2b-256 | 261104df7e1eb872e247be9cbc096e988c90fd1b3365f7c4f31e6469b276782b
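If you want to verify a downloaded file against the digests above, a quick Python check using only the standard library (the filename assumes the source distribution):

import hashlib

with open("torchtrain-0.4.9.tar.gz", "rb") as f:
    print(hashlib.sha256(f.read()).hexdigest())
# Should print the SHA256 digest listed in the table above.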