
A framework for creating boilerplate templates for AI projects that are ready for MLOps

Project description

Machine Learning Extension for pyPhases

This Extension adds:

  • an Exporter for PyTorch and TensorFlow models
  • a ModelManager that can handle PyTorch and TensorFlow models

Setup

  • add pyPhasesML to your dependencies or run pip install -U pyPhasesML
  • add pyPhasesML to your plugins in the main project config, e.g. in your project.yaml:
name: bumpDetector
namespace: ibmt.tud

# load machine learning plugin
plugins:
  - pyPhasesML
  • you do not need to add the ModelExporter manually

Getting started

Minimal Example

For a complete minimal example with data loading, training and evaluation, see: https://gitlab.com/tud.ibmt.public/pyphases/pyphasesml-example-bumpdetector

Adding required config values

These values can be changed in your config. The default values are shown here:

modelPath: models/mymodels

# the name of the model (also defines the path: models/mymodels/CNNPytorch.py)
modelName: CNNPytorch

# the model config for a specific model
model:
    kernelSize: 3

alwaysIgnoreClassIndex: null
inputShape: [16, 50]
oneHotDecoded: False

trainingParameter:
  useEventScorer: false
  stopAfterNotImproving: false
  maxEpochs: false
  batchSize: 32
  validationEvery: false
  optimizer: false
  batchSizeValidation: 32
  learningRate: 0.001
  learningRateDecay: 0.001
  validationMetrics: ["acc", "kappa"]

classification:
  type: classification
  classNames: [A, B]
  classWeights: [0.6, 0.4]
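
The validationMetrics entries are not explained further in this README. As a minimal sketch, assuming "acc" stands for plain accuracy and "kappa" for Cohen's kappa (an assumption, not confirmed by pyPhasesML), the two scores could be computed as follows. The function names and the sample labels are hypothetical and not part of the library:

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def cohen_kappa(y_true, y_pred):
    """Cohen's kappa: observed agreement corrected for chance agreement."""
    n = len(y_true)
    observed = accuracy(y_true, y_pred)
    true_counts = Counter(y_true)
    pred_counts = Counter(y_pred)
    # Chance agreement: probability that truth and prediction pick the
    # same class if both were drawn independently from their marginals.
    expected = sum(true_counts[c] * pred_counts[c] for c in true_counts) / (n * n)
    if expected == 1.0:
        # Only one class occurs in both lists: kappa is undefined (0/0).
        return float("nan")
    return (observed - expected) / (1.0 - expected)


# Hypothetical labels using the class names from the config above.
y_true = ["A", "A", "B", "B", "A", "B"]
y_pred = ["A", "B", "B", "B", "A", "A"]

acc = accuracy(y_true, y_pred)
kappa = cohen_kappa(y_true, y_pred)
print(f"acc={acc:.3f} kappa={kappa:.3f}")
```

Accuracy alone can look good on imbalanced classes, which is one reason a chance-corrected score such as kappa is often tracked alongside it.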

Adding a PyTorch Model CNNPytorch

Create a class that matches your modelPath and modelName. In this example, we need a class CNNPytorch in the file models/mymodels/CNNPytorch.py, relative to your project root.

This class is required to:

  • inherit from ModelTorchAdapter
  • populate self.model with a valid PyTorch model in the define method
  • return a valid loss function from the getLossFunction method
import torch.nn as nn

from pyPhasesML.adapter.ModelTorchAdapter import ModelTorchAdapter


class CNNPytorch(ModelTorchAdapter):
    def define(self):
        # inputShape is taken from the config, e.g. [16, 50]
        length, channelCount = self.inputShape
        numClasses = self.numClasses

        # a single 1D convolution with one output channel per class
        self.model = nn.Conv1d(
            in_channels=channelCount,
            out_channels=numClasses,
            kernel_size=self.getOption("kernelSize"),
        )

    def getLossFunction(self):
        # weighted multi-label loss; self.weightTensors is provided by the adapter
        return nn.MultiLabelSoftMarginLoss(reduction="mean", weight=self.weightTensors)
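
Because the example model is a single unpadded convolution, its output is shorter than its input. The following sketch applies the output-length formula from the torch.nn.Conv1d documentation in plain Python, so it runs without PyTorch installed. The helper name is hypothetical, and the length of 16 assumes the convolved axis is the first entry of inputShape, as the unpacking in define suggests:

```python
def conv1d_output_length(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a 1D convolution, following the torch.nn.Conv1d docs."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Values from the example config: inputShape [16, 50], kernelSize 3.
length = 16
kernel_size = 3

unpadded = conv1d_output_length(length, kernel_size)
same_padded = conv1d_output_length(length, kernel_size, padding=1)

print(unpadded)     # length after the unpadded convolution
print(same_padded)  # padding=1 keeps the length for kernel_size=3
```

If the targets have the same length as the input, passing a padding argument to nn.Conv1d is one way to keep the output aligned with them.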

Load the model

In a phase you can simply use the ModelManager to get the model and registerData to save its state. There is no dependency on PyTorch or TensorFlow in this example, so you can swap your models dynamically depending on your environment:

import numpy as np

from pyPhases import Phase
from pyPhasesML import ModelManager


class TestModel(Phase):
    def main(self):
        # loads the model depending on modelPath and modelName
        model = ModelManager.getModel()

        # random batch of 20 samples matching inputShape [16, 50]
        batch = np.random.randn(20, 16, 50)
        output = model(batch)
        # save the model state
        self.project.registerData("modelState", model)
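
How a model class can be found from modelPath and modelName alone can be illustrated with the standard library. This is a sketch of the general technique, not pyPhasesML's actual loader: the helper load_model_class and the stub model file are hypothetical, and the only convention taken from this README is that the class modelName lives in the file modelPath/modelName.py:

```python
import importlib.util
import tempfile
from pathlib import Path


def load_model_class(model_path, model_name, root="."):
    """Hypothetical helper: import <root>/<model_path>/<model_name>.py
    and return the class named <model_name> defined in it."""
    file_path = Path(root) / model_path / f"{model_name}.py"
    spec = importlib.util.spec_from_file_location(model_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, model_name)


# Demo: build a throwaway project root containing a stub model file.
with tempfile.TemporaryDirectory() as root:
    model_dir = Path(root) / "models" / "mymodels"
    model_dir.mkdir(parents=True)
    (model_dir / "CNNPytorch.py").write_text(
        "class CNNPytorch:\n"
        "    framework = 'stub'\n"
    )
    model_class = load_model_class("models/mymodels", "CNNPytorch", root=root)

print(model_class.__name__)
```

Since the phase only refers to the model by its configured name, switching to another implementation is a matter of changing modelName in the config rather than editing the phase.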


