You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by nj...@apache.org on 2019/04/18 23:00:21 UTC

[madlib] branch master updated (8feb9bf -> f06f2d8)

This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git.


    from 8feb9bf  WCC: Update the design doc
     new 4b8ceb2  DL: Support response and prob prediction outputs.
     new 3bf6da9  Utilities: Add unit tests for create_cols_from_array_sql_string()
     new f437659  DL: Remove reshaping and hard-coded normalizing_const from predict
     new f06f2d8  DL: Code refactor

The 4 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../modules/deep_learning/madlib_keras.py_in       |  26 +-
 .../modules/deep_learning/madlib_keras.sql_in      |  23 +-
 .../deep_learning/madlib_keras_helper.py_in        |  30 +-
 .../deep_learning/madlib_keras_predict.py_in       | 194 +++++++-----
 .../deep_learning/madlib_keras_validator.py_in     |  23 +-
 .../deep_learning/predict_input_params.py_in       |  85 ++++++
 .../modules/deep_learning/test/madlib_keras.sql_in | 339 ++++++++++++++++++++-
 .../test/unit_tests/test_madlib_keras.py_in        |  71 +++--
 .../utilities/test/unit_tests/test_utilities.py_in | 125 ++++++++
 .../postgres/modules/utilities/utilities.py_in     |  92 +++++-
 10 files changed, 863 insertions(+), 145 deletions(-)
 copy deploy/DEB/preinst => src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in (53%)
 create mode 100644 src/ports/postgres/modules/deep_learning/predict_input_params.py_in


[madlib] 02/04: Utilities: Add unit tests for create_cols_from_array_sql_string()

Posted by nj...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 3bf6da956e05a3df052807d3e5c784681287e0cb
Author: Nandish Jayaram <nj...@apache.org>
AuthorDate: Thu Apr 11 17:13:41 2019 -0700

    Utilities: Add unit tests for create_cols_from_array_sql_string()
    
    JIRA: MADLIB-1315
    
    Closes #370
    Co-authored-by: Ekta Khanna <ek...@pivotal.io>
---
 .../utilities/test/unit_tests/test_utilities.py_in | 125 +++++++++++++++++++++
 1 file changed, 125 insertions(+)

diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in
index 2d2c481..b884eec 100644
--- a/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in
+++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in
@@ -247,5 +247,130 @@ class UtilitiesTestCase(unittest.TestCase):
         self.assertTrue(s.is_valid_psql_type('boolean[]', s.INTEGER | s.ANY_ARRAY))
         self.assertFalse(s.is_valid_psql_type('boolean', s.ANY_ARRAY))
 
+    def test_create_cols_from_array_sql_string_empty_pylist(self):
+        utils = self.subject
+        self.py_list = None
+        self.sql_array_col = 'sqlcol'
+        self.colname = 'estimated_col'
+        self.coltype = 'dummy'
+        self.has_one_ele = True
+        out_sql = utils.create_cols_from_array_sql_string(
+            self.py_list, self.sql_array_col, self.colname, self.coltype,
+            self.has_one_ele, "dummy_module")
+        self.assertEqual(out_sql, 'sqlcol[1]+1 AS estimated_col')
+        self.has_one_ele = False
+        out_sql = utils.create_cols_from_array_sql_string(
+            self.py_list, self.sql_array_col, self.colname, self.coltype,
+            self.has_one_ele, "dummy_module")
+        self.assertEqual(out_sql, 'sqlcol AS estimated_col')
+
+    def test_create_cols_from_array_sql_string_one_ele(self):
+        utils = self.subject
+        self.py_list = ['cat', 'dog']
+        self.sql_array_col = 'sqlcol'
+        self.colname = 'estimated_pred'
+        self.coltype = 'TEXT'
+        self.has_one_ele = True
+        out_sql = utils.create_cols_from_array_sql_string(
+            self.py_list, self.sql_array_col, self.colname, self.coltype,
+            self.has_one_ele, "dummy_module")
+        self.assertTrue(out_sql, "(ARRAY['cat','dog'])[sqlcol[1]+1]::TEXT AS estimated_pred")
+
+    def test_create_cols_from_array_sql_string_one_ele_with_NULL(self):
+        utils = self.subject
+        self.py_list = [None, 1, 2]
+        self.sql_array_col = 'sqlcol'
+        self.colname = 'estimated_pred'
+        self.coltype = 'INTEGER'
+        self.has_one_ele = True
+        out_sql = utils.create_cols_from_array_sql_string(
+            self.py_list, self.sql_array_col, self.colname, self.coltype,
+            self.has_one_ele, "dummy_module")
+        self.assertEqual(out_sql, "(ARRAY[ NULL,1,2 ]::INTEGER[])[sqlcol[1]+1]::INTEGER AS estimated_pred")
+
+    def test_create_cols_from_array_sql_string_one_ele_with_many_NULL(self):
+        utils = self.subject
+        self.py_list = [None, 'cat', 'dog', None, None]
+        self.sql_array_col = 'sqlcol'
+        self.colname = 'estimated_pred'
+        self.coltype = 'TEXT'
+        self.has_one_ele = True
+        with self.assertRaises(plpy.PLPYException):
+            utils.create_cols_from_array_sql_string(
+                self.py_list, self.sql_array_col, self.colname, self.coltype,
+                self.has_one_ele, "dummy_module")
+
+    def test_create_cols_from_array_sql_string_many_ele(self):
+        utils = self.subject
+        self.py_list = ['cat', 'dog']
+        self.sql_array_col = 'sqlcol'
+        self.colname = 'prob'
+        self.coltype = 'TEXT'
+        self.has_one_ele = False
+        out_sql = utils.create_cols_from_array_sql_string(
+            self.py_list, self.sql_array_col, self.colname, self.coltype,
+            self.has_one_ele, "dummy_module")
+        self.assertEqual(out_sql, "CAST(sqlcol[1] AS TEXT) AS \"prob_cat\", CAST(sqlcol[2] AS TEXT) AS \"prob_dog\"")
+
+    def test_create_cols_from_array_sql_string_many_ele_with_NULL(self):
+        utils = self.subject
+        self.py_list = [None, 'cat', 'dog']
+        self.sql_array_col = 'sqlcol'
+        self.colname = 'prob'
+        self.coltype = 'TEXT'
+        self.has_one_ele = False
+        out_sql = utils.create_cols_from_array_sql_string(
+            self.py_list, self.sql_array_col, self.colname, self.coltype,
+            self.has_one_ele, "dummy_module")
+        self.assertEqual(out_sql, "CAST(sqlcol[1] AS TEXT) AS \"prob_NULL\", CAST(sqlcol[2] AS TEXT) AS \"prob_cat\", CAST(sqlcol[3] AS TEXT) AS \"prob_dog\"")
+
+    def test_create_cols_from_array_sql_string_many_ele_with_many_NULL(self):
+        utils = self.subject
+        self.py_list = [None, 'cat', 'dog', None, None]
+        self.sql_array_col = 'sqlcol'
+        self.colname = 'prob'
+        self.coltype = 'TEXT'
+        self.has_one_ele = False
+        with self.assertRaises(plpy.PLPYException):
+            utils.create_cols_from_array_sql_string(
+            self.py_list, self.sql_array_col, self.colname, self.coltype,
+            self.has_one_ele, "dummy_module")
+
+    def test_create_cols_from_array_sql_string_invalid_sql_array(self):
+        utils = self.subject
+        self.py_list = ['cat', 'dog']
+        self.sql_array_col = None
+        self.colname = 'prob'
+        self.coltype = 'TEXT'
+        self.has_one_ele = False
+        with self.assertRaises(plpy.PLPYException):
+            utils.create_cols_from_array_sql_string(
+                self.py_list, self.sql_array_col, self.colname, self.coltype,
+                self.has_one_ele, "dummy_module")
+
+    def test_create_cols_from_array_sql_string_invalid_colname(self):
+        utils = self.subject
+        self.py_list = ['cat', 'dog']
+        self.sql_array_col = 'sqlcol'
+        self.colname = ''
+        self.coltype = 'TEXT'
+        self.has_one_ele = False
+        with self.assertRaises(plpy.PLPYException):
+            utils.create_cols_from_array_sql_string(
+                self.py_list, self.sql_array_col, self.colname, self.coltype,
+                self.has_one_ele, "dummy_module")
+
+    def test_create_cols_from_array_sql_string_invalid_coltype(self):
+        utils = self.subject
+        self.py_list = ['cat', 'dog']
+        self.sql_array_col = 'sqlcol'
+        self.colname = 'prob'
+        self.coltype = ''
+        self.has_one_ele = False
+        with self.assertRaises(plpy.PLPYException):
+            utils.create_cols_from_array_sql_string(
+                self.py_list, self.sql_array_col, self.colname, self.coltype,
+                self.has_one_ele, "dummy_module")
+
 if __name__ == '__main__':
     unittest.main()


[madlib] 04/04: DL: Code refactor

Posted by nj...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit f06f2d8fa4442b0bf549bef8f9a9ac1c070c8e5e
Author: Nandish Jayaram <nj...@apache.org>
AuthorDate: Thu Apr 18 14:26:15 2019 -0700

    DL: Code refactor
    
    JIRA: MADLIB-1315
    Refactor code and address comments for PR #370.
    
    Closes #370
---
 .../deep_learning/madlib_keras_helper.py_in        | 59 +---------------------
 .../deep_learning/madlib_keras_predict.py_in       | 46 +++++++++++++----
 ...ras_helper.py_in => predict_input_params.py_in} | 26 +++-------
 .../test/unit_tests/test_madlib_keras.py_in        | 17 +++++++
 4 files changed, 63 insertions(+), 85 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index d56a0e3..445b5b9 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -18,12 +18,8 @@
 # under the License.
 
 import numpy as np
-import plpy
-from keras_model_arch_table import Format
-from utilities.utilities import add_postfix
-from utilities.validate_args import input_tbl_valid
 
-# Prepend 1 to np arrays using expand_dims.
+# Prepend a dimension to np arrays using expand_dims.
 def expand_input_dims(input_data, target_type=None):
     input_data = np.array(input_data)
     input_data = np.expand_dims(input_data, axis=0)
@@ -40,56 +36,3 @@ DEPENDENT_VARTYPE_COLNAME = "dependent_vartype"
 MODEL_ARCH_TABLE_COLNAME = "model_arch_table"
 MODEL_ARCH_ID_COLNAME = "model_arch_id"
 MODEL_DATA_COLNAME = "model_data"
-
-class PredictParamsProcessor:
-    def __init__(self, model_table, module_name):
-        self.module_name = module_name
-        self.model_table = model_table
-        self.model_summary_table = add_postfix(self.model_table, '_summary')
-        input_tbl_valid(self.model_summary_table, self.module_name)
-        self.model_summary_dict = self._get_model_summary_dict()
-        self.model_arch_dict = self._get_model_arch_dict()
-
-    def _get_model_summary_dict(self):
-        return plpy.execute("SELECT * FROM {0}".format(
-            self.model_summary_table))[0]
-
-    def _get_model_arch_dict(self):
-        model_arch_table = self.model_summary_dict[MODEL_ARCH_TABLE_COLNAME]
-        model_arch_id = self.model_summary_dict[MODEL_ARCH_ID_COLNAME]
-        input_tbl_valid(model_arch_table, self.module_name)
-        model_arch_query = """
-            SELECT {0}
-            FROM {1}
-            WHERE {2} = {3}
-        """.format(Format.MODEL_ARCH, model_arch_table, Format.MODEL_ID,
-                   model_arch_id)
-        query_result = plpy.execute(model_arch_query)
-        if not query_result or len(query_result) == 0:
-            plpy.error("{0}: No model arch found in table {1} with id {2}".format(
-                self.module_name, model_arch_table, model_arch_id))
-        return query_result[0]
-
-    def get_class_values(self):
-        return self.model_summary_dict[CLASS_VALUES_COLNAME]
-
-    def get_compile_params(self):
-        return self.model_summary_dict[COMPILE_PARAMS_COLNAME]
-
-    def get_dependent_varname(self):
-        return self.model_summary_dict[DEPENDENT_VARNAME_COLNAME]
-
-    def get_dependent_vartype(self):
-        return self.model_summary_dict[DEPENDENT_VARTYPE_COLNAME]
-
-    def get_model_arch(self):
-        return self.model_arch_dict[Format.MODEL_ARCH]
-
-    def get_model_data(self):
-        return plpy.execute("""
-                SELECT {0} FROM {1}
-            """.format(MODEL_DATA_COLNAME, self.model_table)
-                            )[0][MODEL_DATA_COLNAME]
-
-    def get_normalizing_const(self):
-        return self.model_summary_dict[NORMALIZING_CONST_COLNAME]
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 3108be5..95ae2cf 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -28,9 +28,9 @@ from keras.optimizers import *
 import numpy as np
 
 from madlib_keras_helper import expand_input_dims
-from madlib_keras_helper import PredictParamsProcessor
 from madlib_keras_helper import MODEL_DATA_COLNAME
 from madlib_keras_wrapper import compile_and_set_weights
+from predict_input_params import PredictParamsProcessor
 from utilities.model_arch_info import get_input_shape
 from utilities.utilities import add_postfix
 from utilities.utilities import create_cols_from_array_sql_string
@@ -51,6 +51,41 @@ def validate_pred_type(pred_type, class_values):
             " max number of columns that can be created (1600)".format(
                 MODULE_NAME, len(class_values)+1)})
 
+def _strip_trailing_nulls_from_class_values(class_values):
+    """
+        class_values is a list of unique class levels in training data. It
+        may contain multiple Nones; this function truncates the list at the
+        first None that occurs after the first element.
+        Examples:
+            1) input class_values = ['cat', 'dog']
+               output class_values = ['cat', 'dog']
+
+            2) input class_values = [None, 'cat', 'dog']
+               output class_values = [None, 'cat', 'dog']
+
+            3) input class_values = [None, 'cat', 'dog', None, None]
+               output class_values = [None, 'cat', 'dog']
+
+            4) input class_values = ['cat', 'dog', None, None]
+               output class_values = ['cat', 'dog']
+
+            5) input class_values = [None, None]
+               output class_values = [None]
+        @args:
+            @param: class_values, list
+        @returns:
+            updated class_values list
+    """
+    num_of_valid_class_values = 0
+    if class_values is not None:
+        for ele in class_values:
+            if ele is None and num_of_valid_class_values > 0:
+                break
+            num_of_valid_class_values += 1
+        # Pass only the valid class_values for creating columns
+        class_values = class_values[:num_of_valid_class_values]
+    return class_values
+
 def predict(schema_madlib, model_table, test_table, id_col,
             independent_varname, output_table, pred_type, **kwargs):
     # Refactor and add more validation as part of MADLIB-1312.
@@ -80,14 +115,7 @@ def predict(schema_madlib, model_table, test_table, id_col,
         pred_col_name = "prob"
         pred_col_type = 'double precision'
 
-    num_of_valid_class_values = 0
-    if class_values is not None:
-        for ele in class_values:
-            if ele is None and num_of_valid_class_values > 0:
-                break
-            num_of_valid_class_values += 1
-        # Pass only the valid class_values for creating columns
-        class_values = class_values[:num_of_valid_class_values]
+    class_values = _strip_trailing_nulls_from_class_values(class_values)
 
     prediction_select_clause = create_cols_from_array_sql_string(
         class_values, intermediate_col, pred_col_name,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/predict_input_params.py_in
similarity index 81%
copy from src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
copy to src/ports/postgres/modules/deep_learning/predict_input_params.py_in
index d56a0e3..69ee961 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/predict_input_params.py_in
@@ -17,29 +17,19 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import numpy as np
 import plpy
 from keras_model_arch_table import Format
 from utilities.utilities import add_postfix
 from utilities.validate_args import input_tbl_valid
 
-# Prepend 1 to np arrays using expand_dims.
-def expand_input_dims(input_data, target_type=None):
-    input_data = np.array(input_data)
-    input_data = np.expand_dims(input_data, axis=0)
-    if target_type:
-        input_data = input_data.astype(target_type)
-    return input_data
-
-# Name of columns in model summary table.
-CLASS_VALUES_COLNAME = "class_values"
-NORMALIZING_CONST_COLNAME = "normalizing_const"
-COMPILE_PARAMS_COLNAME = "compile_params"
-DEPENDENT_VARNAME_COLNAME = "dependent_varname"
-DEPENDENT_VARTYPE_COLNAME = "dependent_vartype"
-MODEL_ARCH_TABLE_COLNAME = "model_arch_table"
-MODEL_ARCH_ID_COLNAME = "model_arch_id"
-MODEL_DATA_COLNAME = "model_data"
+from madlib_keras_helper import CLASS_VALUES_COLNAME
+from madlib_keras_helper import COMPILE_PARAMS_COLNAME
+from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
+from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME
+from madlib_keras_helper import MODEL_ARCH_ID_COLNAME
+from madlib_keras_helper import MODEL_ARCH_TABLE_COLNAME
+from madlib_keras_helper import MODEL_DATA_COLNAME
+from madlib_keras_helper import NORMALIZING_CONST_COLNAME
 
 class PredictParamsProcessor:
     def __init__(self, model_table, module_name):
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 533e347..e6b09e4 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -402,6 +402,23 @@ class MadlibKerasPredictTestCase(unittest.TestCase):
         self.subject.validate_pred_type('response', range(1598))
         self.subject.validate_pred_type('response', None)
 
+    def test_strip_trailing_nulls_from_class_values(self):
+        self.assertEqual(['cat', 'dog'],
+                         self.subject._strip_trailing_nulls_from_class_values(
+                ['cat', 'dog']))
+        self.assertEqual([None, 'cat', 'dog'],
+                         self.subject._strip_trailing_nulls_from_class_values(
+                [None, 'cat', 'dog']))
+        self.assertEqual([None, 'cat', 'dog'],
+                         self.subject._strip_trailing_nulls_from_class_values(
+                [None, 'cat', 'dog', None, None]))
+        self.assertEqual(['cat', 'dog'],
+                         self.subject._strip_trailing_nulls_from_class_values(
+                ['cat', 'dog', None, None]))
+        self.assertEqual([None],
+                         self.subject._strip_trailing_nulls_from_class_values(
+                [None, None]))
+
 class MadlibKerasHelperTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')


[madlib] 01/04: DL: Support response and prob prediction outputs.

Posted by nj...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 4b8ceb27752a011ee09efd0a3d26ef279ac84a31
Author: Nandish Jayaram <nj...@apache.org>
AuthorDate: Tue Apr 9 17:40:55 2019 -0700

    DL: Support response and prob prediction outputs.
    
    JIRA: MADLIB-1315
    
    This commit adds a new optional parameter to madlib_keras_predict called
    pred_type.
    1) If pred_type='response', then the prediction output will have a new
    column called estimated_COLNAME, whose value is the class label with the
    highest probability.
    2) If pred_type='prob', then the prediction output will have multiple
    columns, one for each class label. The values in these columns will be
    the probabilities associated with the corresponding class labels.
    
    This commit also adds dev-checks for these scenarios.
    
    Closes #370
    Co-authored-by: Ekta Khanna <ek...@pivotal.io>
---
 .../modules/deep_learning/madlib_keras.py_in       |  16 +-
 .../modules/deep_learning/madlib_keras.sql_in      |  21 +-
 .../deep_learning/madlib_keras_helper.py_in        |  87 +++++++
 .../deep_learning/madlib_keras_predict.py_in       | 151 ++++++------
 .../deep_learning/madlib_keras_validator.py_in     |  27 ++-
 .../modules/deep_learning/test/madlib_keras.sql_in | 270 ++++++++++++++++++++-
 .../test/unit_tests/test_madlib_keras.py_in        |  34 +--
 .../postgres/modules/utilities/utilities.py_in     |  92 ++++++-
 8 files changed, 573 insertions(+), 125 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index a69cc6e..afe2187 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -36,9 +36,9 @@ from keras.models import *
 from keras.optimizers import *
 from keras.regularizers import *
 import madlib_keras_serializer
-from madlib_keras_validator import CLASS_VALUES_COLNAME
-from madlib_keras_validator import DEPENDENT_VARTYPE
-from madlib_keras_validator import NORMALIZING_CONST_COLNAME
+from madlib_keras_helper import CLASS_VALUES_CNAME
+from madlib_keras_helper import DEPENDENT_VARTYPE_CNAME
+from madlib_keras_helper import NORMALIZING_CONST_CNAME
 from madlib_keras_validator import FitInputValidator
 from madlib_keras_wrapper import *
 from keras_model_arch_table import Format
@@ -196,11 +196,11 @@ def fit(schema_madlib, source_table, model, dependent_varname,
         final_validation_loss = validation_aggregate_loss[-1]
     version = madlib_version(schema_madlib)
     class_values, class_values_type = get_col_value_and_type(
-        fit_validator.source_summary_table, CLASS_VALUES_COLNAME)
+        fit_validator.source_summary_table, CLASS_VALUES_CNAME)
     norm_const, norm_const_type = get_col_value_and_type(
-        fit_validator.source_summary_table, NORMALIZING_CONST_COLNAME)
+        fit_validator.source_summary_table, NORMALIZING_CONST_CNAME)
     dep_vartype = plpy.execute("SELECT {0} AS dep FROM {1}".format(
-        DEPENDENT_VARTYPE, fit_validator.source_summary_table))[0]['dep']
+        DEPENDENT_VARTYPE_CNAME, fit_validator.source_summary_table))[0]['dep']
     create_output_summary_table = plpy.prepare("""
         CREATE TABLE {0}_summary AS
         SELECT
@@ -234,8 +234,8 @@ def fit(schema_madlib, source_table, model, dependent_varname,
         $28 AS {1},
         $29 AS {2},
         $30 AS {3}
-        """.format(model, CLASS_VALUES_COLNAME, DEPENDENT_VARTYPE,
-                   NORMALIZING_CONST_COLNAME),
+        """.format(model, CLASS_VALUES_CNAME, DEPENDENT_VARTYPE_CNAME,
+                   NORMALIZING_CONST_CNAME),
                    ["TEXT", "INTEGER", "TEXT", "TIMESTAMP",
                     "TIMESTAMP", "TEXT", "TEXT","TEXT",
                     "TEXT", "TEXT", "TEXT", "TEXT", "INTEGER",
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index 37afad9..8ba24c7 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -172,7 +172,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
     test_table              VARCHAR,
     id_col                  VARCHAR,
     independent_varname     VARCHAR,
-    output_table            VARCHAR
+    output_table            VARCHAR,
+    pred_type               VARCHAR
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
     with AOControl(False):
@@ -181,17 +182,29 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
                test_table,
                id_col,
                independent_varname,
-               output_table)
+               output_table,
+               pred_type)
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
+    model_table             VARCHAR,
+    test_table              VARCHAR,
+    id_col                  VARCHAR,
+    independent_varname     VARCHAR,
+    output_table            VARCHAR
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_predict($1, $2, $3, $4, $5, 'response');
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.internal_keras_predict(
    independent_var double precision [],
    model_architecture TEXT,
    model_data bytea,
    input_shape integer[],
    compile_params TEXT,
-   class_values TEXT[]
+   is_response BOOLEAN
 ) RETURNS DOUBLE PRECISION[] AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
     with AOControl(False):
@@ -201,7 +214,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.internal_keras_predict(
                model_data,
                input_shape,
                compile_params,
-               class_values)
+               is_response)
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
new file mode 100644
index 0000000..bc5e703
--- /dev/null
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -0,0 +1,87 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import plpy
+from keras_model_arch_table import Format
+from utilities.utilities import add_postfix
+from utilities.validate_args import input_tbl_valid
+
+# Name of columns in model summary table.
+CLASS_VALUES_CNAME = "class_values"
+NORMALIZING_CONST_CNAME = "normalizing_const"
+DEPENDENT_VARTYPE_CNAME = "dependent_vartype"
+COMPILE_PARAMS_CNAME = "compile_params"
+DEPENDENT_VARNAME_CNAME = "dependent_varname"
+DEPENDENT_VARTYPE_CNAME = "dependent_vartype"
+MODEL_ARCH_TABLE_CNAME = "model_arch_table"
+MODEL_ARCH_ID_CNAME = "model_arch_id"
+MODEL_DATA_CNAME = "model_data"
+
+class PredictParamsProcessor:
+    def __init__(self, model_table, module_name):
+        self.module_name = module_name
+        self.model_table = model_table
+        self.model_summary_table = add_postfix(self.model_table, '_summary')
+        input_tbl_valid(self.model_summary_table, self.module_name)
+        self.model_summary_dict = self._get_model_summary_dict()
+        self.model_arch_dict = self._get_model_arch_dict()
+
+    def _get_model_summary_dict(self):
+        return plpy.execute("SELECT * FROM {0}".format(
+            self.model_summary_table))[0]
+
+    def _get_model_arch_dict(self):
+        model_arch_table = self.model_summary_dict[MODEL_ARCH_TABLE_CNAME]
+        model_arch_id = self.model_summary_dict[MODEL_ARCH_ID_CNAME]
+        input_tbl_valid(model_arch_table, self.module_name)
+        model_arch_query = """
+            SELECT {0}
+            FROM {1}
+            WHERE {2} = {3}
+        """.format(Format.MODEL_ARCH, model_arch_table, Format.MODEL_ID,
+                   model_arch_id)
+        query_result = plpy.execute(model_arch_query)
+        if not query_result or len(query_result) == 0:
+            plpy.error("{0}: No model arch found in table {1} with id {2}".format(
+                self.module_name, model_arch_table, model_arch_id))
+        return query_result[0]
+
+    def get_class_values(self):
+        return self.model_summary_dict[CLASS_VALUES_CNAME]
+
+    def get_compile_params(self):
+        return self.model_summary_dict[COMPILE_PARAMS_CNAME]
+
+    def get_dependent_varname(self):
+        return self.model_summary_dict[DEPENDENT_VARNAME_CNAME]
+
+    def get_dependent_vartype(self):
+        return self.model_summary_dict[DEPENDENT_VARTYPE_CNAME]
+
+    def get_model_arch(self):
+        return self.model_arch_dict[Format.MODEL_ARCH]
+
+    def get_model_data(self):
+        return plpy.execute("""
+                SELECT {0} FROM {1}
+            """.format(MODEL_DATA_CNAME, self.model_table)
+                            )[0][MODEL_DATA_CNAME]
+
+    def get_normalizing_const(self):
+        return self.model_summary_dict[NORMALIZING_CONST_CNAME]
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 1180d33..34b26c3 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -27,67 +27,89 @@ from keras.models import *
 from keras.optimizers import *
 import numpy as np
 
+from madlib_keras_helper import PredictParamsProcessor
+from madlib_keras_helper import MODEL_DATA_CNAME
+from madlib_keras_wrapper import compile_and_set_weights
 from utilities.model_arch_info import get_input_shape
 from utilities.utilities import add_postfix
-from utilities.validate_args import get_col_value_and_type
+from utilities.utilities import create_cols_from_array_sql_string
+from utilities.utilities import unique_string
 from utilities.validate_args import input_tbl_valid
 from utilities.validate_args import output_tbl_valid
-from madlib_keras_validator import CLASS_VALUES_COLNAME
-from keras_model_arch_table import Format
 
-from madlib_keras_wrapper import compile_and_set_weights
 import madlib_keras_serializer
 
 MODULE_NAME = 'madlib_keras_predict'
+
+def validate_pred_type(pred_type, class_values):
+    if not pred_type in ['prob', 'response']:
+        plpy.error("{0}: Invalid value for pred_type param ({1}). Must be "\
+            "either response or prob.".format(MODULE_NAME, pred_type))
+    if pred_type == 'prob' and class_values and len(class_values)+1 >= 1600:
+        plpy.error({"{0}: The output will have {1} columns, exceeding the "\
+            " max number of columns that can be created (1600)".format(
+                MODULE_NAME, len(class_values)+1)})
+
 def predict(schema_madlib, model_table, test_table, id_col,
-            independent_varname, output_table, **kwargs):
+            independent_varname, output_table, pred_type, **kwargs):
+    """
+    Create output_table with one row of predictions per row of test_table.
+
+    Model metadata (class values, compile params, dependent variable name
+    and type, model data and model architecture) is read through
+    PredictParamsProcessor(model_table, MODULE_NAME). The per-row output of
+    internal_keras_predict is turned into the final column(s) by the select
+    clause built with create_cols_from_array_sql_string.
+
+    Args:
+        @param schema_madlib: schema used to qualify internal_keras_predict.
+        @param model_table: name of the model table.
+        @param test_table: name of the table holding the rows to predict on.
+        @param id_col: id column of test_table, copied to output_table.
+        @param independent_varname: column passed as input to the model.
+        @param output_table: name of the table to create.
+        @param pred_type: 'response' or 'prob' (checked by
+                          validate_pred_type). For 'response' the output
+                          column name is add_postfix("estimated_",
+                          dependent_varname) with type dependent_vartype;
+                          for 'prob' the column name/prefix is "prob" with
+                          type double precision.
+    """
+    # Refactor and add more validation as part of MADLIB-1312.
     input_tbl_valid(model_table, MODULE_NAME)
-    model_summary_table = add_postfix(model_table, '_summary')
-    input_tbl_valid(model_summary_table, MODULE_NAME)
     input_tbl_valid(test_table, MODULE_NAME)
     output_tbl_valid(output_table, MODULE_NAME)
-    model_summary_dict = plpy.execute("SELECT * FROM {0}".format(
-        model_summary_table))[0]
-    model_arch_table = model_summary_dict['model_arch_table']
-    model_arch_id = model_summary_dict['model_arch_id']
-    compile_params = model_summary_dict['compile_params']
-    input_tbl_valid(model_arch_table, MODULE_NAME)
+    param_proc = PredictParamsProcessor(model_table, MODULE_NAME)
 
-    model_data_query = "SELECT model_data from {0}".format(model_table)
-    model_data = plpy.execute(model_data_query)[0]['model_data']
+    class_values = param_proc.get_class_values()
+    compile_params = param_proc.get_compile_params()
+    dependent_varname = param_proc.get_dependent_varname()
+    dependent_vartype = param_proc.get_dependent_vartype()
+    model_data = param_proc.get_model_data()
+    model_arch = param_proc.get_model_arch()
 
-    model_arch_query = """
-        SELECT {0}, {1}
-        FROM {2}
-        WHERE {3} = {4}
-        """.format(Format.MODEL_ARCH, Format.MODEL_WEIGHTS,model_arch_table,
-                   Format.MODEL_ID, model_arch_id)
-    query_result = plpy.execute(model_arch_query)
-    if not  query_result or len(query_result) == 0:
-        plpy.error("{0}: No model arch found in table {1} with id {2}".format(
-            MODULE_NAME, model_arch_table, model_arch_id))
-    query_result = query_result[0]
-    model_arch = query_result[Format.MODEL_ARCH]
     input_shape = get_input_shape(model_arch)
     compile_params = "$madlib$" + compile_params + "$madlib$"
-    model_summary_table = add_postfix(model_table, "_summary")
-    class_values, _ = get_col_value_and_type(model_summary_table,
-                                             CLASS_VALUES_COLNAME)
-    predict_query = plpy.prepare("""
+
+    validate_pred_type(pred_type, class_values)
+    is_response = True if pred_type == 'response' else False
+    intermediate_col = unique_string()
+    if is_response:
+        pred_col_name = add_postfix("estimated_", dependent_varname)
+        pred_col_type = dependent_vartype
+    else:
+        pred_col_name = "prob"
+        pred_col_type = 'double precision'
+
+    # Count class values up to (not including) the first None that appears
+    # after at least one element has been counted. A None in the very first
+    # position is therefore kept, while e.g. [None, 'cat', 'dog', None, None]
+    # is cut down to its first three entries.
+    num_of_valid_class_values = 0
+    if class_values is not None:
+        for ele in class_values:
+            if ele is None and num_of_valid_class_values > 0:
+                break
+            num_of_valid_class_values += 1
+        # Pass only the valid class_values for creating columns
+        class_values = class_values[:num_of_valid_class_values]
+
+    prediction_select_clause = create_cols_from_array_sql_string(
+        class_values, intermediate_col, pred_col_name,
+        pred_col_type, is_response, MODULE_NAME)
+
+    plpy.execute("""
         CREATE TABLE {output_table} AS
-        SELECT {id_col},
-            ({schema_madlib}.internal_keras_predict
-                ({independent_varname},
-                 $MAD${model_arch}$MAD$,
-                 $1,ARRAY{input_shape},
-                 {compile_params},
-                 ARRAY{class_values}::TEXT[])
-            )[1] as prediction
-        from {test_table}""".format(**locals()), ["bytea"])
-    plpy.execute(predict_query, [model_data])
+        SELECT {id_col}, {prediction_select_clause}
+        FROM (
+            SELECT {test_table}.{id_col},
+                   ({schema_madlib}.internal_keras_predict
+                       ({independent_varname},
+                        $MAD${model_arch}$MAD$,
+                        {0},
+                        ARRAY{input_shape},
+                        {compile_params},
+                        {is_response})
+                   ) AS {intermediate_col}
+        FROM {test_table}, {model_table}
+        ) q
+        """.format(MODEL_DATA_CNAME, **locals()))
 
 def internal_keras_predict(x_test, model_arch, model_data, input_shape,
-                           compile_params, class_values):
+                           compile_params, is_response):
     model = model_from_json(model_arch)
     device_name = '/cpu:0'
     os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
@@ -97,37 +119,16 @@ def internal_keras_predict(x_test, model_arch, model_data, input_shape,
 
     x_test = np.array(x_test).reshape(1, *input_shape)
     x_test /= 255
-    proba_argmax = model.predict_classes(x_test)
-    # proba_argmax is a list with exactly one element in it. That element
-    # refers to the index containing the largest probability value in the
-    # output of Keras' predict function.
-    return _get_class_label(class_values, proba_argmax[0])
-
-def _get_class_label(class_values, class_index):
-    """
-    Returns back the class label associated with the index returned by Keras'
-    predict_classes function. Keras' predict_classes function returns back
-    the index of the 1-hot encoded output that has the highest probability
-    value. We should infer the exact class label corresponding to the index
-    by looking at the class_values list (which is obtained from the
-    class_values column of the model summary table). If class_values is None,
-    we return the index as is.
-    Args:
-        @param class_values: list of class labels.
-        @param class_index: integer representing the index with max
-                            probability value.
-    Returns:
-        scalar. If class_values is None, returns class_index, else returns
-        class_values[class_index].
-    """
-    if not class_values:
-        return class_index
-    elif class_index != int(class_index):
-        plpy.error("{0}: Invalid class index {1} returned from Keras predict."\
-            " Index value must be an integer".format(MODULE_NAME, class_index))
-    elif class_index < 0 or class_index >= len(class_values):
-        plpy.error("{0}: Invalid class index {1} returned from Keras predict."\
-            " Index value must be less than {2}".format(
-                MODULE_NAME, class_index, len(class_values)))
+    if is_response:
+        proba_argmax = model.predict_classes(x_test)
+        # proba_argmax is a list with exactly one element in it. That element
+        # refers to the index containing the largest probability value in the
+        # output of Keras' predict function.
+        return proba_argmax
     else:
-        return class_values[class_index]
+        probs = model.predict_proba(x_test)
+        # probs is a list containing a list of probability values, of all
+        # class levels. Since we are assuming each input is a single image,
+        # and not mini-batched, this list contains exactly one list in it,
+        # so return back the first list in probs.
+        return probs[0]
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 52e7d20..481fb1b 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -1,3 +1,23 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from madlib_keras_helper import CLASS_VALUES_CNAME
 from utilities.minibatch_validation import validate_dependent_var_for_minibatch
 from utilities.utilities import _assert
 from utilities.utilities import add_postfix
@@ -9,9 +29,6 @@ from utilities.validate_args import get_expr_type
 from utilities.validate_args import input_tbl_valid
 from utilities.validate_args import output_tbl_valid
 import plpy
-CLASS_VALUES_COLNAME = "class_values"
-NORMALIZING_CONST_COLNAME = "normalizing_const"
-DEPENDENT_VARTYPE = "dependent_vartype"
 
 class FitInputValidator:
     def __init__(self, source_table, validation_table, output_model_table,
@@ -55,11 +72,11 @@ class FitInputValidator:
         input_tbl_valid(self.source_table, self.module_name)
         input_tbl_valid(self.source_summary_table, self.module_name)
         _assert(is_var_valid(
-            self.source_summary_table, CLASS_VALUES_COLNAME),
+            self.source_summary_table, CLASS_VALUES_CNAME),
                 "model_keras error: invalid class_values varname "
                 "('{class_values}') for source_summary_table "
                 "({source_summary_table}).".format(
-                    class_values=CLASS_VALUES_COLNAME,
+                    class_values=CLASS_VALUES_CNAME,
                     source_summary_table=self.source_summary_table))
         # Source table and validation tables must have the same schema
         self._validate_input_table(self.source_table)
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index 8b7bd0c..a259630 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -145,6 +145,7 @@ SELECT assert(
     validation_table is NULL AND
     model = 'keras_out' AND
     dependent_varname = 'dependent_var' AND
+    dependent_vartype = 'smallint' AND
     independent_varname = 'independent_var' AND
     name = 'model name' AND
     description = 'model desc' AND
@@ -183,19 +184,17 @@ SELECT assert(UPPER(atttypid::regtype::TEXT) = 'INTEGER', 'id column should be I
         AND attname = 'id';
 
 SELECT assert(UPPER(atttypid::regtype::TEXT) =
-    'DOUBLE PRECISION', 'prediction column should be DOUBLE PRECISION type')
+    'SMALLINT', 'prediction column should be SMALLINT type')
     FROM pg_attribute WHERE attrelid = 'cifar10_predict'::regclass
-        AND attname = 'prediction';
+        AND attname = 'estimated_dependent_var';
 
 -- Validate correct number of rows returned.
 SELECT assert(COUNT(*)=2, 'Output table of madlib_keras_predict should have two rows') FROM cifar10_predict;
 
 -- First test that all values are in set of class values; if this breaks, it's definitely a problem.
-SELECT assert(prediction in (0,1),'Predicted value not in set of defined class values for model') FROM cifar10_predict;
-
--- Then test that each of the two images is correctly predicted.  If this breaks, it's likely a different problem.
-SELECT assert(prediction=0,'Incorrect prediction for first image.  Predicted: ' || __to_char(prediction) || ', Expected: 0') FROM cifar10_predict WHERE id=1;
-SELECT assert(prediction=1,'Incorrect prediction for second image.  Predicted: ' || __to_char(prediction) || ', Expected: 1') FROM cifar10_predict WHERE id=2;
+SELECT assert(estimated_dependent_var IN (0,1),
+              'Predicted value not in set of defined class values for model')
+FROM cifar10_predict;
 
 select assert(trap_error($TRAP$madlib_keras_predict(
     'keras_saved_out',
@@ -220,7 +219,7 @@ SELECT madlib_keras_fit(
     NULL,
     'model name', 'model desc');
 
--- -- negative test case for passing non numeric y to fit
+-- negative test case for passing non numeric y to fit
 -- induce failure by passing a non numeric column
 create table cifar_10_sample_val_failure as select * from cifar_10_sample_val;
 alter table cifar_10_sample_val_failure rename dependent_var to dependent_var_original;
@@ -240,3 +239,258 @@ select assert(trap_error($TRAP$madlib_keras_fit(
           'cifar_10_sample_val_failure');$TRAP$) = 1,
        'Passing y of type non numeric array to fit should error out.');
 
+-- Test with pred_type=prob
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample',
+    'id',
+    'x',
+    'cifar10_predict',
+    'prob');
+
+SELECT assert(UPPER(atttypid::regtype::TEXT) =
+    'DOUBLE PRECISION', 'column prob_0 should be double precision type')
+    FROM pg_attribute WHERE attrelid = 'cifar10_predict'::regclass
+        AND attname = 'prob_0';
+
+SELECT assert(UPPER(atttypid::regtype::TEXT) =
+    'DOUBLE PRECISION', 'column prob_1 should be double precision type')
+    FROM pg_attribute WHERE attrelid = 'cifar10_predict'::regclass
+        AND attname = 'prob_1';
+
+SELECT assert(COUNT(*)=3, 'Predict out table must have exactly three cols.')
+FROM pg_attribute
+WHERE attrelid='cifar10_predict'::regclass AND attnum>0;
+
+-- Tests with text class values:
+-- Modify input data to have text classes, and mini-batch it.
+CREATE TABLE cifar_10_sample_text AS
+SELECT * FROM cifar_10_sample;
+
+ALTER TABLE cifar_10_sample_text ALTER COLUMN y type TEXT;
+UPDATE cifar_10_sample_text SET y='cat' where y='0';
+UPDATE cifar_10_sample_text SET y='dog' where y='1';
+-- Add a new image with NULL class value
+INSERT INTO cifar_10_sample_text(id, x, y, imgpath)
+SELECT 3, x, NULL, '0/img3.jpg' FROM cifar_10_sample_text
+WHERE y='cat';
+
+
+DROP TABLE IF EXISTS cifar_10_sample_text_batched;
+DROP TABLE IF EXISTS cifar_10_sample_text_batched_summary;
+SELECT minibatch_preprocessor_dl('cifar_10_sample_text','cifar_10_sample_text_batched','y','x', 2, 255, 5);
+
+-- Change model_arch to reflect 5 num_classes
+DROP TABLE IF EXISTS model_arch;
+SELECT load_keras_model('model_arch',
+  $${
+  "class_name": "Sequential",
+  "keras_version": "2.1.6",
+  "config": [{
+    "class_name": "Conv2D", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}},
+    "name": "conv2d_1",
+    "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null,
+    "dtype": "float32", "activation": "relu", "trainable": true,
+    "data_format": "channels_last", "filters": 32, "padding": "valid",
+    "strides": [1, 1], "dilation_rate": [1, 1], "kernel_regularizer": null,
+    "bias_initializer": {"class_name": "Zeros", "config": {}},
+    "batch_input_shape": [null, 32, 32, 3], "use_bias": true,
+    "activity_regularizer": null, "kernel_size": [3, 3]}},
+    {"class_name": "MaxPooling2D", "config": {"name": "max_pooling2d_1", "trainable": true, "data_format": "channels_last", "pool_size": [2, 2], "padding": "valid", "strides": [2, 2]}},
+    {"class_name": "Dropout", "config": {"rate": 0.25, "noise_shape": null, "trainable": true, "seed": null, "name": "dropout_1"}},
+    {"class_name": "Flatten", "config": {"trainable": true, "name": "flatten_1", "data_format": "channels_last"}},
+    {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_1", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "activation": "softmax", "trainable": true, "kernel_regularizer": null, "bias_initializer":
+    {"class_name": "Zeros", "config": {}}, "units": 5, "use_bias": true, "activity_regularizer": null}
+    }], "backend": "tensorflow"}$$);
+
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_text_batched',
+    'keras_saved_out',
+    'dependent_var',
+    'independent_var',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    FALSE);
+-- Assert fit has correct class_values
+SELECT assert(
+    dependent_vartype = 'text' AND
+    class_values = '{NULL,cat,dog,NULL,NULL}',
+    'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM keras_saved_out_summary) summary;
+
+-- Predict with pred_type=prob
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample_text',
+    'id',
+    'x',
+    'cifar10_predict',
+    'prob');
+
+-- Validate the output datatype of newly created prediction columns
+-- for prediction type = 'prob' and class_values 'TEXT' with NULL as a valid
+-- class_values
+SELECT assert(UPPER(atttypid::regtype::TEXT) =
+    'DOUBLE PRECISION', 'column prob_cat should be double precision type')
+FROM pg_attribute
+WHERE attrelid = 'cifar10_predict'::regclass AND attname = 'prob_cat';
+
+SELECT assert(UPPER(atttypid::regtype::TEXT) =
+    'DOUBLE PRECISION', 'column prob_dog should be double precision type')
+FROM pg_attribute
+WHERE attrelid = 'cifar10_predict'::regclass AND attname = 'prob_dog';
+
+SELECT assert(UPPER(atttypid::regtype::TEXT) =
+    'DOUBLE PRECISION', 'column prob_NULL should be double precision type')
+FROM pg_attribute
+WHERE attrelid = 'cifar10_predict'::regclass AND attname = 'prob_NULL';
+
+-- Must have exactly 4 cols (3 for class_values and 1 for id)
+SELECT assert(COUNT(*)=4, 'Predict out table must have exactly four cols.')
+FROM pg_attribute
+WHERE attrelid='cifar10_predict'::regclass AND attnum>0;
+
+-- Predict with pred_type=response
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample_text',
+    'id',
+    'x',
+    'cifar10_predict',
+    'response');
+
+-- Validate the output datatype of newly created prediction columns
+-- for prediction type = 'response' and class_values 'TEXT' with NULL
+-- as a valid class_values
+SELECT assert(UPPER(atttypid::regtype::TEXT) =
+    'TEXT', 'prediction column should be TEXT type')
+FROM pg_attribute
+WHERE attrelid = 'cifar10_predict'::regclass
+      AND attname = 'estimated_dependent_var';
+
+-- Tests where the assumption is user has one-hot encoded, so class_values
+-- in input summary table will be NULL.
+UPDATE keras_saved_out_summary SET class_values=NULL;
+
+-- Predict with pred_type=prob
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample_text',
+    'id',
+    'x',
+    'cifar10_predict',
+    'prob');
+
+-- Validate the output datatype of newly created prediction column
+-- for prediction type = 'prob' and class_value = NULL
+-- Returns: Array of probabilities for user's one-hot encoded data
+SELECT assert(UPPER(atttypid::regtype::TEXT) =
+    'DOUBLE PRECISION[]', 'column prob should be double precision[] type')
+FROM pg_attribute
+WHERE attrelid = 'cifar10_predict'::regclass AND attname = 'prob';
+
+-- Predict with pred_type=response
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample_text',
+    'id',
+    'x',
+    'cifar10_predict',
+    'response');
+
+-- Validate the output datatype of newly created prediction column
+-- for prediction type = 'response' and class_value = NULL
+-- Returns: Index of class value in user's one-hot encoded data with
+-- highest probability
+SELECT assert(UPPER(atttypid::regtype::TEXT) =
+    'DOUBLE PRECISION', 'prediction column should be double precision type')
+FROM pg_attribute
+WHERE attrelid = 'cifar10_predict'::regclass
+      AND attname = 'estimated_dependent_var';
+
+-- Test predict with INTEGER class_values
+-- with NULL as a valid class value
+INSERT INTO cifar_10_sample(id, x, y, imgpath)
+SELECT 3, x, NULL, '0/img3.jpg' FROM cifar_10_sample
+WHERE y = 1;
+INSERT INTO cifar_10_sample(id, x, y, imgpath)
+SELECT 4, x, 4, '0/img4.jpg' FROM cifar_10_sample
+WHERE y = 0;
+INSERT INTO cifar_10_sample(id, x, y, imgpath)
+SELECT 5, x, 5, '0/img5.jpg' FROM cifar_10_sample
+WHERE y = 1;
+
+DROP TABLE IF EXISTS cifar_10_sample_int_batched;
+DROP TABLE IF EXISTS cifar_10_sample_int_batched_summary;
+SELECT minibatch_preprocessor_dl('cifar_10_sample','cifar_10_sample_int_batched','y','x', 2, 255, 5);
+
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_int_batched',
+    'keras_saved_out',
+    'dependent_var',
+    'independent_var',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    FALSE);
+
+-- Assert fit has correct class_values
+SELECT assert(
+    dependent_vartype = 'smallint' AND
+    class_values = '{NULL,0,1,4,5}',
+    'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM keras_saved_out_summary) summary;
+
+-- Predict with pred_type=prob
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample',
+    'id',
+    'x',
+    'cifar10_predict',
+    'prob');
+
+-- Validate the output datatype of newly created prediction column
+-- for prediction type = 'prob' and class_values 'INT' with NULL
+-- as a valid class_values
+SELECT assert(UPPER(atttypid::regtype::TEXT) =
+    'DOUBLE PRECISION', 'column prob_NULL should be double precision type')
+FROM pg_attribute
+WHERE attrelid = 'cifar10_predict'::regclass AND attname = 'prob_NULL';
+
+-- Must have exactly 6 cols (5 for class_values and 1 for id)
+SELECT assert(COUNT(*)=6, 'Predict out table must have exactly six cols.')
+FROM pg_attribute
+WHERE attrelid='cifar10_predict'::regclass AND attnum>0;
+
+-- Predict with pred_type=response
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample',
+    'id',
+    'x',
+    'cifar10_predict',
+    'response');
+
+-- Validate the output datatype of newly created prediction column
+-- for prediction type = 'response' and class_values 'INT' with NULL
+-- as a valid class_values
+-- Returns: class_value with highest probability
+SELECT assert(UPPER(atttypid::regtype::TEXT) =
+    'SMALLINT', 'prediction column should be smallint type')
+FROM pg_attribute
+WHERE attrelid = 'cifar10_predict'::regclass AND attname = 'estimated_dependent_var';
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index c8b649d..f2375ba 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -385,34 +385,22 @@ class MadlibKerasPredictTestCase(unittest.TestCase):
         self.subject = madlib_keras_predict
         self.classes = ['train', 'boat', 'car', 'airplane']
 
-    def test_get_class_label(self):
-        # test that index in range returns correct class value
-        self.assertEqual(
-                           'boat',
-                            self.subject._get_class_label(self.classes,
-                                                          self.classes.index('boat'))
-                        )
-
-        # test that index is returned if class_values param is None
-        self.assertEqual(
-                            5,
-                            self.subject._get_class_label(None,5)
-                        )
-
-        # test that index too high generates plpy error
-        with self.assertRaises(plpy.PLPYException):
-            self.subject._get_class_label(self.classes, 9)
+    def tearDown(self):
+        # Stop the patcher held in self.module_patcher once a test finishes.
+        self.module_patcher.stop()
 
-        # test that index too low generates plpy error
+    def test_validate_pred_type_invalid_pred_type(self):
         with self.assertRaises(plpy.PLPYException):
-            self.subject._get_class_label(self.classes, -1)
+            self.subject.validate_pred_type('invalid', ['cat', 'dog'])
 
-        # test that non-integer index generates plpy error
+    def test_validate_pred_type_valid_pred_type_invalid_num_class_values(self):
         with self.assertRaises(plpy.PLPYException):
-            self.subject._get_class_label(self.classes, 4.5)
+            self.subject.validate_pred_type('prob', range(1599))
 
-    def tearDown(self):
-        self.module_patcher.stop()
+    def test_validate_pred_type_valid_pred_type_valid_class_values(self):
+        # None of these calls is expected to raise. For 'prob', 1598 class
+        # values keeps len(class_values)+1 just below 1600; a None
+        # class_values and the 'response' pred_type are accepted as well.
+        self.subject.validate_pred_type('prob', range(1598))
+        self.subject.validate_pred_type('prob', None)
+        self.subject.validate_pred_type('response', range(1598))
+        self.subject.validate_pred_type('response', None)
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in
index f9f1fd0..e57f407 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -6,14 +6,15 @@ import random
 from distutils.util import strtobool
 
 from validate_args import _get_table_schema_names
-from validate_args import get_first_schema
-from validate_args import get_cols
 from validate_args import cols_in_tbl_valid
 from validate_args import does_exclude_reserved
 from validate_args import explicit_bool_to_text
+from validate_args import get_cols
+from validate_args import get_first_schema
 from validate_args import input_tbl_valid
 from validate_args import is_var_valid
 from validate_args import output_tbl_valid
+from validate_args import quote_ident
 import plpy
 
 
@@ -220,6 +221,7 @@ def is_valid_psql_type(arg, valid_types):
     """
     if not arg or not valid_types:
         return False
+    arg = arg.lower()
     if ANY_ARRAY <= valid_types:
         return arg.rstrip().endswith('[]')
     if ONLY_ARRAY <= valid_types:
@@ -386,6 +388,92 @@ def py_list_to_sql_string(array, array_type=None, long_format=None):
             qd=quote_delimiter)
 # ------------------------------------------------------------------------
 
+def create_cols_from_array_sql_string(py_list, sql_array_col, colname,
+                                      coltype, has_one_ele,
+                                      module_name='Input Error'):
+    """
+    Create SQL string to convert array of elements into multiple columns.
+    @args:
+        @param: py_list, python list. If None or empty, no per-element
+                            handling is done: the result is
+                            '<sql_array_col>[1]+1 AS <colname>' when
+                            has_one_ele is True, else
+                            '<sql_array_col> AS <colname>'.
+                            The py_list can at most have one 'None' element,
+                            which is replaced by the string 'NULL'.
+        @param: sql_array_col, str pointing to a column in table containing
+                               an array.
+        @param: colname, name of output column (can be treated as prefix
+                         to multiple cols if has_one_ele=False)
+        @param: coltype, Type of columns to be created. Only validated
+                         (boolean, numeric or text) when py_list is
+                         non-empty.
+        @param: has_one_ele, bool. if True, assumes sql_array_col has
+                    an array of exactly one element, which is treated as an
+                    index to get value from py_list. If False, then a new
+                    column is created for every element in py_list,
+                    whose corresponding values are obtained from sql_array_col.
+        @param: module_name, str, prefix used in the assertion messages.
+    @examples:
+        1) Input:
+                py_list = ['cat', 'dog']
+                sql_array_col = sqlcol
+                colname = prob
+                coltype = TEXT
+                has_one_ele = FALSE
+            Output (each alias is wrapped in double quotes):
+                CAST(sqlcol[1] AS TEXT) AS "prob_cat", CAST(sqlcol[2] AS TEXT) AS "prob_dog"
+        2) Input:
+                py_list = ['cat', 'dog']
+                sql_array_col = sqlcol
+                colname = estimated_pred
+                coltype = TEXT
+                has_one_ele = TRUE
+            Output (the array literal is whatever
+            py_list_to_sql_string(py_list, coltype+'[]') renders):
+                (ARRAY['cat','dog'])[sqlcol[1]+1]::TEXT AS estimated_pred
+
+    @returns:
+        str, a select-list fragment that can be used in a SQL query.
+
+    """
+    _assert(sql_array_col, "{0}: sql_array_col should be a valid string.".
+        format(module_name))
+    _assert(colname, "{0}: colname should be a valid string.".format(
+        module_name))
+    if py_list:
+        _assert(is_valid_psql_type(coltype, BOOLEAN | NUMERIC | TEXT),
+            "{0}: Invalid coltype parameter {1}.".format(
+                module_name, coltype))
+        _assert(py_list.count(None) <= 1,
+                "{0}: Input list should contain at most 1 None element.".
+                    format(module_name))
+        py_list = ['NULL' if ele is None else ele for ele in py_list]
+        if has_one_ele:
+            # Query to choose the value in the first element of
+            # sql_array_col which is the index to access in py_list.
+            # The value from that corresponding index in py_list is
+            # the value of colname.
+            py_list_sql_str = py_list_to_sql_string(py_list, coltype+'[]')
+            select_clause = "({0})[{1}[1]+1]::{2} AS {3}".format(
+                py_list_sql_str, sql_array_col, coltype, colname)
+        else:
+            # Create as many columns as the length of py_list. The
+            # colnames are created based on the elements in py_list,
+            # while the value for these colnames is obtained from
+            # sql_array_col.
+
+            # we cannot call sql quote_ident on the py_list entries because
+            # aliasing does not support quote_ident. Hence calling our
+            # python implementation of quote_ident
+            # Surrounding spaces/double quotes are stripped from its result
+            # because the template below adds its own pair of double quotes.
+            select_clause = ', '.join(
+                ['CAST({sql_array_col}[{j}] AS {coltype}) AS "{final_colname}"'.
+                    format(j=i + 1,
+                           final_colname=quote_ident("{0}_{1}".
+                               format(colname, str(suffix))).strip(' "'),
+                           sql_array_col=sql_array_col,
+                           coltype=coltype)
+                for i, suffix in enumerate(py_list)
+                ])
+    else:
+        if has_one_ele:
+            select_clause = '{0}[1]+1 AS {1}'.format(sql_array_col, colname)
+        else:
+            select_clause = '{0} AS {1}'.format(sql_array_col, colname)
+    return select_clause
+# ------------------------------------------------------------------------
 
 def _array_to_string(origin):
     """


[madlib] 03/04: DL: Remove reshaping and hard-coded normalizing_const from predict

Posted by nj...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit f43765944c92e82eb2bfcf1449b68da75df4c582
Author: Nandish Jayaram <nj...@apache.org>
AuthorDate: Tue Apr 16 16:53:03 2019 -0700

    DL: Remove reshaping and hard-coded normalizing_const from predict
    
    This commit makes the following changes:
    1. Predict was still reshaping the input, but that's not necessary anymore
    since we assume the input data is already shaped correctly.
    2. Remove hard coded normalizing constant for test data in predict,
    and instead use the normalizing_const from model summary table which is
    essentially the value that was used to normalize training data.
    3. Add dev-check and unit tests for the changes.
    
    Co-authored-by: Ekta Khanna <ek...@pivotal.io>
---
 .../modules/deep_learning/madlib_keras.py_in       | 32 +++++-----
 .../modules/deep_learning/madlib_keras.sql_in      |  6 +-
 .../deep_learning/madlib_keras_helper.py_in        | 44 ++++++++------
 .../deep_learning/madlib_keras_predict.py_in       | 21 ++++---
 .../deep_learning/madlib_keras_validator.py_in     |  6 +-
 .../modules/deep_learning/test/madlib_keras.sql_in | 69 ++++++++++++++++++++++
 .../test/unit_tests/test_madlib_keras.py_in        | 24 ++++++++
 7 files changed, 154 insertions(+), 48 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index afe2187..bbbcfb4 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -36,9 +36,10 @@ from keras.models import *
 from keras.optimizers import *
 from keras.regularizers import *
 import madlib_keras_serializer
-from madlib_keras_helper import CLASS_VALUES_CNAME
-from madlib_keras_helper import DEPENDENT_VARTYPE_CNAME
-from madlib_keras_helper import NORMALIZING_CONST_CNAME
+from madlib_keras_helper import CLASS_VALUES_COLNAME
+from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME
+from madlib_keras_helper import expand_input_dims
+from madlib_keras_helper import NORMALIZING_CONST_COLNAME
 from madlib_keras_validator import FitInputValidator
 from madlib_keras_wrapper import *
 from keras_model_arch_table import Format
@@ -196,11 +197,11 @@ def fit(schema_madlib, source_table, model, dependent_varname,
         final_validation_loss = validation_aggregate_loss[-1]
     version = madlib_version(schema_madlib)
     class_values, class_values_type = get_col_value_and_type(
-        fit_validator.source_summary_table, CLASS_VALUES_CNAME)
+        fit_validator.source_summary_table, CLASS_VALUES_COLNAME)
     norm_const, norm_const_type = get_col_value_and_type(
-        fit_validator.source_summary_table, NORMALIZING_CONST_CNAME)
+        fit_validator.source_summary_table, NORMALIZING_CONST_COLNAME)
     dep_vartype = plpy.execute("SELECT {0} AS dep FROM {1}".format(
-        DEPENDENT_VARTYPE_CNAME, fit_validator.source_summary_table))[0]['dep']
+        DEPENDENT_VARTYPE_COLNAME, fit_validator.source_summary_table))[0]['dep']
     create_output_summary_table = plpy.prepare("""
         CREATE TABLE {0}_summary AS
         SELECT
@@ -234,8 +235,8 @@ def fit(schema_madlib, source_table, model, dependent_varname,
         $28 AS {1},
         $29 AS {2},
         $30 AS {3}
-        """.format(model, CLASS_VALUES_CNAME, DEPENDENT_VARTYPE_CNAME,
-                   NORMALIZING_CONST_CNAME),
+        """.format(model, CLASS_VALUES_COLNAME, DEPENDENT_VARTYPE_COLNAME,
+                   NORMALIZING_CONST_COLNAME),
                    ["TEXT", "INTEGER", "TEXT", "TIMESTAMP",
                     "TIMESTAMP", "TEXT", "TEXT","TEXT",
                     "TEXT", "TEXT", "TEXT", "TEXT", "INTEGER",
@@ -479,15 +480,12 @@ def internal_keras_evaluate(dependent_var, independent_var, model_architecture,
     with K.tf.device(device_name):
         compile_model(model, compile_params)
 
-    # Since the training data is batched but the validation data isn't, we have
-    # to make sure that the validation data np array has the same no of dimensions
-    # as training data. So we prepend 1 to both x and y np arrays using expand_dims.
-    independent_var = np.array(independent_var)
-    independent_var = np.expand_dims(independent_var, axis=0)
-    independent_var = independent_var.astype('float32')
-
-    dependent_var = np.array(dependent_var)
-    dependent_var = np.expand_dims(dependent_var, axis=0)
+    # Since the training data is batched but the validation data isn't,
+    # we have to make sure that the validation data np array has the same
+    # number of dimensions as training data. So we prepend a dimension to
+    # both x and y np arrays using expand_dims.
+    independent_var = expand_input_dims(independent_var, target_type='float32')
+    dependent_var = expand_input_dims(dependent_var)
 
     with K.tf.device(device_name):
         res = model.evaluate(independent_var, dependent_var)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index 8ba24c7..34bf2c2 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -204,7 +204,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.internal_keras_predict(
    model_data bytea,
    input_shape integer[],
    compile_params TEXT,
-   is_response BOOLEAN
+   is_response BOOLEAN,
+   normalizing_const DOUBLE PRECISION
 ) RETURNS DOUBLE PRECISION[] AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
     with AOControl(False):
@@ -214,7 +215,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.internal_keras_predict(
                model_data,
                input_shape,
                compile_params,
-               is_response)
+               is_response,
+               normalizing_const)
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index bc5e703..d56a0e3 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -17,21 +17,29 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import numpy as np
 import plpy
 from keras_model_arch_table import Format
 from utilities.utilities import add_postfix
 from utilities.validate_args import input_tbl_valid
 
+# Prepend 1 to np arrays using expand_dims.
+def expand_input_dims(input_data, target_type=None):
+    input_data = np.array(input_data)
+    input_data = np.expand_dims(input_data, axis=0)
+    if target_type:
+        input_data = input_data.astype(target_type)
+    return input_data
+
 # Name of columns in model summary table.
-CLASS_VALUES_CNAME = "class_values"
-NORMALIZING_CONST_CNAME = "normalizing_const"
-DEPENDENT_VARTYPE_CNAME = "dependent_vartype"
-COMPILE_PARAMS_CNAME = "compile_params"
-DEPENDENT_VARNAME_CNAME = "dependent_varname"
-DEPENDENT_VARTYPE_CNAME = "dependent_vartype"
-MODEL_ARCH_TABLE_CNAME = "model_arch_table"
-MODEL_ARCH_ID_CNAME = "model_arch_id"
-MODEL_DATA_CNAME = "model_data"
+CLASS_VALUES_COLNAME = "class_values"
+NORMALIZING_CONST_COLNAME = "normalizing_const"
+COMPILE_PARAMS_COLNAME = "compile_params"
+DEPENDENT_VARNAME_COLNAME = "dependent_varname"
+DEPENDENT_VARTYPE_COLNAME = "dependent_vartype"
+MODEL_ARCH_TABLE_COLNAME = "model_arch_table"
+MODEL_ARCH_ID_COLNAME = "model_arch_id"
+MODEL_DATA_COLNAME = "model_data"
 
 class PredictParamsProcessor:
     def __init__(self, model_table, module_name):
@@ -47,8 +55,8 @@ class PredictParamsProcessor:
             self.model_summary_table))[0]
 
     def _get_model_arch_dict(self):
-        model_arch_table = self.model_summary_dict[MODEL_ARCH_TABLE_CNAME]
-        model_arch_id = self.model_summary_dict[MODEL_ARCH_ID_CNAME]
+        model_arch_table = self.model_summary_dict[MODEL_ARCH_TABLE_COLNAME]
+        model_arch_id = self.model_summary_dict[MODEL_ARCH_ID_COLNAME]
         input_tbl_valid(model_arch_table, self.module_name)
         model_arch_query = """
             SELECT {0}
@@ -63,16 +71,16 @@ class PredictParamsProcessor:
         return query_result[0]
 
     def get_class_values(self):
-        return self.model_summary_dict[CLASS_VALUES_CNAME]
+        return self.model_summary_dict[CLASS_VALUES_COLNAME]
 
     def get_compile_params(self):
-        return self.model_summary_dict[COMPILE_PARAMS_CNAME]
+        return self.model_summary_dict[COMPILE_PARAMS_COLNAME]
 
     def get_dependent_varname(self):
-        return self.model_summary_dict[DEPENDENT_VARNAME_CNAME]
+        return self.model_summary_dict[DEPENDENT_VARNAME_COLNAME]
 
     def get_dependent_vartype(self):
-        return self.model_summary_dict[DEPENDENT_VARTYPE_CNAME]
+        return self.model_summary_dict[DEPENDENT_VARTYPE_COLNAME]
 
     def get_model_arch(self):
         return self.model_arch_dict[Format.MODEL_ARCH]
@@ -80,8 +88,8 @@ class PredictParamsProcessor:
     def get_model_data(self):
         return plpy.execute("""
                 SELECT {0} FROM {1}
-            """.format(MODEL_DATA_CNAME, self.model_table)
-                            )[0][MODEL_DATA_CNAME]
+            """.format(MODEL_DATA_COLNAME, self.model_table)
+                            )[0][MODEL_DATA_COLNAME]
 
     def get_normalizing_const(self):
-        return self.model_summary_dict[NORMALIZING_CONST_CNAME]
+        return self.model_summary_dict[NORMALIZING_CONST_COLNAME]
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 34b26c3..3108be5 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -27,8 +27,9 @@ from keras.models import *
 from keras.optimizers import *
 import numpy as np
 
+from madlib_keras_helper import expand_input_dims
 from madlib_keras_helper import PredictParamsProcessor
-from madlib_keras_helper import MODEL_DATA_CNAME
+from madlib_keras_helper import MODEL_DATA_COLNAME
 from madlib_keras_wrapper import compile_and_set_weights
 from utilities.model_arch_info import get_input_shape
 from utilities.utilities import add_postfix
@@ -64,7 +65,8 @@ def predict(schema_madlib, model_table, test_table, id_col,
     dependent_vartype = param_proc.get_dependent_vartype()
     model_data = param_proc.get_model_data()
     model_arch = param_proc.get_model_arch()
-
+    normalizing_const = param_proc.get_normalizing_const()
+    # TODO: Validate input shape as part of MADLIB-1312
     input_shape = get_input_shape(model_arch)
     compile_params = "$madlib$" + compile_params + "$madlib$"
 
@@ -102,23 +104,26 @@ def predict(schema_madlib, model_table, test_table, id_col,
                         {0},
                         ARRAY{input_shape},
                         {compile_params},
-                        {is_response})
+                        {is_response},
+                        {normalizing_const})
                    ) AS {intermediate_col}
         FROM {test_table}, {model_table}
         ) q
-        """.format(MODEL_DATA_CNAME, **locals()))
+        """.format(MODEL_DATA_COLNAME, **locals()))
 
 def internal_keras_predict(x_test, model_arch, model_data, input_shape,
-                           compile_params, is_response):
+                           compile_params, is_response, normalizing_const):
     model = model_from_json(model_arch)
     device_name = '/cpu:0'
     os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
     model_shapes = madlib_keras_serializer.get_model_shapes(model)
     compile_and_set_weights(model, compile_params, device_name,
                             model_data, model_shapes)
-
-    x_test = np.array(x_test).reshape(1, *input_shape)
-    x_test /= 255
+    # Since the test data isn't mini-batched,
+    # we have to make sure that the test data np array has the same
+    # number of dimensions as input_shape. So we add a dimension to x.
+    x_test = expand_input_dims(x_test, target_type='float32')
+    x_test /= normalizing_const
     if is_response:
         proba_argmax = model.predict_classes(x_test)
         # proba_argmax is a list with exactly one element in it. That element
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 481fb1b..606461d 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -17,7 +17,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from madlib_keras_helper import CLASS_VALUES_CNAME
+from madlib_keras_helper import CLASS_VALUES_COLNAME
 from utilities.minibatch_validation import validate_dependent_var_for_minibatch
 from utilities.utilities import _assert
 from utilities.utilities import add_postfix
@@ -72,11 +72,11 @@ class FitInputValidator:
         input_tbl_valid(self.source_table, self.module_name)
         input_tbl_valid(self.source_summary_table, self.module_name)
         _assert(is_var_valid(
-            self.source_summary_table, CLASS_VALUES_CNAME),
+            self.source_summary_table, CLASS_VALUES_COLNAME),
                 "model_keras error: invalid class_values varname "
                 "('{class_values}') for source_summary_table "
                 "({source_summary_table}).".format(
-                    class_values=CLASS_VALUES_CNAME,
+                    class_values=CLASS_VALUES_COLNAME,
                     source_summary_table=self.source_summary_table))
         # Source table and validation tables must have the same schema
         self._validate_input_table(self.source_table)
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index a259630..08ac9cb 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -494,3 +494,72 @@ SELECT assert(UPPER(atttypid::regtype::TEXT) =
     'SMALLINT', 'prediction column should be smallint type')
 FROM pg_attribute
 WHERE attrelid = 'cifar10_predict'::regclass AND attname = 'estimated_dependent_var';
+
+-- Test case with a different input shape (3, 32, 32) instead of (32, 32, 3).
+-- Create a new table with image shape 3, 32, 32
+drop table if exists cifar_10_sample_test_shape;
+create table cifar_10_sample_test_shape(id INTEGER, y SMALLINT, x  REAL[] );
+copy cifar_10_sample_test_shape from stdin delimiter '|';
+1|0|{{{248,248,250,245,245,246,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245},{247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245},{245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247},{248,248,250,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247,245,245,247 [...]
+\.
+
+DROP TABLE IF EXISTS cifar_10_sample_test_shape_batched;
+DROP TABLE IF EXISTS cifar_10_sample_test_shape_batched_summary;
+SELECT minibatch_preprocessor_dl('cifar_10_sample_test_shape','cifar_10_sample_test_shape_batched','y','x', NULL, 255, 3);
+
+-- Change model_arch to reflect channels_first
+DROP TABLE IF EXISTS model_arch;
+SELECT load_keras_model('model_arch',
+  $${
+  "class_name": "Sequential",
+  "keras_version": "2.1.6",
+  "config": [{
+    "class_name": "Conv2D", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}},
+    "name": "conv2d_1",
+    "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null,
+    "dtype": "float32", "activation": "relu", "trainable": true,
+    "data_format": "channels_first", "filters": 32, "padding": "valid",
+    "strides": [1, 1], "dilation_rate": [1, 1], "kernel_regularizer": null,
+    "bias_initializer": {"class_name": "Zeros", "config": {}},
+    "batch_input_shape": [null, 3, 32, 32], "use_bias": true,
+    "activity_regularizer": null, "kernel_size": [3, 3]}},
+    {"class_name": "MaxPooling2D", "config": {"name": "max_pooling2d_1", "trainable": true, "data_format": "channels_first", "pool_size": [2, 2], "padding": "valid", "strides": [2, 2]}},
+    {"class_name": "Dropout", "config": {"rate": 0.25, "noise_shape": null, "trainable": true, "seed": null, "name": "dropout_1"}},
+    {"class_name": "Flatten", "config": {"trainable": true, "name": "flatten_1", "data_format": "channels_first"}},
+    {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": null, "mode": "fan_avg"}}, "name": "dense_1", "kernel_constraint": null, "bias_regularizer": null, "bias_constraint": null, "activation": "softmax", "trainable": true, "kernel_regularizer": null, "bias_initializer":
+    {"class_name": "Zeros", "config": {}}, "units": 3, "use_bias": true, "activity_regularizer": null}
+    }], "backend": "tensorflow"}$$);
+
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_test_shape_batched',
+    'keras_saved_out',
+    'dependent_var',
+    'independent_var',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    FALSE);
+
+-- Predict with correctly shaped data, must go thru.
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample_test_shape',
+    'id',
+    'x',
+    'cifar10_predict',
+    'prob');
+
+-- Prediction with incorrectly shaped data must error out.
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT assert(trap_error($TRAP$madlib_keras_predict(
+        'keras_saved_out',
+        'cifar_10_sample',
+        'id',
+        'x',
+        'cifar10_predict',
+        'prob');$TRAP$) = 1,
+    'Input shape is (32, 32, 3) but model was trained with (3, 32, 32). Should have failed.');
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index f2375ba..533e347 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -402,6 +402,30 @@ class MadlibKerasPredictTestCase(unittest.TestCase):
         self.subject.validate_pred_type('response', range(1598))
         self.subject.validate_pred_type('response', None)
 
+class MadlibKerasHelperTestCase(unittest.TestCase):
+    def setUp(self):
+        self.plpy_mock = Mock(spec='error')
+        patches = {
+            'plpy': plpy
+        }
+
+        self.plpy_mock_execute = MagicMock()
+        plpy.execute = self.plpy_mock_execute
+
+        self.module_patcher = patch.dict('sys.modules', patches)
+        self.module_patcher.start()
+        import madlib_keras_helper
+        self.subject = madlib_keras_helper
+        self.input_data = [32, 32, 3]
+
+    def tearDown(self):
+        self.module_patcher.stop()
+
+    def test_expand_input_dims(self):
+        self.assertEqual(np.array(self.input_data).shape, (3,))
+        res = self.subject.expand_input_dims(self.input_data)
+        self.assertEqual(res.shape, (1, 3))
+
 if __name__ == '__main__':
     unittest.main()
 # ---------------------------------------------------------------------