
training manager and loggers

trnbl -- Training Butler

If you train a lot of models, you may find yourself annoyed at swapping between different loggers and fiddling with a bunch of if batch_idx % some_number == 0 statements. This package aims to fix that.

First, it provides a universal interface to wandb, TensorBoard, and a minimal local logging solution (live demo).

  • This interface handles logging, error messages, metrics, and artifacts.
  • Switching to another logger requires no changes beyond initializing the new logger and passing it in instead.
  • You can even log to multiple loggers at once!
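To make the "swap one logger for another, or use several at once" idea concrete, here is a minimal sketch of the general pattern: a shared abstract interface plus a wrapper that fans every call out to several backends. All names here (BaseLogger, log_metrics, log_message, PrintLogger, ListLogger, MultiLogger) are hypothetical and are not trnbl's actual API.

```python
from abc import ABC, abstractmethod
from typing import Any


class BaseLogger(ABC):
    """Hypothetical shared interface; trnbl's real base class may differ."""

    @abstractmethod
    def log_metrics(self, data: dict[str, Any]) -> None: ...

    @abstractmethod
    def log_message(self, msg: str) -> None: ...


class PrintLogger(BaseLogger):
    """Writes everything to stdout."""

    def log_metrics(self, data: dict[str, Any]) -> None:
        print("metrics:", data)

    def log_message(self, msg: str) -> None:
        print("message:", msg)


class ListLogger(BaseLogger):
    """Keeps records in memory, which is handy for tests."""

    def __init__(self) -> None:
        self.records: list[tuple[str, Any]] = []

    def log_metrics(self, data: dict[str, Any]) -> None:
        self.records.append(("metrics", dict(data)))

    def log_message(self, msg: str) -> None:
        self.records.append(("message", msg))


class MultiLogger(BaseLogger):
    """Forwards every call to each wrapped logger, in order."""

    def __init__(self, loggers: list[BaseLogger]) -> None:
        self.loggers = list(loggers)

    def log_metrics(self, data: dict[str, Any]) -> None:
        for logger in self.loggers:
            logger.log_metrics(data)

    def log_message(self, msg: str) -> None:
        for logger in self.loggers:
            logger.log_message(msg)


# usage: code written against BaseLogger accepts a single logger or a MultiLogger
logger: BaseLogger = MultiLogger([PrintLogger(), ListLogger()])
logger.log_metrics({"train/loss": 0.42})
```

Because MultiLogger implements the same interface as any single backend, training code never needs to know how many loggers sit behind it.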

Second, it provides a TrainingManager class that handles logging, artifacts, checkpointing, evaluations, exceptions, and more, at flexibly customizable intervals.

  • Rather than specifying every interval in batches and then updating everything by hand whenever you change the batch size, dataset size, or number of epochs, you specify an interval in samples, batches, epochs, or runs. It is converted into the correct number of batches or epochs based on the current dataset and batch size.

    • "1/10 runs" -- 10 times a run
    • "2.5 epochs" -- every 2 & 1/2 epochs
    • (100, "batches") -- every 100 batches
    • "10k samples" -- every 10,000 samples
  • An evaluation function is passed as a tuple together with its interval; it takes the model as an argument and returns the metrics as a dictionary

  • Checkpointing is handled automatically, with its interval specified the same way as for evaluations

  • The model is saved at the end of the run; if an exception is raised, a model.exception.pt is saved instead
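The interval conversion described above can be sketched as follows. This is an illustrative assumption of how such a parser might work, not trnbl's implementation: parse_interval and interval_to_batches are hypothetical helpers, "runs" is taken to mean the whole training run (all epochs), and the result is rounded to a whole number of batches with a floor of 1.

```python
import math
from typing import Union

# Hypothetical unit aliases and magnitude suffixes; trnbl's parser may accept others.
_UNITS = {
    "sample": "samples", "samples": "samples",
    "batch": "batches", "batches": "batches",
    "epoch": "epochs", "epochs": "epochs",
    "run": "runs", "runs": "runs",
}
_SUFFIXES = {"k": 1_000, "m": 1_000_000}


def parse_interval(spec: Union[str, tuple[float, str]]) -> tuple[float, str]:
    """Turn '2.5 epochs', '1/10 runs', '10k samples' or (100, 'batches') into (quantity, unit)."""
    if isinstance(spec, tuple):
        qty, unit = spec
        return float(qty), _UNITS[unit.lower()]
    qty_str, unit = spec.strip().lower().split()
    if "/" in qty_str:  # fraction, e.g. "1/10"
        num, den = qty_str.split("/")
        qty = float(num) / float(den)
    elif qty_str[-1] in _SUFFIXES:  # magnitude suffix, e.g. "10k"
        qty = float(qty_str[:-1]) * _SUFFIXES[qty_str[-1]]
    else:
        qty = float(qty_str)
    return qty, _UNITS[unit]


def interval_to_batches(spec, dataset_size: int, batch_size: int, n_epochs: int) -> int:
    """Convert an interval spec into a whole number of batches (at least 1)."""
    qty, unit = parse_interval(spec)
    batches_per_epoch = math.ceil(dataset_size / batch_size)  # last batch may be partial
    if unit == "samples":
        batches = qty / batch_size
    elif unit == "batches":
        batches = qty
    elif unit == "epochs":
        batches = qty * batches_per_epoch
    else:  # "runs": one run spans n_epochs epochs
        batches = qty * batches_per_epoch * n_epochs
    return max(1, round(batches))
```

For example, with 1,000 samples, a batch size of 50 and 120 epochs (20 batches per epoch), "2.5 epochs" becomes 50 batches, "10k samples" becomes 200 batches, and "1/10 runs" becomes 240 batches. Changing the batch size only changes the inputs, not the interval strings.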

Installation

pip install trnbl

Usage

See also the notebooks/ folder. A minimal example:

import torch
from torch.utils.data import DataLoader
from trnbl.logging.local import LocalLogger
from trnbl.training_manager import TrainingManager

# set up your dataset, model, optimizer, etc. as usual
# (my_dataset and MyModel are placeholders for your own code)
dataloader: DataLoader = DataLoader(my_dataset, batch_size=32)
model: torch.nn.Module = MyModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# set up a logger -- swap seamlessly between wandb, tensorboard, and local logging
logger: LocalLogger = LocalLogger(
	project="iris-demo",
	metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
	train_config=dict(
		model=str(model), optimizer=str(optimizer), criterion=str(criterion)
	),
)

with TrainingManager(
	# pass your model and logger
	model=model,
	logger=logger,
	evals={
		# pass evaluation functions which take a model and return a dict of metrics
		# (my_evaluation_function and my_other_eval_function are placeholders)
		"1k samples": my_evaluation_function,
		"0.5 epochs": lambda model: logger.get_mem_usage(),
		"100 batches": my_other_eval_function,
	}.items(),
	checkpoint_interval="1/10 run", # will save a checkpoint 10 times per run
) as tr:

	# wrap the loops, and length will be automatically calculated
	# and used to figure out when to run evals, checkpoint, etc
	for epoch in tr.epoch_loop(range(120)):
		for inputs, targets in tr.batch_loop(dataloader):
			# your normal training code
			optimizer.zero_grad()
			outputs = model(inputs)
			loss = criterion(outputs, targets)
			loss.backward()
			optimizer.step()

			# compute whatever you want every batch
			accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)
			
			# log the metrics
			tr.batch_update(
				samples=len(targets),
				**{"train/loss": loss.item(), "train/acc": accuracy},
			)

	# a `model.final.pt` checkpoint will be saved at the end of the run,
	# or a `model.exception.pt` if something crashes inside the context
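The behaviour described in the comments above (evaluations and checkpoints on a schedule, plus a final or exception checkpoint on exit) maps naturally onto a Python context manager. The sketch below is a hypothetical, stripped-down stand-in for TrainingManager, not its real code: MiniManager, save_fn, checkpoint_every and the model.checkpoint-N.pt naming are illustrative, and intervals are assumed to be already converted to batches.

```python
from typing import Any, Callable, Iterable

# Hypothetical callback types: an eval takes the model and returns a metrics dict;
# save_fn takes the model and a file name.
EvalFn = Callable[[Any], dict[str, Any]]
SaveFn = Callable[[Any, str], None]


class MiniManager:
    """Hypothetical, stripped-down manager; not trnbl's TrainingManager."""

    def __init__(
        self,
        model: Any,
        save_fn: SaveFn,
        evals: Iterable[tuple[int, EvalFn]],
        checkpoint_every: int,
    ) -> None:
        self.model = model
        self.save_fn = save_fn  # e.g. lambda m, name: torch.save(m, name)
        self.evals = list(evals)  # (interval_in_batches, eval_fn) pairs
        self.checkpoint_every = checkpoint_every
        self.batches_seen = 0
        self.history: list[tuple[int, dict[str, Any]]] = []

    def __enter__(self) -> "MiniManager":
        return self

    def batch_update(self) -> None:
        """Call once per batch; runs evals and checkpoints when they are due."""
        self.batches_seen += 1
        for interval, eval_fn in self.evals:
            if self.batches_seen % interval == 0:
                self.history.append((self.batches_seen, eval_fn(self.model)))
        if self.batches_seen % self.checkpoint_every == 0:
            self.save_fn(self.model, f"model.checkpoint-{self.batches_seen}.pt")

    def __exit__(self, exc_type, exc, tb) -> bool:
        # a normal exit saves the final model; a crash saves an exception snapshot
        name = "model.final.pt" if exc_type is None else "model.exception.pt"
        self.save_fn(self.model, name)
        return False  # never swallow the exception
```

Returning False from __exit__ lets the original exception propagate after the snapshot is written, so a crash still shows up in your traceback.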

LocalLogger

Intended as a minimal logging solution for local runs, when you're too lazy to set up a new wandb project for a quick test and want to be able to read the logs easily. It logs everything to JSON or JSONL files and provides a simple web interface for viewing the data. The web interface lets you:

  • enable or disable the visibility of individual runs
  • filter and sort runs by various stats via an interactive table
  • smooth the data and change axes scales
  • move and resize all plots and tables

You can view a live demo of the web interface here.
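Logging to JSONL (one JSON object per line) keeps writes append-only: each line is self-contained, so adding a record never requires rewriting the file, and the log can be read back line by line. A minimal sketch of that idea follows; append_metrics, read_metrics and the metrics.jsonl file name are assumptions, and LocalLogger's actual on-disk layout may differ.

```python
import json
from pathlib import Path
from typing import Any


def append_metrics(path: Path, record: dict[str, Any]) -> None:
    """Append one metrics record as a single JSON line."""
    with path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")


def read_metrics(path: Path) -> list[dict[str, Any]]:
    """Read all records back, one JSON object per non-blank line."""
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# usage (hypothetical file name):
# log_path = Path("metrics.jsonl")
# append_metrics(log_path, {"batch": 1, "train/loss": 0.9})
# print(read_metrics(log_path))
```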

TODOs:

  • BUG: minifying the HTML/JS code causes things to break?

  • frontend:

    • add batch/epoch size to the table's config column group
    • box to add aliases to runs
    • customizable grid snap size?
    • display the grid on the background?
  • deployment:

    • demo website for local logger
    • CI/CD for website, minification, tests, etc
    • migrate to typescript

Download files


Source Distribution

trnbl-0.0.2.tar.gz (60.1 kB)

Built Distribution

trnbl-0.0.2-py3-none-any.whl (63.3 kB)

File details

Details for the file trnbl-0.0.2.tar.gz.

File metadata

  • Download URL: trnbl-0.0.2.tar.gz
  • Upload date:
  • Size: 60.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.12.4

File hashes

Hashes for trnbl-0.0.2.tar.gz
Algorithm Hash digest
SHA256 caab016a97ace82f890666ae90d915af601ca499655d32e48a9f88c9d7b1e8a8
MD5 031d5f593b941cf7943654977e8e433e
BLAKE2b-256 7a29fd0ab27506739ad4609089bc3abb5dc2b567a17a42bb8282c6d5b461792b


File details

Details for the file trnbl-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: trnbl-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 63.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.12.4

File hashes

Hashes for trnbl-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 3eb06f6a5f8b80926a652983c5a2736f2ff385cf2290e0425d96dd5d17a4e193
MD5 d9cbbafac7e5c01e60383dd22252274f
BLAKE2b-256 3ba3b0ee22485bdc772d6997788c7b23aa1576c762aba33475c8a0200d337bf2

