
qoala.id data science team library


Qoala AI Library

This project contains a collection of deep learning model wrappers from the Qoala.id data science team.

News

Hooray!
  • DeepLab semantic segmentation released (stable, >=v0.1.18)
  • Object landmark (keypoints) detection released (stable, >=v0.1.18)

Requirements

  • tensorflow-gpu==1.13.1 or 1.14.0 (pip3 install tensorflow-gpu)
  • comdutils (pip3 install comdutils)
  • simple-tensor (pip3 install simple-tensor)
  • numpy
  • opencv-python==3.4.2

Package Installation

  • pip3 install -r requirements.txt
  • pip3 install qoalai
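
After installation, a quick sanity check is to import the two released wrappers (the module paths below are the same ones used in the examples that follow):

# sanity check: both imports should succeed once qoalai and its
# dependencies are installed
from qoalai.segmentations.deeplab_resnet import DeepLab
from qoalai.landmarks.landmark_v1 import Landmark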

Available Docker

  • Download the Docker image:
  • docker pull ...

How To Use

Image Segmentation

import tensorflow as tf
from qoalai.segmentations.deeplab_resnet import DeepLab 

# ---------------------------------- #
# build the DeepLab graph            #
# (ResNet-v2-101 backbone)           #
# ---------------------------------- #
segmentation = DeepLab(num_classes=1, is_training=True)
model_path = '/home/model/resnet_v2_101/resnet_v2_101.ckpt'
# ---------------------------------- #
# calculate loss, using soft dice    #
# ---------------------------------- #
segmentation.cost = segmentation.soft_dice_loss(segmentation.target, segmentation.output)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    segmentation.optimizer = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(segmentation.cost)
# ---------------------------------- #
# tensorflow saver                   #
# ---------------------------------- #
segmentation.saver_partial = tf.train.Saver(var_list=segmentation.base_vars)
segmentation.saver_all = tf.train.Saver()
segmentation.session = tf.Session()
segmentation.session.run(tf.global_variables_initializer())
try:
    # restore the full model if the checkpoint contains all variables
    segmentation.saver_all.restore(segmentation.session, model_path)
except Exception:
    # otherwise restore only the pretrained ResNet base variables
    segmentation.saver_partial.restore(segmentation.session, model_path)

# ---------------------------------- #
# dataset generator                  #
# ---------------------------------- #
train_generator = segmentation.batch_generator(batch_size=1, 
                                               dataset_path='/home/dataset/part_segmentation/', message='TRAIN')
val_generator = segmentation.batch_generator(batch_size=1, 
                                             dataset_path='/home/dataset/part_segmentation/', message='VAL')

# train
segmentation.optimize(subdivisions=2,
                      iterations=10000,
                      best_loss=1000000,
                      train_batch=train_generator,
                      val_batch=val_generator,
                      save_path='/home/model/melon_segmentation/v0')
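
Once training has finished (or a trained checkpoint has been restored as above), the same object can be used for prediction. The sketch below is illustrative only: it assumes the DeepLab wrapper exposes an input placeholder named segmentation.input_placeholder (by analogy with the Landmark class below) and that predictions come from the segmentation.output tensor used in the loss; the image path, input size, and 0-1 scaling are placeholders to adapt to your setup.

import cv2
import numpy as np

# load one image and shape it into a batch of one
# (the 512x512 size and /255. scaling are assumptions, not the
#  wrapper's documented preprocessing)
img = cv2.imread('/home/dataset/part_segmentation/sample.jpg')
img = cv2.resize(img, (512, 512)).astype(np.float32) / 255.
img = np.expand_dims(img, axis=0)

# run the graph: feed the assumed input placeholder, fetch the output
pred = segmentation.session.run(segmentation.output,
                                feed_dict={segmentation.input_placeholder: img})
mask = (pred[0] > 0.5).astype(np.uint8)   # binarize (assumes a sigmoid, single-channel output)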

Object Keypoints (Landmark)

import tensorflow as tf 
from qoalai.landmarks.landmark_v1 import Landmark 
from simple_tensor.tensor_losses import mse_loss_mean

lm = Landmark(num_landmark_point=4,
              input_height=300,
              input_width=300,
              input_channel=3)

out = lm.build_densenet_base(input_tensor=lm.input_placeholder,
                             dropout_rate=0.15,
                             is_training=True,
                             top_layer_depth=128)

# mean squared error between predicted and target landmark points
cost = mse_loss_mean(out, lm.output_placeholder)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost)

saver = tf.train.Saver()
session = tf.Session()
session.run(tf.global_variables_initializer())

train_generator = lm.batch_generator(batch_size=12, dataset_path='/home/dataset/phone_landmark/train/', message='TRAIN')
val_generator = lm.batch_generator(batch_size=12, dataset_path='/home/dataset/phone_landmark/val/', message='VAL')

lm.optimize(iteration=10, 
            subdivition=3,
            cost_tensor=cost,
            optimizer_tensor=optimizer,
            out_tensor=out, 
            session=session,
            saver=saver, 
            train_generator=train_generator,
            val_generator=val_generator,
            best_loss=1000,
            path_tosave_model='model/model1')
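
Inference with the trained landmark model follows the same pattern: feed an image batch into lm.input_placeholder and fetch the out tensor. A minimal sketch, assuming simple 0-1 scaling for preprocessing (the image path is illustrative):

import cv2
import numpy as np

# preprocess one image to the 300x300x3 input configured above
img = cv2.imread('/home/dataset/phone_landmark/val/sample.jpg')
img = cv2.resize(img, (300, 300)).astype(np.float32) / 255.
img = np.expand_dims(img, axis=0)

# predicted coordinates for the 4 landmark points
points = session.run(out, feed_dict={lm.input_placeholder: img})
print(points)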

