
A framework for creating boilerplate templates for AI projects that are ready for MLOps


Machine Learning Extension for pyPhases

This extension adds:

  • an exporter for PyTorch and TensorFlow models
  • a ModelManager that can handle PyTorch and TensorFlow models

Setup

  • add pyPhasesML to your dependencies or run pip install -U pyPhasesML
  • add pyPhasesML to the plugins in your main project config, e.g. in your project.yaml:
name: bumpDetector
namespace: ibmt.tud

# load machine learning plugin
plugins:
  - pyPhasesML
  • you do not need to add the ModelExporter manually

Getting started

Minimal Example

For a complete minimal example that includes loading data, training and evaluation, see: https://gitlab.com/tud.ibmt.public/pyphases/pyphasesml-example-bumpdetector

Adding required config values

These values can be changed in your config. The default values are shown here:

modelPath: models/mymodels

# the name of the model (also defines the file path: models/mymodels/CNNPytorch.py)
modelName: CNNPytorch

# the model config for a specific model
model:
    kernelSize: 3

alwaysIgnoreClassIndex: null
inputShape: [16, 50]
oneHotDecoded: False

trainingParameter:
  useEventScorer: false
  stopAfterNotImproving: false
  maxEpochs: false
  batchSize: 32
  validationEvery: false
  optimizer: false
  batchSizeValidation: 32
  learningRate: 0.001
  learningRateDecay: 0.001
  validationMetrics: ["acc", "kappa"]

classification:
  type: classification
  classNames: [A, B]
  classWeights: [0.6, 0.4]
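
As a rough sanity check of the values above, the following sketch validates a plain dictionary that mirrors the classification and trainingParameter sections. The helper name check_config and the rules it enforces (one class weight per class name, positive integer batch sizes) are assumptions made for illustration; they are not part of pyPhasesML.

```python
def check_config(config):
    """Return a list of problems found in a config dict (hypothetical helper)."""
    problems = []

    # Assumption: every class name needs exactly one class weight.
    classification = config.get("classification", {})
    names = classification.get("classNames", [])
    weights = classification.get("classWeights")
    if weights is not None and len(weights) != len(names):
        problems.append("classWeights needs one entry per class name")

    # Assumption: batch sizes must be positive integers (bools are rejected).
    training = config.get("trainingParameter", {})
    for key in ("batchSize", "batchSizeValidation"):
        value = training.get(key)
        if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
            problems.append(f"{key} must be a positive integer")

    return problems


# Values copied from the default config shown above.
config = {
    "classification": {"classNames": ["A", "B"], "classWeights": [0.6, 0.4]},
    "trainingParameter": {"batchSize": 32, "batchSizeValidation": 32},
}
print(check_config(config))
```

An empty list means the defaults pass these checks. A real project would probably read the dictionary from project.yaml instead of defining it inline.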

Adding a PyTorch Model CNNPytorch

Create a class that matches your modelPath and modelName. In this example, that means a class CNNPytorch in the file models/mymodels/CNNPytorch.py, relative to your project root.

This class is required to:

  • inherit from ModelTorchAdapter
  • populate self.model with a valid PyTorch model in the define method
  • return a valid loss function from the getLossFunction method
import torch.nn as nn

from pyPhasesML.adapter.ModelTorchAdapter import ModelTorchAdapter


class CNNPytorch(ModelTorchAdapter):
    def define(self):
        # inputShape comes from the config, e.g. [16, 50]
        length, channelCount = self.inputShape
        numClasses = self.numClasses

        # a single convolution layer serves as the whole model
        self.model = nn.Conv1d(
            in_channels=channelCount,
            out_channels=numClasses,
            kernel_size=self.getOption("kernelSize"),
        )

    def getLossFunction(self):
        # torch.nn is imported as nn above
        return nn.MultiLabelSoftMarginLoss(reduction="mean", weight=self.weightTensors)
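
To get a feel for what the Conv1d layer in the example above does to the sequence axis, the sketch below computes the output length in plain Python, using the formula from the PyTorch Conv1d documentation. The function name conv1d_output_length is hypothetical. The defaults (stride 1, no padding, dilation 1) match the example, which sets only the channels and the kernel size.

```python
def conv1d_output_length(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a 1D convolution, following the PyTorch Conv1d formula."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# kernelSize: 3 from the default config; both inputShape axes are shown,
# since which axis acts as the sequence length depends on how the data is fed.
print(conv1d_output_length(16, 3))  # 14
print(conv1d_output_length(50, 3))  # 48
```

With a kernel size of 3 and no padding, the sequence axis shrinks by two samples.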

Load the model

In a phase, you can use the ModelManager to get the model and registerData to save its state. This example has no dependency on PyTorch or TensorFlow, so you can swap your models dynamically depending on your environment:

import numpy as np

from pyPhases import Phase
from pyPhasesML import ModelManager


class TestModel(Phase):
    def main(self):
        # loads the model depending on modelPath and modelName
        model = ModelManager.getModel()

        # random batch of 20 samples matching inputShape [16, 50]
        batch = np.random.randn(20, 16, 50)
        output = model(batch)
        # save the model state
        self.project.registerData("modelState", model)
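
The lookup "depending on modelPath and modelName" can be pictured with the standard library alone. The sketch below is not how pyPhasesML implements ModelManager. It only illustrates the naming convention, under the assumption that the file is <modelPath>/<modelName>.py and holds a class with the same name. The function load_model_class and the stand-in class written to the temporary directory are hypothetical.

```python
import importlib.util
import tempfile
from pathlib import Path


def load_model_class(root, model_path, model_name):
    """Import <root>/<model_path>/<model_name>.py and return the class named model_name."""
    file_path = Path(root) / model_path / f"{model_name}.py"
    spec = importlib.util.spec_from_file_location(model_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, model_name)


# Demo with a throwaway project root and a stand-in model class.
with tempfile.TemporaryDirectory() as root:
    model_dir = Path(root) / "models" / "mymodels"
    model_dir.mkdir(parents=True)
    (model_dir / "CNNPytorch.py").write_text("class CNNPytorch:\n    kernelSize = 3\n")

    model_class = load_model_class(root, "models/mymodels", "CNNPytorch")
    print(model_class.__name__, model_class.kernelSize)
```

Because the class is resolved by name at runtime, changing modelName in the config would select a different file without touching the phase code.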
