
A small tool for PyTorch training

Project description

🔥 torchtrain 💪

A small tool for PyTorch training.

Features

  • Avoid boilerplate code for training.
  • Stepwise training.
  • Automatic TensorBoard logging and tqdm progress bar.
  • Count model parameters and save hyperparameters.
  • Multi-GPU training with DataParallel.
  • Early stopping.
  • Save and load checkpoints, and resume training.
  • Catch out-of-memory exceptions to avoid interrupting training.
  • Gradient accumulation.
  • Gradient clipping.
  • Run only a few epochs, steps, and batches for quick code testing.
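Early stopping usually means halting once a validation metric has stopped improving for some number of validation rounds. The sketch below illustrates that patience-based idea in plain Python. The `EarlyStopping` class and its `patience` and `mode` parameters are hypothetical names for illustration. They are not torchtrain's API or its actual stopping rule.

```python
class EarlyStopping:
    """Hypothetical patience-based early stopping (not torchtrain's API)."""

    def __init__(self, patience=3, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience  # rounds without improvement before stopping
        self.mode = mode          # "min" for losses, "max" for scores like accuracy
        self.best = None
        self.bad_rounds = 0

    def _improved(self, value):
        if self.best is None:
            return True
        return value < self.best if self.mode == "min" else value > self.best

    def step(self, value):
        """Record one validation result; return True when training should stop."""
        if self._improved(value):
            self.best = value
            self.bad_rounds = 0
        else:
            self.bad_rounds += 1
        return self.bad_rounds >= self.patience


# Usage: validation losses that stop improving after reaching 0.6.
stopper = EarlyStopping(patience=2, mode="min")
for round_idx, val_loss in enumerate([0.9, 0.7, 0.6, 0.65, 0.62, 0.61]):
    if stopper.step(val_loss):
        print(f"stopping after validation round {round_idx}, best={stopper.best}")
        break
```

Resetting the counter on every improvement means only consecutive non-improving rounds count toward `patience`.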

Install

pip install torchtrain

Example

Check the docstring of the Trainer class for detailed configuration options.

An incomplete minimal example:

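from torch.optim import Adam
from torchtrain import Trainer
from torchtrain.metrics import AverageAggregator

# get_data, Bert, BCELoss and cfg are user-defined placeholders
data_iter = get_data()
model = Bert()
optimizer = Adam(model.parameters(), lr=cfg["lr"])
criteria = {"loss": AverageAggregator(BCELoss())}
trainer = Trainer(cfg, data_iter, model, optimizer, criteria)
trainer.train(stepwise=True)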
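AverageAggregator wraps a per-batch metric so that it can be reported over a whole epoch or validation run. Its exact behavior is defined by torchtrain itself. As an assumption for illustration only, the sketch below shows a batch-size-weighted running average. That is one plausible reason a trainer would accept a get_batch_size callback, as the fuller example below does. `WeightedAverage`, `update`, and `result` are hypothetical names, not torchtrain's API.

```python
class WeightedAverage:
    """Hypothetical batch-size-weighted running mean of a per-batch metric.

    Illustrative only; this is not torchtrain's AverageAggregator.
    """

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, batch_value, batch_size):
        # Assumes batch_value is the *mean* metric over the batch,
        # so it is weighted by the number of examples it covers.
        self.total += batch_value * batch_size
        self.count += batch_size

    def result(self):
        if self.count == 0:
            raise ValueError("no batches recorded")
        return self.total / self.count


# Usage: a full batch of 64 examples and a smaller final batch of 16.
avg = WeightedAverage()
avg.update(0.5, 64)
avg.update(1.0, 16)
print(avg.result())  # (0.5 * 64 + 1.0 * 16) / 80 = 0.6
```

A plain mean of the two batch values would give 0.75. Weighting by batch size keeps a short final batch from being over-counted.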

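A fuller example, which runs a hyperparameter grid search and relies on project-specific modules (data.load, metrics, models):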

from argparse import ArgumentParser

from sklearn.model_selection import ParameterGrid
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from transformers import AutoModel, BertTokenizer

from data.load import get_batch_size, get_data
from metrics import BCELoss
from models import BertSumExt
from torchtrain import Trainer
from torchtrain.metrics import AverageAggregator
from torchtrain.utils import set_random_seeds


def get_args():
    parser = ArgumentParser()
    parser.add_argument("--seed", type=int, default=233666)
    parser.add_argument("--run_ckp", default="")
    parser.add_argument("--run_dataset", default="val")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--warmup", type=int, default=10000)
    # stepwise training is on by default; pass --no_stepwise to disable it
    parser.add_argument("--no_stepwise", dest="stepwise", action="store_false")
    # torchtrain cfgs
    parser.add_argument("--max_n", type=int, default=50000)
    parser.add_argument("--val_step", type=int, default=1000)
    parser.add_argument("--save_path", default="/tmp/runs")
    parser.add_argument("--model_name", default="BertSumExt")
    parser.add_argument("--cuda_list", default="2,3")
    parser.add_argument("--grad_accum_batch", type=int, default=1)
    parser.add_argument("--train_few", action="store_true")
    return vars(parser.parse_args())


def get_param_grid():
    param_grid = [
        {"pretrained_model_name": ["voidful/albert_chinese_tiny"], "lr": [6e-5]},
    ]
    return ParameterGrid(param_grid)


def get_cfg(args=None, params=None):
    # merge CLI args with grid params; grid params win on key clashes
    cfg = {**(args or {}), **(params or {})}
    # other cfgs
    return cfg


def run(cfg):
    set_random_seeds(cfg["seed"])
    tokenizer = BertTokenizer.from_pretrained(cfg["pretrained_model_name"])
    bert = AutoModel.from_pretrained(cfg["pretrained_model_name"])
    data_iter = get_data(
        cfg["batch_size"], tokenizer, bert.config.max_position_embeddings
    )
    model = BertSumExt(bert)
    optimizer = Adam(model.parameters(), lr=cfg["lr"])
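    # Noam-style warmup: the lr multiplier grows linearly for cfg["warmup"]
    # steps, then decays proportionally to step ** -0.5
    scheduler = LambdaLR(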
        optimizer,
        lambda step: min(step ** (-0.5), step * (cfg["warmup"] ** (-1.5)))
        if step > 0
        else 0,
    )
    criteria = {"loss": AverageAggregator(BCELoss())}
    trainer = Trainer(
        cfg,
        data_iter,
        model,
        optimizer,
        criteria,
        scheduler,
        get_batch_size=get_batch_size,
    )
    if cfg["run_ckp"]:
        return trainer.test(cfg["run_ckp"], cfg["run_dataset"])
    return trainer.train(stepwise=cfg["stepwise"])


def main():
    args = get_args()  # parse the command line once, outside the grid loop
    param_grid = get_param_grid()
    for i, params in enumerate(param_grid):
        print("Config", i + 1, "/", len(param_grid))
        cfg = get_cfg(args, params)
        metrics = run(cfg)
        print("Best metrics:", metrics)


if __name__ == "__main__":
    main()
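The LambdaLR lambda in run scales the base learning rate by min(step ** -0.5, step * warmup ** -1.5). The factor grows linearly for the first `warmup` steps and peaks at warmup ** -0.5 when step == warmup. After that, it decays as the inverse square root of the step. LambdaLR multiplies the optimizer's initial lr by this factor. The effective peak learning rate is therefore lr * warmup ** -0.5; with the values above, that is 6e-5 * 0.01 = 6e-7. The standalone function below reproduces the factor so the schedule can be inspected without building a model. `warmup_factor` is a hypothetical helper name.

```python
def warmup_factor(step, warmup):
    """LR multiplier matching the lambda passed to LambdaLR in run()."""
    if step <= 0:
        return 0.0
    # Linear warmup branch vs. inverse-square-root decay branch.
    return min(step ** -0.5, step * warmup ** -1.5)


warmup = 10000
for step in (1, 5000, 10000, 20000, 40000):
    print(step, warmup_factor(step, warmup))
```

The two branches cross exactly at step == warmup. This is why the warmup length also sets the peak multiplier.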



Download files


Source Distribution

torchtrain-0.4.3.tar.gz (9.3 kB)

Uploaded Source

Built Distribution


torchtrain-0.4.3-py3-none-any.whl (9.1 kB)

Uploaded Python 3

File details

Details for the file torchtrain-0.4.3.tar.gz.

File metadata

  • Download URL: torchtrain-0.4.3.tar.gz
  • Size: 9.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.43.0 CPython/3.8.2

File hashes

Hashes for torchtrain-0.4.3.tar.gz
Algorithm Hash digest
SHA256 1b038dc322a6ae0b7a83ea406ae0d7c1fd9e0d8dbd9df4b6c4441214c0fc7ae0
MD5 d856a2d1fe349b341dc6e0bc840cc9f8
BLAKE2b-256 9c9c75d6e4b64c1ba7f35489a46230da2bf03fc96ec3ca0f670c389eef61b3b6
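To check a downloaded archive against the SHA256 digest listed above, hash the file in chunks with Python's standard hashlib module and compare hex digests. The sketch below assumes the archive sits in the current directory under its published name. `sha256_of` is a hypothetical helper.

```python
import hashlib
from pathlib import Path

# SHA256 digest published above for torchtrain-0.4.3.tar.gz
EXPECTED_SHA256 = "1b038dc322a6ae0b7a83ea406ae0d7c1fd9e0d8dbd9df4b6c4441214c0fc7ae0"


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a b"" sentinel stops at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    archive = Path("torchtrain-0.4.3.tar.gz")  # assumed download location
    if archive.exists():
        ok = sha256_of(archive) == EXPECTED_SHA256
        print("hash OK" if ok else "hash MISMATCH")
    else:
        print(f"{archive} not found")
```

Reading in chunks keeps memory use constant regardless of file size. pip can also enforce digests at install time through its hash-checking mode (`--require-hashes`).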


File details

Details for the file torchtrain-0.4.3-py3-none-any.whl.

File metadata

  • Download URL: torchtrain-0.4.3-py3-none-any.whl
  • Size: 9.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.43.0 CPython/3.8.2

File hashes

Hashes for torchtrain-0.4.3-py3-none-any.whl
Algorithm Hash digest
SHA256 7e7911e2575cc438c8ae312f1f6560bd62767e6129aa1baa83e0162b1be34bec
MD5 ca53a02e213ddc4349c7aacb6e69f320
BLAKE2b-256 2660f89e25a92fac8290701176402f8f8ff4efea5930daadc03056d9ddff339b

