You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@madlib.apache.org by GitBox <gi...@apache.org> on 2019/04/18 17:57:43 UTC

[GitHub] [madlib] kaknikhil commented on a change in pull request #370: DL: Support response and prob prediction outputs

kaknikhil commented on a change in pull request #370: DL: Support response and prob prediction outputs
URL: https://github.com/apache/madlib/pull/370#discussion_r276448537
 
 

 ##########
 File path: src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
 ##########
 @@ -27,107 +27,113 @@ from keras.models import *
 from keras.optimizers import *
 import numpy as np
 
+from madlib_keras_helper import expand_input_dims
+from madlib_keras_helper import PredictParamsProcessor
+from madlib_keras_helper import MODEL_DATA_CNAME
+from madlib_keras_wrapper import compile_and_set_weights
 from utilities.model_arch_info import get_input_shape
 from utilities.utilities import add_postfix
-from utilities.validate_args import get_col_value_and_type
+from utilities.utilities import create_cols_from_array_sql_string
+from utilities.utilities import unique_string
 from utilities.validate_args import input_tbl_valid
 from utilities.validate_args import output_tbl_valid
-from madlib_keras_validator import CLASS_VALUES_COLNAME
-from keras_model_arch_table import Format
 
-from madlib_keras_wrapper import compile_and_set_weights
 import madlib_keras_serializer
 
 MODULE_NAME = 'madlib_keras_predict'
+
+def validate_pred_type(pred_type, class_values):
+    if not pred_type in ['prob', 'response']:
+        plpy.error("{0}: Invalid value for pred_type param ({1}). Must be "\
+            "either response or prob.".format(MODULE_NAME, pred_type))
+    if pred_type == 'prob' and class_values and len(class_values)+1 >= 1600:
+        plpy.error({"{0}: The output will have {1} columns, exceeding the "\
+            " max number of columns that can be created (1600)".format(
+                MODULE_NAME, len(class_values)+1)})
+
 def predict(schema_madlib, model_table, test_table, id_col,
-            independent_varname, output_table, **kwargs):
+            independent_varname, output_table, pred_type, **kwargs):
+    # Refactor and add more validation as part of MADLIB-1312.
     input_tbl_valid(model_table, MODULE_NAME)
-    model_summary_table = add_postfix(model_table, '_summary')
-    input_tbl_valid(model_summary_table, MODULE_NAME)
     input_tbl_valid(test_table, MODULE_NAME)
     output_tbl_valid(output_table, MODULE_NAME)
-    model_summary_dict = plpy.execute("SELECT * FROM {0}".format(
-        model_summary_table))[0]
-    model_arch_table = model_summary_dict['model_arch_table']
-    model_arch_id = model_summary_dict['model_arch_id']
-    compile_params = model_summary_dict['compile_params']
-    input_tbl_valid(model_arch_table, MODULE_NAME)
-
-    model_data_query = "SELECT model_data from {0}".format(model_table)
-    model_data = plpy.execute(model_data_query)[0]['model_data']
+    param_proc = PredictParamsProcessor(model_table, MODULE_NAME)
 
-    model_arch_query = """
-        SELECT {0}, {1}
-        FROM {2}
-        WHERE {3} = {4}
-        """.format(Format.MODEL_ARCH, Format.MODEL_WEIGHTS,model_arch_table,
-                   Format.MODEL_ID, model_arch_id)
-    query_result = plpy.execute(model_arch_query)
-    if not  query_result or len(query_result) == 0:
-        plpy.error("{0}: No model arch found in table {1} with id {2}".format(
-            MODULE_NAME, model_arch_table, model_arch_id))
-    query_result = query_result[0]
-    model_arch = query_result[Format.MODEL_ARCH]
+    class_values = param_proc.get_class_values()
+    compile_params = param_proc.get_compile_params()
+    dependent_varname = param_proc.get_dependent_varname()
+    dependent_vartype = param_proc.get_dependent_vartype()
+    model_data = param_proc.get_model_data()
+    model_arch = param_proc.get_model_arch()
+    normalizing_const = param_proc.get_normalizing_const()
+    # TODO: Validate input shape as part of MADLIB-1312
     input_shape = get_input_shape(model_arch)
     compile_params = "$madlib$" + compile_params + "$madlib$"
-    model_summary_table = add_postfix(model_table, "_summary")
-    class_values, _ = get_col_value_and_type(model_summary_table,
-                                             CLASS_VALUES_COLNAME)
-    predict_query = plpy.prepare("""
+
+    validate_pred_type(pred_type, class_values)
+    is_response = True if pred_type == 'response' else False
+    intermediate_col = unique_string()
+    if is_response:
+        pred_col_name = add_postfix("estimated_", dependent_varname)
+        pred_col_type = dependent_vartype
+    else:
+        pred_col_name = "prob"
+        pred_col_type = 'double precision'
+
+    num_of_valid_class_values = 0
+    if class_values is not None:
+        for ele in class_values:
+            if ele is None and num_of_valid_class_values > 0:
+                break
 
 Review comment:
   We should encapsulate this class_values logic in an appropriately named function and add a docstring to it.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services