
IrisML tasks for PyTorch training


irisml-tasks-training

This is a package for IrisML training-related tasks.

See the irisml repository for details of the irisml framework.

Tasks

train

Train a PyTorch model. The model object must have "criterion" and "predictor" properties. See the documentation for details. Returns the trained model.

predict

Run inference with a given PyTorch model. Returns prediction results.

append_classifier

Append a classifier layer to an encoder model.
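A minimal sketch of what appending a classifier means, assuming the encoder maps a batch to features of shape [N, feature_size]. EncoderWithClassifier and its parameters are hypothetical names, not the task's actual implementation or options.

```python
import torch


class EncoderWithClassifier(torch.nn.Module):
    """Hypothetical wrapper: runs an encoder, then a linear classification head."""

    def __init__(self, encoder: torch.nn.Module, feature_size: int, num_classes: int):
        super().__init__()
        self.encoder = encoder
        self.classifier = torch.nn.Linear(feature_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.encoder(x)        # assumed shape: [N, feature_size]
        return self.classifier(features)  # shape: [N, num_classes]


# Toy encoder that flattens a [N, 3, 8, 8] image batch into 16 features.
encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 16))
model = EncoderWithClassifier(encoder, feature_size=16, num_classes=5)
logits = model(torch.rand(2, 3, 8, 8))
```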

build_classification_prompt_dataset

Convert a multiclass classification image dataset into a dataset with text prompts.

build_zero_shot_classifier

Build a classifier FC layer from text features. See the CLIP repository for details.

create_classification_prompt_generator

Create a prompt generator for a classification task.
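The sketch below shows, under stated assumptions, how a prompt generator and a prompt dataset could fit together: each class name is expanded through CLIP-style templates, and (image, class_index) samples become (image, prompt) pairs. The template strings, make_prompt_generator, and to_prompt_pairs are hypothetical; the package's actual templates and return types are not documented here.

```python
# Hypothetical CLIP-style templates; not the package's actual list.
TEMPLATES = ["a photo of a {}.", "a close-up photo of a {}."]


def make_prompt_generator(templates):
    """Return a function mapping a class name to one prompt per template."""
    def generate(class_name):
        return [template.format(class_name) for template in templates]
    return generate


def to_prompt_pairs(samples, class_names, generator):
    """Turn (image, class_index) samples into (image, prompt) pairs, one per template."""
    return [(image, prompt)
            for image, class_index in samples
            for prompt in generator(class_names[class_index])]


generator = make_prompt_generator(TEMPLATES)
pairs = to_prompt_pairs([("img0", 1)], ["cat", "dog"], generator)
# pairs == [("img0", "a photo of a dog."), ("img0", "a close-up photo of a dog.")]
```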

export_onnx

Trace a PyTorch model and export it as ONNX using torch.onnx.export(). Raises an exception if the export fails. Returns the exported ONNX model.

evaluate_accuracy

Calculate top-1 accuracy for the given prediction results. Only image classification results are supported.
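Top-1 accuracy is the fraction of samples whose highest-scoring class matches the target. A minimal sketch, assuming predictions are a [N, num_classes] score tensor (the prediction format is not specified in this document) and targets use the multiclass [N, 1] layout described later; top1_accuracy is a hypothetical helper.

```python
import torch


def top1_accuracy(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """Fraction of samples whose highest-scoring class equals the target class."""
    # Assumed layout: predictions [N, num_classes]; targets [N] or [N, 1] class indices.
    predicted = predictions.argmax(dim=1)
    return (predicted == targets.reshape(-1)).float().mean().item()


scores = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels = torch.tensor([[1], [1], [1]])
accuracy = top1_accuracy(scores, labels)  # 2 of 3 predictions are correct
```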

evaluate_detection_average_precision

Calculate mAP for object detection results.

get_targets_from_dataset

Get a list or a tensor of targets from a Dataset.

make_feature_extractor_model

Make a new model to extract intermediate features from the given model. Use the predict task to run the extractor model.
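One common way to extract intermediate features in plain PyTorch is a forward hook (torch.nn.Module.register_forward_hook). The sketch assumes that approach; it is not necessarily how make_feature_extractor_model works. FeatureExtractor is a hypothetical class, and it exposes prediction_step() so that it fits the prediction interface described in this document.

```python
import torch


class FeatureExtractor(torch.nn.Module):
    """Hypothetical wrapper that returns the output of one named submodule."""

    def __init__(self, model: torch.nn.Module, module_name: str):
        super().__init__()
        self.model = model
        self._features = None
        # Look up the submodule by its name in model.named_modules() and hook it.
        module = dict(model.named_modules())[module_name]
        module.register_forward_hook(self._save_output)

    def _save_output(self, module, inputs, output):
        self._features = output  # called by PyTorch after the submodule's forward()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.model(x)  # run the full model; the hook captures the intermediate output
        return self._features

    def prediction_step(self, inputs_batch):
        with torch.no_grad():
            return self(inputs_batch)


base = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
extractor = FeatureExtractor(base, "1")  # "1" is the ReLU inside the Sequential
out = extractor.prediction_step(torch.rand(3, 4))  # shape [3, 8], not the final [3, 2]
```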

make_image_text_contrastive_model

Make a new model to run image-text contrastive training like CLIP.
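CLIP-style training typically uses a symmetric contrastive (InfoNCE) loss over the image-text similarity matrix, where the i-th image and the i-th text form the positive pair. The sketch assumes a fixed temperature of 0.07 (logit_scale = 1 / 0.07), whereas CLIP learns this scale; contrastive_loss is a hypothetical helper, not the loss used by make_image_text_contrastive_model.

```python
import torch
import torch.nn.functional as F


def contrastive_loss(image_features, text_features, logit_scale=1 / 0.07):
    """Symmetric cross-entropy over image-to-text and text-to-image similarities."""
    # Unit-normalize so the dot product is cosine similarity.
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    logits = logit_scale * image_features @ text_features.t()  # [N, N]
    labels = torch.arange(logits.shape[0], device=logits.device)  # positives on the diagonal
    loss_image_to_text = F.cross_entropy(logits, labels)
    loss_text_to_image = F.cross_entropy(logits.t(), labels)
    return (loss_image_to_text + loss_text_to_image) / 2


features = torch.eye(4)
matched = contrastive_loss(features, features)                # aligned pairs: low loss
mismatched = contrastive_loss(features, features.flip(0))     # shuffled pairs: high loss
```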

make_image_text_transform

Make a transform function that can be used for contrastive training.

split_image_text_model

Extract image_model and text_model from an image-text model.

Available plugins for the train task:

  • log_summary
  • log_tensorboard
  • progressbar

Interfaces for training and prediction

The tasks in this package expect the following interfaces.

Notations

  • input: An input object for a single example. For example, an image tensor.
  • target: A ground truth for a single example.
  • inputs_batch: A batch of inputs.
  • targets_batch: A batch of targets.

Model

import torch

class Model(torch.nn.Module):
    def training_step(self, inputs_batch, targets_batch):
        # Must return a dictionary such as {'loss': loss_tensor}
        raise NotImplementedError

A model for training must implement a training_step() method. The trainer provides inputs and targets to the method, and it must return a dictionary containing a 'loss' entry.

import torch

class Model(torch.nn.Module):
    def prediction_step(self, inputs_batch):
        # Must return the prediction results for inputs_batch
        raise NotImplementedError

Similarly, a model for prediction must have a prediction_step() method. Inputs are provided to this method, and it must return prediction results.

In most cases, a model implements both methods, training_step() and prediction_step().
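A minimal model that satisfies both interfaces might look like the following. LinearClassifier, the use of CrossEntropyLoss, and returning softmax probabilities from prediction_step() are illustrative assumptions; targets follow the multiclass [N, 1] layout described later in this document.

```python
import torch


class LinearClassifier(torch.nn.Module):
    """Hypothetical model implementing training_step() and prediction_step()."""

    def __init__(self, in_features: int, num_classes: int):
        super().__init__()
        self.fc = torch.nn.Linear(in_features, num_classes)
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, inputs_batch):
        return self.fc(inputs_batch.flatten(start_dim=1))  # [N, num_classes] logits

    def training_step(self, inputs_batch, targets_batch):
        logits = self(inputs_batch)
        # targets_batch is [N, 1]; CrossEntropyLoss expects class indices of shape [N].
        loss = self.criterion(logits, targets_batch.reshape(-1))
        return {'loss': loss}

    def prediction_step(self, inputs_batch):
        # One possible choice of prediction result: class probabilities.
        return torch.softmax(self(inputs_batch), dim=1)


model = LinearClassifier(3 * 4 * 4, num_classes=3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.rand(2, 3, 4, 4)
targets = torch.tensor([[0], [2]])

loss = model.training_step(inputs, targets)['loss']  # one manual optimization step
optimizer.zero_grad()
loss.backward()
optimizer.step()
probs = model.prediction_step(inputs)
```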

Dataset

The trainer accepts an instance of the torch.utils.data.Dataset class. For each index, the dataset must return a tuple (raw_input, target). Currently, raw_input must be an RGB PIL Image object.

Transform

Given (raw_input, target), a transform function must return (input, target).
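As an illustration of the dataset and transform contracts, the sketch below pairs a toy dataset that yields RGB PIL images with a transform that converts them into float tensors [3, H, W] with values in [0, 1]. SolidColorDataset and transform are hypothetical names, and the sketch assumes Pillow and NumPy are installed.

```python
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class SolidColorDataset(Dataset):
    """Hypothetical dataset returning (raw_input, target) with an RGB PIL image."""

    def __init__(self, colors):
        self.colors = colors  # list of (r, g, b); the index doubles as the class label

    def __len__(self):
        return len(self.colors)

    def __getitem__(self, index):
        image = Image.new('RGB', (8, 6), self.colors[index])  # size is (width, height)
        return image, torch.tensor(index)


def transform(raw_input, target):
    """Convert an RGB PIL image into a float tensor [3, H, W] in [0, 1]."""
    array = np.asarray(raw_input, dtype=np.float32) / 255.0  # [H, W, 3]
    return torch.from_numpy(array).permute(2, 0, 1), target


dataset = SolidColorDataset([(255, 0, 0), (0, 0, 255)])
input_tensor, target = transform(*dataset[0])
```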

Inputs and targets formats

Multiclass Image classification

  • input: A float tensor [3, H, W] that represents an RGB image. Its values are in the range [0, 1].
  • inputs_batch: A float tensor [N, 3, H, W] if all inputs have the same shape; otherwise, a list of inputs.
  • target: An integer tensor that represents a class index.
  • targets_batch: An integer tensor [N, 1].
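A collate function that builds inputs_batch and targets_batch according to these rules could look like this (a sketch; collate_classification is a hypothetical name). It can be passed as the collate_fn argument of torch.utils.data.DataLoader.

```python
import torch


def collate_classification(batch):
    """Collate (input, target) pairs into (inputs_batch, targets_batch)."""
    inputs, targets = zip(*batch)
    if all(x.shape == inputs[0].shape for x in inputs):
        inputs_batch = torch.stack(inputs)  # [N, 3, H, W]
    else:
        inputs_batch = list(inputs)         # differently sized images stay a list
    # Each target is a 0-dim class index; reshape to [1] and stack into [N, 1].
    targets_batch = torch.stack([t.reshape(1) for t in targets])
    return inputs_batch, targets_batch


same = collate_classification([(torch.rand(3, 4, 4), torch.tensor(0)),
                               (torch.rand(3, 4, 4), torch.tensor(2))])
mixed = collate_classification([(torch.rand(3, 4, 4), torch.tensor(1)),
                                (torch.rand(3, 5, 4), torch.tensor(1))])
```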

Multilabel Image Classification

  • input, inputs_batch: Same as above.
  • target: An integer tensor [num_classes]. Each value is 0 (negative) or 1 (positive).
  • targets_batch: An integer tensor [N, num_classes].
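Building a multilabel target from a list of positive class indices, as a minimal sketch (to_multilabel_target is a hypothetical helper):

```python
import torch


def to_multilabel_target(positive_classes, num_classes):
    """Build a [num_classes] integer tensor with 1 at each positive class, 0 elsewhere."""
    target = torch.zeros(num_classes, dtype=torch.long)
    target[torch.tensor(list(positive_classes), dtype=torch.long)] = 1
    return target


target = to_multilabel_target([0, 2], num_classes=4)
# Stacking per-example targets gives targets_batch of shape [N, num_classes].
targets_batch = torch.stack([target, to_multilabel_target([], num_classes=4)])
```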

Object Detection

  • input, inputs_batch: Same as above.
  • target: A float tensor [num_boxes, 5]. Each bounding box is represented as [class_index, x0, y0, x1, y1], where x0, y0, x1, y1 are the relative coordinates of the left, top, right, and bottom edges of the box, with 0 <= x0 < x1 <= 1 and 0 <= y0 < y1 <= 1.
  • targets_batch: A list of targets.
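The sketch below converts boxes given in pixel coordinates into the relative format above and checks the coordinate constraints. make_detection_target and the pixel-coordinate input format are assumptions for illustration.

```python
import torch


def make_detection_target(boxes, image_width, image_height):
    """Convert (class_index, x0, y0, x1, y1) pixel boxes into a [num_boxes, 5] relative target."""
    target = torch.tensor(boxes, dtype=torch.float32).reshape(-1, 5)
    target[:, [1, 3]] = target[:, [1, 3]] / image_width   # x0, x1
    target[:, [2, 4]] = target[:, [2, 4]] / image_height  # y0, y1
    x0, y0, x1, y1 = target[:, 1], target[:, 2], target[:, 3], target[:, 4]
    valid = (0 <= x0) & (x0 < x1) & (x1 <= 1) & (0 <= y0) & (y0 < y1) & (y1 <= 1)
    if not bool(valid.all()):
        raise ValueError("boxes must satisfy 0 <= x0 < x1 <= 1 and 0 <= y0 < y1 <= 1")
    return target


target = make_detection_target([(1, 10, 20, 50, 60)], image_width=100, image_height=80)
empty = make_detection_target([], image_width=100, image_height=80)  # image with no boxes
targets_batch = [target, empty]  # a plain list, since num_boxes varies per image
```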

Image Segmentation

  • input, inputs_batch: Same as above.
  • target: A float tensor [num_classes, H, W]. Each value is 0 (negative) or 1 (positive) for the corresponding class and pixel of the sample.
  • targets_batch: A float tensor [N, num_classes, H, W].
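If ground truth is stored as a per-pixel class-index mask [H, W], it can be converted into this format with a one-hot encoding. This sketch assumes each pixel belongs to exactly one class, which is narrower than the format allows; mask_to_segmentation_target is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def mask_to_segmentation_target(class_mask, num_classes):
    """Convert an integer mask [H, W] into a float target [num_classes, H, W]."""
    one_hot = F.one_hot(class_mask.long(), num_classes)  # [H, W, num_classes]
    return one_hot.permute(2, 0, 1).float()              # [num_classes, H, W]


mask = torch.tensor([[0, 1], [2, 1]])
target = mask_to_segmentation_target(mask, num_classes=3)
```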

CLIP Zero-shot classifier build

The build_zero_shot_classifier task has a different interface. It doesn't require a Model instance; instead, it requires two tensors, text_features and text_labels.

  • text_features: A float tensor [N, feature_size].
  • text_labels: An integer tensor [N, 1] that represents the class index of each text.
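CLIP builds zero-shot classifier weights by L2-normalizing the text embeddings of each class, averaging them, and normalizing the result. A sketch of that idea for these tensor shapes, assuming the classifier is a bias-free torch.nn.Linear; make_zero_shot_classifier is hypothetical and may differ from what build_zero_shot_classifier actually returns.

```python
import torch
import torch.nn.functional as F


def make_zero_shot_classifier(text_features, text_labels, num_classes):
    """Per-class mean of normalized text features, used as Linear weights."""
    features = F.normalize(text_features, dim=-1)
    labels = text_labels.reshape(-1).long()
    weight = torch.zeros(num_classes, features.shape[1], dtype=features.dtype)
    weight.index_add_(0, labels, features)  # sum the features belonging to each class
    weight = F.normalize(weight, dim=-1)    # normalizing the sum equals normalizing the mean
    classifier = torch.nn.Linear(features.shape[1], num_classes, bias=False)
    with torch.no_grad():
        classifier.weight.copy_(weight)
    return classifier


text_features = torch.tensor([[1.0, 0.0], [2.0, 0.0], [0.0, 3.0]])
text_labels = torch.tensor([[0], [0], [1]])
classifier = make_zero_shot_classifier(text_features, text_labels, num_classes=2)
```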




This version

0.0.3

Download files


Source Distribution

irisml-tasks-training-0.0.3.tar.gz (24.0 kB)

Uploaded Source

Built Distribution


irisml_tasks_training-0.0.3-py3-none-any.whl (28.9 kB)

Uploaded Python 3

File details

Details for the file irisml-tasks-training-0.0.3.tar.gz.

File metadata

  • Download URL: irisml-tasks-training-0.0.3.tar.gz
  • Size: 24.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for irisml-tasks-training-0.0.3.tar.gz
Algorithm Hash digest
SHA256 e95350ee7ae4d8e5414b07833e682ad64b8fd4b134a3b29ba41b56240432c4b3
MD5 155bd3fe93e9c654c9ea56b87fb87b6c
BLAKE2b-256 b2bf935527835047b5bd9154017f0bdafd2491ae7720f08d49b00b4ecf47e37f


File details

Details for the file irisml_tasks_training-0.0.3-py3-none-any.whl.


File hashes

Hashes for irisml_tasks_training-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 9ef4b82a352b3a50d7ff7fec1a50776f31407225d8ab6c3fabe6ecddfb7ebefb
MD5 729fd7d65dcca3b55e2edfbd32f30c7a
BLAKE2b-256 12fbf06cf92fd1fa54ad98df78673cb9099ed4962f20b921e51941c0787b0da9
