You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by nj...@apache.org on 2019/04/09 23:32:29 UTC

[madlib] branch master updated: DL: Simplify madlib_keras_predict interface

This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new 57b6e50  DL: Simplify madlib_keras_predict interface
57b6e50 is described below

commit 57b6e50a76ddb3e43ee81006956ef6e7660bec5b
Author: Nandish Jayaram <nj...@apache.org>
AuthorDate: Tue Apr 9 14:00:40 2019 -0700

    DL: Simplify madlib_keras_predict interface
    
    JIRA: MADLIB-1316
    
    This commit removes some unnecessary parameters from the predict
    function for deep learning. These parameters (model_arch_table,
    model_arch_id and compile_params) are now read from the model
    summary table instead.
    Note that this PR does only very basic input validation for the
    predict function. There is a separate JIRA
    (https://issues.apache.org/jira/browse/MADLIB-1321) to add more
    thorough validation and do other related refactoring.
    
    Closes #366
---
 .../modules/deep_learning/madlib_keras.sql_in      |  6 ---
 .../deep_learning/madlib_keras_predict.py_in       | 44 +++++++++++++---------
 .../modules/deep_learning/test/madlib_keras.sql_in |  6 ---
 3 files changed, 26 insertions(+), 30 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index e4a8534..9bc462f 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -171,10 +171,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
     model_table             VARCHAR,
     test_table              VARCHAR,
     id_col                  VARCHAR,
-    model_arch_table        VARCHAR,
-    model_arch_id           INTEGER,
     independent_varname     VARCHAR,
-    compile_params          VARCHAR,
     output_table            VARCHAR
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
@@ -183,10 +180,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
                model_table,
                test_table,
                id_col,
-               model_arch_table,
-               model_arch_id,
                independent_varname,
-               compile_params,
                output_table)
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index bf14d1e..85cd6c3 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -37,24 +37,33 @@ from madlib_keras_wrapper import convert_string_of_args_to_dict
 from madlib_keras_helper import CLASS_VALUES_COLNAME
 from madlib_keras_helper import KerasWeightsSerializer
 
-def predict(schema_madlib, model_table, test_table, id_col, model_arch_table,
-            model_arch_id, independent_varname, compile_params, output_table,
-            **kwargs):
-    module_name = 'madlib_keras_predict'
-    input_tbl_valid(test_table, module_name)
-    input_tbl_valid(model_arch_table, module_name)
-    output_tbl_valid(output_table, module_name)
-
-    # _validate_input_args(test_table, model_arch_table, output_table)
+MODULE_NAME = 'madlib_keras_predict'
+def predict(schema_madlib, model_table, test_table, id_col,
+            independent_varname, output_table, **kwargs):
+    input_tbl_valid(model_table, MODULE_NAME)
+    model_summary_table = add_postfix(model_table, '_summary')
+    input_tbl_valid(model_summary_table, MODULE_NAME)
+    input_tbl_valid(test_table, MODULE_NAME)
+    output_tbl_valid(output_table, MODULE_NAME)
+    model_summary_dict = plpy.execute("SELECT * FROM {0}".format(
+        model_summary_table))[0]
+    model_arch_table = model_summary_dict['model_arch_table']
+    model_arch_id = model_summary_dict['model_arch_id']
+    compile_params = model_summary_dict['compile_params']
+    input_tbl_valid(model_arch_table, MODULE_NAME)
 
     model_data_query = "SELECT model_data from {0}".format(model_table)
     model_data = plpy.execute(model_data_query)[0]['model_data']
 
-    model_arch_query = "SELECT model_arch, model_weights FROM {0} " \
-                       "WHERE id = {1}".format(model_arch_table, model_arch_id)
+    model_arch_query = """
+        SELECT model_arch, model_weights
+        FROM {0}
+        WHERE id = {1}
+        """.format(model_arch_table, model_arch_id)
     query_result = plpy.execute(model_arch_query)
     if not  query_result or len(query_result) == 0:
-        plpy.error("no model arch found in table {0} with id {1}".format(model_arch_table, model_arch_id))
+        plpy.error("{0}: No model arch found in table {1} with id {2}".format(
+            MODULE_NAME, model_arch_table, model_arch_id))
     query_result = query_result[0]
     model_arch = query_result['model_arch']
     input_shape = get_input_shape(model_arch)
@@ -112,12 +121,11 @@ def _get_class_label(class_values, class_index):
     if not class_values:
         return class_index
     elif class_index != int(class_index):
-        plpy.error("Invalid class index {0} returned from Keras predict. "\
-                   "Index value must be an integer".format(
-                    class_index))
+        plpy.error("{0}: Invalid class index {1} returned from Keras predict."\
+            " Index value must be an integer".format(MODULE_NAME, class_index))
     elif class_index < 0 or class_index >= len(class_values):
-        plpy.error("Invalid class index {0} returned from Keras predict. "\
-                   "Index value must be less than {1}".format(
-                    class_index, len(class_values)))
+        plpy.error("{0}: Invalid class index {1} returned from Keras predict."\
+            " Index value must be less than {2}".format(
+                MODULE_NAME, class_index, len(class_values)))
     else:
         return class_values[class_index]
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index a1743c9..ca6f9b6 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -172,10 +172,7 @@ SELECT madlib_keras_predict(
     'keras_saved_out',
     'cifar_10_sample',
     'id',
-    'model_arch',
-    1,
     'x',
-    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
     'cifar10_predict');
 
 -- Validate that prediction output table exists and has correct schema
@@ -202,9 +199,6 @@ select assert(trap_error($TRAP$madlib_keras_predict(
     'keras_saved_out',
     'cifar_10_sample_batched',
     'id',
-    'model_arch',
-    1,
     'x',
-    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
     'cifar10_predict');$TRAP$) = 1,
     'Passing batched image table to predict should error out.');