Posted to commits@madlib.apache.org by kh...@apache.org on 2020/10/21 00:47:17 UTC

[madlib] branch master updated (49262a5 -> 181c28e)

This is an automated email from the ASF dual-hosted git repository.

khannaekta pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git.


    from 49262a5  DL: Implement caching for fit_multiple_model
     new 918256d  DL: Add a helper function to load custom top n accuracy functions
     new 181c28e  update user docs and examples for custom functions

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../madlib_keras_custom_function.py_in             |  85 ++++++++++-
 .../madlib_keras_custom_function.sql_in            | 164 ++++++++++++++++++---
 .../madlib_keras_model_selection.sql_in            | 116 +++++++++++++--
 .../deep_learning/madlib_keras_wrapper.py_in       |  10 +-
 .../test/madlib_keras_custom_function.sql_in       | 105 +++++++------
 .../test/madlib_keras_model_averaging_e2e.sql_in   |  10 +-
 .../test/madlib_keras_model_selection_e2e.sql_in   |  10 +-
 .../test/unit_tests/test_madlib_keras.py_in        |   7 +
 8 files changed, 417 insertions(+), 90 deletions(-)
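For readers skimming the thread: the custom-function mechanism these commits extend stores serialized Python callables as bytes in a table. A minimal sketch of that round trip, using the standard pickle module in place of dill (dill extends pickle to more types, but the flow is the same; the metric below is a toy stand-in, not MADlib code):

```python
import pickle

# A toy metric like the ones loaded by load_custom_function()
# (the real examples serialize Keras-backend functions with dill).
def squared_error(y_true, y_pred):
    return [(p - t) ** 2 for t, p in zip(y_true, y_pred)]

# Serialize to bytes, as would be inserted into the BYTEA 'object' column.
blob = pickle.dumps(squared_error)

# Deserialize and call, as the deep learning modules do at fit time.
fn = pickle.loads(blob)
print(fn([1.0, 2.0], [1.5, 2.0]))  # [0.25, 0.0]
```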


[madlib] 02/02: update user docs and examples for custom functions

Posted by kh...@apache.org.

khannaekta pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 181c28e726e72b7624195473d0618b9f9e7d3c9b
Author: Frank McQuillan <fm...@pivotal.io>
AuthorDate: Thu Oct 8 14:48:08 2020 -0700

    update user docs and examples for custom functions
    
    Also, fix format error in user docs
---
 .../madlib_keras_custom_function.sql_in            |  37 ++++---
 .../madlib_keras_model_selection.sql_in            | 116 +++++++++++++++++++--
 2 files changed, 131 insertions(+), 22 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
index acdaa28..bb9864d 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
@@ -38,7 +38,7 @@ Interface and implementation are subject to change. </em>
 <div class="toc"><b>Contents</b><ul>
 <li class="level1"><a href="#load_function">Load Function</a></li>
 <li class="level1"><a href="#delete_function">Delete Function</a></li>
-<li class="level1"><a href="#top_n_function">Top n Function</a></li>
+<li class="level1"><a href="#top_k_function">Top k Accuracy Function</a></li>
 <li class="level1"><a href="#example">Examples</a></li>
 <li class="level1"><a href="#literature">Literature</a></li>
 <li class="level1"><a href="#related">Related Topics</a></li>
@@ -52,6 +52,11 @@ The functions to be loaded must be in the form of serialized Python objects
 created using Dill, which extends Python's pickle module to the majority
 of the built-in Python types [1].
 
+Custom functions can also be used to compute the top k categorical accuracy
+rate when you want a k value different from the Keras default.
+This module includes a helper function to create the custom function
+automatically for a specified k.
+
 There is also a utility function to delete a function
 from the table.
 
@@ -150,10 +155,13 @@ delete_custom_function(
   </dd>
 </dl>
 
-@anchor top_n_function
-@par Top n Function
+@anchor top_k_function
+@par Top k Accuracy Function
 
-Load a top n function with a specific n to the custom functions table.
+Create and load a custom function for a specific k into the custom functions table.
+The Keras accuracy parameter 'top_k_categorical_accuracy' returns top 5 accuracy by default [2].
+If you want a different top k value, use this helper function to create a custom
+Python function to compute the top k accuracy that you specify.
 
 <pre class="syntax">
 load_top_k_accuracy_function(
@@ -170,7 +178,7 @@ load_top_k_accuracy_function(
   </dd>
 
   <dt>k</dt>
-  <dd>INTEGER. k value for the top k accuracy function.
+  <dd>INTEGER. k value for the top k accuracy that you want.
   </dd>
 
 </dl>
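The helper derives the stored function name from k with a simple format pattern. A sketch mirroring the `top_{k}_accuracy` naming visible in the patch (the validation message here is paraphrased):

```python
# Sketch of the name generation used by load_top_k_accuracy_function,
# mirroring the "top_{k}_accuracy" pattern in the implementation.
def top_k_function_name(k):
    if k <= 0:
        raise ValueError("k must be a positive integer")
    return "top_{k}_accuracy".format(k=k)

print(top_k_function_name(3))   # top_3_accuracy
print(top_k_function_name(10))  # top_10_accuracy
```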
@@ -187,12 +195,12 @@ load_top_k_accuracy_function(
       <tr>
         <th>name</th>
         <td>TEXT PRIMARY KEY. Name of the object.
-        Generated with the following pattern: (sparse_,)top_(n)_accuracy.
+        Generated with the following pattern: top_(k)_accuracy.
         </td>
       </tr>
       <tr>
         <th>description</th>
-        <td>TEXT. Description of the object (free text).
+        <td>TEXT. Description of the object.
         </td>
       </tr>
       <tr>
@@ -233,7 +241,7 @@ conn.commit()
 </pre>
 List table to see objects:
 <pre class="example">
-SELECT id, name, description FROM test_custom_function_table ORDER BY id;
+SELECT id, name, description FROM custom_function_table ORDER BY id;
 </pre>
 <pre class="result">
  id |     name      |      description
@@ -292,23 +300,28 @@ SELECT madlib.delete_custom_function( 'custom_function_table', 'rmse');
 </pre>
 If all objects are deleted from the table using this function, the table itself will be dropped.
 </pre>
-Load top 3 accuracy function:
+-# Load top 3 accuracy function followed by a top 10 accuracy function:
 <pre class="example">
 DROP TABLE IF EXISTS custom_function_table;
 SELECT madlib.load_top_k_accuracy_function('custom_function_table',
                                            3);
+SELECT madlib.load_top_k_accuracy_function('custom_function_table',
+                                           10);
 SELECT id, name, description FROM custom_function_table ORDER BY id;
 </pre>
 <pre class="result">
- id |      name      |      description
-----+----------------+------------------------
-  1 | top_3_accuracy | returns top_3_accuracy
+ id |      name       |       description       
+----+-----------------+-------------------------
+  1 | top_3_accuracy  | returns top_3_accuracy
+  2 | top_10_accuracy | returns top_10_accuracy
 </pre>
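As a point of reference, top k categorical accuracy simply checks whether the true class appears among the k highest-scoring predictions. An illustrative pure-Python version (not the Keras or MADlib implementation; list-based inputs are assumed for clarity):

```python
# Illustrative top-k categorical accuracy (Keras defaults to k=5;
# the helper in this patch creates variants for other k values).
def top_k_accuracy(y_true, y_pred, k):
    """y_true: list of true class indices; y_pred: list of score vectors."""
    hits = 0
    for true_idx, scores in zip(y_true, y_pred):
        # indices of the k highest scores
        top_k = sorted(range(len(scores)),
                       key=lambda i: scores[i], reverse=True)[:k]
        hits += true_idx in top_k
    return hits / float(len(y_true))

y_true = [0, 2, 1]
y_pred = [[0.6, 0.3, 0.1],   # true class 0 is rank 1
          [0.5, 0.3, 0.2],   # true class 2 is rank 3
          [0.1, 0.2, 0.7]]   # true class 1 is rank 2
print(top_k_accuracy(y_true, y_pred, 1))  # 1 of 3 correct
print(top_k_accuracy(y_true, y_pred, 2))  # 2 of 3 correct
```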
 @anchor literature
 @literature
 
 [1] Dill https://pypi.org/project/dill/
 
+[2] https://keras.io/api/metrics/accuracy_metrics/#topkcategoricalaccuracy-class
+
 @anchor related
 @par Related Topics
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in
index fd18edb..870dd18 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in
@@ -81,7 +81,8 @@ generate_model_configs(
   </dd>
 
   <dt>model_selection_table</dt>
-  <dd>VARCHAR. Model selection table created by this module.  A summary table
+  <dd>VARCHAR. Model selection table created by this module.  If this table already
+  exists, it will be appended to.  A summary table
   named <model_selection_table>_summary is also created.  Contents of both output
   tables are described below.
   </dd>
@@ -119,10 +120,12 @@ generate_model_configs(
   than regular log-based sampling. 
   
   In the case of grid search, omit the sample type and just put the grid points in the list.
-  For custom loss functions or custom metrics,
-  list the custom function name in the usual way, and provide the name of the
+  For custom loss functions, custom metrics, and custom top k categorical accuracy,
+  list the custom function name and provide the name of the
   table where the serialized Python objects reside using the 
-  parameter 'object_table' below. See the examples section later on this page for more examples. 
+  parameter 'object_table' below. See the examples section later on this page.
+  For more information on custom functions, please
+  see <a href="group__grp__custom__function.html">Load Custom Functions</a>. 
   </dd>
 
   <dt>fit_params_grid</dt>
@@ -139,7 +142,6 @@ generate_model_configs(
     } 
   $$
   </pre>
-  See the examples section later on this page for more examples.
   </dd>
 
   <dt>search_type</dt>
@@ -223,7 +225,7 @@ generate_model_configs(
 @anchor load_mst_table
 @par Load Model Selection Table [Deprecated]
 
-This method is deprecated and replaced by the method 'generate_model_configs()' described above.
+This method is deprecated and replaced by the method 'generate_model_configs' described above.
 
 <pre class="syntax">
 load_model_selection_table(
@@ -247,7 +249,7 @@ load_model_selection_table(
   <dt>model_selection_table</dt>
   <dd>VARCHAR. Model selection table created by this utility.  A summary table
   named <model_selection_table>_summary is also created.  Contents of both output
-  tables are the same as described above for the method 'generate_model_configs()'.
+  tables are the same as described above for the method 'generate_model_configs'.
   </dd>
 
   <dt>model_id_list</dt>
@@ -672,10 +674,104 @@ SELECT * FROM mst_table_manual_summary;
 </pre>
 
 -# Custom loss functions and custom metrics.
-TBD
-
+Let's say we have a table 'custom_function_table' that contains a custom loss
+function called 'my_custom_loss' and a custom accuracy function
+called 'my_custom_accuracy' based
+on <a href="group__grp__custom__function.html">Load Custom Functions.</a>
+Generate the model configurations with:
+<pre class="example">
+DROP TABLE IF EXISTS mst_table, mst_table_summary;
+SELECT madlib.generate_model_configs(
+                                        'model_arch_library', -- model architecture table
+                                        'mst_table',          -- model selection table output
+                                         ARRAY[1,2],          -- model ids from model architecture table
+                                         $$
+                                            {'loss': ['my_custom_loss'],
+                                             'optimizer_params_list': [ {'optimizer': ['Adam', 'SGD'], 'lr': [0.001, 0.01]} ],
+                                             'metrics': ['my_custom_accuracy']}
+                                         $$,                  -- compile_param_grid
+                                         $$
+                                         { 'batch_size': [64, 128],
+                                           'epochs': [10]
+                                         }
+                                         $$,                  -- fit_param_grid
+                                         'grid',              -- search_type
+                                         NULL,                -- num_configs
+                                         NULL,                -- random_state
+                                         'custom_function_table'  -- table with custom functions
+                                         );
+SELECT * FROM mst_table ORDER BY mst_key;
+</pre>
+<pre class="result">
+ mst_key | model_id |                                 compile_params                                  |        fit_params
+---------+----------+---------------------------------------------------------------------------------+--------------------------
+       1 |        1 | optimizer='Adam(lr=0.001)',metrics=['my_custom_accuracy'],loss='my_custom_loss' | epochs=10,batch_size=64
+       2 |        1 | optimizer='Adam(lr=0.001)',metrics=['my_custom_accuracy'],loss='my_custom_loss' | epochs=10,batch_size=128
+       3 |        1 | optimizer='SGD(lr=0.001)',metrics=['my_custom_accuracy'],loss='my_custom_loss'  | epochs=10,batch_size=64
+       4 |        1 | optimizer='SGD(lr=0.001)',metrics=['my_custom_accuracy'],loss='my_custom_loss'  | epochs=10,batch_size=128
+       5 |        1 | optimizer='Adam(lr=0.01)',metrics=['my_custom_accuracy'],loss='my_custom_loss'  | epochs=10,batch_size=64
+       6 |        1 | optimizer='Adam(lr=0.01)',metrics=['my_custom_accuracy'],loss='my_custom_loss'  | epochs=10,batch_size=128
+       7 |        1 | optimizer='SGD(lr=0.01)',metrics=['my_custom_accuracy'],loss='my_custom_loss'   | epochs=10,batch_size=64
+       8 |        1 | optimizer='SGD(lr=0.01)',metrics=['my_custom_accuracy'],loss='my_custom_loss'   | epochs=10,batch_size=128
+       9 |        2 | optimizer='Adam(lr=0.001)',metrics=['my_custom_accuracy'],loss='my_custom_loss' | epochs=10,batch_size=64
+      10 |        2 | optimizer='Adam(lr=0.001)',metrics=['my_custom_accuracy'],loss='my_custom_loss' | epochs=10,batch_size=128
+      11 |        2 | optimizer='SGD(lr=0.001)',metrics=['my_custom_accuracy'],loss='my_custom_loss'  | epochs=10,batch_size=64
+      12 |        2 | optimizer='SGD(lr=0.001)',metrics=['my_custom_accuracy'],loss='my_custom_loss'  | epochs=10,batch_size=128
+      13 |        2 | optimizer='Adam(lr=0.01)',metrics=['my_custom_accuracy'],loss='my_custom_loss'  | epochs=10,batch_size=64
+      14 |        2 | optimizer='Adam(lr=0.01)',metrics=['my_custom_accuracy'],loss='my_custom_loss'  | epochs=10,batch_size=128
+      15 |        2 | optimizer='SGD(lr=0.01)',metrics=['my_custom_accuracy'],loss='my_custom_loss'   | epochs=10,batch_size=64
+      16 |        2 | optimizer='SGD(lr=0.01)',metrics=['my_custom_accuracy'],loss='my_custom_loss'   | epochs=10,batch_size=128
+(16 rows)
+</pre>
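The 16 rows above are the Cartesian product of 2 model ids × 2 learning rates × 2 optimizers × 2 batch sizes. A sketch of how 'grid' search expands such a parameter grid (illustrative only; MADlib's internal ordering and representation may differ):

```python
from itertools import product

# Grid points taken from the example call above.
model_ids = [1, 2]
learning_rates = [0.001, 0.01]
optimizers = ['Adam', 'SGD']
batch_sizes = [64, 128]

# Grid search enumerates every combination of the grid points.
configs = [
    dict(model_id=m, lr=lr, optimizer=o, batch_size=b)
    for m, lr, o, b in product(model_ids, learning_rates,
                               optimizers, batch_sizes)
]
print(len(configs))  # 16
```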
+Similarly, if you created a custom top k categorical accuracy function 'top_3_accuracy'
+in <a href="group__grp__custom__function.html">Load Custom Functions</a>
+you can generate the model configurations as:
+<pre class="example">
+DROP TABLE IF EXISTS mst_table, mst_table_summary;
+SELECT madlib.generate_model_configs(
+                                        'model_arch_library', -- model architecture table
+                                        'mst_table',          -- model selection table output
+                                         ARRAY[1,2],          -- model ids from model architecture table
+                                         $$
+                                            {'loss': ['categorical_crossentropy'],
+                                             'optimizer_params_list': [ {'optimizer': ['Adam', 'SGD'], 'lr': [0.001, 0.01]} ],
+                                             'metrics': ['top_3_accuracy']}
+                                         $$,                  -- compile_param_grid
+                                         $$
+                                         { 'batch_size': [64, 128],
+                                           'epochs': [10]
+                                         }
+                                         $$,                  -- fit_param_grid
+                                         'grid',              -- search_type
+                                         NULL,                -- num_configs
+                                         NULL,                -- random_state
+                                         'custom_function_table'  -- table with custom functions
+                                         );
+SELECT * FROM mst_table ORDER BY mst_key;
+</pre>
+<pre class="result">
+ mst_key | model_id |                                 compile_params                                  |        fit_params
+---------+----------+---------------------------------------------------------------------------------+--------------------------
+       1 |        1 | optimizer='Adam(lr=0.001)',metrics=['top_3_accuracy'],loss='categorical_crossentropy' | epochs=10,batch_size=64
+       2 |        1 | optimizer='Adam(lr=0.001)',metrics=['top_3_accuracy'],loss='categorical_crossentropy' | epochs=10,batch_size=128
+       3 |        1 | optimizer='SGD(lr=0.001)',metrics=['top_3_accuracy'],loss='categorical_crossentropy'  | epochs=10,batch_size=64
+       4 |        1 | optimizer='SGD(lr=0.001)',metrics=['top_3_accuracy'],loss='categorical_crossentropy'  | epochs=10,batch_size=128
+       5 |        1 | optimizer='Adam(lr=0.01)',metrics=['top_3_accuracy'],loss='categorical_crossentropy'  | epochs=10,batch_size=64
+       6 |        1 | optimizer='Adam(lr=0.01)',metrics=['top_3_accuracy'],loss='categorical_crossentropy'  | epochs=10,batch_size=128
+       7 |        1 | optimizer='SGD(lr=0.01)',metrics=['top_3_accuracy'],loss='categorical_crossentropy'   | epochs=10,batch_size=64
+       8 |        1 | optimizer='SGD(lr=0.01)',metrics=['top_3_accuracy'],loss='categorical_crossentropy'   | epochs=10,batch_size=128
+       9 |        2 | optimizer='Adam(lr=0.001)',metrics=['top_3_accuracy'],loss='categorical_crossentropy' | epochs=10,batch_size=64
+      10 |        2 | optimizer='Adam(lr=0.001)',metrics=['top_3_accuracy'],loss='categorical_crossentropy' | epochs=10,batch_size=128
+      11 |        2 | optimizer='SGD(lr=0.001)',metrics=['top_3_accuracy'],loss='categorical_crossentropy'  | epochs=10,batch_size=64
+      12 |        2 | optimizer='SGD(lr=0.001)',metrics=['top_3_accuracy'],loss='categorical_crossentropy'  | epochs=10,batch_size=128
+      13 |        2 | optimizer='Adam(lr=0.01)',metrics=['top_3_accuracy'],loss='categorical_crossentropy'  | epochs=10,batch_size=64
+      14 |        2 | optimizer='Adam(lr=0.01)',metrics=['top_3_accuracy'],loss='categorical_crossentropy'  | epochs=10,batch_size=128
+      15 |        2 | optimizer='SGD(lr=0.01)',metrics=['top_3_accuracy'],loss='categorical_crossentropy'   | epochs=10,batch_size=64
+      16 |        2 | optimizer='SGD(lr=0.01)',metrics=['top_3_accuracy'],loss='categorical_crossentropy'   | epochs=10,batch_size=128
+(16 rows)
+</pre>
 -# <b>[Deprecated]</b> Load model selection table.  This method is replaced 
-by the 'generate_model_configs()' method described above.
+by the 'generate_model_configs' method described above.
 Select the model(s) from the model
 architecture table that you want to run, along with the compile and
 fit parameters.  Unique combinations will be created:


[madlib] 01/02: DL: Add a helper function to load custom top n accuracy functions

Posted by kh...@apache.org.

khannaekta pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 918256db8da56e7de690e3770c02fbf8afafb9ad
Author: Orhan Kislal <ok...@apache.org>
AuthorDate: Wed Sep 23 19:53:55 2020 +0300

    DL: Add a helper function to load custom top n accuracy functions
    
    JIRA: MADLIB-1452
    
    This commit enables the top_n_accuracy metric. The current parser
    cannot use top_n_accuracy(k=3) format because we don't want to
    run eval for security reasons. Instead, we add a helper function
    so that the user can easily create a custom top_n_accuracy
    function.
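The approach described above can be sketched as a closure factory: instead of eval-ing a string like `top_n_accuracy(k=3)`, bake k into a function at creation time and serialize the result. The code below is hypothetical; MADlib's actual helper wraps the real Keras metric via `top_k_categorical_acc_pickled`:

```python
# Hypothetical closure factory illustrating the eval-free design:
# k is bound when the function is created, so no string needs parsing.
def make_top_k_accuracy(k):
    def top_k_accuracy(y_true, y_pred):
        # y_true: true class index; y_pred: one score vector
        top = sorted(range(len(y_pred)),
                     key=lambda i: y_pred[i], reverse=True)[:k]
        return 1.0 if y_true in top else 0.0
    # give the closure the generated name so it matches the table entry
    top_k_accuracy.__name__ = "top_{0}_accuracy".format(k)
    return top_k_accuracy

fn = make_top_k_accuracy(3)
print(fn.__name__)                     # top_3_accuracy
print(fn(2, [0.5, 0.1, 0.3, 0.1]))    # class 2 is within the top 3
```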
---
 .../madlib_keras_custom_function.py_in             |  85 +++++++++++-
 .../madlib_keras_custom_function.sql_in            | 149 ++++++++++++++++++---
 .../deep_learning/madlib_keras_wrapper.py_in       |  10 +-
 .../test/madlib_keras_custom_function.sql_in       | 105 +++++++++------
 .../test/madlib_keras_model_averaging_e2e.sql_in   |  10 +-
 .../test/madlib_keras_model_selection_e2e.sql_in   |  10 +-
 .../test/unit_tests/test_madlib_keras.py_in        |   7 +
 7 files changed, 297 insertions(+), 79 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
index 23e16f6..e500970 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
@@ -60,7 +60,6 @@ def _validate_object(object, **kwargs):
     except Exception as e:
         plpy.error("{0}: Invalid function object".format(module_name, e))
 
-@MinWarning("error")
 def load_custom_function(object_table, object, name, description=None, **kwargs):
     object_table = quote_ident(object_table)
     _validate_object(object)
@@ -74,7 +73,19 @@ def load_custom_function(object_table, object, name, description=None, **kwargs)
             .format(object_table, col_defs, CustomFunctionSchema.FN_NAME)
 
         plpy.execute(sql, 0)
-        plpy.info("{0}: Created new custom function table {1}." \
+        # Using plpy.notice here as this function can be called:
+        # 1. Directly by the user, we do want to display to the user
+        #    if we create a new table or later the function name that
+        #    is added to the table
+        # 2. From load_top_k_accuracy_function: since plpy.info
+        #    displays the query context when called from a function,
+        #    the output is very verbose and cannot be suppressed with
+        #    the MinWarning decorator, as INFO is always displayed
+        #    irrespective of what the decorator sets client_min_messages to.
+        #    Therefore, instead we print this information as a NOTICE
+        #    when called directly by the user and suppress it by setting
+        #    MinWarning decorator to 'error' level in the calling function.
+        plpy.notice("{0}: Created new custom function table {1}." \
                   .format(module_name, object_table))
     else:
         missing_cols = columns_missing_from_table(object_table,
@@ -98,10 +109,9 @@ def load_custom_function(object_table, object, name, description=None, **kwargs)
             plpy.error("Function '{0}' already exists in {1}".format(name, object_table))
         plpy.error(e)
 
-    plpy.info("{0}: Added function {1} to {2} table".
+    plpy.notice("{0}: Added function {1} to {2} table".
               format(module_name, name, object_table))
 
-@MinWarning("error")
 def delete_custom_function(object_table, id=None, name=None, **kwargs):
     object_table = quote_ident(object_table)
     input_tbl_valid(object_table, "Keras Custom Function")
@@ -126,7 +136,7 @@ def delete_custom_function(object_table, id=None, name=None, **kwargs):
     res = plpy.execute(sql, 0)
 
     if res.nrows() > 0:
-        plpy.info("{0}: Object id {1} has been deleted from {2}.".
+        plpy.notice("{0}: Object id {1} has been deleted from {2}.".
                   format(module_name, id, object_table))
     else:
         plpy.error("{0}: Object id {1} not found".format(module_name, id))
@@ -134,7 +144,7 @@ def delete_custom_function(object_table, id=None, name=None, **kwargs):
     sql = "SELECT {0} FROM {1}".format(CustomFunctionSchema.FN_ID, object_table)
     res = plpy.execute(sql, 0)
     if not res:
-        plpy.info("{0}: Dropping empty custom keras function table " \
+        plpy.notice("{0}: Dropping empty custom keras function " \
                   "table {1}".format(module_name, object_table))
         sql = "DROP TABLE {0}".format(object_table)
         plpy.execute(sql, 0)
@@ -146,6 +156,27 @@ def update_builtin_metrics(builtin_metrics):
     builtin_metrics.append('ce')
     return builtin_metrics
 
+@MinWarning("error")
+def load_top_k_accuracy_function(schema_madlib, object_table, k, **kwargs):
+
+    object_table = quote_ident(object_table)
+    _assert(k > 0,
+        "{0}: For top k accuracy functions k has to be a positive integer.".format(module_name))
+    fn_name = "top_{k}_accuracy".format(**locals())
+
+    sql = """
+        SELECT  {schema_madlib}.load_custom_function(\'{object_table}\',
+                {schema_madlib}.top_k_categorical_acc_pickled({k}, \'{fn_name}\'),
+                \'{fn_name}\',
+                \'returns {fn_name}\');
+        """.format(**locals())
+    plpy.execute(sql)
+    # As this function allocates the name for the top_k_accuracy function,
+    # printing it out here so the user doesn't need to look up the
+    # newly added custom function name in the object_table
+    plpy.info("{0}: Added function \'{1}\' to \'{2}\' table".
+                format(module_name, fn_name, object_table))
+    return
 
 class KerasCustomFunctionDocumentation:
     @staticmethod
@@ -250,3 +281,45 @@ class KerasCustomFunctionDocumentation:
 
         return KerasCustomFunctionDocumentation._returnHelpMsg(
             schema_madlib, message, summary, usage, method)
+
+    @staticmethod
+    def load_top_k_accuracy_function_help(schema_madlib, message):
+        method = "load_top_k_accuracy_function"
+        summary = """
+        ----------------------------------------------------------------
+                            SUMMARY
+        ----------------------------------------------------------------
+        The user can specify a custom k value for the top_k_accuracy metric.
+        If the output table already exists, the custom function specified
+        will be added as a new row into the table. The output table could
+        thus act as a repository of Keras custom functions.
+
+        For more details on function usage:
+        SELECT {schema_madlib}.{method}('usage')
+        """.format(**locals())
+
+        usage = """
+        ---------------------------------------------------------------------------
+                                        USAGE
+        ---------------------------------------------------------------------------
+        SELECT {schema_madlib}.{method}(
+            object_table,       --  VARCHAR. Output table to load custom function.
+            k                   --  INTEGER. k value for the top k accuracy function
+        );
+
+
+        ---------------------------------------------------------------------------
+                                        OUTPUT
+        ---------------------------------------------------------------------------
+        The output table produced by load_top_k_accuracy_function contains the following columns:
+
+        'id'                    -- SERIAL. Function ID.
+        'name'                  -- TEXT PRIMARY KEY. unique function name.
+        'description'           -- TEXT. function description.
+        'object'                -- BYTEA. dill pickled function object.
+
+        """.format(**locals())
+
+        return KerasCustomFunctionDocumentation._returnHelpMsg(
+            schema_madlib, message, summary, usage, method)
+    # ---------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
index 01523f3..acdaa28 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
@@ -38,6 +38,7 @@ Interface and implementation are subject to change. </em>
 <div class="toc"><b>Contents</b><ul>
 <li class="level1"><a href="#load_function">Load Function</a></li>
 <li class="level1"><a href="#delete_function">Delete Function</a></li>
+<li class="level1"><a href="#top_n_function">Top n Function</a></li>
 <li class="level1"><a href="#example">Examples</a></li>
 <li class="level1"><a href="#literature">Literature</a></li>
 <li class="level1"><a href="#related">Related Topics</a></li>
@@ -45,10 +46,10 @@ Interface and implementation are subject to change. </em>
 
 This utility function loads custom Python functions
 into a table for use by deep learning algorithms.
-Custom functions can be useful if, for example, you need loss functions 
+Custom functions can be useful if, for example, you need loss functions
 or metrics that are not built into the standard libraries.
-The functions to be loaded must be in the form of serialized Python objects 
-created using Dill, which extends Python's pickle module to the majority 
+The functions to be loaded must be in the form of serialized Python objects
+created using Dill, which extends Python's pickle module to the majority
 of the built-in Python types [1].
 
 There is also a utility function to delete a function
@@ -69,8 +70,8 @@ load_custom_function(
 <dl class="arglist">
   <dt>object table</dt>
   <dd>VARCHAR. Table to load serialized Python objects.  If this table
-  does not exist, it will be created.  If this table already 
-  exists, a new row is inserted into the existing table. 
+  does not exist, it will be created.  If this table already
+  exists, a new row is inserted into the existing table.
   </dd>
 
   <dt>object</dt>
@@ -149,10 +150,63 @@ delete_custom_function(
   </dd>
 </dl>
 
+@anchor top_n_function
+@par Top n Function
+
+Load a top n function with a specific n to the custom functions table.
+
+<pre class="syntax">
+load_top_k_accuracy_function(
+    object table,
+    k
+    )
+</pre>
+\b Arguments
+<dl class="arglist">
+  <dt>object table</dt>
+  <dd>VARCHAR. Table to load serialized Python objects.  If this table
+  does not exist, it will be created.  If this table already
+  exists, a new row is inserted into the existing table.
+  </dd>
+
+  <dt>k</dt>
+  <dd>INTEGER. k value for the top k accuracy function.
+  </dd>
+
+</dl>
+
+<b>Output table</b>
+<br>
+    The output table contains the following columns:
+    <table class="output">
+      <tr>
+        <th>id</th>
+        <td>SERIAL. Object ID.
+        </td>
+      </tr>
+      <tr>
+        <th>name</th>
+        <td>TEXT PRIMARY KEY. Name of the object.
+        Generated with the following pattern: (sparse_,)top_(n)_accuracy.
+        </td>
+      </tr>
+      <tr>
+        <th>description</th>
+        <td>TEXT. Description of the object (free text).
+        </td>
+      </tr>
+      <tr>
+        <th>object</th>
+        <td>BYTEA. Serialized Python object stored as a PostgreSQL binary data type.
+        </td>
+      </tr>
+    </table>
+</br>
+
 @anchor example
 @par Examples
--# Load object using psycopg2. Psycopg is a PostgreSQL database 
-adapter for the Python programming language.  Note need to use the 
+-# Load object using psycopg2. Psycopg is a PostgreSQL database
+adapter for the Python programming language.  Note the need to use the
 psycopg2.Binary() method to pass as bytes.
 <pre class="example">
 \# import database connector psycopg2 and create connection cursor
@@ -163,12 +217,12 @@ cur = conn.cursor()
 import dill
 \# custom loss
 def squared_error(y_true, y_pred):
-    import keras.backend as K 
+    import keras.backend as K
     return K.square(y_pred - y_true)
 pb_squared_error=dill.dumps(squared_error)
 \# custom metric
 def rmse(y_true, y_pred):
-    import keras.backend as K 
+    import keras.backend as K
     return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
 pb_rmse=dill.dumps(rmse)
 \# call load function
@@ -182,7 +236,7 @@ List table to see objects:
 SELECT id, name, description FROM test_custom_function_table ORDER BY id;
 </pre>
 <pre class="result">
- id |     name      |      description       
+ id |     name      |      description
 ----+---------------+------------------------
   1 | squared_error | squared error
   2 | rmse          | root mean square error
@@ -194,7 +248,7 @@ RETURNS BYTEA AS
 $$
 import dill
 def squared_error(y_true, y_pred):
-    import keras.backend as K 
+    import keras.backend as K
     return K.square(y_pred - y_true)
 pb_squared_error=dill.dumps(squared_error)
 return pb_squared_error
@@ -204,7 +258,7 @@ RETURNS BYTEA AS
 $$
 import dill
 def rmse(y_true, y_pred):
-    import keras.backend as K 
+    import keras.backend as K
     return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
 pb_rmse=dill.dumps(rmse)
 return pb_rmse
@@ -213,13 +267,13 @@ $$ language plpythonu;
 Now call loader:
 <pre class="result">
 DROP TABLE IF EXISTS custom_function_table;
-SELECT madlib.load_custom_function('custom_function_table', 
-                                   custom_function_squared_error(), 
-                                   'squared_error', 
+SELECT madlib.load_custom_function('custom_function_table',
+                                   custom_function_squared_error(),
+                                   'squared_error',
                                    'squared error');
-SELECT madlib.load_custom_function('custom_function_table', 
-                                   custom_function_rmse(), 
-                                   'rmse', 
+SELECT madlib.load_custom_function('custom_function_table',
+                                   custom_function_rmse(),
+                                   'rmse',
                                    'root mean square error');
 </pre>
 -# Delete an object by id:
@@ -228,7 +282,7 @@ SELECT madlib.delete_custom_function( 'custom_function_table', 1);
 SELECT id, name, description FROM custom_function_table ORDER BY id;
 </pre>
 <pre class="result">
- id | name |      description       
+ id | name |      description
 ----+------+------------------------
   2 | rmse | root mean square error
 </pre>
@@ -237,7 +291,19 @@ Delete an object by name:
 SELECT madlib.delete_custom_function( 'custom_function_table', 'rmse');
 </pre>
 If all objects are deleted from the table using this function, the table itself will be dropped.
-
+</pre>
+Load top 3 accuracy function:
+<pre class="example">
+DROP TABLE IF EXISTS custom_function_table;
+SELECT madlib.load_top_k_accuracy_function('custom_function_table',
+                                           3);
+SELECT id, name, description FROM custom_function_table ORDER BY id;
+</pre>
+<pre class="result">
+ id |      name      |      description
+----+----------------+------------------------
+  1 | top_3_accuracy | returns top_3_accuracy
+</pre>
 @anchor literature
 @literature
 
@@ -323,3 +389,46 @@ RETURNS VARCHAR AS $$
     return madlib_keras_custom_function.KerasCustomFunctionDocumentation.delete_custom_function_help(schema_madlib, '')
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- Top n accuracy function
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_top_k_accuracy_function(
+    object_table            VARCHAR,
+    k                       INTEGER
+) RETURNS VOID AS $$
+    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_custom_function')
+    with AOControl(False):
+        madlib_keras_custom_function.load_top_k_accuracy_function(**globals())
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_top_k_accuracy_function(
+    message VARCHAR
+) RETURNS VARCHAR AS $$
+    PythonFunctionBodyOnly(deep_learning, madlib_keras_custom_function)
+    return madlib_keras_custom_function.KerasCustomFunctionDocumentation.load_top_k_accuracy_function_help(schema_madlib, message)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_top_k_accuracy_function()
+RETURNS VARCHAR AS $$
+    PythonFunctionBodyOnly(deep_learning, madlib_keras_custom_function)
+    return madlib_keras_custom_function.KerasCustomFunctionDocumentation.load_top_k_accuracy_function_help(schema_madlib, '')
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.top_k_categorical_acc_pickled(
+n INTEGER,
+fn_name VARCHAR
+) RETURNS BYTEA AS $$
+    import dill
+    from keras.metrics import top_k_categorical_accuracy
+
+    def fn(Y_true, Y_pred):
+        return top_k_categorical_accuracy(Y_true,
+                                          Y_pred,
+                                          k = n)
+    fn.__name__= fn_name
+    pb=dill.dumps(fn)
+    return pb
+$$ language plpythonu
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
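For reference, the metric built by `top_k_categorical_acc_pickled` above counts a prediction as correct when the true class is among the `k` highest-scoring classes. A plain-Python sketch of that semantics (function and variable names here are illustrative, not part of MADlib or Keras):

```python
def top_k_accuracy(true_labels, pred_scores, k):
    """Fraction of rows whose true class index is among the k top-scoring classes."""
    hits = 0
    for label, scores in zip(true_labels, pred_scores):
        # indices of the k largest scores for this row
        top_k = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
        hits += int(label in top_k)
    return float(hits) / len(true_labels)

scores = [[0.1, 0.3, 0.6],   # true class 1 is second best -> counted for k >= 2
          [0.5, 0.4, 0.1]]   # true class 2 is last        -> counted only for k = 3
labels = [1, 2]
print(top_k_accuracy(labels, scores, k=2))  # -> 0.5
```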
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index 780de8a..57827c5 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -217,6 +217,9 @@ def parse_and_validate_compile_params(str_of_args, additional_params=[]):
         opt_name, opt_args = None, None
 
     _assert('loss' in compile_dict, "loss is a required parameter for compile")
+    unsupported_loss_list = ['sparse_categorical_crossentropy']
+    _assert(compile_dict['loss'] not in unsupported_loss_list,
+            "Loss function {0} is not supported.".format(compile_dict['loss']))
     validate_compile_param_types(compile_dict)
     _validate_metrics(compile_dict)
     return (opt_name, opt_args, compile_dict)
@@ -226,10 +229,10 @@ def _validate_metrics(compile_dict):
             compile_dict['metrics'] is None or
             type(compile_dict['metrics']) is list,
             "wrong input type for compile parameter metrics: multi-output model"
-            "and user defined metrics are not supported yet, please pass a list")
+            "are not supported yet, please pass a list")
     if 'metrics' in compile_dict and compile_dict['metrics']:
         unsupported_metrics_list = ['sparse_categorical_accuracy',
-                                    'sparse_categorical_crossentropy', 'top_k_categorical_accuracy',
+                                    'sparse_categorical_crossentropy',
                                     'sparse_top_k_categorical_accuracy']
         _assert(len(compile_dict['metrics']) == 1,
                 "Only one metric at a time is supported.")
@@ -436,6 +439,7 @@ def get_custom_functions_list(compile_params):
     if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
         custom_fn_list.append(local_loss)
     if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-        custom_fn_list.append(local_metric)
+        if 'top_k_categorical_accuracy' not in local_metric:
+            custom_fn_list.append(local_metric)
 
     return custom_fn_list
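The filtering in `get_custom_functions_list` above can be sketched as follows. The builtin lists here are abbreviated stand-ins for the full Keras lists the wrapper actually consults:

```python
# Abbreviated stand-ins for the builtin Keras losses/metrics checked by the wrapper.
BUILTIN_LOSSES = ['categorical_crossentropy', 'mse']
BUILTIN_METRICS = ['accuracy', 'categorical_accuracy', 'top_k_categorical_accuracy']

def custom_fn_names(loss, metric):
    names = []
    if loss and loss.lower() not in BUILTIN_LOSSES:
        names.append(loss)
    if metric and metric.lower() not in BUILTIN_METRICS:
        # per the patch above, metric strings naming the builtin
        # top_k_categorical_accuracy directly are not looked up in the object table
        if 'top_k_categorical_accuracy' not in metric:
            names.append(metric)
    return names

print(custom_fn_names('test_custom_fn', 'top_3_accuracy'))
# -> ['test_custom_fn', 'top_3_accuracy']
```

Generated functions such as `top_3_accuracy` are therefore resolved from the object table like any other custom function.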
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
index ddfcc8d..520b9c9 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
@@ -31,107 +31,130 @@ m4_include(`SQLCommon.m4')
 )
 
 /* Test successful table creation where no table exists */
-DROP TABLE IF EXISTS test_custom_function_table;
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
+DROP TABLE IF EXISTS __test_custom_function_table__;
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum');
 
 SELECT assert(UPPER(atttypid::regtype::TEXT) = 'INTEGER', 'id column should be INTEGER type')
-    FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass
+    FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass
         AND attname = 'id';
 SELECT assert(UPPER(atttypid::regtype::TEXT) = 'BYTEA', 'object column should be BYTEA type' )
-    FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass
+    FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass
         AND attname = 'object';
 SELECT assert(UPPER(atttypid::regtype::TEXT) = 'TEXT',
     'name column should be TEXT type')
-    FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass
+    FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass
         AND attname = 'name';
 SELECT assert(UPPER(atttypid::regtype::TEXT) = 'TEXT',
     'description column should be TEXT type')
-    FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass
+    FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass
         AND attname = 'description';
 
 /*  id should be 1 */
 SELECT assert(id = 1, 'Wrong id written by load_custom_function')
-    FROM test_custom_function_table;
+    FROM __test_custom_function_table__;
 
 /* Validate function object created */
 SELECT assert(read_custom_function(object, 2, 3) = 5, 'Custom function should return sum of args.')
-    FROM test_custom_function_table;
+    FROM __test_custom_function_table__;
 
 /* Test custom function insertion where valid table exists */
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1');
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn1');
 SELECT assert(name = 'sum_fn', 'Custom function sum_fn found in table.')
-    FROM test_custom_function_table WHERE id = 1;
+    FROM __test_custom_function_table__ WHERE id = 1;
 SELECT assert(name = 'sum_fn1', 'Custom function sum_fn1 found in table.')
-    FROM test_custom_function_table WHERE id = 2;
+    FROM __test_custom_function_table__ WHERE id = 2;
 
 /* Test adding an existing function name should error out */
 SELECT assert(MADLIB_SCHEMA.trap_error($TRAP$
-    SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1');
+    SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn1');
     $TRAP$) = 1, 'Should error out for duplicate function name');
 
 /* Test deletion by id where valid table exists */
 /* Assert id exists before deleting */
 SELECT assert(COUNT(id) = 1, 'id 2 should exist before deletion!')
-    FROM test_custom_function_table WHERE id = 2;
-SELECT delete_custom_function('test_custom_function_table', 2);
+    FROM __test_custom_function_table__ WHERE id = 2;
+SELECT delete_custom_function('__test_custom_function_table__', 2);
 SELECT assert(COUNT(id) = 0, 'id 2 should have been deleted!')
-    FROM test_custom_function_table WHERE id = 2;
+    FROM __test_custom_function_table__ WHERE id = 2;
 
 /* Test deletion by name where valid table exists */
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1');
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn1');
 /* Assert id exists before deleting */
 SELECT assert(COUNT(id) = 1, 'function name sum_fn1 should exist before deletion!')
-    FROM test_custom_function_table WHERE name = 'sum_fn1';
-SELECT delete_custom_function('test_custom_function_table', 'sum_fn1');
+    FROM __test_custom_function_table__ WHERE name = 'sum_fn1';
+SELECT delete_custom_function('__test_custom_function_table__', 'sum_fn1');
 SELECT assert(COUNT(id) = 0, 'function name sum_fn1 should have been deleted!')
-    FROM test_custom_function_table WHERE name = 'sum_fn1';
+    FROM __test_custom_function_table__ WHERE name = 'sum_fn1';
 
 /* Test deleting an already deleted entry should error out */
 SELECT assert(MADLIB_SCHEMA.trap_error($TRAP$
-    SELECT delete_custom_function('test_custom_function_table', 2);
+    SELECT delete_custom_function('__test_custom_function_table__', 2);
     $TRAP$) = 1, 'Should error out for trying to delete an entry that does not exist');
 
 /* Test delete drops the table after deleting last entry*/
-DROP TABLE IF EXISTS test_custom_function_table;
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
-SELECT delete_custom_function('test_custom_function_table', 1);
-SELECT assert(COUNT(relname) = 0, 'Table test_custom_function_table should have been deleted.')
-    FROM pg_class where relname='test_custom_function_table';
+DROP TABLE IF EXISTS __test_custom_function_table__;
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum');
+SELECT delete_custom_function('__test_custom_function_table__', 1);
+SELECT assert(COUNT(relname) = 0, 'Table __test_custom_function_table__ should have been deleted.')
+    FROM pg_class where relname='__test_custom_function_table__';
 
 /* Test deletion where empty table exists */
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
-DELETE FROM test_custom_function_table;
-SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('test_custom_function_table', 1)$$) = 1,
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum');
+DELETE FROM __test_custom_function_table__;
+SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('__test_custom_function_table__', 1)$$) = 1,
     'Deleting function in an empty table should generate an exception.');
 
 /* Test deletion where no table exists */
-DROP TABLE IF EXISTS test_custom_function_table;
-SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('test_custom_function_table', 1)$$) = 1,
+DROP TABLE IF EXISTS __test_custom_function_table__;
+SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('__test_custom_function_table__', 1)$$) = 1,
               'Deleting a non-existent table should raise exception.');
 
 /* Test where invalid table exists */
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
-ALTER TABLE test_custom_function_table DROP COLUMN id;
-SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('test_custom_function_table', 2)$$) = 1,
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum');
+ALTER TABLE __test_custom_function_table__ DROP COLUMN id;
+SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('__test_custom_function_table__', 2)$$) = 1,
     'Deleting an invalid table should generate an exception.');
 
-SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum')$$) = 1,
+SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum')$$) = 1,
     'Passing an invalid table to load_custom_function() should raise exception.');
 
 /* Test input validation */
-DROP TABLE IF EXISTS test_custom_function_table;
+DROP TABLE IF EXISTS __test_custom_function_table__;
 SELECT assert(MADLIB_SCHEMA.trap_error($$
-  SELECT load_custom_function('test_custom_function_table', custom_function_object(), NULL, NULL);
+  SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), NULL, NULL);
 $$) = 1, 'Name cannot be NULL');
 SELECT assert(MADLIB_SCHEMA.trap_error($$
-  SELECT load_custom_function('test_custom_function_table', NULL, 'sum_fn', NULL);
+  SELECT load_custom_function('__test_custom_function_table__', NULL, 'sum_fn', NULL);
 $$) = 1, 'Function object cannot be NULL');
 SELECT assert(MADLIB_SCHEMA.trap_error($$
-  SELECT load_custom_function('test_custom_function_table', 'invalid_obj'::bytea, 'sum_fn', NULL);
+  SELECT load_custom_function('__test_custom_function_table__', 'invalid_obj'::bytea, 'sum_fn', NULL);
 $$) = 1, 'Invalid custom function object');
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', NULL);
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', NULL);
 SELECT assert(name IS NOT NULL AND description IS NULL, 'validate name is not NULL.')
-    FROM test_custom_function_table;
+    FROM __test_custom_function_table__;
 SELECT assert(MADLIB_SCHEMA.trap_error($$
-  SELECT delete_custom_function('test_custom_function_table', NULL);
+  SELECT delete_custom_function('__test_custom_function_table__', NULL);
 $$) = 1, 'id/name cannot be NULL!');
+
+/* Test top n accuracy */
+
+DROP TABLE IF EXISTS __test_custom_function_table__;
+SELECT load_top_k_accuracy_function('__test_custom_function_table__', 3);
+SELECT load_top_k_accuracy_function('__test_custom_function_table__', 7);
+SELECT load_top_k_accuracy_function('__test_custom_function_table__', 4);
+SELECT load_top_k_accuracy_function('__test_custom_function_table__', 8);
+
+SELECT assert(count(*) = 4, 'Table __test_custom_function_table__ should have 4 entries')
+FROM __test_custom_function_table__;
+
+SELECT assert(name = 'top_3_accuracy', 'Top 3 accuracy name is incorrect')
+FROM __test_custom_function_table__ WHERE id = 1;
+
+SELECT assert(name = 'top_7_accuracy', 'Top 7 accuracy name is incorrect')
+FROM __test_custom_function_table__ WHERE id = 2;
+
+SELECT assert(name = 'top_4_accuracy', 'Top 4 accuracy name is incorrect')
+FROM __test_custom_function_table__ WHERE id = 3;
+
+SELECT assert(name = 'top_8_accuracy', 'Top 8 accuracy name is incorrect')
+FROM __test_custom_function_table__ WHERE id = 4;
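The assertions above depend on `load_top_k_accuracy_function` naming each table entry `top_<k>_accuracy`. A one-line sketch of that naming convention (the helper name is illustrative):

```python
def top_k_fn_name(k):
    # naming convention the tests above assert on
    return "top_{0}_accuracy".format(k)

print([top_k_fn_name(k) for k in (3, 7, 4, 8)])
# -> ['top_3_accuracy', 'top_7_accuracy', 'top_4_accuracy', 'top_8_accuracy']
```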
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
index fecd19f..b002550 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
@@ -175,12 +175,14 @@ SELECT madlib_keras_fit(
     'test_custom_function_table'
 );
 DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+-- Test load_top_k_accuracy_function with a custom k value
+SELECT load_top_k_accuracy_function('test_custom_function_table', 3);
 SELECT madlib_keras_fit(
     'iris_data_packed',
     'iris_model',
     'iris_model_arch',
     1,
-    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['test_custom_fn1']$$::text,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['top_3_accuracy']$$::text,
     $$ batch_size=2, epochs=1, verbose=0 $$::text,
     3,
     FALSE, NULL, 1, NULL, NULL, NULL,
@@ -203,13 +205,13 @@ SELECT assert(
         object_table = 'test_custom_function_table' AND
         model_size > 0 AND
         madlib_version is NOT NULL AND
-        compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['test_custom_fn1']$$::text AND
+        compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['top_3_accuracy']$$::text AND
         fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
         num_iterations = 3 AND
         metrics_compute_frequency = 1 AND
         num_classes = 3 AND
         class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
-        metrics_type = '{test_custom_fn1}' AND
+        metrics_type = '{top_3_accuracy}' AND
         array_upper(training_metrics, 1) = 3 AND
         training_loss = '{0,0,0}' AND
         array_upper(metrics_elapsed_time, 1) = 3 ,
@@ -230,7 +232,7 @@ SELECT madlib_keras_evaluate(
 
 SELECT assert(loss >= 0 AND
         metric >= 0 AND
-        metrics_type = '{test_custom_fn1}' AND
+        metrics_type = '{top_3_accuracy}' AND
         loss_type = 'test_custom_fn', 'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
 FROM evaluate_out;
 SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('plan_cache_mode', 'auto') END;
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
index c4c0315..b9b775c 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
@@ -166,21 +166,21 @@ SELECT assert(loss >= 0 AND
         metrics_type = '{accuracy}', 'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
 FROM evaluate_out;
 
--- TEST custom loss function
+-- TEST custom loss function and top k accuracy metric
 
 DROP TABLE IF EXISTS test_custom_function_table;
 SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn');
-SELECT load_custom_function('test_custom_function_table', custom_function_one_object(), 'test_custom_fn1', 'returns test_custom_fn1');
 
 -- Prepare model selection table with four rows
 DROP TABLE IF EXISTS mst_object_table, mst_object_table_summary;
+SELECT load_top_k_accuracy_function('test_custom_function_table', 4);
 SELECT load_model_selection_table(
     'iris_model_arch',
     'mst_object_table',
     ARRAY[1],
     ARRAY[
         $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
-        $$loss='test_custom_fn', optimizer='Adam(lr=0.001)', metrics=['test_custom_fn1']$$
+        $$loss='test_custom_fn', optimizer='Adam(lr=0.001)', metrics=['top_4_accuracy']$$
     ],
     ARRAY[
         $$batch_size=16, epochs=1$$
@@ -222,7 +222,7 @@ SELECT assert(
         model_type = 'madlib_keras' AND
         model_size > 0 AND
         fit_params = $MAD$batch_size=16, epochs=1$MAD$::text AND
-        metrics_type = '{test_custom_fn1}' AND
+        metrics_type = '{top_4_accuracy}' AND
         training_metrics_final >= 0  AND
         training_loss_final  = 0  AND
         training_loss = '{0,0,0}' AND
@@ -259,7 +259,7 @@ SELECT madlib_keras_evaluate(
 
 SELECT assert(loss = 0 AND
         metric >= 0 AND
-        metrics_type = '{test_custom_fn1}' AND
+        metrics_type = '{top_4_accuracy}' AND
         loss_type = 'test_custom_fn', 'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
 FROM evaluate_out;
 
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 4ccf2bd..e69bab4 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -1092,6 +1092,13 @@ class MadlibKerasWrapperTestCase(unittest.TestCase):
         with self.assertRaises(plpy.PLPYException):
             self.subject.parse_and_validate_compile_params(test_str)
 
+    def test_parse_and_validate_compile_params_unsupported_loss_fail(self):
+        test_str = "optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), " \
+                   "metrics=['accuracy'], loss='sparse_categorical_crossentropy'"
+
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.parse_and_validate_compile_params(test_str)
+
     def test_parse_and_validate_compile_params_dict_metrics_fail(self):
         test_str = "optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), " \
                    "loss='categorical_crossentropy', metrics={'0':'accuracy'}"