
Functions to make working in PyTorch quik-er

Project description



For the quick-er distribution of neural nets in PyTorch

I was building out the same set of code for a recommender system and a BERT sentiment model, and a co-worker was about to build a classification model, so I decided to standardize the code into this package. It's lightweight because I didn't want to hide the standard steps and keep users from learning how neural networks work, but I also didn't want to maintain the same code in multiple places.


v0.3.0 (10/25/2021)

  • Integrated Ray Tune for hyperparameter optimization using use_ray_tune = True
  • Added needed unit tests to bring code coverage to 85%
  • Updated README

v0.2.0 (07/22/2021)

  • Integrated MLflow into QuikTrek using QuikMlflow and use_mlflow = True
  • Added utilities to create all required files for a transformers TorchServe model archive
  • Added utilities to build a dataset and query a TorchServe API

v0.1.0 (06/03/2021)

  • Refactored to provide classes QuikTrek, QuikTraveler, QuikData
  • Slight change from 0.0.1, where now you need trek =, gpu); tr =
  • Added documentation


pip install pytorch-quik


Intro, QuikTrek, QuikTraveler

In its simplest form, you'll want to:

  • create a QuikTrek (an object that stores overall attributes and various travelers across your neural network)
  • create a QuikTraveler (an object that "travels" forward and backward on your neural network)
  • add QuikData
  • add a criterion and an optimizer
  • train
import pytorch_quik as pq
from torch import nn, optim

trek =
tr =
tr.add_data("train", tr.args))

# from
model = nn.Sequential(nn.Linear(3, 1), nn.Flatten(0, 1))


for epoch in range(tr.epochs):
    for batch in
        users, items, labels = [ for tens in batch]
        outputs = tr.model.forward(users, items)
        loss = tr.criterion(outputs, labels)
        tr.criterion.step(), tr.args, epoch)


A little more about how the data pull works. I usually run my project from my repo, from the command line. QuikData expects your tensors, models, and state_dicts to be in a data subfolder of your current path. It also uses your traveler's arguments to determine the date in the filename. You can set that date using argparse like I do, or set it directly in your script like this:

tr =
tr.args.data_date = '20210101'
tr.add_data("train", tr.args))

The filename will also contain the word tensor, model, or state_dict; train, valid, or test if it's a tensor; and e* if it's a state_dict, where * is the epoch in which the state_dict was saved. Therefore, using your repo as the base directory, the file structure may look like this (I also save my original state_dict to make sure my model is actually training):

  • data/
  • data/
  • data/
  • data/
  • data/
  • data/
  • data/
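The exact filenames are not shown above, so here is a minimal sketch of the naming rule the previous paragraph describes, assuming the pieces are joined with underscores and end in .pt. The helper quik_filename, the separator, the extension, and the "orig" label for the untrained weights are all hypothetical, not pytorch-quik's actual format.

```python
from pathlib import Path
from typing import Optional


def quik_filename(data_date: str, kind: str, split: Optional[str] = None,
                  epoch: Optional[int] = None, base: Path = Path("data")) -> Path:
    """Hypothetical helper: build data/<kind>[_<split or e#>]_<date>.pt."""
    if kind not in {"tensor", "model", "state_dict"}:
        raise ValueError(f"unknown kind: {kind}")
    parts = [kind]
    if kind == "tensor":
        # tensors carry the dataset split in their name
        if split not in {"train", "valid", "test"}:
            raise ValueError("tensors need a train/valid/test split")
        parts.append(split)
    elif kind == "state_dict":
        # state_dicts carry the epoch; "orig" marks the untrained weights (assumed label)
        parts.append(f"e{epoch}" if epoch is not None else "orig")
    parts.append(data_date)
    return base / ("_".join(parts) + ".pt")
```

For example, quik_filename("20210101", "state_dict", epoch=1) gives data/state_dict_e1_20210101.pt under this assumed scheme.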

Model and Functions

Model State

Sometimes your training and validation losses will converge sooner than expected, and you'll want to test an epoch before the final one. This is possible because the function saves the weights and biases to disk at the end of each epoch.
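Choosing which saved epoch to test can be as simple as picking the one with the lowest validation loss. This is a minimal sketch, assuming you have the per-epoch validation losses in a list; best_epoch is a hypothetical helper, not part of pytorch-quik.

```python
from typing import Sequence


def best_epoch(valid_losses: Sequence[float]) -> int:
    """Return the 0-based epoch with the lowest validation loss.

    min() keeps the first of several equal values, so ties go to the
    earliest epoch, which was the cheapest one to reach.
    """
    if not valid_losses:
        raise ValueError("no epochs recorded")
    return min(range(len(valid_losses)), key=valid_losses.__getitem__)
```

You would then load the state_dict file saved for that epoch instead of the last one.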

Set_* Functions

Setting the criterion, optimizer, and scheduler just takes a callable (the class itself), and can combine general defaults with specific arguments. For instance, I have an OptKwargs class that can receive, via argparse, the parameters most optimizers share (lr, weight_decay, eps, and betas), but you can also pass in specific parameters at instantiation, like amsgrad=False if you're using Adam: tr.set_optimizer(Adam, amsgrad=False). For simplicity I didn't use a scheduler above, but you could include something like tr.set_scheduler(OneCycleLR), and then call tr.scheduler.step() after your backward pass and optimizer step.
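One way to combine shared defaults with optimizer-specific kwargs is to pass only the shared fields the optimizer's constructor accepts, then layer the specific ones on top. This is a sketch under that assumption: OptKwargsSketch and build_optimizer are hypothetical stand-ins, and the default values are illustrative rather than pytorch-quik's.

```python
import inspect
from dataclasses import asdict, dataclass
from typing import Tuple


@dataclass
class OptKwargsSketch:
    """Hypothetical stand-in for OptKwargs: settings most optimizers share."""
    lr: float = 1e-3
    weight_decay: float = 0.0
    eps: float = 1e-8
    betas: Tuple[float, float] = (0.9, 0.999)


def build_optimizer(opt_cls, params, general: OptKwargsSketch, **specific):
    """Instantiate opt_cls with the general kwargs it accepts, plus specific ones."""
    accepted = inspect.signature(opt_cls).parameters  # __init__ params, minus self
    kwargs = {k: v for k, v in asdict(general).items() if k in accepted}
    kwargs.update(specific)  # e.g. amsgrad=False wins over any general value
    return opt_cls(params, **kwargs)
```

With PyTorch installed, build_optimizer(torch.optim.SGD, model.parameters(), OptKwargsSketch(), momentum=0.9) would silently drop eps and betas, which SGD's constructor does not take, while torch.optim.Adam would receive all four.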

Distributed Data Parallel (DDP)

Intro to DDP

This is really why pytorch-quik is quick-er for me. Adding DDP can be tough, so I tried to add it in a way that lets you switch back and forth when necessary. I run my code on an OpenShift cluster and sometimes can't get onto my multi-GPU setup. This lets me just use a different set of args and deal with slower code, not broken code!

I'd suggest spending time setting up argparse so that you have your own default arguments for batch size, learning rate, etc., but if you don't want to, you'll get my defaults. These assume you have one GPU on one node, which is the simplest way to benefit from pytorch-quik:

from argparse import ArgumentParser
parser = pq.args.add_learn_args(ArgumentParser())
args = parser.parse_args()
gpu = 0
tr =, gpu)

Notice that you now give the QuikTraveler your args and tell it to run on GPU 0. If you were truly distributing this across GPUs, you'd have to spawn a QuikTraveler on each GPU, but more on that later. Also, tr.run_prep() will start your DDP process group.
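When you do spawn one traveler per GPU, each process needs a world size and a unique global rank to join the process group. The sketch below shows the common one-process-per-GPU convention; whether pytorch-quik computes these exactly this way is an assumption. These are the values torch.distributed.init_process_group takes as world_size and rank.

```python
def world_size(nodes: int, gpus_per_node: int) -> int:
    """Total number of DDP processes: one per GPU on every node."""
    return nodes * gpus_per_node


def global_rank(node_rank: int, gpus_per_node: int, local_gpu: int) -> int:
    """Unique rank of the process driving local_gpu on node node_rank."""
    if not 0 <= local_gpu < gpus_per_node:
        raise ValueError("local_gpu must be in [0, gpus_per_node)")
    return node_rank * gpus_per_node + local_gpu
```

On the single-GPU default above, both functions collapse to world size 1 and rank 0.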

Automatic Mixed Precision (AMP)

Tangentially related is AMP: if your model.forward() is already set up for mixed precision, this should work for you too. Just set args.mixed_precision = True before creating your traveler, wrap your forward call in with tr.amp.caster, and also pass the caster into your forward (you'll have to change your with autocast(): to with myparam:, where myparam is the caster we're sending in here):

with tr.amp.caster:
  outputs = tr.model.forward(users, items, tr.amp.caster)
  loss = tr.criterion(outputs, labels)

Here is my .forward:

def forward(self, users, items, caster):
  with caster:


This isn't true logging (though there is some), but this is a good place to talk about the progress bar and metrics. If your training/validation is distributed, DDP averages your gradients across GPUs with a ring all-reduce, so every GPU ends each step with the same weights and it doesn't much matter which one reports the loss. So, I make GPU 0 your "logger" GPU, and it will be the one with = True. When you see this used, it's just telling the progress bar to be drawn, or the metrics to be calculated.
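The effect of the all-reduce, and why a single logger rank is enough, can be simulated without any GPUs. This is a sketch: all_reduce_mean only mimics what torch.distributed.all_reduce with ReduceOp.SUM followed by division by the world size produces, using a plain list of per-rank values, and is_logger is a hypothetical helper.

```python
from typing import List


def all_reduce_mean(per_rank_values: List[float]) -> List[float]:
    """Simulate all_reduce(SUM) then / world_size: every rank ends with the mean."""
    world = len(per_rank_values)
    total = sum(per_rank_values)
    return [total / world for _ in per_rank_values]


def is_logger(rank: int) -> bool:
    """Only rank 0 draws the progress bar and computes metrics."""
    return rank == 0
```

Because every rank holds the same averaged gradient afterward, the weights stay identical across GPUs, and printing from rank 0 alone avoids four copies of every bar and table.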

tqdm Progress Bar

One benefit of this is the tqdm progress bar. I didn't bother with the notebook version, so that it works when distributed across GPUs, in IPython, or in a Jupyter Notebook. It should look something like this:

epoch: 1/2:  22%|████████████▏                                          | 1020/4591 [00:20<01:13, 48.69it/s]

I take care of running it only on your first GPU, and calculate your steps based on how your samples are split into batches across GPUs. So, your initialization looks like this:

trtq = pq.ddp.tq_bar(, epoch, tr.epochs)

and then your update and close work as they normally would with a tqdm bar. I also suggest you del your bar, because tqdm can have trouble drawing the next one if you don't.
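The bar's total is the number of batches each GPU processes per epoch. Here is a sketch of that arithmetic, assuming a DistributedSampler-style split (each rank gets ceil(samples / world_size) samples) and a DataLoader that keeps the last partial batch; steps_per_epoch is a hypothetical helper.

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int, world_size: int = 1) -> int:
    """Batches one rank processes per epoch, keeping the last partial batch."""
    per_rank = math.ceil(num_samples / world_size)  # samples assigned to this rank
    return math.ceil(per_rank / batch_size)
```

For instance, 146,912 samples at a batch size of 32 gives 4,591 steps on one GPU (a made-up dataset size, chosen only to match the bar's total), and 1,148 steps per GPU across four GPUs.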


One cool addition that I pulled from (thanks!) is the way they track training and validation loss during training and show them at the end of each epoch, so you can see your progress. It works everywhere because it's just a pandas DataFrame, and it carries over between training and validation because it's a class attribute. (Here we're only using training, so the validation loss is zero.) After the backward pass, you just need tr.metrics.add_loss(loss) to accumulate the loss, and then after your loop, tr.metrics.write() to build the DataFrame. Then you can access it whenever you want at tr.metrics.results:

   epoch    train_loss  valid_loss   time
0      0        0.6435         0.0  05:03
1      1        0.5281         0.0  06:20
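The accumulate-then-summarize pattern behind that table can be sketched with the standard library alone. LossTracker is a hypothetical stand-in for tr.metrics that omits the time column: it averages per-step losses into one row per epoch and fills in zero for a phase that never ran, as in the table above.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, List, Tuple


class LossTracker:
    """Hypothetical stand-in for tr.metrics: one averaged row per epoch."""

    def __init__(self) -> None:
        self._losses: Dict[Tuple[int, str], List[float]] = defaultdict(list)
        self.results: List[dict] = []

    def add_loss(self, epoch: int, phase: str, loss: float) -> None:
        # float() also works on a 0-dim torch tensor, like loss.item()
        self._losses[(epoch, phase)].append(float(loss))

    def write(self, epoch: int) -> dict:
        row = {"epoch": epoch}
        for phase in ("train", "valid"):
            values = self._losses.get((epoch, phase), [])
            row[f"{phase}_loss"] = mean(values) if values else 0.0
        self.results.append(row)
        return row
```

If pandas is available, pandas.DataFrame(tracker.results) turns the rows into a table like the one above.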


Pytorch-quik class instance attributes come in as arguments, and are (mostly) stored in dataclasses.


There are tons of arguments to set when you're training your model. These are split into various types:

  • DDP args: (node rank, nodes, gpus, number of workers)
  • learning args: (epochs, mixed precision, batch size, learning rate, weight decay, betas, epc, and find_unused_parameters)
  • Ray Tune args: (use_ray_tune, number of samples)
  • MLflow args: (use_mlflow, experiment, user, tracking URI, endpoint URL)

There are four ways to set arguments with pytorch-quik:

  • Use the defaults: pq.arg.add_learn_args(parser)
  • Provide them at the command line: --learning_rate 1.5e-6
  • Defaulted within your script: pq.arg.add_learn_args(parser, {"learning_rate": 1.5e-6})
  • And of course, fix them afterward (but check their name in the Namespace): = 1.5e-6

Putting all these things together, a simple setup would be:

from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser = pq.arg.add_ddp_args(parser)
parser = pq.arg.add_learn_args(parser)
parser = pq.arg.add_mlflow_args(parser)
parser = pq.arg.add_ray_tune_args(parser)
args = parser.parse_args()


Similar to arguments, I've organized the attributes within a QuikTrek into three dataclasses, and those within QuikMlflow into one, each containing keyword arguments (kwargs):

  • DlKwargs: Data Loader kwargs such as batch size, pin memory, and number of workers
  • OptKwargs: Optimizer kwargs such as learning rate, weight decay, and betas (these attempt to use universal names like lr)
  • World: This plays off the distributed kwarg world_size, and is really anything about the world, such as nodes, GPUs, init group, and Ray Tune
  • MlfKwargs: MLflow kwargs such as tracking URI, endpoint URL and is_parent

Usage Summary

Putting this all together, I move the argument creation into a main() and the rest into a function (let's call it train) that takes just two parameters, args and gpu. This would be my full example using all of the above:

import pytorch_quik as pq
from torch import nn, optim
from argparse import ArgumentParser

def train(args, gpu):
    trek =, gpu)
    tr =

    tr.args.data_date = '20210101'
    tr.add_data("train", tr.args))

    # from
    model = nn.Sequential(nn.Linear(3, 1), nn.Flatten(0, 1))

    if, tr.args, "orig")

    for epoch in range(tr.epochs):
            trbar = pq.ddp.tq_bar(, epoch, tr.epochs)
        for batch in
            users, items, labels = [ for tens in batch]
            with tr.amp.caster:
                outputs = tr.model.forward(users, items, tr.amp.caster)
                loss = tr.criterion(outputs, labels)
            del trbar
  , tr.args, epoch)

def main():
    parser = pq.args.add_learn_args(ArgumentParser())
    args = parser.parse_args()
    gpu = 0
    train(args, gpu)

if __name__ == "__main__":
    main()

Additional Usage Information

Spawning across GPUs... or not!

This is why I love pytorch-quik: you can use my traverse function, which makes sure your GPUs are available to PyTorch and, if so, sends your training across each GPU! If it decides they aren't, it sends that same exact code to the CPU and runs it there. This is what it all looks like (I'll abbreviate train):

import pytorch_quik as pq
from torch import nn, optim
from argparse import ArgumentParser

def train(args, gpu):
    ...  # abbreviated: same body as the full train example above

def main():
    parser = pq.args.add_learn_args(ArgumentParser())
    args = parser.parse_args()
    args.gpus = 4
    pq.ddp.traverse(train, args)

if __name__ == "__main__":
    main()

Running from command line

Now you can package your repo and, without touching the code, run it on four GPUs like python --gpus 4, on one GPU when there are four like python --gpus 1, or throw it on a cluster and let the cluster decide where to run it. When you run python, pytorch-quik will determine whether you have GPUs and whether it should distribute or not!

Cool, huh.
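The decision traverse makes can be pictured as a small planning step. This is a sketch of that logic under stated assumptions, not pytorch-quik's code: plan_processes is hypothetical, and with PyTorch the available count would typically come from torch.cuda.device_count(). If you launch the processes with torch.multiprocessing.spawn, note that it passes the process index as the first argument to your function, so a train(args, gpu) signature needs a small wrapper.

```python
def plan_processes(requested_gpus: int, available_gpus: int) -> dict:
    """Decide how to launch: one process per usable GPU, else a single CPU process."""
    usable = min(requested_gpus, available_gpus)
    if usable < 1:
        # no GPU visible to PyTorch: run the same code once on the CPU
        return {"device": "cpu", "processes": 1, "distributed": False}
    # only bother with DDP when there is more than one GPU to share the work
    return {"device": "cuda", "processes": usable, "distributed": usable > 1}
```

Asking for four GPUs on a one-GPU machine then degrades to a single CUDA process, and on a CPU-only node to a single CPU process, without changing the training code.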

Let me know if you have any questions, and I'll keep adding to this documentation!


MLflow integration


MLflow for model tracking has been integrated with QuikTrek. When a QuikTrek is instantiated with use_mlflow=True, a QuikMlflow is instantiated as well, along with a run, which involves:

  • checking whether args.experiment exists, and creating it if not
  • creating a run under the experiment
  • adding the appropriate run tags
  • adding the data loader, optimizer, and world dataclasses as run parameters


Next, as training occurs, if the QuikTraveler add_loss function is used, the "train_loss" metric will be added to MLflow after every step, and the "valid_loss" metric after every epoch. It's best to plot by step, since this aligns the points on the MLflow graph by training progress rather than clock time.
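Plotting by step only works if the step keeps increasing across epochs. Below is a sketch of one way to number them; the helpers are hypothetical, and placing each epoch's validation point on that epoch's last training step is a design choice, not necessarily what pytorch-quik does. MLflow's mlflow.log_metric(key, value, step=...) is where such a value would go.

```python
def train_step(epoch: int, batch_idx: int, steps_per_epoch: int) -> int:
    """Monotonic step for a per-batch train_loss point."""
    return epoch * steps_per_epoch + batch_idx


def valid_step(epoch: int, steps_per_epoch: int) -> int:
    """Step for a per-epoch valid_loss point: the last training step of that epoch."""
    return (epoch + 1) * steps_per_epoch - 1
```

With 100 steps per epoch, the sixth batch of the third epoch lands on step 205, and the first epoch's validation loss sits on step 99, right beside its final training loss.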

Saving artifacts and accuracy metrics

In addition, artifacts are saved when the following QuikTraveler functions are used:

  • save_state_dict now saves each epoch's state dict as an artifact in the run
  • record_results: records the classification report results as metrics, and saves the png of the confusion matrix as an artifact

Ray Tune integration

Since it's difficult to manage the execution (and tracking) of hundreds of training runs, pytorch-quik tried integrating both Optuna and Ray Tune. Both integrations worked, but Optuna lacked the ability to work with DDP. If there are multiple GPUs on your platform, Ray Tune will execute a run on each GPU. When use_ray_tune=True, the following options are available:

  • Creating a parent (with is_parent=True) in MLflow, to contain all runs in a collapsible parent run
  • Using get_tune_config to pull the hyperparameter options from a yaml file, such as:
            categories: [8, 12, 16, 24]
            categories: [0, 0.001]
  • Using run_ddp_tune to distribute all test permutations across GPUs
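The "test permutations" are every combination of the categories in the config. This sketch expands a config shaped like the yaml above, assuming hypothetical parameter names batch_size and weight_decay (the real names are not shown); grid_permutations is illustrative only. In Ray Tune itself, the same grid is usually expressed with ray.tune.grid_search per parameter.

```python
from itertools import product
from typing import Any, Dict, List


def grid_permutations(config: Dict[str, Dict[str, List[Any]]]) -> List[Dict[str, Any]]:
    """Expand {"param": {"categories": [...]}} into one dict per combination."""
    names = list(config)
    choices = [config[name]["categories"] for name in names]
    return [dict(zip(names, combo)) for combo in product(*choices)]


# hypothetical parameter names standing in for the ones in the yaml above
config = {
    "batch_size": {"categories": [8, 12, 16, 24]},
    "weight_decay": {"categories": [0, 0.001]},
}
trials = grid_permutations(config)  # 4 x 2 = 8 runs to spread across GPUs
```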

Huggingface integration

As many of my models are transformers, the following options are available (so far only tested on BERT and RoBERTa models):

  • get_bert_info: using arg.bert_type, will pull model information for a sequence classification model
  • get_pretrained_model: using arg.bert_type, will pull the model itself
  • get_encodings: using arg.bert_type, will pull the model's tokenizer and encode text
  • save_bert_model: pull a state dict and load it into the model before using transformers' save_pretrained
  • save_tokenizer: save the tokenizer, and also update the tokenizer.json to include the correct pathing for serving

Torch Serve integration

TS Functionality

Finally, to productionize a PyTorch model, TorchServe is an excellent choice. There's quite a bit of overhead in getting it up and running, so I've provided some helpful functions. The main ones to note are:

  • build_extra_files: There are many files needed to build a model archive. This function will build:
    • setup_config.json
    • index_to_name.json
    • sample_text.json
    • A default transformer handler if you don't have one (transformer_handler_pq)
  • create_mar: first checks that you have all the required files, then uses the command-line tool to create your torch model archive.
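Of those files, index_to_name.json is the simplest: TorchServe handlers use it to map an output index to a label name, keyed by the index as a string. This sketch assumes the keys of INDEX_LABELS are the labels in output order (as get_pretrained_model's labels=list(INDEX_LABELS.keys()) suggests); index_to_name and write_index_to_name are hypothetical helpers, not build_extra_files itself.

```python
import json
from pathlib import Path
from typing import Dict, Iterable


def index_to_name(labels: Iterable[str]) -> Dict[str, str]:
    """Map each output index (as a string key) to its label."""
    return {str(i): label for i, label in enumerate(labels)}


def write_index_to_name(labels: Iterable[str], serve_path: Path) -> Path:
    """Write index_to_name.json into serve_path and return its path."""
    serve_path.mkdir(parents=True, exist_ok=True)
    out = serve_path / "index_to_name.json"
    out.write_text(json.dumps(index_to_name(labels), indent=2))
    return out
```

For example, write_index_to_name(INDEX_LABELS.keys(), serve_path) would produce {"0": ..., "1": ...} in label order.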

I should note that the benefit of my customized handler is that it provides predictions like Huggingface's example, but also provides the logits, as required by my stakeholder.

TS Usage

To create your torch model archive (mar) file, there are only a few steps. When I do so, I pull my state dict from MLflow, which makes it even more modular. If you create a function called parse_my_args() that includes all the pq.args functions from the Arguments snippet, and INDEX_LABELS is an OrderedDict of your labels, here's an example where you (1) pull the state dict, (2) download the model, (3) save the model and files to disk, and (4) create the mar:

args = parse_my_args() # this would be all pq.arg functions you'd need
mlf = pq.mlflow.QuikMlflow(args)
serve_path ="", args).parent.joinpath("serve")
filter_str = " = '1.6e-06'" # some filter to differentiate the specific run for your state dict
args.data_date = mlf.get_state_dict(args.test_epoch, filter_string=filter_str, serve_path=serve_path)
model = pq.bert.get_pretrained_model(labels=list(INDEX_LABELS.keys()), bert_type=args.bert_type)
pq.bert.save_bert_model(model, args, args.test_epoch, serve_path)
pq.serve.build_extra_files(args, INDEX_LABELS, serve_path)
pq.serve.create_mar(args, serve_path)

Torch Serve API Usage

Now that you have a torch serve mar (let's call it my_model.mar) and have started it (hopefully in a container!) by running torchserve --start --models "my_tc=my_model.mar" --ncs, you can use pytorch-quik's batch_inference function. The benefits of this function are:

  • it takes an array of records, vs. sending one input at a time to the API!
  • formats the text as the handler expects (json with an "inference" header and a "data" key for each record)
  • batches the API calls so as to not overload the API (I usually use a batch_size of 20-35 records)
  • opens a session with the API, and retries whenever you receive a 507 error (which is common for me)
  • normalizes the response and drops it into a pandas dataframe.

Usage is pretty easy. If your test set comes from scikit-learn's train_test_split, this should work (INDEX_LABELS is an OrderedDict of your labels):

df = pq.api.batch_inference(Xte[:, 0], INDEX_LABELS, batch_size=20)
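The batching and retry behavior described above can be sketched independently of any HTTP library. This is a minimal sketch, assuming a hypothetical send(batch) function that posts one batch and returns (status_code, body); chunks and call_with_retry are illustrative, not batch_inference's internals, and the linear backoff is a design choice.

```python
import time
from typing import Any, Callable, Iterator, List, Sequence, Tuple


def chunks(records: Sequence[Any], batch_size: int) -> Iterator[List[Any]]:
    """Yield consecutive batches of at most batch_size records."""
    for start in range(0, len(records), batch_size):
        yield list(records[start:start + batch_size])


def call_with_retry(send: Callable[[List[Any]], Tuple[int, Any]], batch: List[Any],
                    max_tries: int = 5, backoff: float = 0.5,
                    retry_status: int = 507) -> Tuple[int, Any]:
    """Call send(batch), retrying with a growing pause while it returns retry_status."""
    for attempt in range(1, max_tries + 1):
        status, body = send(batch)
        if status != retry_status:
            return status, body
        if attempt < max_tries:
            time.sleep(backoff * attempt)  # wait a bit longer after each 507
    raise RuntimeError(f"still got {retry_status} after {max_tries} tries")
```

Looping call_with_retry over chunks(records, 20) keeps each request small, and the collected bodies can then be normalized into one DataFrame.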


Download files


Source Distribution

pytorch-quik-0.3.3.tar.gz (45.1 kB)

Built Distribution

pytorch_quik-0.3.3-py2.py3-none-any.whl (45.6 kB)
