You are viewing a plain-text version of this content; the hyperlink to its canonical version was lost in the plain-text conversion.
Posted to commits@madlib.apache.org by kh...@apache.org on 2020/05/28 00:25:22 UTC

[madlib] branch master updated (42d049c -> 93cfa56)

This is an automated email from the ASF dual-hosted git repository.

khannaekta pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git.


    from 42d049c  DBSCAN: Fix predict on Greenplum (#499)
     new 27f8ac9  DL: Add object table info in load MST table utility
     new 93cfa56  update user docs for new object_table param

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../madlib_keras_custom_function.py_in             |   1 -
 .../madlib_keras_model_selection.py_in             |  23 ++-
 .../madlib_keras_model_selection.sql_in            |  34 +++-
 .../deep_learning/madlib_keras_validator.py_in     |  38 +++-
 .../madlib_keras_custom_function.setup.sql_in}     |  41 +++--
 .../test/madlib_keras_custom_function.sql_in       |  25 +--
 .../test/madlib_keras_model_selection.sql_in       |  42 +++++
 .../test_madlib_keras_model_selection_table.py_in  | 194 +++++++++++++++++++++
 8 files changed, 347 insertions(+), 51 deletions(-)
 copy src/ports/postgres/modules/{sample/test/stratified_sample.ic.sql_in => deep_learning/test/madlib_keras_custom_function.setup.sql_in} (63%)


[madlib] 01/02: DL: Add object table info in load MST table utility

Posted by kh...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

khannaekta pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 27f8ac96afc19d80ab4eb6a034d3b3ac29f1011f
Author: Ekta Khanna <ek...@pivotal.io>
AuthorDate: Tue May 19 11:53:48 2020 -0700

    DL: Add object table info in load MST table utility
    
    This commit adds an optional parameter `object_table` (a table storing
    Keras custom function objects) to `load_model_selection_table()`. If
    specified, the name of this object table is recorded in a new
    `object_table` column of the summary table
    `<model_selection_table>_summary`, where the fit/evaluate functions can
    pick it up.
---
 .../madlib_keras_custom_function.py_in             |   1 -
 .../madlib_keras_model_selection.py_in             |  23 ++-
 .../madlib_keras_model_selection.sql_in            |  13 +-
 .../deep_learning/madlib_keras_validator.py_in     |  38 +++-
 .../test/madlib_keras_custom_function.setup.sql_in |  41 +++++
 .../test/madlib_keras_custom_function.sql_in       |  25 +--
 .../test/madlib_keras_model_selection.sql_in       |  42 +++++
 .../test_madlib_keras_model_selection_table.py_in  | 194 +++++++++++++++++++++
 8 files changed, 349 insertions(+), 28 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
index 246c72d..9dcefed 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
@@ -18,7 +18,6 @@
 
 import dill
 import plpy
-from plpy import spiexceptions
 from utilities.control import MinWarning
 from utilities.utilities import _assert
 from utilities.utilities import get_col_name_type_sql_string
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in
index f3b02c7..46267f0 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in
@@ -28,6 +28,7 @@ class ModelSelectionSchema:
     MST_KEY = 'mst_key'
     MODEL_ID = ModelArchSchema.MODEL_ID
     MODEL_ARCH_TABLE = 'model_arch_table'
+    OBJECT_TABLE = 'object_table'
     COMPILE_PARAMS = 'compile_params'
     FIT_PARAMS = 'fit_params'
     col_types = ('SERIAL', 'INTEGER', 'VARCHAR', 'VARCHAR')
@@ -55,6 +56,7 @@ class MstLoader():
                  model_id_list,
                  compile_params_list,
                  fit_params_list,
+                 object_table=None,
                  **kwargs):
 
         self.model_arch_table = model_arch_table
@@ -62,13 +64,15 @@ class MstLoader():
         self.model_selection_summary_table = add_postfix(
             model_selection_table, "_summary")
         self.model_id_list = sorted(list(set(model_id_list)))
+        self.object_table = object_table
         MstLoaderInputValidator(
             model_arch_table=self.model_arch_table,
             model_selection_table=self.model_selection_table,
             model_selection_summary_table=self.model_selection_summary_table,
             model_id_list=self.model_id_list,
             compile_params_list=compile_params_list,
-            fit_params_list=fit_params_list
+            fit_params_list=fit_params_list,
+            object_table=object_table
         )
         self.compile_params_list = self.params_preprocessed(
             compile_params_list)
@@ -148,10 +152,12 @@ class MstLoader():
         """
         create_query = """
                         CREATE TABLE {self.model_selection_summary_table} (
-                            {model_arch_table} VARCHAR
+                            {model_arch_table} VARCHAR,
+                            {object_table} VARCHAR
                         );
                        """.format(self=self,
-                                  model_arch_table=ModelSelectionSchema.MODEL_ARCH_TABLE)
+                                  model_arch_table=ModelSelectionSchema.MODEL_ARCH_TABLE,
+                                  object_table=ModelSelectionSchema.OBJECT_TABLE)
         with MinWarning('warning'):
             plpy.execute(create_query)
 
@@ -179,14 +185,21 @@ class MstLoader():
                                       fit_params_col=ModelSelectionSchema.FIT_PARAMS,
                                       **locals())
             plpy.execute(insert_query)
+        if self.object_table is None:
+            object_table = 'NULL::VARCHAR'
+        else:
+            object_table = '$${0}$$'.format(self.object_table)
         insert_summary_query = """
                         INSERT INTO
                             {self.model_selection_summary_table}(
-                                {model_arch_table_name}
+                                {model_arch_table_name},
+                                {object_table_name}
                         )
                         VALUES (
-                            $${self.model_arch_table}$$
+                            $${self.model_arch_table}$$,
+                            {object_table}
                         )
                        """.format(model_arch_table_name=ModelSelectionSchema.MODEL_ARCH_TABLE,
+                                  object_table_name=ModelSelectionSchema.OBJECT_TABLE,
                                   **locals())
         plpy.execute(insert_summary_query)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in
index c15757c..7903e7f 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in
@@ -426,7 +426,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_model_selection_table(
     model_selection_table   VARCHAR,
     model_id_list           INTEGER[],
     compile_params_list     VARCHAR[],
-    fit_params_list         VARCHAR[]
+    fit_params_list         VARCHAR[],
+    object_table            VARCHAR
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_model_selection')
     with AOControl(False):
@@ -435,3 +436,13 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_model_selection_table(
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_model_selection_table(
+    model_arch_table        VARCHAR,
+    model_selection_table   VARCHAR,
+    model_id_list           INTEGER[],
+    compile_params_list     VARCHAR[],
+    fit_params_list         VARCHAR[]
+) RETURNS VOID AS $$
+  SELECT MADLIB_SCHEMA.load_model_selection_table($1, $2, $3, $4, $5, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 11730cf..a364a9e 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -20,6 +20,7 @@
 import plpy
 from keras_model_arch_table import ModelArchSchema
 from model_arch_info import get_num_classes
+from madlib_keras_custom_function import CustomFunctionSchema
 from madlib_keras_helper import CLASS_VALUES_COLNAME
 from madlib_keras_helper import COMPILE_PARAMS_COLNAME
 from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
@@ -45,6 +46,8 @@ from utilities.validate_args import input_tbl_valid
 from utilities.validate_args import output_tbl_valid
 from madlib_keras_wrapper import parse_and_validate_fit_params
 from madlib_keras_wrapper import parse_and_validate_compile_params
+import keras.losses as losses
+import keras.metrics as metrics
 
 class InputValidator:
     @staticmethod
@@ -443,7 +446,8 @@ class MstLoaderInputValidator():
                  model_selection_summary_table,
                  model_id_list,
                  compile_params_list,
-                 fit_params_list
+                 fit_params_list,
+                 object_table
                  ):
         self.model_arch_table = model_arch_table
         self.model_selection_table = model_selection_table
@@ -451,6 +455,7 @@ class MstLoaderInputValidator():
         self.model_id_list = model_id_list
         self.compile_params_list = compile_params_list
         self.fit_params_list = fit_params_list
+        self.object_table = object_table
         self.module_name = 'load_model_selection_table'
         self._validate_input_args()
 
@@ -489,9 +494,36 @@ class MstLoaderInputValidator():
                     """.format(fit_params, str(e)))
         if not self.compile_params_list:
             plpy.error( "compile_params_list cannot be NULL")
+        custom_fn_name = []
+        ## Initialize builtin loss/metrics functions
+        builtin_losses = dir(losses)
+        builtin_metrics = dir(metrics)
+        # Default metrics, since it is not part of the builtin metrics list
+        builtin_metrics.append('accuracy')
+        if self.object_table is not None:
+            res = plpy.execute("SELECT {0} from {1}".format(CustomFunctionSchema.FN_NAME,
+                                                            self.object_table))
+            for r in res:
+                custom_fn_name.append(r[CustomFunctionSchema.FN_NAME])
         for compile_params in self.compile_params_list:
             try:
-                res = parse_and_validate_compile_params(compile_params)
+                _, _, res = parse_and_validate_compile_params(compile_params)
+                # Validating if loss/metrics function called in compile_params
+                # is either defined in object table or is a built_in keras
+                # loss/metrics function
+                error_suffix = "but input object table missing!"
+                if self.object_table is not None:
+                    error_suffix = "is not defined in object table '{0}'!".format(self.object_table)
+
+                _assert(res['loss'] in custom_fn_name or res['loss'] in builtin_losses,
+                        "custom function '{0}' used in compile params "\
+                        "{1}".format(res['loss'], error_suffix))
+                if 'metrics' in res:
+                    _assert((len(set(res['metrics']).intersection(custom_fn_name)) > 0
+                            or len(set(res['metrics']).intersection(builtin_metrics)) > 0),
+                            "custom function '{0}' used in compile params " \
+                            "{1}".format(res['metrics'], error_suffix))
+
             except Exception as e:
                 plpy.error(
                     """Compile param check failed for: {0} \n
@@ -500,6 +532,8 @@ class MstLoaderInputValidator():
 
     def _validate_input_output_tables(self):
         input_tbl_valid(self.model_arch_table, self.module_name)
+        if self.object_table is not None:
+            input_tbl_valid(self.object_table, self.module_name)
         output_tbl_valid(self.model_selection_table, self.module_name)
         output_tbl_valid(self.model_selection_summary_table, self.module_name)
 
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in
new file mode 100644
index 0000000..671cf07
--- /dev/null
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in
@@ -0,0 +1,41 @@
+/* ---------------------------------------------------------------------*//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ ** ---------------------------------------------------------------------*/
+
+---- utility for creating valid dill objects ----
+CREATE OR REPLACE FUNCTION custom_function_object()
+RETURNS BYTEA AS
+$$
+import dill
+def test_sum_fn(a, b):
+	return a+b
+
+pb=dill.dumps(test_sum_fn)
+return pb
+$$ language plpythonu;
+
+CREATE OR REPLACE FUNCTION read_custom_function(pb bytea, arg1 int, arg2 int)
+RETURNS INTEGER AS
+$$
+import dill
+obj=dill.loads(pb)
+res=obj(arg1, arg2)
+return res
+$$ language plpythonu;
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
index 74f6ba2..82d5e97 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
@@ -23,25 +23,12 @@
  * Test load custom function helper functions
  * -------------------------------------------------------------------------- */
 
-CREATE OR REPLACE FUNCTION custom_function_object()
-RETURNS BYTEA AS
-$$
-import dill
-def test_sum_fn(a, b):
-	return a+b
-
-pb=dill.dumps(test_sum_fn)
-return pb
-$$ language plpythonu;
-
-CREATE OR REPLACE FUNCTION read_custom_function(pb bytea, arg1 int, arg2 int)
-RETURNS INTEGER AS
-$$
-import dill
-obj=dill.loads(pb)
-res=obj(arg1, arg2)
-return res
-$$ language plpythonu;
+m4_include(`SQLCommon.m4')
+
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+             `\1../../modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in'
+)
 
 /* Test successful table creation where no table exists */
 DROP TABLE IF EXISTS test_custom_function_table;
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index e1dbe0c..fa90c86 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -26,6 +26,11 @@ m4_include(`SQLCommon.m4')
              `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
 )
 
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+             `\1../../modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in'
+)
+
 -- MST table generation tests
 -- Valid inputs should pass and yield 6 msts in the table
 DROP TABLE IF EXISTS mst_table, mst_table_summary;
@@ -215,6 +220,43 @@ SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model') = false, 'M
 SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_summary') = false, 'Model summary output table is unlogged');
 SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_info') = false, 'Model info output table is unlogged');
 
+-- Test for object table
+
+DROP TABLE IF EXISTS test_custom_function_table;
+SELECT assert(MADLIB_SCHEMA.trap_error($MAD$
+  SELECT load_model_selection_table(
+    'iris_model_arch',
+    'mst_object_table',
+    ARRAY[1],
+    ARRAY[
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=16, epochs=1$$
+    ],
+    'test_custom_function_table')
+$MAD$) = 1, 'Object table does not exist!');
+SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
+
+DROP TABLE IF EXISTS mst_object_table, mst_object_table_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'mst_object_table',
+    ARRAY[1],
+    ARRAY[
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=16, epochs=1$$
+    ],
+    'test_custom_function_table'
+);
+
+SELECT assert(
+        object_table = 'test_custom_function_table',
+        'Keras Fit Multiple Output Summary Validation failed when user passes in object_table. Actual:' || __to_char(summary))
+FROM (SELECT * FROM mst_object_table_summary) summary;
+
 -- Test when number of configs(3) equals number of segments(3)
 DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
 SELECT setseed(0);
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in
index 57e08a5..b911992 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in
@@ -53,6 +53,7 @@ class LoadModelSelectionTableTestCase(unittest.TestCase):
         self.subject = self.module.MstLoader
         self.model_selection_table = 'mst_table'
         self.model_arch_table = 'model_arch_library'
+        self.object_table = 'custom_function_table'
         self.model_id_list = [1]
         self.compile_params_list = [
             """
@@ -99,6 +100,20 @@ class LoadModelSelectionTableTestCase(unittest.TestCase):
                 self.fit_params_list
             )
 
+    def test_invalid_input_args_optional_param(self):
+        self.module.MstLoaderInputValidator \
+            ._validate_input_args \
+            .side_effect = plpy.PLPYException('Invalid input args')
+        with self.assertRaises(plpy.PLPYException):
+            generate_mst = self.subject(
+                self.model_selection_table,
+                self.model_arch_table,
+                self.model_id_list,
+                self.compile_params_list,
+                self.fit_params_list,
+                "invalid_table"
+            )
+
     def test_duplicate_params(self):
         self.model_id_list = [1, 1, 2]
         self.compile_params_list = [
@@ -135,6 +150,185 @@ class LoadModelSelectionTableTestCase(unittest.TestCase):
     def tearDown(self):
         self.module_patcher.stop()
 
+class MstLoaderInputValidatorTestCase(unittest.TestCase):
+    def setUp(self):
+        # The side effects of this class(writing to the output table) are not
+        # tested here. They are tested in dev-check.
+        self.plpy_mock = Mock(spec='error')
+        patches = {
+            'plpy': plpy
+        }
+
+        self.plpy_mock_execute = MagicMock()
+        plpy.execute = self.plpy_mock_execute
+
+        self.module_patcher = patch.dict('sys.modules', patches)
+        self.module_patcher.start()
+        import deep_learning.madlib_keras_validator
+        self.module = deep_learning.madlib_keras_validator
+
+        self.subject = self.module.MstLoaderInputValidator
+        self.model_selection_table = 'mst_table'
+        self.model_arch_table = 'model_arch_library'
+        self.model_arch_summary_table = 'model_arch_library_summary'
+        self.object_table = 'custom_function_table'
+        self.model_id_list = [1]
+        self.compile_params_list = [
+            """
+                loss='categorical_crossentropy',
+                optimizer='Adam(lr=0.1)',
+                metrics=['accuracy']
+            """,
+            """
+                loss='categorical_crossentropy',
+                optimizer='Adam(lr=0.01)',
+                metrics=['accuracy']
+            """,
+            """
+                loss='categorical_crossentropy',
+                optimizer='Adam(lr=0.001)',
+                metrics=['accuracy']
+            """
+        ]
+        self.fit_params_list = [
+            "batch_size=5,epochs=1",
+            "batch_size=10,epochs=1"
+        ]
+
+    def test_validate_compile_params_no_custom_fn_table(self):
+        self.subject._validate_input_output_tables = Mock()
+        self.subject._validate_model_ids = Mock()
+        self.subject.parse_and_validate_fit_params = Mock()
+
+        self.subject(
+            self.model_selection_table,
+            self.model_arch_table,
+            self.model_arch_summary_table,
+            self.model_id_list,
+            self.compile_params_list,
+            self.fit_params_list,
+            None
+        )
+
+    def test_test_validate_compile_params_custom_fn_table(self):
+        self.subject._validate_input_output_tables = Mock()
+        self.subject._validate_model_ids = Mock()
+        self.subject.parse_and_validate_fit_params = Mock()
+        self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'},
+                                              {'name': 'custom_fn2'}]]
+        self.subject(
+            self.model_selection_table,
+            self.model_arch_table,
+            self.model_arch_summary_table,
+            self.model_id_list,
+            self.compile_params_list,
+            self.fit_params_list,
+            self.object_table
+        )
+
+    def test_test_validate_compile_params_valid_custom_fn(self):
+        self.subject._validate_input_output_tables = Mock()
+        self.subject._validate_model_ids = Mock()
+        self.subject.parse_and_validate_fit_params = Mock()
+        self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'},
+                                               {'name': 'custom_fn2'}]]
+        self.compile_params_list_valid_custom_fn = [
+            """
+                loss='custom_fn1',
+                optimizer='Adam(lr=0.1)',
+                metrics=['accuracy']
+            """
+        ]
+        self.subject(
+            self.model_selection_table,
+            self.model_arch_table,
+            self.model_arch_summary_table,
+            self.model_id_list,
+            self.compile_params_list_valid_custom_fn,
+            self.fit_params_list,
+            self.object_table
+        )
+
+    def test_test_validate_compile_params_valid_custom_fn_missing_obj_tbl(self):
+        self.subject._validate_input_output_tables = Mock()
+        self.subject._validate_model_ids = Mock()
+        self.subject.parse_and_validate_fit_params = Mock()
+        self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'},
+                                               {'name': 'custom_fn2'}]]
+        self.compile_params_list_valid_custom_fn = [
+            """
+                loss='custom_fn1',
+                optimizer='Adam(lr=0.1)',
+                metrics=['accuracy']
+            """
+        ]
+
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.subject(
+                self.model_selection_table,
+                self.model_arch_table,
+                self.model_arch_summary_table,
+                self.model_id_list,
+                self.compile_params_list_valid_custom_fn,
+                self.fit_params_list,
+                None
+            )
+        self.assertIn("object table missing", str(error.exception).lower())
+
+    def test_test_validate_compile_params_missing_loss_fn(self):
+        self.subject._validate_input_output_tables = Mock()
+        self.subject._validate_model_ids = Mock()
+        self.subject.parse_and_validate_fit_params = Mock()
+        self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'},
+                                               {'name': 'custom_fn2'}]]
+        self.compile_params_list_invalid_loss_fn = [
+            """
+                loss='invalid_loss',
+                optimizer='Adam(lr=0.1)',
+                metrics=['accuracy']
+            """
+        ]
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.subject(
+                self.model_selection_table,
+                self.model_arch_table,
+                self.model_arch_summary_table,
+                self.model_id_list,
+                self.compile_params_list_invalid_loss_fn,
+                self.fit_params_list,
+                self.object_table
+            )
+        self.assertIn("invalid_loss", str(error.exception).lower())
+
+    def test_test_validate_compile_params_missing_metric_fn(self):
+        self.subject._validate_input_output_tables = Mock()
+        self.subject._validate_model_ids = Mock()
+        self.subject.parse_and_validate_fit_params = Mock()
+        self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'},
+                                               {'name': 'custom_fn2'}]]
+
+        self.compile_params_list_invalid_metric_fn = [
+            """
+                loss='custom_fn1',
+                optimizer='Adam(lr=0.1)',
+                metrics=['invalid_metrics']
+            """
+        ]
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.subject(
+                self.model_selection_table,
+                self.model_arch_table,
+                self.model_arch_summary_table,
+                self.model_id_list,
+                self.compile_params_list_invalid_metric_fn,
+                self.fit_params_list,
+                self.object_table
+            )
+        self.assertIn("invalid_metrics", str(error.exception).lower())
+
+    def tearDown(self):
+        self.module_patcher.stop()
+
 if __name__ == '__main__':
     unittest.main()
 # ---------------------------------------------------------------------


[madlib] 02/02: update user docs for new object_table param

Posted by kh...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

khannaekta pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 93cfa565088c61cf741111ebf56e5eab9f012577
Author: Frank McQuillan <fm...@pivotal.io>
AuthorDate: Wed May 27 12:50:30 2020 -0700

    update user docs for new object_table param
---
 .../madlib_keras_model_selection.sql_in             | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in
index 7903e7f..d6c10e3 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in
@@ -59,7 +59,8 @@ load_model_selection_table(
     model_selection_table,
     model_id_list,
     compile_params_list,
-    fit_params_list
+    fit_params_list,
+    object_table
     )
 </pre>
 
@@ -87,7 +88,10 @@ load_model_selection_table(
   <dt>compile_params_list</dt>
   <dd>VARCHAR[]. Array of compile parameters to be tested.  Each element
   of the array should consist of a string of compile parameters
-  exactly as it is to be passed to Keras.
+  exactly as it is to be passed to Keras. For custom loss functions or custom metrics,
+  list the custom function name in the usual way, and also pass the name of the
+  table holding their serialized objects via the 'object_table' parameter
+  below.
   </dd>
 
   <dt>fit_params_list</dt>
@@ -96,6 +100,12 @@ load_model_selection_table(
   exactly as it is to be passed to Keras.
   </dd>
 
+  <dt>object_table (optional)</dt>
+  <dd>VARCHAR, default: NULL. Name of the table containing the serialized Python
+  objects for any custom loss functions or custom metrics referenced in the
+  parameter 'compile_params_list'.
+  </dd>
+
 </dl>
 
 <b>Output table</b>
@@ -133,6 +143,13 @@ load_model_selection_table(
         model architecture IDs.
         </td>
       </tr>
+      <tr>
+        <th>object_table</th>
+        <td>VARCHAR. Name of the object table containing the serialized
+        Python objects for custom loss functions and custom metrics.
+        If no object table was specified, this field is NULL.
+        </td>
+      </tr>
     </table>
 </br>