Posted to dev@madlib.apache.org by GitBox <gi...@apache.org> on 2020/11/19 00:57:54 UTC

[GitHub] [madlib] reductionista opened a new pull request #525: DL: Model Hopper Refactor

reductionista opened a new pull request #525:
URL: https://github.com/apache/madlib/pull/525


   This is a major refactor of the model hopper parallelism of the deep learning module.
   
   **A new data flow for the hop/training cycle**
   
   Background:
   
   In v1.17, MADlib creates 3 temporary tables to store the model weights while they are being
   hopped around and trained:  mst_weights_tbl, weights_to_update_tbl, and model_output_table.
   This gives rise to 3 different long-running queries that each involve at least one Redistribute Motion:
   the UDA query (which does the actual training), the UPDATE query, and the HOP query.  Most of the time in
   madlib_keras_fit_multiple() is spent in a loop over iterations and hops calling the run_training()
   function, which performs each of these 3 stages; all of the models circulate between
   these 3 temporary tables and get back to where they started after each full iteration.
   The overall effect is a lot of unnecessary motion of the model weights, going back and forth
   between different segment hosts several times for each call to run_training().  Each hop is
   really 3 or 4 hops in terms of data motion, which for a modest-sized data set can cause
   the hopping phase of the cycle to take just as long as the training phase, leaving GPUs with a lot
   of idle time.
   
   3 Biggest Changes:
   
   1.  This refactor is a major simplification: the 3 temporary tables are replaced by 2 (model_input_tbl and model_output_tbl), and the UPDATE stage is eliminated, leaving only the HOP and
   UDA stages.  The 3 shorter-running TRUNCATE & DROP queries that ran in between the 3 longer queries
   are likewise reduced to only 2 TRUNCATE & DROP queries.
   
   Two additional performance improvements closely related to the above are:
   2.  A __dist_key__ column is added to the model output table, and both of the temporary tables are
          DISTRIBUTED BY this key instead of by mst_key.  The __dist_key__ values have a 1-to-1 mapping with segment_ids,
          as in the batched source table, and __dist_key__ now serves as the hash key for all JOINs.
          The schedule table has both a __dist_key__ column and a __previous_dist_key__ column,
          so that it can guide the weights from the specific segment they were on previously
          to the one they are scheduled to be on for the next hop.  This avoids any additional
          Redistribute Motions caused by table Hash Joins.  With this change, there is no longer any
          movement of the models at all except during the hop itself.  They move
          exactly once per call to run_training().  (A sketch of the resulting hop query follows this list.)
   
   and
   3.  For model weight averaging (madlib_keras_fit()), we needed a merge and a final function
        in addition to the fit_transition function, so a UDA was the natural choice.  For the model
        hopper (madlib_keras_fit_multiple()), there is no need for a merge or final function; all we need is
        to call the fit_transition function directly on each row, so it was less clear whether a UDA
        was the right choice.  We found that calling it directly as a UDF, and using SD to pass
        the state from one row to the next (as we were already doing with the UDA), resulted
        in slightly better performance.  So now fit_transition can be called either as a UDF or as part
        of a UDA, and fit_transition_multiple_model() is always called as a UDF.
   
   
# Other Performance Enhancements
   
   4.  There is no need to hop the models between the last hop of the previous iteration and the first hop of the next iteration, so once per iteration we skip the hop and just rename model_output_tbl to model_input_tbl, instead of copying the models and then truncating and dropping model_output_tbl (sketched below).  For example, if there are 4 segments, v1.17 performs 4 model hops per iteration.  In this branch, the last hop is skipped and the next iteration just starts with the models on the segments they are already on, so the 4 hops are reduced to only 3.  The same amount of training occurs as before, and each model is still paired exactly once per iteration with each segment--just in a slightly different order.
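   A minimal sketch of the skipped hop (same placeholder names as the sketch above); the real code
   does this in run_training() when the hop is the first one of an iteration:
   ```
    # Sketch: the previous output table simply becomes the next input table.
    plpy.execute("""
        ALTER TABLE {model_output} RENAME TO {model_input}
    """.format(model_output=model_output_tbl, model_input=model_input_tbl))
   ```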
   
   5.  Much faster initialization code:  previously, we were reading the weights
         in from the original model output table (during warm start) and the model
         arch table (for transfer learning) one mst row at a time from the segments to
         master, then writing each of them back out one row at a time from master
         back to the segments with a large number of SELECT and INSERT queries.
         Now we just use a single query to copy the weights directly from the
         original model output table into the new model output table on the
         segments, without ever sending them to master.  A similar single
         query copies the transfer learning weights directly from model_arch to
         model_output for training.  Both of these happen in parallel on the
         segments, instead of in sequence on master.  During testing of warm
         start on a 20-segment cluster with 20 models, this resulted in a roughly 10x reduction
         in initialization time (26s instead of 5 minutes in v1.17).
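   One way to express that single warm-start copy is sketched below.  This is illustrative only
   (the actual query and table names in the module differ), but it shows the idea of updating the
   weights in place on the segments rather than round-tripping each row through master.
   ```
    # Illustrative sketch, not the exact MADlib query.  'prev_run_model_output'
    # is a hypothetical name for the model output table from the previous run.
    plpy.execute("""
        UPDATE {model_output} mo
        SET    model_weights = prev.model_weights
        FROM   prev_run_model_output prev
        WHERE  mo.mst_key = prev.mst_key
    """.format(model_output=model_output_tbl))
   ```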
   
   6.    Split get_model_arch_and_weights() into query_weights() and get_model_arch(),
           so we don't have to transfer weights from the segments to master in places
           where we only need the model_arch JSON.
   
   7.     Enables JIT XLA auto-clustering, if the installed version of tensorflow supports it (see the sketch below).
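   For reference, this is roughly what enabling XLA auto-clustering looks like in TF 1.x.  This is a
   generic TensorFlow sketch, not necessarily the exact call the module makes, and it assumes a
   tensorflow build with XLA support:
   ```
    import tensorflow as tf
    from keras import backend as K

    # Request XLA auto-clustering ("global JIT") for the session Keras will use.
    config = tf.ConfigProto()
    config.graph_options.optimizer_options.global_jit_level = \
        tf.OptimizerOptions.ON_1
    K.set_session(tf.Session(config=config))
   ```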
   
   # Other Changes
   
 8.  Simplified schedule rotation: the schedule table is created only once, then gets
      rotated on the segments, instead of being re-created many times by transferring
      data back and forth from master to segments to master on each hop.  There is no longer
      any need for separate "current_schedule" and "grand_schedule" data structures.
      These tables do not contain model weights, so there is not much of a
      performance benefit, but it's at least simpler and more maintainable.
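   The rotation itself is a single window-function query; a simplified sketch is below (the full
   version, from rotate_schedule_tbl(), is quoted in the review discussion further down):
   ```
    # Sketch: each model's current __dist_key__ becomes its __prev_dist_key__,
    # and LEAD()/FIRST_VALUE() pick the next segment in round-robin order.
    plpy.execute("""
        CREATE TABLE {next_schedule} AS
            SELECT mst_key, model_id,
                   __dist_key__ AS __prev_dist_key__,
                   COALESCE(
                       LEAD(__dist_key__)        OVER (ORDER BY __dist_key__),
                       FIRST_VALUE(__dist_key__) OVER (ORDER BY __dist_key__)
                   ) AS __dist_key__
            FROM {schedule}
    """.format(next_schedule='__madlib_temp_next_schedule__',   # placeholder
               schedule=schedule_tbl))
   ```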
   
   9.   Added some debugging output that can be enabled to help profile the
         performance of fit multiple, and to track which segment each mst_key
         is located on during each hop.  This also serves as an example for
         the utils/debug PR this is rebased on top of.
   
   10.  Removed AutoML's dependency on the internals of fit_multiple (needed to make AutoML
        compatible with the new FitMultiple class, but also good for avoiding having to
        do the same thing for any future updates to fit_multiple).  Better modularity.
   
    11.  Improved exception handling:  we now send the full stack traceback from the segments back to master (soon to be renamed "coordinator").  When an exception happens in fit_transition (or merge, final, etc.) we'll be able to see the stack trace of the internal UDF or UDA running on the segments, along with the stack trace of the outer UDF that runs on the
   coordinator.  It's attached to the DETAIL of the Exception and includes the line number where the error occurred--something that we had to guess at before, making debugging difficult.
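   The segment-side half of this is visible in the madlib_keras.sql_in diff below: fit_transition is
   wrapped in a try/except that formats the traceback and appends it to the error message after a
   marker string ('TransAggDetail').  The coordinator-side handling is sketched here in simplified
   form (the actual code attaches the traceback to the exception's DETAIL field; this sketch just
   appends it to the message, and uda_query is a placeholder for the training query):
   ```
    import plpy

    DETAIL_MARKER = 'TransAggDetail'   # marker appended by the segment-side wrapper
    uda_query = "SELECT ..."           # placeholder for the actual training query

    try:
        plpy.execute(uda_query)
    except plpy.SPIError as e:
        msg = e.args[0] if e.args else str(e)
        if DETAIL_MARKER in msg:
            short_msg, seg_traceback = msg.split(DETAIL_MARKER, 1)
            # Surface both the coordinator-side and the segment-side stack traces.
            plpy.error(short_msg + "\nSegment traceback:\n" + seg_traceback)
        else:
            raise
   ```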


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r558566450



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -134,57 +140,56 @@ class FitMultipleModel():
         self.info_str = ""
         self.dep_shape_col = add_postfix(mb_dep_var_col, "_shape")
         self.ind_shape_col = add_postfix(mb_indep_var_col, "_shape")
-        self.use_gpus = use_gpus
+        self.use_gpus = use_gpus if use_gpus else False
         self.segments_per_host = get_segments_per_host()
+        self.model_input_tbl = unique_string('model_input')
+        self.model_output_tbl = unique_string('model_output')
+        self.schedule_tbl = unique_string('schedule')

Review comment:
       This should be gone now.




----------------------------------------------------------------



[GitHub] [madlib] reductionista edited a comment on pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista edited a comment on pull request #525:
URL: https://github.com/apache/madlib/pull/525#issuecomment-732492520


   Thanks for testing it out Frank!
   
   You're the first to test this with tensorflow 1.13, and based on one of the errors you're seeing, I'm guessing you're also the first to test it with gpdb5.  I only tested it on gpdb6 and postgres.
   
   > ```
   > WARNING:  This version of tensorflow does not support XLA auto-cluster JIT optimization.  HINT:  upgrading tensorflow may improve performance.  (seg1 slice1 10.128.0.41:40001 pid=6271)
   > CONTEXT:  PL/Python function "fit_transition"
   > ```
   > 
   > What does user need to do to enable XLA? I am on TF 1.13.1 currently.
   
   Nothing, probably just need to upgrade to TF 1.14.  I should add the specific minimum version requirement in the warning message.  I think it was added in 1.14 but I've only used 1.10 and 1.14, so I'll need to confirm that.  Could also be that it's been in there for a while, and they just changed how you access it.  I'll look both of those up, and also run some more comprehensive performance tests to see how much of a difference it actually makes and in what kind of setups.  If it looks like it's not helping much anyway (even though the TF docs say it should) then we can get rid of this message... or maybe even not bother to enable it.  I figure if it does give better performance consistently across models, then we may as well let the user know there's an easy way to get that, without having to tune any parameters or upgrade hardware.  Would also be important to know if there are some cases where it could hurt performance.
    
   > ```
   > ERROR:  plpy.Error: madlib_keras_fit_multiple_model error: No GPUs configured on hosts. (plpython.c:5038)
   > CONTEXT:  Traceback (most recent call last):
   >   PL/Python function "madlib_keras_fit_multiple_model", line 23, in <module>
   >     fit_obj = madlib_keras_fit_multiple_model.FitMultipleModel(**globals())
   >   PL/Python function "madlib_keras_fit_multiple_model", line 147, in __init__
   >   PL/Python function "madlib_keras_fit_multiple_model", line 295, in get_accessible_gpus_for_seg
   > PL/Python function "madlib_keras_fit_multiple_model"
   > ```
   > 
   > so it looks like `use gpus=NULL` is now defaulting to `TRUE` but it should default to `FALSE` i.e., CPUs. It used to default to CPUs.
   
   Good catch!  For some reason, I got a different error when I tried to reproduce this (about using BOOLEAN::None in a query).  But I think I see the cause of it--fixing.  I thought we had some dev-check tests for this; I'll add one if we don't.
   
   > ```
   > ERROR:  plpy.SPIError: PRIMARY KEY and DISTRIBUTED BY definitions incompatible
   > HINT:  When there is both a PRIMARY KEY, and a DISTRIBUTED BY clause, the DISTRIBUTED BY clause must be equal to or a left-subset of the PRIMARY KEY
   > CONTEXT:  Traceback (most recent call last):
   >   PL/Python function "madlib_keras_fit_multiple_model", line 24, in <module>
   >     fit_obj.fit_multiple_model()
   >   PL/Python function "madlib_keras_fit_multiple_model", line 241, in fit_multiple_model
   >   PL/Python function "madlib_keras_fit_multiple_model", line 509, in init_model_output_tbl
   > PL/Python function "madlib_keras_fit_multiple_model"
   > ```
   > 
   > I have also seen this error when running on a single segment.
   
   Interesting!  Looks like this was due to an unnecessarily strict check in the gpdb5 server code, where it required a specific order for the keys in the PRIMARY KEY constraint.  The server team has fixed that in gpdb6:
   https://github.com/greenplum-db/gpdb/commit/e572af54e4255d501d261d51d32e1f3d3cf4b9c9
   
   I can't recall for sure if we decided we care about supporting deep learning in gpdb5 for this release.  But since it's an easy fix, I'll just reverse the order of the keys so that it works on both.


----------------------------------------------------------------



[GitHub] [madlib] kaknikhil commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r560537577



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
##########
@@ -1793,13 +1793,23 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition(
     segments_per_host           INTEGER,
     images_per_seg              INTEGER[],
     use_gpus                    BOOLEAN,
-    accessible_gpus_for_seg                INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
     prev_serialized_weights     BYTEA,
     is_final_iteration          BOOLEAN,
-    custom_function_map        BYTEA
+    custom_function_map         BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    return madlib_keras.fit_transition(**globals())
+    import traceback
+    from sys import exc_info
+    import plpy
+    try:
+        return madlib_keras.fit_transition(**globals())
+    except Exception as e:
+        etype, _, tb = exc_info()
+        detail = ''.join(traceback.format_exception(etype, e, tb))
+        message = e.args[0] + 'TransAggDetail' + detail

Review comment:
       Good suggestions; I think we should take care of this in a future PR.  This looks good for now.




----------------------------------------------------------------



[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r538740058



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})

Review comment:
       Actually, I decided I'm going to leave it `dist_key_col` here, but add a distribution rule for the rotate_schedule_tbl query to DISTRIBUTE BY `prev_dist_key`, as you suggested in another comment.
   
   The first hop is always just a table rename, so the initial schedule table never actually gets used for the hop itself.  It's only used to generate the next schedule table in the rotate_schedule_tbl query.  And in that query, it's the `dist_key_col` that matters... `prev_dist_key_col` is unused.  I guess we really don't need that column to exist at all... it's just there as a placeholder so that the schema is the same as the next schedule table that gets generated later by `rotate_schedule_tbl()`.
   
   




----------------------------------------------------------------



[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r538747152



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -75,8 +79,7 @@ segment.
 Note that this function is disabled for Postgres.
 """
 
-@MinWarning("warning")

Review comment:
       Thanks, I think I removed it temporarily for debugging and forgot to add it back :-) 




----------------------------------------------------------------



[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r538744805



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):
+            self.next_schedule_tbl = unique_string('next_schedule')
+            rotate_schedule_tbl_query = """
+                CREATE TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl};

Review comment:
      I'm not sure whether the DISTRIBUTED BY in the `rotate_schedule_tbl` query actually matters... if I recall, in the hop query, there is one Redistribute motion based on the dist key and another one based on the prev dist key... so I left it off, figuring either way this small amount of data is going to have to move.  It's the movement of the other tables (input & output) we care about, since they're large.  I haven't checked today, but it's possible the schedule table will just get broadcast anyway, since it's much smaller than the model_input and model_output tables.  But I am going to add a DISTRIBUTED BY `prev_dist_key_col`, because that seems like the safest of the 3 options, in case the planner or orca gets changed in the future and it starts picking a bad plan... choosing either `prev_dist_key_col` or `dist_key_col` at least seems better than letting it randomly distribute.




----------------------------------------------------------------



[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r537809115



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
##########
@@ -1793,13 +1793,23 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition(
     segments_per_host           INTEGER,
     images_per_seg              INTEGER[],
     use_gpus                    BOOLEAN,
-    accessible_gpus_for_seg                INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
     prev_serialized_weights     BYTEA,
     is_final_iteration          BOOLEAN,
-    custom_function_map        BYTEA
+    custom_function_map         BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    return madlib_keras.fit_transition(**globals())
+    import traceback
+    from sys import exc_info
+    import plpy
+    try:
+        return madlib_keras.fit_transition(**globals())
+    except Exception as e:
+        etype, _, tb = exc_info()
+        detail = ''.join(traceback.format_exception(etype, e, tb))
+        message = e.args[0] + 'TransAggDetail' + detail

Review comment:
       Good point.
   
     Originally I was thinking I would make a simple general mechanism for doing this for any python function running on the segments that we're calling from another python function on master with `plpy.execute()`.  But the story had already gotten long enough that I ended up just leaving it kind of hacked together, with duplicated code reused in both fit and fit multiple.
   
     Eventually, I think we should move it into `utilities` somewhere, but I haven't thought it through much.  Come to think of it, maybe I should move this over into `utilities/debug.py` and have it as a part of that PR instead of this one? 
    It could be enabled with a flag passed to plpy_execute() I guess.  Or maybe even just automatically enabled for that function, since it shouldn't affect anything if the `TransAggDetail` (or whatever keyword we choose) doesn't come back with the exception.




----------------------------------------------------------------



[GitHub] [madlib] fmcquillan99 commented on pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
fmcquillan99 commented on pull request #525:
URL: https://github.com/apache/madlib/pull/525#issuecomment-736941725


   
   re-testing...
   
   (1)
   initial tests for functionality - keras_fit()
   
   After upgrading to TF 1.14 and re-running a test with XLA auto-cluster JIT optimization, I did not see the warning message anymore.  Also, I saw 30% training time improvement from 237 sec/iteration to 166 sec/iteration for the query, which is great.
   
   
   (2)
   initial tests for functionality - keras_fit_multiple_model()
   
   first I started with single segment:
   
   This query no longer hits the GPU error, but it still errors out after several minutes.  Also, please turn off the verbose output:
   ```
   NOTICE:  CREATE TABLE / PRIMARY KEY will create implicit index "__madlib_temp_model_output49543509_1606874577_33655832___pkey" for table "__madlib_temp_model_output49543509_1606874577_33655832__"
   CONTEXT:  SQL statement "
                                       CREATE TABLE __madlib_temp_model_output49543509_1606874577_33655832__
                                       (mst_key INTEGER,
                                        model_weights BYTEA,
                                        model_arch JSON,
                                        compile_params TEXT,
                                        fit_params TEXT,
                                        object_map BYTEA,
                                        __dist_key__ INTEGER,
                                        PRIMARY KEY (__dist_key__, mst_key)
                                       )
                                       DISTRIBUTED BY (__dist_key__)
                                       "
   PL/Python function "madlib_keras_fit_multiple_model"
   NOTICE:  CREATE TABLE / PRIMARY KEY will create implicit index "cifar10_multi_model_info_pkey" for table "cifar10_multi_model_info"
   CONTEXT:  SQL statement "
               CREATE TABLE cifar10_multi_model_info (
                   mst_key INTEGER PRIMARY KEY,
                   model_id INTEGER,
                   compile_params TEXT,
                   fit_params TEXT,
                   model_type TEXT,
                   model_size DOUBLE PRECISION,
                   metrics_elapsed_time DOUBLE PRECISION[],
                   metrics_type TEXT[],
                   loss_type TEXT,
                   training_metrics_final DOUBLE PRECISION,
                   training_loss_final DOUBLE PRECISION,
                   training_metrics DOUBLE PRECISION[],
                   training_loss DOUBLE PRECISION[],
                   validation_metrics_final DOUBLE PRECISION,
                   validation_loss_final DOUBLE PRECISION,
                   validation_metrics DOUBLE PRECISION[],
                   validation_loss DOUBLE PRECISION[]
              ) "
   PL/Python function "madlib_keras_fit_multiple_model"
   NOTICE:  Table doesn't have 'DISTRIBUTED BY' clause -- Using column(s) named 'mst_key' as the Greenplum Database data distribution key for this table.
   HINT:  The 'DISTRIBUTED BY' clause determines the distribution of data. Make sure column(s) chosen are the optimal data distribution key to minimize skew.
   CONTEXT:  SQL statement "
                   CREATE TABLE __madlib_temp_next_schedule42211658_1606874876_919189__ AS
                       SELECT
                           mst_key,
                           model_id,
                           __dist_key__ AS __prev_dist_key__,
                           COALESCE(
                               LEAD(__dist_key__)
                                   OVER(ORDER BY __dist_key__),
                               FIRST_VALUE(__dist_key__)
                                   OVER(ORDER BY __dist_key__)
                           ) AS __dist_key__
                       FROM __madlib_temp_schedule9385959_1606874577_2407209__;
               "
   PL/Python function "madlib_keras_fit_multiple_model"
   ERROR:  plpy.SPIError: AttributeError: '_TfDeviceCaptureOp' object has no attribute '_set_device_from_string' (plpython.c:5038)  (seg0 slice1 10.128.0.41:40000 pid=9278) (plpython.c:5038)
   DETAIL:  Message skipped due to incorrect encoding.
   CONTEXT:  Traceback (most recent call last):
     PL/Python function "madlib_keras_fit_multiple_model", line 24, in <module>
       fit_obj.fit_multiple_model()
     PL/Python function "madlib_keras_fit_multiple_model", line 247, in fit_multiple_model
     PL/Python function "madlib_keras_fit_multiple_model", line 288, in train_multiple_model
     PL/Python function "madlib_keras_fit_multiple_model", line 929, in run_training
   PL/Python function "madlib_keras_fit_multiple_model"
   ```
   
   (3)
   initial tests for functionality - keras_fit_multiple_model()
   
   next I used 2 segments:
   
   This query errors out after several minutes.
   
   
   ```
   NOTICE:  CREATE TABLE / PRIMARY KEY will create implicit index "__madlib_temp_model_output49543509_1606875004_33656259___pkey" for table "__madlib_temp_model_output49543509_1606875004_33656259__"
   CONTEXT:  SQL statement "
                                       CREATE TABLE __madlib_temp_model_output49543509_1606875004_33656259__
                                       (mst_key INTEGER,
                                        model_weights BYTEA,
                                        model_arch JSON,
                                        compile_params TEXT,
                                        fit_params TEXT,
                                        object_map BYTEA,
                                        __dist_key__ INTEGER,
                                        PRIMARY KEY (__dist_key__, mst_key)
                                       )
                                       DISTRIBUTED BY (__dist_key__)
                                       "
   PL/Python function "madlib_keras_fit_multiple_model"
   NOTICE:  CREATE TABLE / PRIMARY KEY will create implicit index "cifar10_multi_model_info_pkey" for table "cifar10_multi_model_info"
   CONTEXT:  SQL statement "
               CREATE TABLE cifar10_multi_model_info (
                   mst_key INTEGER PRIMARY KEY,
                   model_id INTEGER,
                   compile_params TEXT,
                   fit_params TEXT,
                   model_type TEXT,
                   model_size DOUBLE PRECISION,
                   metrics_elapsed_time DOUBLE PRECISION[],
                   metrics_type TEXT[],
                   loss_type TEXT,
                   training_metrics_final DOUBLE PRECISION,
                   training_loss_final DOUBLE PRECISION,
                   training_metrics DOUBLE PRECISION[],
                   training_loss DOUBLE PRECISION[],
                   validation_metrics_final DOUBLE PRECISION,
                   validation_loss_final DOUBLE PRECISION,
                   validation_metrics DOUBLE PRECISION[],
                   validation_loss DOUBLE PRECISION[]
              ) "
   PL/Python function "madlib_keras_fit_multiple_model"
   ERROR:  plpy.SPIError: AttributeError: '_TfDeviceCaptureOp' object has no attribute '_set_device_from_string' (plpython.c:5038)  (seg0 slice1 10.128.0.41:40000 pid=9489) (plpython.c:5038)
   DETAIL:  :
   Traceback (most recent call last):
     File "<string>", line 15, in __plpython_procedure_fit_transition_multiple_model_1572293
     File "/home/gpadmin/madlib/build/src/ports/greenplum/5/modules/deep_learning/madlib_keras.py", line 545, in fit_transition
       custom_function_map)
     File "/home/gpadmin/madlib/build/src/ports/greenplum/5/modules/deep_learning/madlib_keras.py", line 90, in get_init_model_and_sess
       segment_model = init_model(model_architecture, compile_params, custom_function_map)
     File "/home/gpadmin/madlib/build/src/ports/greenplum/5/modules/deep_learning/madlib_keras.py", line 497, in init_model
       segment_model = model_from_json(model_architecture)
     File "/home/gpadmin/.local/lib/python2.7/site-packages/keras/engine/saving.py", line 492, in model_from_json
       return deserialize(config, custom_objects=custom_objects)
     File "/home/gpadmin/.local/lib/python2.7/site-packages/keras/layers/__init__.py", line 55, in deserialize
       printable_module_name='layer')
     File "/home/gpadmin/.local/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 145, in deserialize_keras_object
       list(custom_objects.items())))
     File "/home/gpadmin/.local/lib/python2.7/site-packages/keras/engine/sequential.py", line 301, in from_config
       model.add(layer)
     File "/home/gpadmin/.local/lib/python2.7/site-packages/keras/engine/sequential.py", line 181, in add
       output_tensor = layer(self.outputs[0])
     File "/home/gpadmin/.local/lib/python2.7/site-packages/keras/engine/base_layer.py", line 457, in __call__
       output = self.call(inputs, **kwargs)
     File "/home/gpadmin/.local/lib/python2.7/site-packages/keras/layers/normalization.py", line 185, in call
       epsilon=self.epsilon)
     File "/home/gpadmin/.local/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 1858, in normalize_batch_in_training
       if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]:
     File "/home/gpadmin/.local/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 291, in _has_nchw_support
       explicitly_on_cpu = _is_current_explicit_device('CPU')
     File "/home/gpadmin/.local/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 266, in _is_current_explicit_device
       device = _get_current_tf_device()
     File "/home/gpadmin/.local/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 247, in _get_current_tf_device
       g._apply_device_functions(op)
     File "/usr/local/greenplum-db/ext/python/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 4581, in _apply_device_functions
       op._set_device_from_string(device_string)
   AttributeError: '_TfDeviceCaptureOp' object has no attribute '_set_device_from_string'
   Traceback (most recent call last):
     PL/Python function "fit_transition_multiple_model", line 20, in <module>
       raise e
   PL/Python function "fit_transition_multiple_model"
   CONTEXT:  Traceback (most recent call last):
     PL/Python function "madlib_keras_fit_multiple_model", line 24, in <module>
       fit_obj.fit_multiple_model()
     PL/Python function "madlib_keras_fit_multiple_model", line 247, in fit_multiple_model
     PL/Python function "madlib_keras_fit_multiple_model", line 288, in train_multiple_model
     PL/Python function "madlib_keras_fit_multiple_model", line 929, in run_training
   PL/Python function "madlib_keras_fit_multiple_model"
   ```
   


----------------------------------------------------------------



[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r537800172



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -196,72 +195,110 @@ class FitMultipleModel():
             self.dist_key_mapping_valid, self.images_per_seg_valid = \
                 get_image_count_per_seg_for_minibatched_data_from_db(
                     self.validation_table)
-        self.mst_weights_tbl = unique_string(desp='mst_weights')
-        self.mst_current_schedule_tbl = unique_string(desp='mst_current_schedule')
+        self.model_input_tbl = unique_string(desp='model_input')
+        self.schedule_tbl = unique_string(desp='schedule')
 
-        self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
-        if len(self.msts) < len(self.dist_keys):
+        self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
+        DEBUG.plpy.info("init_dist_keys = {0}".format(self.dist_keys))
+        self.max_dist_key = sorted(self.dist_keys)[-1]
+        DEBUG.plpy.info("sorted_dist_keys = {0}".format(sorted(self.dist_keys)))
+        DEBUG.plpy.info("max_dist_key = {0}".format(self.max_dist_key))
+        self.extra_dist_keys = []
+
+        num_msts = len(self.msts)
+        num_dist_keys = len(self.dist_keys)
+
+        if num_msts < num_dist_keys:
             self.msts_for_schedule = self.msts + [None] * \
-                                     (len(self.dist_keys) - len(self.msts))
+                                     (num_dist_keys - num_msts)
         else:
             self.msts_for_schedule = self.msts
+            if num_msts > num_dist_keys:
+                for i in range(num_msts - num_dist_keys):
+                    self.extra_dist_keys.append(self.max_dist_key + 1 + i)
+
+        DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys))
+        DEBUG.plpy.info('extra_dist_keys : {}'.format(self.extra_dist_keys))
+
         random.shuffle(self.msts_for_schedule)
-        self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
-        self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
-        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
-        if self.warm_start:
-            self.create_model_output_table_warm_start()
-        else:
-            self.create_model_output_table()
+        # Comma-separated list of the mst_keys, including NULL's
+        #  This will be used to pass the mst keys to the db as
+        #  a sql ARRAY[]
+        self.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                for mst in self.msts_for_schedule ]
 
-        self.weights_to_update_tbl = unique_string(desp='weights_to_update')
-        self.fit_multiple_model()
+        # List of all dist_keys, including any extra dist keys beyond
+        #  the # segments we'll be training on--these represent the
+        #  segments models will rest on while not training, which
+        #  may overlap with the ones that will have training on them.
+        self.all_dist_keys = self.dist_keys + self.extra_dist_keys
 
-        # Update and cleanup metadata tables
-        self.insert_info_table()
-        self.create_model_summary_table()
-        if self.warm_start:
-            self.cleanup_for_warm_start()
-        reset_cuda_env(original_cuda_env)
+        self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
+        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
     def fit_multiple_model(self):
+        self.init_schedule_tbl()
+        self.init_model_output_tbl()
+        self.init_model_info_tbl()
+
         # WARNING: set orca off to prevent unwanted redistribution
         with OptimizerControl(False):
             self.start_training_time = datetime.datetime.now()
             self.metrics_elapsed_start_time = time.time()
             self.train_multiple_model()
             self.end_training_time = datetime.datetime.now()
 
-    def cleanup_for_warm_start(self):
+        # Update and cleanup metadata tables
+        self.insert_info_table()
+        self.create_model_summary_table()
+        self.write_final_model_output_tbl()
+        reset_cuda_env(self.original_cuda_env)
+
+    def write_final_model_output_tbl(self):
         """
-        1. drop original model table
+        1. drop original model table if exists
         2. rename temp to original
         :return:
         """
-        drop_query = "DROP TABLE IF EXISTS {}".format(
-            self.original_model_output_table)
-        plpy.execute(drop_query)
-        rename_table(self.schema_madlib, self.model_output_table,
-                     self.original_model_output_table)
+        final_output_table_create_query = """
+                                    DROP TABLE IF EXISTS {self.original_model_output_tbl};
+                                    CREATE TABLE {self.original_model_output_tbl} AS

Review comment:
       I'd consider changing it if there's an argument for the RENAME being faster, but I'm not sure it would be... I imagine the VACUUM might take just as long.  Either way, there shouldn't be any motion across segments since the DISTRIBUTED BY is the same on both.




----------------------------------------------------------------



[GitHub] [madlib] fmcquillan99 commented on pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
fmcquillan99 commented on pull request #525:
URL: https://github.com/apache/madlib/pull/525#issuecomment-731350733


   (1)
   initial tests for functionality - keras_fit()
   
   ```
   DROP TABLE IF EXISTS cifar_10_model, cifar_10_model_summary;
   SELECT madlib.madlib_keras_fit('cifar_10_train_data_packed_allseg',    -- source table
                                  'cifar_10_model',                -- model output table
                                  'model_arch_library',            -- model arch table
                                   1,                              -- model arch id
                                   $$ loss='categorical_crossentropy', optimizer='rmsprop(lr=0.0001, decay=1e-6)', metrics=['accuracy']$$,  -- compile_params
                                   $$ batch_size=32, epochs=3 $$,  -- fit_params
                                   3,                              -- num_iterations
                                   NULL,                          -- use GPUs
                                   'cifar_10_test_data_packed_allseg',    -- validation dataset 
                                   1                               -- metrics compute frequency 
                                 ); 
   ```
   produces warning:
   
   ```
   WARNING:  This version of tensorflow does not support XLA auto-cluster JIT optimization.  HINT:  upgrading tensorflow may improve performance.  (seg0 slice1 10.128.0.41:40000 pid=6270)
   CONTEXT:  PL/Python function "fit_transition"
   WARNING:  This version of tensorflow does not support XLA auto-cluster JIT optimization.  HINT:  upgrading tensorflow may improve performance.  (seg1 slice1 10.128.0.41:40001 pid=6271)
   CONTEXT:  PL/Python function "fit_transition"
   ```
   
   What does user need to do to enable XLA?  I am on TF 1.13.1 currently.
   
   Otherwise this ran fine, and warm start also seemed to work.
   
   
   (2)
   initial tests for functionality - keras_fit_multiple_model()
   
   first I started with single segment:
   ```
   SELECT __dist_key__, independent_var_shape, dependent_var_shape, buffer_id FROM cifar_10_train_data_packed ORDER BY __dist_key__;
    __dist_key__ | independent_var_shape | dependent_var_shape | buffer_id 
   --------------+-----------------------+---------------------+-----------
               1 | {16667,32,32,3}       | {16667,10}          |         0
               1 | {16666,32,32,3}       | {16666,10}          |         2
               1 | {16667,32,32,3}       | {16667,10}          |         1
   (3 rows)
   
   SELECT __dist_key__, independent_var_shape, dependent_var_shape, buffer_id FROM cifar_10_test_data_packed ORDER BY __dist_key__;
    __dist_key__ | independent_var_shape | dependent_var_shape | buffer_id 
   --------------+-----------------------+---------------------+-----------
               1 | {10000,32,32,3}       | {10000,10}          |         0
   (1 row)
   ```
   
   run multi fit:
   ```
   DROP TABLE IF EXISTS cifar10_multi_model, cifar10_multi_model_summary, cifar10_multi_model_info;
   SELECT madlib.madlib_keras_fit_multiple_model('cifar_10_train_data_packed',    -- source_table
                                                 'cifar10_multi_model',     -- model_output_table
                                                 'mst_table',               -- model_selection_table
                                                  3,                       -- num_iterations
                                                  NULL,                     -- use gpus
                                                 'cifar_10_test_data_packed',      -- validation dataset
                                                  1,                         -- metrics compute frequency
                                                  NULL,                      -- warm_start
                                                  'me',
                                                  'this is a test run'
                                                );
   ```
   
   produces error:
   ```
   ERROR:  plpy.Error: madlib_keras_fit_multiple_model error: No GPUs configured on hosts. (plpython.c:5038)
   CONTEXT:  Traceback (most recent call last):
     PL/Python function "madlib_keras_fit_multiple_model", line 23, in <module>
       fit_obj = madlib_keras_fit_multiple_model.FitMultipleModel(**globals())
     PL/Python function "madlib_keras_fit_multiple_model", line 147, in __init__
     PL/Python function "madlib_keras_fit_multiple_model", line 295, in get_accessible_gpus_for_seg
   PL/Python function "madlib_keras_fit_multiple_model"
   
   ```
   
   so it looks like `use gpus=NULL` is now defaulting to `TRUE` but it should default to `FALSE` i.e., CPUs.  It used to default to CPUs.
   
   
   (3)
   initial tests for functionality - keras_fit_multiple_model()
   
   next I used 2 segments:
   ```
   SELECT __dist_key__, independent_var_shape, dependent_var_shape, buffer_id FROM cifar_10_train_data_packed_allseg ORDER BY __dist_key__;
    __dist_key__ | independent_var_shape | dependent_var_shape | buffer_id 
   --------------+-----------------------+---------------------+-----------
               0 | {12500,32,32,3}       | {12500,10}          |         3
               0 | {12500,32,32,3}       | {12500,10}          |         1
               1 | {12500,32,32,3}       | {12500,10}          |         2
               1 | {12500,32,32,3}       | {12500,10}          |         0
   (4 rows)
   
   SELECT __dist_key__, independent_var_shape, dependent_var_shape, buffer_id FROM cifar_10_test_data_packed_allseg ORDER BY __dist_key__;
    __dist_key__ | independent_var_shape | dependent_var_shape | buffer_id 
   --------------+-----------------------+---------------------+-----------
               0 | {5000,32,32,3}        | {5000,10}           |         1
               1 | {5000,32,32,3}        | {5000,10}           |         0
   (2 rows)
   ```
   
   run multi fit:
   ```
   DROP TABLE IF EXISTS cifar10_multi_model, cifar10_multi_model_summary, cifar10_multi_model_info;
   SELECT madlib.madlib_keras_fit_multiple_model('cifar_10_train_data_packed_allseg',    -- source_table
                                                 'cifar10_multi_model',     -- model_output_table
                                                 'mst_table',               -- model_selection_table
                                                  3,                       -- num_iterations
                                                  NULL,                     -- use gpus
                                                 'cifar_10_test_data_packed_allseg',      -- validation dataset
                                                  1,                         -- metrics compute frequency
                                                  NULL,                      -- warm_start
                                                  'me',
                                                  'this is a test run'
                                                );
   ```
   
   which produced error:
   ```
   ERROR:  plpy.SPIError: PRIMARY KEY and DISTRIBUTED BY definitions incompatible
   HINT:  When there is both a PRIMARY KEY, and a DISTRIBUTED BY clause, the DISTRIBUTED BY clause must be equal to or a left-subset of the PRIMARY KEY
   CONTEXT:  Traceback (most recent call last):
     PL/Python function "madlib_keras_fit_multiple_model", line 24, in <module>
       fit_obj.fit_multiple_model()
     PL/Python function "madlib_keras_fit_multiple_model", line 241, in fit_multiple_model
     PL/Python function "madlib_keras_fit_multiple_model", line 509, in init_model_output_tbl
   PL/Python function "madlib_keras_fit_multiple_model"
   ```
   
   I have also seen this error when running on a single segment.
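
   For reference, a minimal standalone illustration of the Greenplum rule behind this error (hypothetical table names, not what init_model_output_tbl actually creates):
   ```
   -- Fails with "PRIMARY KEY and DISTRIBUTED BY definitions incompatible":
   -- the distribution key is not part of the primary key.
   CREATE TABLE bad_example (
       mst_key      INTEGER PRIMARY KEY,
       __dist_key__ INTEGER
   ) DISTRIBUTED BY (__dist_key__);

   -- Works: the distribution key is a left-subset of the primary key.
   CREATE TABLE good_example (
       mst_key      INTEGER,
       __dist_key__ INTEGER,
       PRIMARY KEY (__dist_key__, mst_key)
   ) DISTRIBUTED BY (__dist_key__);
   ```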


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [madlib] kaknikhil commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r560539401



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -629,149 +794,187 @@ class FitMultipleModel():
         # Therefore we want to have queries that do not add motions and all the
         # sub-queries running Keras/tensorflow operations reuse the same slice(process)
         # that was used for initializing GPU memory.
-        use_gpus = self.use_gpus if self.use_gpus else False
-        mst_weights_query = """
-            CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
-                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
-                       model_arch_tbl.{self.model_arch_col}
-                FROM
-                    {self.mst_current_schedule_tbl} mst_tbl
-                    LEFT JOIN {self.model_output_table} wgh_tbl
-                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
-                        LEFT JOIN {self.model_arch_table} model_arch_tbl
-                        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
-                DISTRIBUTED BY ({dist_key_col})
-        """.format(dist_key_col=dist_key_col,
-                   **locals())
-        plpy.execute(mst_weights_query)
-        use_gpus = self.use_gpus if self.use_gpus else False
-        dep_shape_col = self.dep_shape_col
-        ind_shape_col = self.ind_shape_col
+
+        DEBUG.start_timing("run_training")
+        if hop > 0:
+            DEBUG.print_mst_keys(self.model_output_tbl, 'before_hop')
+            DEBUG.start_timing("hop")
+            hop_query = """
+                CREATE {self.unlogged_table} TABLE {self.model_input_tbl} AS
+                    SELECT o.{self.mst_key_col},
+                           o.{self.model_weights_col},
+                           o.{self.model_arch_col},
+                           o.{self.compile_params_col},
+                           o.{self.fit_params_col},
+                           o.{self.object_map_col},
+                           s.{self.dist_key_col}
+                    FROM {self.model_output_tbl} o JOIN {self.schedule_tbl} s
+                        ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
+                    DISTRIBUTED BY ({self.dist_key_col});
+            """.format(self=self)
+
+            DEBUG.plpy.execute(hop_query)
+
+            DEBUG.print_timing("hop")
+            DEBUG.print_mst_keys(self.model_input_tbl, 'after_hop')
+
+            DEBUG.start_timing("truncate_output")
+            self.truncate_and_drop(self.model_output_tbl)
+            DEBUG.print_timing("truncate_output")
+        else:
+            # Skip hop if it's the first in an iteration, just rename
+            plpy.execute("""
+                ALTER TABLE {self.model_output_tbl}
+                    RENAME TO {self.model_input_tbl}
+            """.format(self=self))
+ 
+        ind_shape = self.ind_shape_col
+        dep_shape = self.dep_shape_col
         dep_var = mb_dep_var_col
         indep_var = mb_indep_var_col
         source_table = self.source_table
-        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+
         if self.use_caching:
             # Caching populates the independent_var and dependent_var into the cache on the very first hop
             # For the very_first_hop, we want to run the transition function on all segments, including
-            # the one's where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
+            # the ones where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
             # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
             # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
             # table to join for referencing the dist_key
             if is_very_first_hop:
                 plpy.execute("""
                     DROP TABLE IF EXISTS {self.cached_source_table};
-                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
-                    """.format(self=self, dist_key_col=dist_key_col))
+                    CREATE TABLE {self.cached_source_table} AS
+                        SELECT {self.dist_key_col} FROM {self.source_table}
+                            GROUP BY {self.dist_key_col}
+                                DISTRIBUTED BY({self.dist_key_col});
+                    """.format(self=self))
             else:
-                dep_shape_col = 'ARRAY[0]'
-                ind_shape_col = 'ARRAY[0]'
-                dep_var = 'NULL'
-                indep_var = 'NULL'
+                dep_shape = ind_shape = 'NULL'
+                dep_var = indep_var = 'NULL'
                 source_table = self.cached_source_table
-            if is_very_first_hop or self.is_final_training_call:
-                where_clause = ""
-
-        uda_query = """
-            CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
-            SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
-                {mb_indep_var_col},
-                {dep_shape_col},
-                {ind_shape_col},
-                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
-                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
-                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                src.{dist_key_col},
-                ARRAY{self.dist_key_mapping},
-                src.{self.gp_segment_id_col},
-                {self.segments_per_host},
-                ARRAY{self.images_per_seg_train},
-                {use_gpus}::BOOLEAN,
-                ARRAY{self.accessible_gpus_for_seg},
-                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_training_call}::BOOLEAN,
-                {use_caching}::BOOLEAN,
-                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
-                )::BYTEA AS {self.model_weights_col},
-                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
-                ,src.{dist_key_col} AS {dist_key_col}
-            FROM {source_table} src JOIN {self.mst_weights_tbl}
-                USING ({dist_key_col})
-            {where_clause}
-            GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
-            DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=dep_var,
-                       mb_indep_var_col=indep_var,
-                       dep_shape_col=dep_shape_col,
-                       ind_shape_col=ind_shape_col,
-                       is_final_training_call=self.is_final_training_call,
+
+        res = plpy.execute("""
+            SELECT count(*)
+            FROM {self.model_input_tbl}
+        """.format(self=self))
+        if res:
+            DEBUG.plpy.info("rows in model_input table: {}".format(res[0]['count']))
+        else:
+            DEBUG.plpy.error("No rows in model_input table!")
+
+#TODO: prepare this statement once, then just fill in the params with execute()
+#      on all the rest of the hops / iterations
+
+        DEBUG.start_timing("udf")
+        udf_query = plpy.prepare("""
+            CREATE {self.unlogged_table} TABLE {self.model_output_tbl} AS
+            SELECT
+                model_in.{self.mst_key_col},
+                CASE WHEN model_in.{self.dist_key_col} > {self.max_dist_key}
+                THEN
+                    model_in.{self.model_weights_col}
+                ELSE
+                    {self.schema_madlib}.fit_transition_multiple_model(
+                        {dep_var_col},
+                        {indep_var_col},
+                        {dep_shape},
+                        {ind_shape},
+                        model_in.{self.model_arch_col}::TEXT,
+                        model_in.{self.compile_params_col}::TEXT,
+                        model_in.{self.fit_params_col}::TEXT,
+                        src.{self.dist_key_col},
+                        ARRAY{self.dist_key_mapping},
+                        src.{self.gp_segment_id_col},
+                        {self.segments_per_host},
+                        ARRAY{self.images_per_seg_train},
+                        {self.use_gpus}::BOOLEAN,
+                        ARRAY{self.accessible_gpus_for_seg},
+                        model_in.{self.model_weights_col}::BYTEA,
+                        {self.is_final_training_call}::BOOLEAN,
+                        {use_caching}::BOOLEAN,
+                        model_in.{self.object_map_col}::BYTEA
+                    )
+                END::BYTEA AS {self.model_weights_col},
+                model_in.{self.model_arch_col},
+                model_in.{self.compile_params_col},
+                model_in.{self.fit_params_col},
+                model_in.{self.object_map_col},
+                model_in.{self.dist_key_col}
+            FROM {self.model_input_tbl} model_in
+                FULL JOIN {source_table} src
+                USING ({self.dist_key_col}) 
+            DISTRIBUTED BY({self.dist_key_col})
+            """.format(dep_var_col=dep_var,
+                       indep_var_col=indep_var,
+                       dep_shape=dep_shape,
+                       ind_shape=ind_shape,
                        use_caching=self.use_caching,
-                       dist_key_col=dist_key_col,
-                       use_gpus=use_gpus,
                        source_table=source_table,
-                       where_clause=where_clause,
                        self=self
                        )
-        plpy.execute(uda_query)
+        )
+
+        try:
+            plpy.execute(udf_query)
+        except plpy.SPIError as e:
+            msg = e.message
+            if not 'TransAggDetail' in msg:
+                raise e
+            e.message, detail = msg.split('TransAggDetail')
+            # Extract Traceback from segment, add to
+            #  DETAIL of error message on coordinator
+            e.args = (e.message,)
+            spidata = list(e.spidata)
+            spidata[1] = detail
+            e.spidata = tuple(spidata)
+            raise e
+
+        DEBUG.print_timing("udf")
+
+        res = plpy.execute("""
+            SELECT {self.mst_key_col} AS mst_key, {self.model_weights_col} IS NOT NULL AS weights
+                FROM {self.model_output_tbl}
+        """.format(self=self))
+        if res:
+            null_msts = len([None for row in res if row['mst_key'] is None])
+            null_weights = len([None for row in res if row['weights'] is False])
+            DEBUG.plpy.info(
+                "{} rows total ({} mst_key=NULL and {} weights=NULL) in model_output table."\
+                    .format(res.nrows(), null_msts, null_weights))
+        else:
+            plpy.error("No rows in output of UDF!")
 
-        update_query = """
-            UPDATE {self.model_output_table}
-            SET {self.model_weights_col} = {self.weights_to_update_tbl}.{self.model_weights_col}
-            FROM {self.weights_to_update_tbl}
-            WHERE {self.model_output_table}.{self.mst_key_col} = {self.weights_to_update_tbl}.{self.mst_key_col}
-        """.format(self=self)
-        plpy.execute(update_query)
+        plpy.execute("DELETE FROM {self.model_output_tbl} WHERE model_weights IS NULL".format(self=self))

Review comment:
       It would be interesting to figure out how much extra time a subquery would add, and whether there is a better, more efficient way to write this query. But that can be done in a future PR.
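
       One low-effort way to get a ballpark number (a sketch; substitute the generated temp table names) is to time the current DELETE under EXPLAIN ANALYZE in a rolled-back transaction, and time a subquery-based CTAS variant the same way:
   ```
   BEGIN;
   -- EXPLAIN ANALYZE actually executes the DML, so roll it back after reading the timing.
   EXPLAIN ANALYZE
       DELETE FROM model_output_example WHERE model_weights IS NULL;
   ROLLBACK;
   ```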




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [madlib] kaknikhil commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r530005710



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -196,72 +195,110 @@ class FitMultipleModel():
             self.dist_key_mapping_valid, self.images_per_seg_valid = \
                 get_image_count_per_seg_for_minibatched_data_from_db(
                     self.validation_table)
-        self.mst_weights_tbl = unique_string(desp='mst_weights')
-        self.mst_current_schedule_tbl = unique_string(desp='mst_current_schedule')
+        self.model_input_tbl = unique_string(desp='model_input')
+        self.schedule_tbl = unique_string(desp='schedule')
 
-        self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
-        if len(self.msts) < len(self.dist_keys):
+        self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
+        DEBUG.plpy.info("init_dist_keys = {0}".format(self.dist_keys))
+        self.max_dist_key = sorted(self.dist_keys)[-1]
+        DEBUG.plpy.info("sorted_dist_keys = {0}".format(sorted(self.dist_keys)))
+        DEBUG.plpy.info("max_dist_key = {0}".format(self.max_dist_key))
+        self.extra_dist_keys = []
+
+        num_msts = len(self.msts)
+        num_dist_keys = len(self.dist_keys)
+
+        if num_msts < num_dist_keys:
             self.msts_for_schedule = self.msts + [None] * \
-                                     (len(self.dist_keys) - len(self.msts))
+                                     (num_dist_keys - num_msts)
         else:
             self.msts_for_schedule = self.msts
+            if num_msts > num_dist_keys:
+                for i in range(num_msts - num_dist_keys):
+                    self.extra_dist_keys.append(self.max_dist_key + 1 + i)
+
+        DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys))
+        DEBUG.plpy.info('extra_dist_keys : {}'.format(self.extra_dist_keys))
+
         random.shuffle(self.msts_for_schedule)
-        self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
-        self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
-        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
-        if self.warm_start:
-            self.create_model_output_table_warm_start()
-        else:
-            self.create_model_output_table()
+        # Comma-separated list of the mst_keys, including NULL's
+        #  This will be used to pass the mst keys to the db as
+        #  a sql ARRAY[]
+        self.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                for mst in self.msts_for_schedule ]
 
-        self.weights_to_update_tbl = unique_string(desp='weights_to_update')
-        self.fit_multiple_model()
+        # List of all dist_keys, including any extra dist keys beyond
+        #  the # segments we'll be training on--these represent the
+        #  segments models will rest on while not training, which
+        #  may overlap with the ones that will have training on them.
+        self.all_dist_keys = self.dist_keys + self.extra_dist_keys
 
-        # Update and cleanup metadata tables
-        self.insert_info_table()
-        self.create_model_summary_table()
-        if self.warm_start:
-            self.cleanup_for_warm_start()
-        reset_cuda_env(original_cuda_env)
+        self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
+        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
     def fit_multiple_model(self):
+        self.init_schedule_tbl()
+        self.init_model_output_tbl()
+        self.init_model_info_tbl()
+
         # WARNING: set orca off to prevent unwanted redistribution
         with OptimizerControl(False):
             self.start_training_time = datetime.datetime.now()
             self.metrics_elapsed_start_time = time.time()
             self.train_multiple_model()
             self.end_training_time = datetime.datetime.now()
 
-    def cleanup_for_warm_start(self):
+        # Update and cleanup metadata tables
+        self.insert_info_table()
+        self.create_model_summary_table()
+        self.write_final_model_output_tbl()
+        reset_cuda_env(self.original_cuda_env)
+
+    def write_final_model_output_tbl(self):
         """
-        1. drop original model table
+        1. drop original model table if exists
         2. rename temp to original
         :return:
         """
-        drop_query = "DROP TABLE IF EXISTS {}".format(
-            self.original_model_output_table)
-        plpy.execute(drop_query)
-        rename_table(self.schema_madlib, self.model_output_table,
-                     self.original_model_output_table)
+        final_output_table_create_query = """
+                                    DROP TABLE IF EXISTS {self.original_model_output_tbl};
+                                    CREATE TABLE {self.original_model_output_tbl} AS

Review comment:
       Why not drop the original model output table and then rename `model_output_tbl` to `original_model_output_tbl`?
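
       Roughly what that alternative looks like in SQL (a sketch with placeholder names; the real code would go through the generated table names / the rename_table helper):
   ```
   DROP TABLE IF EXISTS original_model_output_tbl;
   ALTER TABLE model_output_tbl RENAME TO original_model_output_tbl;
   ```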

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -629,149 +794,187 @@ class FitMultipleModel():
         # Therefore we want to have queries that do not add motions and all the
         # sub-queries running Keras/tensorflow operations reuse the same slice(process)
         # that was used for initializing GPU memory.
-        use_gpus = self.use_gpus if self.use_gpus else False
-        mst_weights_query = """
-            CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
-                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
-                       model_arch_tbl.{self.model_arch_col}
-                FROM
-                    {self.mst_current_schedule_tbl} mst_tbl
-                    LEFT JOIN {self.model_output_table} wgh_tbl
-                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
-                        LEFT JOIN {self.model_arch_table} model_arch_tbl
-                        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
-                DISTRIBUTED BY ({dist_key_col})
-        """.format(dist_key_col=dist_key_col,
-                   **locals())
-        plpy.execute(mst_weights_query)
-        use_gpus = self.use_gpus if self.use_gpus else False
-        dep_shape_col = self.dep_shape_col
-        ind_shape_col = self.ind_shape_col
+
+        DEBUG.start_timing("run_training")
+        if hop > 0:
+            DEBUG.print_mst_keys(self.model_output_tbl, 'before_hop')
+            DEBUG.start_timing("hop")
+            hop_query = """
+                CREATE {self.unlogged_table} TABLE {self.model_input_tbl} AS
+                    SELECT o.{self.mst_key_col},
+                           o.{self.model_weights_col},
+                           o.{self.model_arch_col},
+                           o.{self.compile_params_col},
+                           o.{self.fit_params_col},
+                           o.{self.object_map_col},
+                           s.{self.dist_key_col}
+                    FROM {self.model_output_tbl} o JOIN {self.schedule_tbl} s
+                        ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
+                    DISTRIBUTED BY ({self.dist_key_col});
+            """.format(self=self)
+
+            DEBUG.plpy.execute(hop_query)
+
+            DEBUG.print_timing("hop")
+            DEBUG.print_mst_keys(self.model_input_tbl, 'after_hop')
+
+            DEBUG.start_timing("truncate_output")
+            self.truncate_and_drop(self.model_output_tbl)
+            DEBUG.print_timing("truncate_output")
+        else:
+            # Skip hop if it's the first in an iteration, just rename
+            plpy.execute("""
+                ALTER TABLE {self.model_output_tbl}
+                    RENAME TO {self.model_input_tbl}
+            """.format(self=self))
+ 
+        ind_shape = self.ind_shape_col
+        dep_shape = self.dep_shape_col
         dep_var = mb_dep_var_col
         indep_var = mb_indep_var_col
         source_table = self.source_table
-        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+
         if self.use_caching:
             # Caching populates the independent_var and dependent_var into the cache on the very first hop
             # For the very_first_hop, we want to run the transition function on all segments, including
-            # the one's where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
+            # the ones where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
             # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
             # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
             # table to join for referencing the dist_key
             if is_very_first_hop:
                 plpy.execute("""
                     DROP TABLE IF EXISTS {self.cached_source_table};
-                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
-                    """.format(self=self, dist_key_col=dist_key_col))
+                    CREATE TABLE {self.cached_source_table} AS
+                        SELECT {self.dist_key_col} FROM {self.source_table}
+                            GROUP BY {self.dist_key_col}
+                                DISTRIBUTED BY({self.dist_key_col});
+                    """.format(self=self))
             else:
-                dep_shape_col = 'ARRAY[0]'
-                ind_shape_col = 'ARRAY[0]'
-                dep_var = 'NULL'
-                indep_var = 'NULL'
+                dep_shape = ind_shape = 'NULL'
+                dep_var = indep_var = 'NULL'
                 source_table = self.cached_source_table
-            if is_very_first_hop or self.is_final_training_call:
-                where_clause = ""
-
-        uda_query = """
-            CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
-            SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
-                {mb_indep_var_col},
-                {dep_shape_col},
-                {ind_shape_col},
-                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
-                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
-                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                src.{dist_key_col},
-                ARRAY{self.dist_key_mapping},
-                src.{self.gp_segment_id_col},
-                {self.segments_per_host},
-                ARRAY{self.images_per_seg_train},
-                {use_gpus}::BOOLEAN,
-                ARRAY{self.accessible_gpus_for_seg},
-                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_training_call}::BOOLEAN,
-                {use_caching}::BOOLEAN,
-                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
-                )::BYTEA AS {self.model_weights_col},
-                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
-                ,src.{dist_key_col} AS {dist_key_col}
-            FROM {source_table} src JOIN {self.mst_weights_tbl}
-                USING ({dist_key_col})
-            {where_clause}
-            GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
-            DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=dep_var,
-                       mb_indep_var_col=indep_var,
-                       dep_shape_col=dep_shape_col,
-                       ind_shape_col=ind_shape_col,
-                       is_final_training_call=self.is_final_training_call,
+
+        res = plpy.execute("""

Review comment:
       Is this plpy.execute only for the plpy.info commands? If so, maybe we should consider whether we really need to run this plpy.execute and the plpy.info prints.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
##########
@@ -64,6 +64,13 @@ def reset_cuda_env(value):
         if CUDA_VISIBLE_DEVICES_KEY in os.environ:
             del os.environ[CUDA_VISIBLE_DEVICES_KEY]
 
+def enable_xla():
+    os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
+    try:
+        K.tf.config.optimizer.set_jit(True)
+    except:
+        plpy.warning("This version of tensorflow does not support XLA auto-cluster JIT optimization.  HINT:  upgrading tensorflow may improve performance.")

Review comment:
       Maybe we should be more specific with the upgrade hint, since we don't support TF 2.0 yet.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -518,111 +567,130 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
     # Fit segment model on data
     #TODO consider not doing this every time
     fit_params = parse_and_validate_fit_params(fit_params)
-    segment_model.fit(x_train, y_train, **fit_params)
+    with K.tf.device(device_name):
+        segment_model.fit(x_train, y_train, **fit_params)
 
     # Aggregating number of images, loss and accuracy
     agg_image_count += len(x_train)
+    SD[SD_STORE.AGG_IMAGE_COUNT] = agg_image_count
     total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
                                                       images_per_seg)
     is_last_row = agg_image_count == total_images
     return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
                                        agg_image_count, total_images)
+
     if is_last_row:
+        SD[SD_STORE.AGG_IMAGE_COUNT] = 0  # Must be reset after each pass through images
         if is_final_iteration or is_multiple_model:
             SD_STORE.clear_SD(SD)
             clear_keras_session(sess)
 
+    trans_exit_time = time.time()
+    DEBUG.plpy.info("|_fit_transition_time_|{}|".format(trans_exit_time - trans_enter_time))
+
+    SD[SD_STORE.TRANS_EXIT_TIME] = trans_exit_time
     return return_state
 
-def fit_multiple_transition_caching(state, dependent_var, independent_var, dependent_var_shape,
-                             independent_var_shape, model_architecture,
-                             compile_params, fit_params, dist_key, dist_key_mapping,
-                             current_seg_id, segments_per_host, images_per_seg, use_gpus,
-                             accessible_gpus_for_seg, prev_serialized_weights,
-                             is_final_training_call, custom_function_map=None, **kwargs):
+def fit_multiple_transition_caching(
+    dependent_var, independent_var, dependent_var_shape, independent_var_shape,
+    model_architecture, compile_params, fit_params, dist_key, dist_key_mapping,
+    current_seg_id, segments_per_host, images_per_seg, use_gpus, accessible_gpus_for_seg,
+    serialized_weights, is_final_training_call, custom_function_map=None, **kwargs):
     """
     This transition function is called when caching is called for
     madlib_keras_fit_multiple_model().
-    The input params: dependent_var, independent_var are passed in
-    as None and dependent_var_shape, independent_var_shape as [0]
-    for all hops except the very first hop
+    The input params: dependent_var, independent_var,
+    dependent_var_shape and independent_var_shape are passed
+    in as None for all hops except the very first hop
     Some things to note in this function are:
-    - prev_serialized_weights can be passed in as None for the
+    - weights can be passed in as None for the
       very first hop and the final training call
     - x_train, y_train and cache_set is cleared from SD for
-      final_training_call = TRUE
+      is_final_training_call = True
     """
-    if not state:
-        agg_image_count = 0
+    SD = kwargs['SD']
+
+    trans_enter_time = time.time()
+    trans_exit_time = None
+    if SD_STORE.TRANS_EXIT_TIME in SD:
+        trans_exit_time = SD[SD_STORE.TRANS_EXIT_TIME]
+        SD[SD_STORE.TRANS_EXIT_TIME] = trans_enter_time
+
+    if SD_STORE.AGG_IMAGE_COUNT in SD:
+        agg_image_count = SD[SD_STORE.AGG_IMAGE_COUNT]
     else:
-        agg_image_count = float(state)
+        agg_image_count = 0
+        SD[SD_STORE.AGG_IMAGE_COUNT] = agg_image_count
 
-    SD = kwargs['SD']
-    is_cache_set = 'cache_set' in SD
+    if agg_image_count > 0:
+        if trans_exit_time:
+            DEBUG.plpy.info("|_gpdb_btw_rows_time_|{}|".format(trans_enter_time - trans_exit_time))
+    else:
+        if trans_exit_time:
+            DEBUG.plpy.info("|_gpdb_btw_hops_end_time_|{}|".format(trans_enter_time - trans_exit_time))
 
     # Prepare the data
-    if is_cache_set:
+    if dependent_var_shape is None:
         if 'x_train' not in SD or 'y_train' not in SD:
             plpy.error("cache not populated properly.")
-        total_images = None
         is_last_row = True
+        total_images = None
     else:
-        if not independent_var or not dependent_var:
-            return state
-        if 'x_train' not in SD:
+        if 'x_train' not in SD or 'y_train' not in SD:
             SD['x_train'] = list()
             SD['y_train'] = list()
+
         agg_image_count += independent_var_shape[0]
-        total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
-                                                          images_per_seg)
+        SD[SD_STORE.AGG_IMAGE_COUNT] = agg_image_count
+        total_images = get_image_count_per_seg_from_array(
+            dist_key_mapping.index(dist_key), images_per_seg
+        )
         is_last_row = agg_image_count == total_images
-        if is_last_row:
-            SD['cache_set'] = True
         x_train_current = np_array_float32(independent_var, independent_var_shape)
         y_train_current = np_array_int16(dependent_var, dependent_var_shape)
         SD['x_train'].append(x_train_current)
         SD['y_train'].append(y_train_current)
 
     # Passed in weights can be None. Irrespective of the weights, we want to populate the cache for the very first hop.
     # But if the weights are None, we do not want to set any model. So early return in that case
-    if prev_serialized_weights is None:
+    if serialized_weights is None:
         if is_final_training_call:
             del SD['x_train']
             del SD['y_train']
-            del SD['cache_set']
-        return float(agg_image_count)
+        return None
 
     segment_model = None
     if is_last_row:
         device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
-        segment_model, sess = get_init_model_and_sess(SD, device_name,
-                                                      accessible_gpus_for_seg[current_seg_id],
-                                                      segments_per_host,
-                                                      model_architecture, compile_params,
-                                                      custom_function_map)
-        set_model_weights(segment_model, prev_serialized_weights)
+        with K.tf.device(device_name):

Review comment:
       We already wrap the code with `K.tf.device(device_name)` inside the `get_init_model_and_sess` function, so we don't really need to do it here as well.
   

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
##########
@@ -1793,13 +1793,23 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition(
     segments_per_host           INTEGER,
     images_per_seg              INTEGER[],
     use_gpus                    BOOLEAN,
-    accessible_gpus_for_seg                INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
     prev_serialized_weights     BYTEA,
     is_final_iteration          BOOLEAN,
-    custom_function_map        BYTEA
+    custom_function_map         BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    return madlib_keras.fit_transition(**globals())
+    import traceback
+    from sys import exc_info
+    import plpy
+    try:
+        return madlib_keras.fit_transition(**globals())
+    except Exception as e:
+        etype, _, tb = exc_info()
+        detail = ''.join(traceback.format_exception(etype, e, tb))
+        message = e.args[0] + 'TransAggDetail' + detail

Review comment:
       Instead of adding `TransAggDetail` to the error message, why not use a keyword that is more relevant to the fit multiple UDF? If there is a specific reason for this keyword, I would recommend documenting it.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -629,149 +794,187 @@ class FitMultipleModel():
         # Therefore we want to have queries that do not add motions and all the
         # sub-queries running Keras/tensorflow operations reuse the same slice(process)
         # that was used for initializing GPU memory.
-        use_gpus = self.use_gpus if self.use_gpus else False
-        mst_weights_query = """
-            CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
-                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
-                       model_arch_tbl.{self.model_arch_col}
-                FROM
-                    {self.mst_current_schedule_tbl} mst_tbl
-                    LEFT JOIN {self.model_output_table} wgh_tbl
-                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
-                        LEFT JOIN {self.model_arch_table} model_arch_tbl
-                        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
-                DISTRIBUTED BY ({dist_key_col})
-        """.format(dist_key_col=dist_key_col,
-                   **locals())
-        plpy.execute(mst_weights_query)
-        use_gpus = self.use_gpus if self.use_gpus else False
-        dep_shape_col = self.dep_shape_col
-        ind_shape_col = self.ind_shape_col
+
+        DEBUG.start_timing("run_training")
+        if hop > 0:
+            DEBUG.print_mst_keys(self.model_output_tbl, 'before_hop')
+            DEBUG.start_timing("hop")
+            hop_query = """
+                CREATE {self.unlogged_table} TABLE {self.model_input_tbl} AS
+                    SELECT o.{self.mst_key_col},
+                           o.{self.model_weights_col},
+                           o.{self.model_arch_col},
+                           o.{self.compile_params_col},
+                           o.{self.fit_params_col},
+                           o.{self.object_map_col},
+                           s.{self.dist_key_col}
+                    FROM {self.model_output_tbl} o JOIN {self.schedule_tbl} s
+                        ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
+                    DISTRIBUTED BY ({self.dist_key_col});
+            """.format(self=self)
+
+            DEBUG.plpy.execute(hop_query)
+
+            DEBUG.print_timing("hop")
+            DEBUG.print_mst_keys(self.model_input_tbl, 'after_hop')
+
+            DEBUG.start_timing("truncate_output")
+            self.truncate_and_drop(self.model_output_tbl)
+            DEBUG.print_timing("truncate_output")
+        else:
+            # Skip hop if it's the first in an iteration, just rename
+            plpy.execute("""
+                ALTER TABLE {self.model_output_tbl}
+                    RENAME TO {self.model_input_tbl}
+            """.format(self=self))
+ 
+        ind_shape = self.ind_shape_col
+        dep_shape = self.dep_shape_col
         dep_var = mb_dep_var_col
         indep_var = mb_indep_var_col
         source_table = self.source_table
-        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+
         if self.use_caching:
             # Caching populates the independent_var and dependent_var into the cache on the very first hop
             # For the very_first_hop, we want to run the transition function on all segments, including
-            # the one's where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
+            # the ones where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
             # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
             # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
             # table to join for referencing the dist_key
             if is_very_first_hop:
                 plpy.execute("""
                     DROP TABLE IF EXISTS {self.cached_source_table};
-                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
-                    """.format(self=self, dist_key_col=dist_key_col))
+                    CREATE TABLE {self.cached_source_table} AS
+                        SELECT {self.dist_key_col} FROM {self.source_table}
+                            GROUP BY {self.dist_key_col}
+                                DISTRIBUTED BY({self.dist_key_col});
+                    """.format(self=self))
             else:
-                dep_shape_col = 'ARRAY[0]'
-                ind_shape_col = 'ARRAY[0]'
-                dep_var = 'NULL'
-                indep_var = 'NULL'
+                dep_shape = ind_shape = 'NULL'
+                dep_var = indep_var = 'NULL'
                 source_table = self.cached_source_table
-            if is_very_first_hop or self.is_final_training_call:
-                where_clause = ""
-
-        uda_query = """
-            CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
-            SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
-                {mb_indep_var_col},
-                {dep_shape_col},
-                {ind_shape_col},
-                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
-                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
-                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                src.{dist_key_col},
-                ARRAY{self.dist_key_mapping},
-                src.{self.gp_segment_id_col},
-                {self.segments_per_host},
-                ARRAY{self.images_per_seg_train},
-                {use_gpus}::BOOLEAN,
-                ARRAY{self.accessible_gpus_for_seg},
-                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_training_call}::BOOLEAN,
-                {use_caching}::BOOLEAN,
-                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
-                )::BYTEA AS {self.model_weights_col},
-                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
-                ,src.{dist_key_col} AS {dist_key_col}
-            FROM {source_table} src JOIN {self.mst_weights_tbl}
-                USING ({dist_key_col})
-            {where_clause}
-            GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
-            DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=dep_var,
-                       mb_indep_var_col=indep_var,
-                       dep_shape_col=dep_shape_col,
-                       ind_shape_col=ind_shape_col,
-                       is_final_training_call=self.is_final_training_call,
+
+        res = plpy.execute("""
+            SELECT count(*)
+            FROM {self.model_input_tbl}
+        """.format(self=self))
+        if res:
+            DEBUG.plpy.info("rows in model_input table: {}".format(res[0]['count']))
+        else:
+            DEBUG.plpy.error("No rows in model_input table!")
+
+#TODO: prepare this statement once, then just fill in the params with execute()
+#      on all the rest of the hops / iterations
+
+        DEBUG.start_timing("udf")
+        udf_query = plpy.prepare("""
+            CREATE {self.unlogged_table} TABLE {self.model_output_tbl} AS
+            SELECT
+                model_in.{self.mst_key_col},
+                CASE WHEN model_in.{self.dist_key_col} > {self.max_dist_key}
+                THEN
+                    model_in.{self.model_weights_col}
+                ELSE
+                    {self.schema_madlib}.fit_transition_multiple_model(
+                        {dep_var_col},
+                        {indep_var_col},
+                        {dep_shape},
+                        {ind_shape},
+                        model_in.{self.model_arch_col}::TEXT,
+                        model_in.{self.compile_params_col}::TEXT,
+                        model_in.{self.fit_params_col}::TEXT,
+                        src.{self.dist_key_col},
+                        ARRAY{self.dist_key_mapping},
+                        src.{self.gp_segment_id_col},
+                        {self.segments_per_host},
+                        ARRAY{self.images_per_seg_train},
+                        {self.use_gpus}::BOOLEAN,
+                        ARRAY{self.accessible_gpus_for_seg},
+                        model_in.{self.model_weights_col}::BYTEA,
+                        {self.is_final_training_call}::BOOLEAN,
+                        {use_caching}::BOOLEAN,
+                        model_in.{self.object_map_col}::BYTEA
+                    )
+                END::BYTEA AS {self.model_weights_col},
+                model_in.{self.model_arch_col},
+                model_in.{self.compile_params_col},
+                model_in.{self.fit_params_col},
+                model_in.{self.object_map_col},
+                model_in.{self.dist_key_col}
+            FROM {self.model_input_tbl} model_in
+                FULL JOIN {source_table} src
+                USING ({self.dist_key_col}) 
+            DISTRIBUTED BY({self.dist_key_col})
+            """.format(dep_var_col=dep_var,
+                       indep_var_col=indep_var,
+                       dep_shape=dep_shape,
+                       ind_shape=ind_shape,
                        use_caching=self.use_caching,
-                       dist_key_col=dist_key_col,
-                       use_gpus=use_gpus,
                        source_table=source_table,
-                       where_clause=where_clause,
                        self=self
                        )
-        plpy.execute(uda_query)
+        )
+
+        try:
+            plpy.execute(udf_query)
+        except plpy.SPIError as e:
+            msg = e.message
+            if not 'TransAggDetail' in msg:
+                raise e
+            e.message, detail = msg.split('TransAggDetail')
+            # Extract Traceback from segment, add to
+            #  DETAIL of error message on coordinator
+            e.args = (e.message,)
+            spidata = list(e.spidata)
+            spidata[1] = detail
+            e.spidata = tuple(spidata)
+            raise e
+
+        DEBUG.print_timing("udf")
+
+        res = plpy.execute("""

Review comment:
       Same as the previous comment:
   
   `Is this plpy.execute only for the plpy.info commands? If so, maybe we should consider whether we really need to run this plpy.execute and the plpy.info prints.`
   

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -629,149 +794,187 @@ class FitMultipleModel():
         # Therefore we want to have queries that do not add motions and all the
         # sub-queries running Keras/tensorflow operations reuse the same slice(process)
         # that was used for initializing GPU memory.
-        use_gpus = self.use_gpus if self.use_gpus else False
-        mst_weights_query = """
-            CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
-                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
-                       model_arch_tbl.{self.model_arch_col}
-                FROM
-                    {self.mst_current_schedule_tbl} mst_tbl
-                    LEFT JOIN {self.model_output_table} wgh_tbl
-                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
-                        LEFT JOIN {self.model_arch_table} model_arch_tbl
-                        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
-                DISTRIBUTED BY ({dist_key_col})
-        """.format(dist_key_col=dist_key_col,
-                   **locals())
-        plpy.execute(mst_weights_query)
-        use_gpus = self.use_gpus if self.use_gpus else False
-        dep_shape_col = self.dep_shape_col
-        ind_shape_col = self.ind_shape_col
+
+        DEBUG.start_timing("run_training")
+        if hop > 0:
+            DEBUG.print_mst_keys(self.model_output_tbl, 'before_hop')
+            DEBUG.start_timing("hop")
+            hop_query = """
+                CREATE {self.unlogged_table} TABLE {self.model_input_tbl} AS
+                    SELECT o.{self.mst_key_col},
+                           o.{self.model_weights_col},
+                           o.{self.model_arch_col},
+                           o.{self.compile_params_col},
+                           o.{self.fit_params_col},
+                           o.{self.object_map_col},
+                           s.{self.dist_key_col}
+                    FROM {self.model_output_tbl} o JOIN {self.schedule_tbl} s
+                        ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
+                    DISTRIBUTED BY ({self.dist_key_col});
+            """.format(self=self)
+
+            DEBUG.plpy.execute(hop_query)
+
+            DEBUG.print_timing("hop")
+            DEBUG.print_mst_keys(self.model_input_tbl, 'after_hop')
+
+            DEBUG.start_timing("truncate_output")
+            self.truncate_and_drop(self.model_output_tbl)
+            DEBUG.print_timing("truncate_output")
+        else:
+            # Skip hop if it's the first in an iteration, just rename
+            plpy.execute("""
+                ALTER TABLE {self.model_output_tbl}
+                    RENAME TO {self.model_input_tbl}
+            """.format(self=self))
+ 
+        ind_shape = self.ind_shape_col
+        dep_shape = self.dep_shape_col
         dep_var = mb_dep_var_col
         indep_var = mb_indep_var_col
         source_table = self.source_table
-        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+
         if self.use_caching:
             # Caching populates the independent_var and dependent_var into the cache on the very first hop
             # For the very_first_hop, we want to run the transition function on all segments, including
-            # the one's where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
+            # the ones where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
             # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
             # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
             # table to join for referencing the dist_key
             if is_very_first_hop:
                 plpy.execute("""
                     DROP TABLE IF EXISTS {self.cached_source_table};
-                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
-                    """.format(self=self, dist_key_col=dist_key_col))
+                    CREATE TABLE {self.cached_source_table} AS
+                        SELECT {self.dist_key_col} FROM {self.source_table}
+                            GROUP BY {self.dist_key_col}
+                                DISTRIBUTED BY({self.dist_key_col});
+                    """.format(self=self))
             else:
-                dep_shape_col = 'ARRAY[0]'
-                ind_shape_col = 'ARRAY[0]'
-                dep_var = 'NULL'
-                indep_var = 'NULL'
+                dep_shape = ind_shape = 'NULL'
+                dep_var = indep_var = 'NULL'
                 source_table = self.cached_source_table
-            if is_very_first_hop or self.is_final_training_call:
-                where_clause = ""
-
-        uda_query = """
-            CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
-            SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
-                {mb_indep_var_col},
-                {dep_shape_col},
-                {ind_shape_col},
-                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
-                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
-                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                src.{dist_key_col},
-                ARRAY{self.dist_key_mapping},
-                src.{self.gp_segment_id_col},
-                {self.segments_per_host},
-                ARRAY{self.images_per_seg_train},
-                {use_gpus}::BOOLEAN,
-                ARRAY{self.accessible_gpus_for_seg},
-                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_training_call}::BOOLEAN,
-                {use_caching}::BOOLEAN,
-                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
-                )::BYTEA AS {self.model_weights_col},
-                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
-                ,src.{dist_key_col} AS {dist_key_col}
-            FROM {source_table} src JOIN {self.mst_weights_tbl}
-                USING ({dist_key_col})
-            {where_clause}
-            GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
-            DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=dep_var,
-                       mb_indep_var_col=indep_var,
-                       dep_shape_col=dep_shape_col,
-                       ind_shape_col=ind_shape_col,
-                       is_final_training_call=self.is_final_training_call,
+
+        res = plpy.execute("""
+            SELECT count(*)
+            FROM {self.model_input_tbl}
+        """.format(self=self))
+        if res:
+            DEBUG.plpy.info("rows in model_input table: {}".format(res[0]['count']))
+        else:
+            DEBUG.plpy.error("No rows in model_input table!")
+
+#TODO: prepare this statement once, then just fill in the params with execute()
+#      on all the rest of the hops / iterations
+
+        DEBUG.start_timing("udf")
+        udf_query = plpy.prepare("""
+            CREATE {self.unlogged_table} TABLE {self.model_output_tbl} AS
+            SELECT
+                model_in.{self.mst_key_col},
+                CASE WHEN model_in.{self.dist_key_col} > {self.max_dist_key}
+                THEN
+                    model_in.{self.model_weights_col}
+                ELSE
+                    {self.schema_madlib}.fit_transition_multiple_model(
+                        {dep_var_col},
+                        {indep_var_col},
+                        {dep_shape},
+                        {ind_shape},
+                        model_in.{self.model_arch_col}::TEXT,
+                        model_in.{self.compile_params_col}::TEXT,
+                        model_in.{self.fit_params_col}::TEXT,
+                        src.{self.dist_key_col},
+                        ARRAY{self.dist_key_mapping},
+                        src.{self.gp_segment_id_col},
+                        {self.segments_per_host},
+                        ARRAY{self.images_per_seg_train},
+                        {self.use_gpus}::BOOLEAN,
+                        ARRAY{self.accessible_gpus_for_seg},
+                        model_in.{self.model_weights_col}::BYTEA,
+                        {self.is_final_training_call}::BOOLEAN,
+                        {use_caching}::BOOLEAN,
+                        model_in.{self.object_map_col}::BYTEA
+                    )
+                END::BYTEA AS {self.model_weights_col},
+                model_in.{self.model_arch_col},
+                model_in.{self.compile_params_col},
+                model_in.{self.fit_params_col},
+                model_in.{self.object_map_col},
+                model_in.{self.dist_key_col}
+            FROM {self.model_input_tbl} model_in
+                FULL JOIN {source_table} src
+                USING ({self.dist_key_col}) 
+            DISTRIBUTED BY({self.dist_key_col})
+            """.format(dep_var_col=dep_var,
+                       indep_var_col=indep_var,
+                       dep_shape=dep_shape,
+                       ind_shape=ind_shape,
                        use_caching=self.use_caching,
-                       dist_key_col=dist_key_col,
-                       use_gpus=use_gpus,
                        source_table=source_table,
-                       where_clause=where_clause,
                        self=self
                        )
-        plpy.execute(uda_query)
+        )
+
+        try:
+            plpy.execute(udf_query)
+        except plpy.SPIError as e:
+            msg = e.message
+            if not 'TransAggDetail' in msg:
+                raise e
+            e.message, detail = msg.split('TransAggDetail')
+            # Extract Traceback from segment, add to
+            #  DETAIL of error message on coordinator
+            e.args = (e.message,)
+            spidata = list(e.spidata)
+            spidata[1] = detail
+            e.spidata = tuple(spidata)
+            raise e
+
+        DEBUG.print_timing("udf")
+
+        res = plpy.execute("""
+            SELECT {self.mst_key_col} AS mst_key, {self.model_weights_col} IS NOT NULL AS weights
+                FROM {self.model_output_tbl}
+        """.format(self=self))
+        if res:
+            null_msts = len([None for row in res if row['mst_key'] is None])
+            null_weights = len([None for row in res if row['weights'] is False])
+            DEBUG.plpy.info(
+                "{} rows total ({} mst_key=NULL and {} weights=NULL) in model_output table."\
+                    .format(res.nrows(), null_msts, null_weights))
+        else:
+            plpy.error("No rows in output of UDF!")
 
-        update_query = """
-            UPDATE {self.model_output_table}
-            SET {self.model_weights_col} = {self.weights_to_update_tbl}.{self.model_weights_col}
-            FROM {self.weights_to_update_tbl}
-            WHERE {self.model_output_table}.{self.mst_key_col} = {self.weights_to_update_tbl}.{self.mst_key_col}
-        """.format(self=self)
-        plpy.execute(update_query)
+        plpy.execute("DELETE FROM {self.model_output_tbl} WHERE model_weights IS NULL".format(self=self))

Review comment:
       Instead of running a DELETE command, can't we filter out the NULL model_weights when executing the UDF query?
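
       A rough sketch of that idea (placeholder table/column names; the inner SELECT stands in for the UDF query above): wrap the UDF SELECT in a subquery so the NULL-weight rows produced by the FULL JOIN never land in the output table:
   ```
   CREATE UNLOGGED TABLE model_output_example AS
       SELECT *
       FROM (
           SELECT model_in.mst_key,
                  model_in.model_weights,   -- stands in for the CASE / fit_transition_multiple_model(...) expression
                  model_in.__dist_key__
           FROM model_input_example model_in
               FULL JOIN source_example src USING (__dist_key__)
       ) q
       WHERE q.model_weights IS NOT NULL
       DISTRIBUTED BY (__dist_key__);
   ```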

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)

Review comment:
       In the current implementation of `DEBUG.plpy.execute`, the query is still executed; only the plan and the timing are skipped.
   It's a bit confusing to see the word DEBUG in front of plpy.execute and then find out that the query runs regardless of whether the debug flag is turned on.

   So I would suggest the following:
   1. Reduce the number of `DEBUG.plpy.execute` calls to only a couple, maybe just the hop and the UDF query.
   2. Don't use the word DEBUG; a name like `plpy_execute` would not confuse the reader.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -495,21 +500,45 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
         b. keras session is cleared at the end of the final iteration,
         i.e, last row of last iteration.
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape:
+        plpy.error("fit_transition called with no data")
+
+    if not prev_serialized_weights or not model_architecture:
         return state
+
     SD = kwargs['SD']
+
+    trans_enter_time = time.time()
+
+    trans_exit_time = None
+    if SD_STORE.TRANS_EXIT_TIME in SD:
+        trans_exit_time = SD[SD_STORE.TRANS_EXIT_TIME]
+        SD[SD_STORE.TRANS_EXIT_TIME] = None
+
     device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
 
-    segment_model, sess = get_init_model_and_sess(SD, device_name,
-                                                  accessible_gpus_for_seg[current_seg_id],
-                                                  segments_per_host,
-                                                  model_architecture, compile_params,
-                                                  custom_function_map)
-    if not state:
+    with K.tf.device(device_name):

Review comment:
       We already wrap the code in `with K.tf.device(device_name):` inside the `get_init_model_and_sess` function, so we don't really need to do it here as well.
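
   A rough sketch of the simplification being suggested, assuming the call keeps the same arguments it had before this diff (illustrative only):

       # get_init_model_and_sess() already enters the K.tf.device(device_name)
       # context internally, so the caller does not need its own
       # `with K.tf.device(device_name):` block around this section.
       segment_model, sess = get_init_model_and_sess(SD, device_name,
                                                     accessible_gpus_for_seg[current_seg_id],
                                                     segments_per_host,
                                                     model_architecture, compile_params,
                                                     custom_function_map)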

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +377,307 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
 
-    def create_model_output_table_warm_start(self):
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):
+            self.next_schedule_tbl = unique_string('next_schedule')
+            rotate_schedule_tbl_query = """
+                CREATE TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl};
+            """.format(self=self)
+            self.rotate_schedule_tbl_plan = plpy.prepare(rotate_schedule_tbl_query)
+
+        plpy.execute(self.rotate_schedule_tbl_plan)
+
+        self.truncate_and_drop(self.schedule_tbl)
+        plpy.execute("""
+            ALTER TABLE {self.next_schedule_tbl}
+            RENAME TO {self.schedule_tbl}
+        """.format(self=self))
+
+    def load_warm_start_weights(self):
         """
-        For warm start, we need to copy the model output table to a temp table
-        because we call truncate on the model output table while training.
-        If the query gets aborted, we need to make sure that the user passed
-        model output table can be recovered.
+        For warm start, we need to copy any rows of the model output
+        table provided by the user whose mst keys appear in the
+        supplied model selection table.  We also copy over the 
+        compile & fit params from the model_selection_table, and
+        the dist_key's from the schedule table.
         """
-        plpy.execute("""
-            CREATE TABLE {self.model_output_table} (
-            LIKE {self.original_model_output_table} INCLUDING indexes);
-            """.format(self=self))
+        load_warm_start_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    o.{self.model_weights_col},
+                    o.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.original_model_output_tbl} o
+                        USING ({self.mst_key_col})
+        """.format(self=self)
+        DEBUG.plpy.execute(load_warm_start_weights_query)
 
-        plpy.execute("""INSERT INTO {self.model_output_table}
-            SELECT * FROM {self.original_model_output_table};
-            """.format(self=self))
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(self.model_info_tbl))
 
-        plpy.execute(""" DELETE FROM {self.model_output_table}
-                WHERE {self.mst_key_col} NOT IN (
-                    SELECT {self.mst_key_col} FROM {self.model_selection_table})
-                """.format(self=self))
-        self.warm_start_msts = plpy.execute(
-            """ SELECT array_agg({0}) AS a FROM {1}
-            """.format(self.mst_key_col, self.model_output_table))[0]['a']
-        plpy.execute("DROP TABLE {0}".format(self.model_info_table))
-        self.initialize_model_output_and_info()
-
-    def initialize_model_output_and_info(self):
+    def load_xfer_learning_weights(self, warm_start=False):
+        """
+            Copy transfer learning weights from
+            model_arch table.  Ignore models with
+            no xfer learning weights, these will
+            be generated by keras and added one at a
+            time later.
+        """
+        load_xfer_learning_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    a.{self.model_weights_col},
+                    a.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.model_arch_table} a
+                        ON m.{self.model_id_col} = a.{self.model_id_col}
+                WHERE a.{self.model_weights_col} IS NOT NULL;
+        """.format(self=self)
+        DEBUG.plpy.execute(load_xfer_learning_weights_query)
+
+    def init_model_output_tbl(self):
+        DEBUG.start_timing('init_model_output_and_info')
+
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_tbl}
+                                    ({self.mst_key_col} INTEGER,
+                                     {self.model_weights_col} BYTEA,
+                                     {self.model_arch_col} JSON,
+                                     {self.compile_params_col} TEXT,
+                                     {self.fit_params_col} TEXT,
+                                     {self.object_map_col} BYTEA,
+                                     {self.dist_key_col} INTEGER,
+                                     PRIMARY KEY ({self.mst_key_col}, {self.dist_key_col}))
+                                     DISTRIBUTED BY ({self.dist_key_col})
+                                    """.format(self=self)
+        plpy.execute(output_table_create_query)
+
+        if self.warm_start:

Review comment:
       1. Maybe consider renaming these two functions, since they load not only the weights but also the model_arch, compile_params, fit_params, etc.; see the sketch below for possible names.
   2. Also consider not including the keyword `xfer_learning` in the function name, since the function is also used in the case where model_weights are NULL in the model arch table.
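
   A rough naming sketch (the names here are only suggestions, and the bodies are elided):

       # Possible renames reflecting that these functions copy the weights plus
       # model_arch, compile_params, fit_params and the dist_key:
       def load_warm_start_models(self):       # was: load_warm_start_weights
           pass

       def load_initial_models(self):          # was: load_xfer_learning_weights
           pass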

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
##########
@@ -81,6 +88,7 @@ def set_keras_session(device_name, gpu_count, segments_per_host):
     with K.tf.device(device_name):
         session = get_keras_session(device_name, gpu_count, segments_per_host)
         K.set_session(session)
+        enable_xla()

Review comment:
       enable_xla() is already called inside `get_keras_session`. Since `set_keras_session` calls `get_keras_session`, we don't really need to call enable_xla() here as well.
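
   A small sketch of what that would look like, assuming the function shape shown in this diff (illustrative only):

       # get_keras_session() already calls enable_xla(), so set_keras_session()
       # does not need a second enable_xla() call after K.set_session().
       def set_keras_session(device_name, gpu_count, segments_per_host):
           with K.tf.device(device_name):
               session = get_keras_session(device_name, gpu_count, segments_per_host)
               K.set_session(session)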

##########
File path: src/ports/postgres/modules/utilities/debug.py_in
##########
@@ -0,0 +1,145 @@
+import plpy as plpy_orig
+import time
+from deep_learning.madlib_keras_model_selection import ModelSelectionSchema
+from deep_learning.madlib_keras_helper import DISTRIBUTION_KEY_COLNAME
+
+mst_key_col = ModelSelectionSchema.MST_KEY
+dist_key_col = DISTRIBUTION_KEY_COLNAME
+
+start_times = dict()
+timings_enabled = False
+
+def start_timing(msg, force=False):
+    if timings_enabled or force:
+        start_times[msg] = time.time()
+        plpy_orig.info("|_{}_time_HDR|Elapsed (s)|Current|Current (s)|Start|Start (s)|".format(msg))
+
+def print_timing(msg, force=False):
+    if timings_enabled or force:
+        try:
+            start_time = start_times[msg]
+        except:
+            raise Exception(
+                "print_timing({msg}) called with no start_timing({msg})!".format(msg=msg)
+            )
+        current_time = time.time() 
+        plpy_orig.info(
+            '|_{0}_time|{1}|{2}|{3}|{4}|{5}'.format(
+                msg,
+                current_time - start_time,
+                time.ctime(current_time),
+                current_time,
+                time.ctime(start_time),
+                start_time
+            )
+        )
+
+mst_keys_enabled = False
+def print_mst_keys(table, label, force=False):
+    if not (mst_keys_enabled or force):
+        return
+
+    res = plpy_orig.execute("""
+        SELECT gp_segment_id AS seg_id,
+               {mst_key_col},
+               {dist_key_col}
+        FROM {table} ORDER BY {dist_key_col}
+    """.format(dist_key_col=dist_key_col,
+               table=table,
+               mst_key_col=mst_key_col))
+
+    plpy_orig.info("|_MST_KEYS_{label}_HDR|mst_key|seg_id|dist_key|table".format(**locals()))
+    if not res:
+        plpy_orig.error("{table} is empty!  Aborting".format(table=table))
+
+    for r in res:
+        seg_id = r['seg_id']
+        mst_key = r['mst_key']
+        dist_key = r[dist_key_col]
+        plpy_orig.info("|_MST_KEYS_{label}|{mst_key}|{seg_id}|{dist_key}|{table}".format(**locals()))
+
+plpy_execute_enabled = False
+def plpy_execute(*args, **kwargs):
+    """ debug.plpy.execute(sql, ..., force=False)
+
+        Replace plpy.execute(sql, ...) with
+        debug.plpy.execute(sql, ...) to debug
+        a query.  Shows the query itself, the
+        EXPLAIN of it, and how long the query
+        takes to execute.
+    """
+
+    force = False
+    if 'force' in kwargs:
+        force = kwargs['force']
+        del kwargs['force']
+
+    plpy = plpy_orig # override global plpy,
+                     # to avoid infinite recursion
+
+    if not (plpy_execute_enabled or force):
+        return plpy.execute(*args, **kwargs)
+
+    if len(args) > 0:
+        sql = args[0]
+    else:
+        raise TypeError('debug.plpy.execute() takes at least 1 parameter, 0 passed')
+
+    if type(sql) == str: # can't print if a PLyPlan object
+        plpy.info(sql)
+
+        # Print EXPLAIN of sql command
+        res = plpy.execute("EXPLAIN " + sql, *args[1:], **kwargs)
+        for r in res:
+            plpy.info(r['QUERY PLAN'])
+
+    # Run actual sql command, with timing
+    start = time.time()
+    res = plpy.execute(*args, **kwargs)
+
+    # Print how long execution of query took
+    plpy.info("Query took {0}s".format(time.time() - start))
+    if res:
+        plpy.info("Query returned {} row(s)".format(len(res)))
+    else:
+        plpy.info("Query returned 0 rows")
+    return res
+
+plpy_info_enabled = False
+def plpy_info(*args, **kwargs):
+    """ plpy_info(..., force=False)
+
+      plpy.info() if enabled, otherwise do nothing   
+    """
+
+    force = False
+    if 'force' in kwargs:
+        force = kwargs['force']
+        del kwargs['force']
+
+    if plpy_info_enabled or force:
+        plpy_orig.info(*args, **kwargs)
+
+plpy_debug_enabled = False
+def plpy_debug(*args, **kwargs):
+    """ debug.plpy.debug(..., force=False)
+
+        Behaves like plpy.debug() if disabled (printing only
+        if DEBUG level is set high enough), but becomes a
+        plpy.info() if enabled.
+    """
+
+    force = False
+    if 'force' in kwargs:
+        force = kwargs['force']
+        del kwargs['force']
+
+    if plpy_debug_enabled or force:
+        plpy_orig.info(*args, **kwargs)
+    else:
+        plpy_orig.debug(*args, **kwargs)
+
+class plpy:
+    execute = staticmethod(plpy_execute)
+    info = staticmethod(plpy_info)
+    debug = staticmethod(plpy_debug)

Review comment:
       Since this function is never used, do you think it makes sense to delete this line and the `plpy_debug` function as well (given that developers can still use `plpy_info` for printing purposes)?

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):
+            self.next_schedule_tbl = unique_string('next_schedule')
+            rotate_schedule_tbl_query = """
+                CREATE TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl};
+            """.format(self=self)
+            self.rotate_schedule_tbl_plan = plpy.prepare(rotate_schedule_tbl_query)
+
+        DEBUG.plpy.execute(self.rotate_schedule_tbl_plan)
+
+        self.truncate_and_drop(self.schedule_tbl)
+        plpy.execute("""
+            ALTER TABLE {self.next_schedule_tbl}
+            RENAME TO {self.schedule_tbl}
+        """.format(self=self))
 
-    def create_model_output_table_warm_start(self):
+    def load_warm_start_weights(self):
         """
-        For warm start, we need to copy the model output table to a temp table
-        because we call truncate on the model output table while training.
-        If the query gets aborted, we need to make sure that the user passed
-        model output table can be recovered.
+        For warm start, we need to copy any rows of the model output
+        table provided by the user whose mst keys appear in the
+        supplied model selection table.  We also copy over the 
+        compile & fit params from the model_selection_table, and
+        the dist_key's from the schedule table.
         """
-        plpy.execute("""
-            CREATE TABLE {self.model_output_table} (
-            LIKE {self.original_model_output_table} INCLUDING indexes);
-            """.format(self=self))
+        load_warm_start_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    o.{self.model_weights_col},
+                    o.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.original_model_output_tbl} o
+                        USING ({self.mst_key_col})
+        """.format(self=self)
+        DEBUG.plpy.execute(load_warm_start_weights_query)
 
-        plpy.execute("""INSERT INTO {self.model_output_table}
-            SELECT * FROM {self.original_model_output_table};
-            """.format(self=self))
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(self.model_info_tbl))
 
-        plpy.execute(""" DELETE FROM {self.model_output_table}
-                WHERE {self.mst_key_col} NOT IN (
-                    SELECT {self.mst_key_col} FROM {self.model_selection_table})
-                """.format(self=self))
-        self.warm_start_msts = plpy.execute(
-            """ SELECT array_agg({0}) AS a FROM {1}
-            """.format(self.mst_key_col, self.model_output_table))[0]['a']
-        plpy.execute("DROP TABLE {0}".format(self.model_info_table))
-        self.initialize_model_output_and_info()
-
-    def initialize_model_output_and_info(self):
+    def load_xfer_learning_weights(self, warm_start=False):
+        """
+            Copy transfer learning weights from
+            model_arch table.  Ignore models with
+            no xfer learning weights, these will
+            be generated by keras and added one at a
+            time later.
+        """
+        load_xfer_learning_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    a.{self.model_weights_col},
+                    a.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.model_arch_table} a
+                        ON m.{self.model_id_col} = a.{self.model_id_col}
+                WHERE a.{self.model_weights_col} IS NOT NULL;
+        """.format(self=self)
+        DEBUG.plpy.execute(load_xfer_learning_weights_query)
+
+    def init_model_output_tbl(self):
+        DEBUG.start_timing('init_model_output_and_info')
+
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_tbl}
+                                    ({self.mst_key_col} INTEGER,
+                                     {self.model_weights_col} BYTEA,
+                                     {self.model_arch_col} JSON,
+                                     {self.compile_params_col} TEXT,
+                                     {self.fit_params_col} TEXT,
+                                     {self.object_map_col} BYTEA,
+                                     {self.dist_key_col} INTEGER,
+                                     PRIMARY KEY ({self.dist_key_col}, {self.mst_key_col})
+                                    )
+                                    DISTRIBUTED BY ({self.dist_key_col})
+                                    """.format(self=self)
+        plpy.execute(output_table_create_query)
+
+        if self.warm_start:
+            self.load_warm_start_weights()
+        else:  # Note:  We only support xfer learning when warm_start=False
+            self.load_xfer_learning_weights()
+
+        res = DEBUG.plpy.execute("""
+            SELECT {self.mst_key_col} AS mst_keys FROM {self.model_output_tbl}
+        """.format(self=self))
+       
+        if res:
+            initialized_msts = set([ row['mst_keys'] for row in res ])
+        else:
+            initialized_msts = set()
+
+        DEBUG.plpy.info("Pre-initialized mst keys: {}".format(initialized_msts))

Review comment:
       We should try to reduce the number of `DEBUG.plpy.info` calls, since they can be feature specific and having a lot of them makes the code slightly harder to read.
   Any developer working on a feature can add their own `DEBUG.plpy.info` calls as needed.







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r558631652



##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
##########
@@ -0,0 +1,845 @@
+m4_include(`SQLCommon.m4')
+m4_changequote(<<<,>>>)
+m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres
+,<<<
+m4_changequote(<!,!>)
+
+-- =================== Setup & Initialization for FitMultiple tests ========================
+--
+--  For fit multiple, we test end-to-end functionality along with performance elsewhere.
+--  They take a long time to run.  Including similar tests here would probably not be worth
+--  the extra time added to dev-check.
+--
+--  Instead, we just want to unit test different python functions in the FitMultiple class.
+--  However, most of the important behavior we need to test requires access to an actual
+--  Greenplum database... mostly, we want to make sure that the models hop around to the
+--  right segments in the right order.  Therefore, the unit tests are here, as a part of
+--  dev-check. We mock fit_transition() and some validation functions in FitMultiple, but
+--  do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to
+--  get through to gpdb. We also want to mock the number of segments, so we can test what
+--  the model hopping behavior will be for a large cluster, even though dev-check should be
+--  able to run on a single dev host.
+
+\i m4_regexp(MODULE_PATHNAME,
+             <!\(.*\)libmadlib\.so!>,
+            <!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!>
+)
+
+-- Mock version() function to convince the InputValidator this is the real madlib schema
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version() RETURNS VARCHAR AS
+$$
+    SELECT MADLIB_SCHEMA.version();
+$$ LANGUAGE sql IMMUTABLE;
+
+-- Call this first to initialize the FitMultiple object, before anything else happens.
+-- Pass a real mst table and source table, rest of FitMultipleModel() constructor params
+--  are filled in.  They can be overridden later, before test functions are called, if necessary.
+CREATE OR REPLACE FUNCTION init_fit_mult(
+    source_table            VARCHAR,
+    model_selection_table   VARCHAR
+) RETURNS VOID AS
+$$
+    import sys
+    from mock import Mock, patch
+
+    PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model)
+    schema_madlib = 'madlib_installcheck_deep_learning'
+
+    GD['fit_mult'] = madlib_keras_fit_multiple_model.FitMultipleModel(
+        schema_madlib,
+        source_table,
+        'orig_model_out',
+        model_selection_table,
+        1
+    )
+    
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_init_schedule(
+    schedule_table VARCHAR
+) RETURNS BOOLEAN AS
+$$
+    fit_mult = GD['fit_mult']
+    fit_mult.schedule_tbl = schedule_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table))
+    if fit_mult.init_schedule_tbl():
+        err_msg = None
+    else:
+        err_msg = 'FitMultiple.init_schedule_tbl() returned False'
+
+    return err_msg
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_rotate_schedule(
+    schedule_table          VARCHAR
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    if fit_mult.schedule_tbl != schedule_table:
+        fit_mult.init_schedule_tbl()
+
+    fit_mult.rotate_schedule_tbl()
+
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+-- Mock fit_transition function, for testing
+--  madlib_keras_fit_multiple_model() python code
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
+    dependent_var               BYTEA,
+    independent_var             BYTEA,
+    dependent_var_shape         INTEGER[],
+    independent_var_shape       INTEGER[],
+    model_architecture          TEXT,
+    compile_params              TEXT,
+    fit_params                  TEXT,
+    dist_key                    INTEGER,
+    dist_key_mapping            INTEGER[],
+    current_seg_id              INTEGER,
+    segments_per_host           INTEGER,
+    images_per_seg              INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
+    serialized_weights          BYTEA,
+    is_final_training_call      BOOLEAN,
+    use_caching                 BOOLEAN,
+    custom_function_map         BYTEA
+) RETURNS BYTEA AS
+$$
+    param_keys = [ 'compile_params', 'accessible_gpus_for_seg', 'dependent_var_shape', 'dist_key_mapping',
+                   'current_seg_id', 'segments_per_host', 'custom_function_map', 'is_final_training_call',
+                   'dist_key', 'serialized_weights', 'images_per_seg', 'model_architecture', 'fit_params',
+                   'independent_var_shape', 'use_caching' ]
+
+    num_calls = 1
+    if 'transition_function_params' in GD:
+        if dist_key in GD['transition_function_params']:
+            if not 'reset' in GD['transition_function_params'][dist_key]:
+                num_calls = GD['transition_function_params'][dist_key]['num_calls']
+                num_calls += 1
+
+    g = globals()
+    params = dict()
+
+    for k in param_keys:
+        params[k] = g[k]
+
+    params['dependent_var'] = len(dependent_var) if dependent_var else 0
+    params['independent_var'] = len(independent_var) if independent_var else 0
+    params['num_calls'] = num_calls
+
+    if not 'transition_function_params' in GD:
+        GD['transition_function_params'] = dict()
+    GD['transition_function_params'][dist_key] = params
+
+    # compute simulated seg_id ( current_seg_id is the actual seg_id )
+    seg_id = dist_key_mapping.index( dist_key )
+
+    if dependent_var_shape and dependent_var_shape[0] * num_calls < images_per_seg [ seg_id ]:
+        return None
+    else:
+        GD['transition_function_params'][dist_key]['reset'] = True
+        return serialized_weights
+$$ LANGUAGE plpythonu VOLATILE;
+
+CREATE OR REPLACE FUNCTION validate_transition_function_params(
+    current_seg_id                       INTEGER,
+    segments_per_host                    INTEGER,
+    images_per_seg                       INTEGER[],
+    expected_num_calls                   INTEGER,
+    expected_dist_key                    INTEGER,
+    expected_is_final_training_call      BOOLEAN,
+    expected_dist_key_mapping            INTEGER[],
+    dependent_var_len                    INTEGER,
+    independent_var_len                  INTEGER,
+    use_caching                          BOOLEAN
+) RETURNS TEXT AS
+$$
+    err_msg = "transition function was not called on segment "
+
+    if 'transition_function_params' not in GD:
+        return err_msg.format(current_seg_id)
+    elif expected_dist_key not in GD['transition_function_params']:
+        return err_msg + " for __dist_key__ = {}".format(expected_dist_key)
+    actual = GD['transition_function_params'][expected_dist_key]
+
+    err_msg = """Incorrect value for {} param passed to fit_transition_multiple_model:
+       Actual={}, Expected={}"""
+
+    validation_map = {
+        'current_seg_id'         : current_seg_id,
+        'segments_per_host'      : segments_per_host,
+        'num_calls'              : expected_num_calls,
+        'is_final_training_call' : expected_is_final_training_call,
+        'dist_key'               : expected_dist_key,
+        'dependent_var'          : dependent_var_len,
+        'independent_var'        : independent_var_len,
+        'use_caching'            : use_caching
+    }
+
+    for param, expected in validation_map.items():
+        if actual[param] != expected:
+            return err_msg.format(
+                param,
+                actual[param],
+                expected
+            )
+
+    return 'PASS'  # actual params match expected params
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Helper to rotate an array of int's
+CREATE OR REPLACE FUNCTION rotate_keys(
+    keys    INTEGER[]
+) RETURNS INTEGER[]
+AS $$
+   return keys[-1:] + keys[:-1]
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION reverse_rotate_keys(
+    keys    INTEGER[]
+) RETURNS INTEGER[]
+AS $$
+   return keys[1:] + keys[:1]
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION setup_model_tables(
+    input_table TEXT,
+    output_table TEXT,
+    cached_source_table TEXT
+) RETURNS TEXT AS
+$$ 
+    fit_mult = GD['fit_mult']
+
+    fit_mult.model_input_tbl = input_table
+    fit_mult.model_output_tbl = output_table
+    fit_mult.cached_source_table = cached_source_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(output_table))
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(cached_source_table))
+    fit_mult.init_model_output_tbl()
+    q = """
+        UPDATE {model_out} -- Reduce size of model for faster tests
+            SET ( model_weights, model_arch, compile_params, fit_params )
+                  = ( mst_key::TEXT::BYTEA,
+                      ( '{{ "a" : ' || mst_key::TEXT || ' }}' )::JSON,
+                      'c' || mst_key::TEXT,
+                      'f' || mst_key::TEXT 
+                    )
+        WHERE mst_key IS NOT NULL;
+    """.format(model_out=fit_mult.model_output_tbl)
+    plpy.execute(q) 
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Updates dist keys in src table and internal fit_mult class variables
+--    num_data_segs can be larger than actual number of segments, since this
+--    is just for simulated testing.  This will also write to expected_distkey_mappings_tbl
+--    which can be used for validating dist key mappings and images per seg later.
+CREATE OR REPLACE FUNCTION update_dist_keys(
+    src_table TEXT,
+    num_data_segs INTEGER,
+    num_models INTEGER,
+    expected_distkey_mappings_tbl TEXT
+) RETURNS VOID AS
+$$ 
+    redist_cmd = """
+        UPDATE {src_table}
+            SET __dist_key__ = (buffer_id % {num_data_segs})
+    """.format(**globals())
+    plpy.execute(redist_cmd)
+
+    fit_mult = GD['fit_mult']
+
+    q = """
+        SELECT SUM(independent_var_shape[1]) AS image_count,
+            __dist_key__
+        FROM {src_table}
+        GROUP BY __dist_key__
+        ORDER BY __dist_key__
+    """.format(**globals())
+    res = plpy.execute(q)
+
+    images_per_seg = [ int(r['image_count']) for r in res ]
+    dist_keys = [ int(r['__dist_key__']) for r in res ]
+    num_dist_keys = len(dist_keys)
+
+    fit_mult.source_table = src_table
+    fit_mult.max_dist_key = sorted(dist_keys)[-1]
+    fit_mult.images_per_seg_train = images_per_seg
+    fit_mult.dist_key_mapping = fit_mult.dist_keys = dist_keys
+    fit_mult.accessible_gpus_per_seg = [0] * num_dist_keys
+    fit_mult.segments_per_host = num_data_segs
+
+    fit_mult.msts_for_schedule = fit_mult.msts[:num_models]
+    if num_models < num_dist_keys:
+        fit_mult.msts_for_schedule += [None] * \
+                                 (num_dist_keys - num_models)
+    fit_mult.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                              for mst in fit_mult.msts_for_schedule ]
+    fit_mult.num_msts = num_models
+
+    fit_mult.extra_dist_keys = []
+    for i in range(num_models - num_dist_keys):
+        fit_mult.extra_dist_keys.append(fit_mult.max_dist_key + 1 + i)
+    fit_mult.all_dist_keys = fit_mult.dist_key_mapping + fit_mult.extra_dist_keys
+
+    create_distkey_map_tbl_cmd = """
+        DROP TABLE IF EXISTS {exp_table};
+        CREATE TABLE {exp_table} AS
+        SELECT
+            ARRAY(  -- map of dist_keys to seg_ids from source table
+                SELECT __dist_key__
+                FROM {fm.source_table}
+                GROUP BY __dist_key__
+                ORDER BY __dist_key__  -- This would be gp_segment_id if it weren't a simulation
+            ) AS expected_dist_key_mapping,
+            ARRAY{fm.images_per_seg_train} AS expected_images_per_seg,
+            {num_data_segs} AS segments_per_host,
+            __dist_key__
+        FROM {fm.source_table}
+        GROUP BY __dist_key__
+        DISTRIBUTED BY (__dist_key__);
+    """.format(
+            fm=fit_mult,
+            num_data_segs=num_data_segs,
+            exp_table=expected_distkey_mappings_tbl
+        )
+    plpy.execute(create_distkey_map_tbl_cmd)
+$$ LANGUAGE plpythonu VOLATILE;
+
+CREATE OR REPLACE FUNCTION test_run_training(
+    source_table TEXT,
+    hop INTEGER,
+    is_very_first_hop BOOLEAN,
+    is_final_training_call BOOLEAN,
+    use_caching BOOLEAN
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    # Each time we start a new test, clear out stats
+    #   like num_calls from GD so we don't end up validating
+    #   against old results
+    if 'transition_function_params' in GD:
+        del GD['transition_function_params']
+
+    fit_mult.source_tbl = source_table
+    fit_mult.is_very_first_hop = is_very_first_hop
+    fit_mult.is_final_training_call = is_final_training_call
+    if use_caching != fit_mult.use_caching:
+        fit_mult.udf_plan = None  # Otherwise it will execute the wrong
+                                  # query when use_caching changes!
+    fit_mult.use_caching = use_caching
+
+    fit_mult.run_training(hop=hop, is_very_first_hop=is_very_first_hop)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+m4_ifelse(m4_eval(__DBMS_VERSION_MAJOR__ <= 6),1,<!
+
+-- format() function for dynamic SQL was introduced in gpdb6,
+--   so for gpdb5 we need to define a simple version of it
+CREATE OR REPLACE FUNCTION format(s TEXT, vars VARIADIC TEXT[])
+RETURNS TEXT AS
+$$
+    import re
+    from internal.db_utils import quote_nullable
+    from utilities.utilities import quote_ident
+
+    seps = re.findall(r'%.', s)
+    pieces = re.split(r'%.', s)
+
+    res = pieces[0]
+    for i in range(len(seps)):
+        if vars[i] is None:
+            raise ValueError('Not enough params passed to format({})'.format(s))
+        else:
+            c = seps[i][1]
+            if c == '%':
+                quoted = '%'
+            elif c == 's':
+                quoted = vars[i]
+            elif c == 'L':
+                quoted = quote_nullable(vars[i])
+            elif c == 'I':
+                quoted = quote_ident(vars[i])
+            else:
+                raise ValueError('{} in format({}) not recognized'.format(seps[i],s))
+            res += quoted
+        res += pieces[i+1]
+    return res
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION format(s TEXT,
+                                  v1 INTEGER[],
+                                  v2 INTEGER[],
+                                  v3 TEXT,
+                                  v4 TEXT,
+                                  v5 INTEGER[],
+                                  v6 TEXT,
+                                  v7 INTEGER[]
+)
+RETURNS TEXT AS
+$$
+    SELECT format($1,
+               $2::TEXT,
+               $3::TEXT,
+               $4,
+               $5,
+               $6::TEXT,
+               $7,
+               $8::TEXT
+        );
+$$ LANGUAGE SQL;
+
+CREATE OR REPLACE FUNCTION format(s TEXT, vars VARIADIC ANYARRAY)
+RETURNS TEXT AS
+$$
+    --SELECT format($1, $2::TEXT[])
+    SELECT $2;
+$$ LANGUAGE sql IMMUTABLE;
+
+!>) -- m4_endif DBMS_VERSION_MAJOR <= 5
+
+CREATE OR REPLACE FUNCTION validate_mst_key_order(output_tbl TEXT, expected_tbl TEXT)
+RETURNS VOID AS
+$$
+DECLARE
+    actual INTEGER[];
+    expected INTEGER[];
+    res RECORD;
+BEGIN
+    EXECUTE format(
+        'SELECT ARRAY(' ||
+            'SELECT mst_key FROM %I ' ||
+                'ORDER BY __dist_key__)',
+        output_tbl
+    ) INTO actual;
+
+    EXECUTE format(
+        'SELECT mst_keys FROM %I',
+        expected_tbl
+    ) INTO expected;
+
+    EXECUTE format(
+        'SELECT assert(' ||
+            '%L = %L, ' ||
+            '%L || ' ||
+            '%L || %L::TEXT || ' ||
+            '%L || %L::TEXT' ||
+        ')',
+        actual,
+        expected,
+        'mst keys found in wrong order / wrong segments!',
+        E'\nActual: ',
+        actual,
+        E'\nExpected: ',
+        expected
+    ) INTO res;
+END
+$$ LANGUAGE PLpgSQL VOLATILE;
+
+-- Create mst table
+DROP TABLE IF EXISTS iris_mst_table, iris_mst_table_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'iris_mst_table',
+    ARRAY[1],
+    ARRAY[
+        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$,
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']$$,
+        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=5,epochs=1$$,
+        $$batch_size=10,epochs=1$$
+    ]
+);
+
+-- Create FitMultiple object for running test functions
+SELECT init_fit_mult('iris_data_15buf_packed', 'iris_mst_table');
+
+CREATE TABLE src_3segs AS
+    SELECT * FROM iris_data_15buf_packed
+    DISTRIBUTED BY (__dist_key__);
+
+-- Simulate 6 models on 3 segments --
+SELECT update_dist_keys('src_3segs', 3, 6, 'expected_dist_key_mappings');
+
+--=== Test init_schedule_tbl() ===--
+-- ====================================================================
+-- ===========  Enough setup, now for the actual tests! ===============
+-- ====================================================================
+
+SELECT test_init_schedule('current_schedule');
+SELECT assert(
+    s.mst_key IS NOT NULL AND m.mst_key IS NOT NULL,
+    'mst_keys in schedule table created by test_init_schedule() does not match keys in mst_table'
+) FROM current_schedule s FULL JOIN iris_mst_table m USING (mst_key);
+
+-- Save order of mst keys in schedule for tracking 
+DROP TABLE IF EXISTS expected_order;
+CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys;
+
+--=== Test rotate_schedule() ===--
+SELECT test_rotate_schedule('current_schedule');
+UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+SELECT validate_mst_key_order('current_schedule', 'expected_order');
+UPDATE expected_order SET mst_keys=reverse_rotate_keys(mst_keys);  -- Undo for later
+
+-- Initialize model_output table, and set model_input & cached_src table names
+SELECT setup_model_tables('model_input', 'model_output', 'cached_src');
+SELECT validate_mst_key_order('model_output', 'expected_order');
+
+-- Order of params in run_training test function (for reference below):
+--
+--  test_run_training(src_tbl, hop, is_v_first_hop, is_final_call, use_caching)
+
+--=== Test first hop of an iteration - no caching (# msts > # segs) ===--
+SELECT test_run_training('src_3segs', 0, False, False, False);

Review comment:
       This is the first hop of an iteration, but not the "very first hop", i.e., the first hop of the first iteration.







[GitHub] [madlib] kaknikhil commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r544747487



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -289,7 +327,14 @@ class FitMultipleModel():
             images_per_seg = self.images_per_seg_valid
             self.info_str += "\n\tValidation set after iteration {0}:".format(epoch)
         for mst in self.msts:
-            model_arch, _ = get_model_arch_weights(self.model_arch_table, mst[self.model_id_col])
+            DEBUG.start_timing('eval_query_weights')
+            weights = query_weights(self.model_output_tbl, self.model_weights_col,

Review comment:
       We just need to query the model_arch here; we don't need to query the weights, right?
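
   Something like the following sketch would fetch just the architecture (column and table names follow the self.* attributes used in this diff; the exact query is illustrative):

       # Query only the model architecture for evaluation and skip the much
       # larger serialized-weights column.  plpy is provided by plpython.
       res = plpy.execute("""
           SELECT {arch_col} AS model_arch
           FROM {out_tbl}
           WHERE {mst_key_col} = {mst_key}
       """.format(arch_col=self.model_arch_col,
                  out_tbl=self.model_output_tbl,
                  mst_key_col=self.mst_key_col,
                  mst_key=mst[self.mst_key_col]))
       model_arch = res[0]['model_arch'] if res else None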







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r545551642



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -289,7 +327,14 @@ class FitMultipleModel():
             images_per_seg = self.images_per_seg_valid
             self.info_str += "\n\tValidation set after iteration {0}:".format(epoch)
         for mst in self.msts:
-            model_arch, _ = get_model_arch_weights(self.model_arch_table, mst[self.model_id_col])
+            DEBUG.start_timing('eval_query_weights')
+            weights = query_weights(self.model_output_tbl, self.model_weights_col,

Review comment:
       Right, good catch!  Took it out.







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r556085247



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -520,21 +544,35 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
         and only gets cleared in eval transition at the last row of the last iteration.
 
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape:

Review comment:
       Maybe the safest option would be to check all 4 of them?  If any one of them is NULL while another of them is non-NULL, that's a sign that something is seriously wrong, and we should probably just stop and error out immediately.
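
       A minimal sketch of what that combined check could look like inside fit_transition (a sketch only; it assumes the fourth data argument is named independent_var_shape, and the error messages are illustrative):

           # Hypothetical helper: the four data columns of a row should be either
           # all present or all NULL; a partially NULL row means something is
           # seriously wrong, so stop immediately rather than silently skipping it.
           def _validate_transition_row(dependent_var, independent_var,
                                        dependent_var_shape, independent_var_shape):
               args = [dependent_var, independent_var,
                       dependent_var_shape, independent_var_shape]
               nulls = [a is None for a in args]
               if any(nulls) and not all(nulls):
                   plpy.error("fit_transition: row has a mix of NULL and non-NULL data columns")
               if all(nulls):
                   plpy.error("fit_transition called with no data")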







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r556079234



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -520,21 +544,35 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
         and only gets cleared in eval transition at the last row of the last iteration.
 
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape:

Review comment:
       Re: checking dependent_var vs dependent_var_shape, I don't have a strong preference.  I think I went with the shape column out of a habit of avoiding potentially large data structures in favor of small ones when all we need is a boolean.  But I don't think it actually matters in this case; if you prefer dependent_var, I can change it.







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r537838544



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +377,307 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
 
-    def create_model_output_table_warm_start(self):
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):
+            self.next_schedule_tbl = unique_string('next_schedule')
+            rotate_schedule_tbl_query = """
+                CREATE TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl};
+            """.format(self=self)
+            self.rotate_schedule_tbl_plan = plpy.prepare(rotate_schedule_tbl_query)
+
+        plpy.execute(self.rotate_schedule_tbl_plan)
+
+        self.truncate_and_drop(self.schedule_tbl)
+        plpy.execute("""
+            ALTER TABLE {self.next_schedule_tbl}
+            RENAME TO {self.schedule_tbl}
+        """.format(self=self))
+
+    def load_warm_start_weights(self):
         """
-        For warm start, we need to copy the model output table to a temp table
-        because we call truncate on the model output table while training.
-        If the query gets aborted, we need to make sure that the user passed
-        model output table can be recovered.
+        For warm start, we need to copy any rows of the model output
+        table provided by the user whose mst keys appear in the
+        supplied model selection table.  We also copy over the 
+        compile & fit params from the model_selection_table, and
+        the dist_key's from the schedule table.
         """
-        plpy.execute("""
-            CREATE TABLE {self.model_output_table} (
-            LIKE {self.original_model_output_table} INCLUDING indexes);
-            """.format(self=self))
+        load_warm_start_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    o.{self.model_weights_col},
+                    o.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.original_model_output_tbl} o
+                        USING ({self.mst_key_col})
+        """.format(self=self)
+        DEBUG.plpy.execute(load_warm_start_weights_query)
 
-        plpy.execute("""INSERT INTO {self.model_output_table}
-            SELECT * FROM {self.original_model_output_table};
-            """.format(self=self))
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(self.model_info_tbl))
 
-        plpy.execute(""" DELETE FROM {self.model_output_table}
-                WHERE {self.mst_key_col} NOT IN (
-                    SELECT {self.mst_key_col} FROM {self.model_selection_table})
-                """.format(self=self))
-        self.warm_start_msts = plpy.execute(
-            """ SELECT array_agg({0}) AS a FROM {1}
-            """.format(self.mst_key_col, self.model_output_table))[0]['a']
-        plpy.execute("DROP TABLE {0}".format(self.model_info_table))
-        self.initialize_model_output_and_info()
-
-    def initialize_model_output_and_info(self):
+    def load_xfer_learning_weights(self, warm_start=False):
+        """
+            Copy transfer learning weights from
+            model_arch table.  Ignore models with
+            no xfer learning weights, these will
+            be generated by keras and added one at a
+            time later.
+        """
+        load_xfer_learning_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    a.{self.model_weights_col},
+                    a.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.model_arch_table} a
+                        ON m.{self.model_id_col} = a.{self.model_id_col}
+                WHERE a.{self.model_weights_col} IS NOT NULL;
+        """.format(self=self)
+        DEBUG.plpy.execute(load_xfer_learning_weights_query)
+
+    def init_model_output_tbl(self):
+        DEBUG.start_timing('init_model_output_and_info')
+
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_tbl}
+                                    ({self.mst_key_col} INTEGER,
+                                     {self.model_weights_col} BYTEA,
+                                     {self.model_arch_col} JSON,
+                                     {self.compile_params_col} TEXT,
+                                     {self.fit_params_col} TEXT,
+                                     {self.object_map_col} BYTEA,
+                                     {self.dist_key_col} INTEGER,
+                                     PRIMARY KEY ({self.mst_key_col}, {self.dist_key_col}))
+                                     DISTRIBUTED BY ({self.dist_key_col})
+                                    """.format(self=self)
+        plpy.execute(output_table_create_query)
+
+        if self.warm_start:

Review comment:
       Re: 1, I'll think about it.  Any suggestions?
   
   2. It's only for transfer learning; that's actually the important part of the name.  Notice the IS NOT NULL part of the query, which makes sure it only copies the transfer learning weights.
   
   The basic logic for initializing the model output table is this:
   
   1.  Bulk load any warm start weights from the previous model output table.
   2.  Bulk load any transfer learning weights from the model arch table.
   3.  Iterate over all remaining mst_keys that haven't already been loaded, initialize them with random weights, and write them one by one with INSERTs to the model output table (see the sketch below).
   
   We could drop the IS NOT NULL condition in 2, but since there are no weights, it would just be copying the tiny stuff (compile params, fit params, etc.), so bulk loading doesn't really help there anyway.  And it would make things more difficult later, since we'd have to INSERT the weights in a different way for rows that already have some of the columns initialized but not others.  (Or just write over them, in which case nothing useful was accomplished by copying them the first time.)
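
   A condensed sketch of that flow (names follow the quoted diff; this is a summary, not the exact implementation):

       # 1. Bulk load any user-supplied weights.
       if self.warm_start:
           self.load_warm_start_weights()
       else:
           self.load_xfer_learning_weights()

       # 2. See which mst_keys the bulk loads covered.
       res = plpy.execute("SELECT {0} AS mst_key FROM {1}".format(
           self.mst_key_col, self.model_output_tbl))
       initialized = {row['mst_key'] for row in res}

       # 3. Everything else gets keras-generated random weights, one INSERT each.
       for mst in self.msts_for_schedule:
           if mst is None or mst['mst_key'] in initialized:
               continue
           model_arch = get_model_arch(self.model_arch_table, mst[self.model_id_col])
           serialized_weights = get_initial_weights(None, model_arch, None, False,
                                                    self.accessible_gpus_for_seg)
           # ... INSERT (mst_key, serialized_weights, ...) INTO self.model_output_tbl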







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r538651733



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})

Review comment:
       We would if we wanted the weights to stay fixed... but the purpose of this query is to redistribute the weights from the previous distribution to the next distribution.  The output has to be distributed by __dist_key__ since that's what will be used as the JOIN key with the source table's __dist_key__ in the next query.
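
       To make that concrete, here is a rough sketch (not the exact query from the PR) of the kind of hop those two schedule columns enable: models are matched on the segment they are currently on (the join on __prev_dist_key__) and written out distributed by the segment they are moving to (__dist_key__):

           hop_query = """
               CREATE TABLE {self.model_input_tbl} AS
                   SELECT o.{self.mst_key_col},
                          o.{self.model_weights_col},
                          s.{self.dist_key_col}
                   FROM {self.model_output_tbl} o
                       JOIN {self.schedule_tbl} s
                           ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
                   DISTRIBUTED BY ({self.dist_key_col})
           """.format(self=self)
           plpy.execute(hop_query)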







[GitHub] [madlib] kaknikhil commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r552332253



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -134,57 +140,56 @@ class FitMultipleModel():
         self.info_str = ""
         self.dep_shape_col = add_postfix(mb_dep_var_col, "_shape")
         self.ind_shape_col = add_postfix(mb_indep_var_col, "_shape")
-        self.use_gpus = use_gpus
+        self.use_gpus = use_gpus if use_gpus else False
         self.segments_per_host = get_segments_per_host()
+        self.model_input_tbl = unique_string('model_input')
+        self.model_output_tbl = unique_string('model_output')
+        self.schedule_tbl = unique_string('schedule')

Review comment:
       While testing locally, I found out that there is a leftover schedule table after a successful run of fit_multiple.
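
       One way to address that (a sketch only, not necessarily how the PR resolves it) would be to explicitly drop the per-run temp tables, including the schedule table, at the end of a successful run:

           def cleanup_temp_tables(self):
               # Hypothetical cleanup: remove the per-run temp tables once
               # training has finished, so nothing is left behind on success.
               for tbl in (self.model_input_tbl, self.model_output_tbl, self.schedule_tbl):
                   plpy.execute("DROP TABLE IF EXISTS {0}".format(tbl))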







[GitHub] [madlib] kaknikhil commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r553567263



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -391,12 +422,14 @@ def get_initial_weights(model_table, model_arch, serialized_weights, warm_start,
         will only be used for segment nodes.
         @args:
             @param model_table: Output model table passed in to fit.
-            @param model_arch_result: Dict containing model architecture info.
+            @param model_arch: Dict containing model architecture info.
             @param warm_start: Boolean flag indicating warm start or not.
     """
     if is_platform_pg():
+        # Use GPU's if they are enabled
         _ = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[0], None)
     else:
+        # We are on master, so never use GPU's

Review comment:
       This comment is only valid for gpdb. For postgres there is no master/coordinator node, so we always use GPUs if they are available, whereas on gpdb we use GPUs only on the segment nodes. We should update the comment to make this clearer.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -520,21 +544,35 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
         and only gets cleared in eval transition at the last row of the last iteration.
 
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape:
+        plpy.error("fit_transition called with no data")
+
+    if not prev_serialized_weights or not model_architecture:
         return state

Review comment:
       Shouldn't we error out if either prev_serialized_weights or model_architecture is None?
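
       A sketch of the stricter check being suggested, replacing the early return quoted above (the message text is illustrative):

           if not prev_serialized_weights or not model_architecture:
               plpy.error("fit_transition called without model weights or architecture")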

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -663,17 +717,20 @@ def get_state_to_return(segment_model, is_last_row, is_multiple_model, agg_image
     :param is_last_row: boolean to indicate if last row for that hop
     :param is_multiple_model: boolean
     :param agg_image_count: aggregated image count per hop
-    :param total_images: total images per segment
+    :param total_images: total images per segment (only used for madlib_keras_fit() )
     :return:
     """
-    if is_last_row:
-        updated_model_weights = segment_model.get_weights()
-        if is_multiple_model:
+    if is_multiple_model:
+        if is_last_row:
+            updated_model_weights = segment_model.get_weights()
             new_state = madlib_keras_serializer.serialize_nd_weights(updated_model_weights)
         else:
-            updated_model_weights = [total_images * w for w in updated_model_weights]
-            new_state = madlib_keras_serializer.serialize_state_with_nd_weights(
-                agg_image_count, updated_model_weights)
+            new_state = None
+    elif is_last_row:

Review comment:
       Now that we don't get `agg_image_count` from the state, do we still need to set `new_state = float(agg_image_count)` at line 678/735?  This might help simplify the if/elif condition as well.
   

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -338,183 +372,307 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
 
-    def create_model_output_table_warm_start(self):
+        create_sched_query = """
+            CREATE {self.unlogged_table} TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        plpy_execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if self.rotate_schedule_tbl_plan is None:
+            rotate_schedule_tbl_query = """
+                CREATE {self.unlogged_table} TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl}
+                DISTRIBUTED BY ({self.prev_dist_key_col})
+            """.format(self=self)
+            self.rotate_schedule_tbl_plan = plpy.prepare(rotate_schedule_tbl_query)
+
+        plpy.execute(self.rotate_schedule_tbl_plan)
+
+        self.truncate_and_drop(self.schedule_tbl)
+        plpy.execute("""
+            ALTER TABLE {self.next_schedule_tbl}
+            RENAME TO {self.schedule_tbl}
+        """.format(self=self))
+
+    def load_warm_start_weights(self):
         """
-        For warm start, we need to copy the model output table to a temp table
-        because we call truncate on the model output table while training.
-        If the query gets aborted, we need to make sure that the user passed
-        model output table can be recovered.
+        For warm start, we need to copy any rows of the model output
+        table provided by the user whose mst keys appear in the
+        supplied model selection table.  We also copy over the 
+        compile & fit params from the model_selection_table, and
+        the dist_key's from the schedule table.
         """
-        plpy.execute("""
-            CREATE TABLE {self.model_output_table} (
-            LIKE {self.original_model_output_table} INCLUDING indexes);
-            """.format(self=self))
+        load_warm_start_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    o.{self.model_weights_col},
+                    o.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.original_model_output_tbl} o
+                        USING ({self.mst_key_col})
+        """.format(self=self)
+        plpy_execute(load_warm_start_weights_query)
 
-        plpy.execute("""INSERT INTO {self.model_output_table}
-            SELECT * FROM {self.original_model_output_table};
-            """.format(self=self))
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(self.model_info_tbl))
 
-        plpy.execute(""" DELETE FROM {self.model_output_table}
-                WHERE {self.mst_key_col} NOT IN (
-                    SELECT {self.mst_key_col} FROM {self.model_selection_table})
-                """.format(self=self))
-        self.warm_start_msts = plpy.execute(
-            """ SELECT array_agg({0}) AS a FROM {1}
-            """.format(self.mst_key_col, self.model_output_table))[0]['a']
-        plpy.execute("DROP TABLE {0}".format(self.model_info_table))
-        self.initialize_model_output_and_info()
-
-    def initialize_model_output_and_info(self):
+    def load_xfer_learning_weights(self, warm_start=False):
+        """
+            Copy transfer learning weights from
+            model_arch table.  Ignore models with
+            no xfer learning weights, these will
+            be generated by keras and added one at a
+            time later.
+        """
+        load_xfer_learning_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    a.{self.model_weights_col},
+                    a.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.model_arch_table} a
+                        ON m.{self.model_id_col} = a.{self.model_id_col}
+                WHERE a.{self.model_weights_col} IS NOT NULL;
+        """.format(self=self)
+        plpy_execute(load_xfer_learning_weights_query)
+
+    def init_model_output_tbl(self):
+        DEBUG.start_timing('init_model_output_and_info')
+
+        output_table_create_query = """
+                                    CREATE {self.unlogged_table} TABLE {self.model_output_tbl}
+                                    ({self.mst_key_col} INTEGER,
+                                     {self.model_weights_col} BYTEA,
+                                     {self.model_arch_col} JSON,
+                                     {self.compile_params_col} TEXT,
+                                     {self.fit_params_col} TEXT,
+                                     {self.object_map_col} BYTEA,
+                                     {self.dist_key_col} INTEGER,
+                                     PRIMARY KEY ({self.dist_key_col})
+                                    )
+                                    DISTRIBUTED BY ({self.dist_key_col})
+                                    """.format(self=self)
+        plpy.execute(output_table_create_query)
+
+        if self.warm_start:
+            self.load_warm_start_weights()
+        else:  # Note:  We only support xfer learning when warm_start=False
+            self.load_xfer_learning_weights()
+
+        res = plpy.execute("""
+            SELECT {self.mst_key_col} AS mst_keys FROM {self.model_output_tbl}
+        """.format(self=self))
+       
+        if res:
+            initialized_msts = set([ row['mst_keys'] for row in res ])
+        else:
+            initialized_msts = set()
+
+        DEBUG.plpy.info("Pre-initialized mst keys: {}".format(initialized_msts))
+
+        # We've already bulk loaded all of the models with user-specified weights.
+        #  For the rest of the models, we need to generate the weights for each
+        #  by initializing them with keras and adding them one row at a time.
+        #
+        # TODO:  In the future, we should probably move the weight initialization
+        #  into the transition function on the segments.  Here, we would just
+        #  bulk load everything with a single query (or 2, for the warm start case),
+        #  and leave the weights column as NULL for any model whose weights need
+        #  to be randomly initialized.  Then in fit_transition, if prev_weights is
+        #  NULL, and there is nothing in GD, it should just skip the call to
+        #  set_weights(), and keras will automatically initialize them during
+        #  model.from_json(model_arch).
+        #
+        #  This would be a very easy change for fit_multiple(), but might require
+        #   some more work to support fit().  All of the segments there need to
+        #   start with the same weights, so we'd at least have to pass a random
+        #   seed to the transition function for keras to use.  Or generate a seed
+        #   on the segments in some deterministic way that's the same for all.
+        for index, mst in enumerate(self.msts_for_schedule):
+            if mst is None:
+                continue
+
+            if mst['mst_key'] in initialized_msts:
+                continue  # skip if we've already loaded this mst
+
+            num_dist_keys = len(self.dist_keys)
+
+            if index < num_dist_keys:
+                dist_key = self.dist_keys[index]
+            else:  # For models that won't be trained on first hop
+                dist_key = self.extra_dist_keys[index - num_dist_keys]
+
+            model_arch = get_model_arch(self.model_arch_table, mst[self.model_id_col])
+            serialized_weights = get_initial_weights(None, model_arch, None, False,
+                                                     self.accessible_gpus_for_seg)
+
+            DEBUG.plpy.info(

Review comment:
       For code readability, I think we can remove this debug plpy.info statement. We can add it back for debugging as and when needed
   

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -197,72 +202,104 @@ class FitMultipleModel():
             self.dist_key_mapping_valid, self.images_per_seg_valid = \
                 get_image_count_per_seg_for_minibatched_data_from_db(
                     self.validation_table)
-        self.mst_weights_tbl = unique_string(desp='mst_weights')
-        self.mst_current_schedule_tbl = unique_string(desp='mst_current_schedule')
 
-        self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
-        if len(self.msts) < len(self.dist_keys):
+        self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
+        self.max_dist_key = sorted(self.dist_keys)[-1]
+        self.extra_dist_keys = []
+
+        num_msts = len(self.msts)
+        num_dist_keys = len(self.dist_keys)
+
+        if num_msts < num_dist_keys:
             self.msts_for_schedule = self.msts + [None] * \
-                                     (len(self.dist_keys) - len(self.msts))
+                                     (num_dist_keys - num_msts)
         else:
             self.msts_for_schedule = self.msts
+            if num_msts > num_dist_keys:
+                for i in range(num_msts - num_dist_keys):
+                    self.extra_dist_keys.append(self.max_dist_key + 1 + i)
+
+        DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys))

Review comment:
       For code readability, I think we can remove these two debug plpy.info statements. We can add them back for debugging as and when needed

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -62,23 +68,29 @@ class GD_STORE:
     def clear(GD):
         del GD[GD_STORE.SEGMENT_MODEL]
         del GD[GD_STORE.SESS]
+        if GD_STORE.AGG_IMAGE_COUNT in GD:
+            del GD[GD_STORE.AGG_IMAGE_COUNT]
 
 def get_init_model_and_sess(GD, device_name, gpu_count, segments_per_host,
                                model_architecture, compile_params, custom_function_map):
     # If a live session is present, re-use it. Otherwise, recreate it.
-    if GD_STORE.SESS in GD:
+
+    if GD_STORE.SESS in GD :
+        if GD_STORE.AGG_IMAGE_COUNT not in GD:

Review comment:
       The logic to initialize agg_image_count is also in the `fit_transition` and `fit_multiple_transition_caching` functions. We may not have to initialize it here as well

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -520,21 +544,35 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
         and only gets cleared in eval transition at the last row of the last iteration.
 
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape:
+        plpy.error("fit_transition called with no data")
+
+    if not prev_serialized_weights or not model_architecture:
         return state
+
     GD = kwargs['GD']
+
+    trans_enter_time = time.time()
+
     device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
 
     segment_model, sess = get_init_model_and_sess(GD, device_name,
-                                                  accessible_gpus_for_seg[current_seg_id],
-                                                  segments_per_host,
-                                                  model_architecture, compile_params,
-                                                  custom_function_map)
-    if not state:
-        agg_image_count = 0
-        set_model_weights(segment_model, prev_serialized_weights)
+        accessible_gpus_for_seg[current_seg_id],
+        segments_per_host,
+        model_architecture, compile_params,
+        custom_function_map)
+
+    if GD_STORE.AGG_IMAGE_COUNT in GD:
+        agg_image_count = GD[GD_STORE.AGG_IMAGE_COUNT]
     else:
-        agg_image_count = float(state)
+        agg_image_count = 0
+        GD[GD_STORE.AGG_IMAGE_COUNT] = agg_image_count
+
+    DEBUG.plpy_info("agg_image_count={}".format(agg_image_count))

Review comment:
       I think we can remove this debug plpy.info statement. We can add it back for debugging as and when needed

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -43,16 +47,17 @@ from utilities.utilities import is_platform_pg
 from utilities.utilities import get_seg_number
 from utilities.utilities import get_segments_per_host
 from utilities.utilities import rename_table
+import utilities.debug as DEBUG
+from utilities.debug import plpy_prepare
+from utilities.debug import plpy_execute
 
-import json
-from collections import defaultdict
-import random
-import datetime
+DEBUG.timings_enabled = True

Review comment:
       We should set this to false before merging the code

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -27,6 +27,7 @@ from madlib_keras_model_selection import ModelSelectionSchema
 from keras_model_arch_table import ModelArchSchema
 from utilities.validate_args import table_exists, drop_tables, input_tbl_valid
 from utilities.validate_args import quote_ident
+from madlib_keras_helper import DISTRIBUTION_KEY_COLNAME

Review comment:
       Changes to the automl code are now part of PR #526, right?

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -520,21 +544,35 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
         and only gets cleared in eval transition at the last row of the last iteration.
 
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape:

Review comment:
       Isn't it better to check the actual data column rather than the shape column to assert for missing data?
   
   Also, previously we supported the case where one of the rows had missing x/y, but now it will error out. Is this intentional?

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -196,72 +195,110 @@ class FitMultipleModel():
             self.dist_key_mapping_valid, self.images_per_seg_valid = \
                 get_image_count_per_seg_for_minibatched_data_from_db(
                     self.validation_table)
-        self.mst_weights_tbl = unique_string(desp='mst_weights')
-        self.mst_current_schedule_tbl = unique_string(desp='mst_current_schedule')
+        self.model_input_tbl = unique_string(desp='model_input')
+        self.schedule_tbl = unique_string(desp='schedule')
 
-        self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
-        if len(self.msts) < len(self.dist_keys):
+        self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
+        DEBUG.plpy.info("init_dist_keys = {0}".format(self.dist_keys))
+        self.max_dist_key = sorted(self.dist_keys)[-1]
+        DEBUG.plpy.info("sorted_dist_keys = {0}".format(sorted(self.dist_keys)))
+        DEBUG.plpy.info("max_dist_key = {0}".format(self.max_dist_key))
+        self.extra_dist_keys = []
+
+        num_msts = len(self.msts)
+        num_dist_keys = len(self.dist_keys)
+
+        if num_msts < num_dist_keys:
             self.msts_for_schedule = self.msts + [None] * \
-                                     (len(self.dist_keys) - len(self.msts))
+                                     (num_dist_keys - num_msts)
         else:
             self.msts_for_schedule = self.msts
+            if num_msts > num_dist_keys:
+                for i in range(num_msts - num_dist_keys):
+                    self.extra_dist_keys.append(self.max_dist_key + 1 + i)
+
+        DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys))
+        DEBUG.plpy.info('extra_dist_keys : {}'.format(self.extra_dist_keys))
+
         random.shuffle(self.msts_for_schedule)
-        self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
-        self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
-        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
-        if self.warm_start:
-            self.create_model_output_table_warm_start()
-        else:
-            self.create_model_output_table()
+        # Comma-separated list of the mst_keys, including NULL's
+        #  This will be used to pass the mst keys to the db as
+        #  a sql ARRAY[]
+        self.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                for mst in self.msts_for_schedule ]
 
-        self.weights_to_update_tbl = unique_string(desp='weights_to_update')
-        self.fit_multiple_model()
+        # List of all dist_keys, including any extra dist keys beyond
+        #  the # segments we'll be training on--these represent the
+        #  segments models will rest on while not training, which
+        #  may overlap with the ones that will have training on them.
+        self.all_dist_keys = self.dist_keys + self.extra_dist_keys
 
-        # Update and cleanup metadata tables
-        self.insert_info_table()
-        self.create_model_summary_table()
-        if self.warm_start:
-            self.cleanup_for_warm_start()
-        reset_cuda_env(original_cuda_env)
+        self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
+        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
     def fit_multiple_model(self):
+        self.init_schedule_tbl()
+        self.init_model_output_tbl()
+        self.init_model_info_tbl()
+
         # WARNING: set orca off to prevent unwanted redistribution
         with OptimizerControl(False):
             self.start_training_time = datetime.datetime.now()
             self.metrics_elapsed_start_time = time.time()
             self.train_multiple_model()
             self.end_training_time = datetime.datetime.now()
 
-    def cleanup_for_warm_start(self):
+        # Update and cleanup metadata tables
+        self.insert_info_table()
+        self.create_model_summary_table()
+        self.write_final_model_output_tbl()
+        reset_cuda_env(self.original_cuda_env)
+
+    def write_final_model_output_tbl(self):
         """
-        1. drop original model table
+        1. drop original model table if exists
         2. rename temp to original
         :return:
         """
-        drop_query = "DROP TABLE IF EXISTS {}".format(
-            self.original_model_output_table)
-        plpy.execute(drop_query)
-        rename_table(self.schema_madlib, self.model_output_table,
-                     self.original_model_output_table)
+        final_output_table_create_query = """
+                                    DROP TABLE IF EXISTS {self.original_model_output_tbl};
+                                    CREATE TABLE {self.original_model_output_tbl} AS

Review comment:
       Sure, makes sense. In that case, drop and create looks good.







+        """.format(self=self)
+        DEBUG.plpy.execute(load_xfer_learning_weights_query)
+
+    def init_model_output_tbl(self):
+        DEBUG.start_timing('init_model_output_and_info')
+
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_tbl}
+                                    ({self.mst_key_col} INTEGER,
+                                     {self.model_weights_col} BYTEA,
+                                     {self.model_arch_col} JSON,
+                                     {self.compile_params_col} TEXT,
+                                     {self.fit_params_col} TEXT,
+                                     {self.object_map_col} BYTEA,
+                                     {self.dist_key_col} INTEGER,
+                                     PRIMARY KEY ({self.mst_key_col}, {self.dist_key_col}))
+                                     DISTRIBUTED BY ({self.dist_key_col})
+                                    """.format(self=self)
+        plpy.execute(output_table_create_query)
+
+        if self.warm_start:

Review comment:
       Re: 1, I'll think about it.  Any suggestions?
   
   2. Actually it's only for transfer learning; that's the important part of the name.  Notice the IS NOT NULL condition in the query, which makes sure only the transfer learning weights get copied.
   
   The basic logic for initializing the model output table is this:
   
   1.  Bulk load any warm start weights from the previous model output table.
   2.  Bulk load any transfer learning weights from the model arch table.
   3.  Iterate over any remaining mst_keys that haven't been loaded yet, initialize them with random weights, and write them one at a time with INSERTs into the model output table.
   
   We could drop the IS NOT NULL condition in step 2, but since those rows have no weights, it would just be copying the tiny stuff (compile params, fit params, etc.), so bulk loading doesn't really help there anyway.  It would also make things harder later, since we'd have to INSERT the weights differently for rows that already have some of the columns initialized but not others.  (Or just overwrite them, in which case nothing useful was accomplished by copying them the first time.)
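
   For illustration only, here is a minimal Python sketch of that three-way split.  The function and parameter names (partition_msts_for_init, warm_start_keys, xfer_weights_by_model_id) are hypothetical stand-ins, not part of the PR; the sketch just shows which loading path each mst would take.

    ```python
    # Hypothetical sketch of the three loading paths described above (not the PR's code).
    def partition_msts_for_init(msts, warm_start_keys, xfer_weights_by_model_id):
        # Split model selection rows into warm-start, transfer-learning, and random-init groups.
        warm_start, xfer_learning, random_init = [], [], []
        for mst in msts:
            if mst['mst_key'] in warm_start_keys:
                warm_start.append(mst)       # step 1: bulk INSERT ... SELECT from the old output table
            elif xfer_weights_by_model_id.get(mst['model_id']) is not None:
                xfer_learning.append(mst)    # step 2: bulk INSERT ... SELECT where weights IS NOT NULL
            else:
                random_init.append(mst)      # step 3: keras generates weights, INSERTed one row at a time
        return warm_start, xfer_learning, random_init

    msts = [{'mst_key': 1, 'model_id': 7},
            {'mst_key': 2, 'model_id': 8},
            {'mst_key': 3, 'model_id': 7}]
    print(partition_msts_for_init(msts, warm_start_keys={1},
                                  xfer_weights_by_model_id={8: b'pretrained'}))
    ```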







[GitHub] [madlib] kaknikhil commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r560533501



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -520,21 +544,35 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
         and only gets cleared in eval transition at the last row of the last iteration.
 
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape:
+        plpy.error("fit_transition called with no data")
+
+    if not prev_serialized_weights or not model_architecture:
         return state

Review comment:
       makes sense 







[GitHub] [madlib] kaknikhil commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r556811701



##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
##########
@@ -0,0 +1,845 @@
+m4_include(`SQLCommon.m4')
+m4_changequote(<<<,>>>)
+m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres
+,<<<
+m4_changequote(<!,!>)
+
+-- =================== Setup & Initialization for FitMultiple tests ========================
+--
+--  For fit multiple, we test end-to-end functionality along with performance elsewhere.
+--  They take a long time to run.  Including similar tests here would probably not be worth
+--  the extra time added to dev-check.
+--
+--  Instead, we just want to unit test different python functions in the FitMultiple class.
+--  However, most of the important behavior we need to test requires access to an actual
+--  Greenplum database... mostly, we want to make sure that the models hop around to the
+--  right segments in the right order.  Therefore, the unit tests are here, as a part of
+--  dev-check. we mock fit_transition() and some validation functions in FitMultiple, but
+--  do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to
+--  get through to gpdb. We also want to mock the number of segments, so we can test what
+--  the model hopping behavior will be for a large cluster, even though dev-check should be
+--  able to run on a single dev host.
+
+\i m4_regexp(MODULE_PATHNAME,
+             <!\(.*\)libmadlib\.so!>,
+            <!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!>
+)
+
+-- Mock version() function to convince the InputValidator this is the real madlib schema
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version() RETURNS VARCHAR AS
+$$
+    SELECT MADLIB_SCHEMA.version();
+$$ LANGUAGE sql IMMUTABLE;
+
+-- Call this first to initialize the FitMultiple object, before anything else happens.
+-- Pass a real mst table and source table, rest of FitMultipleModel() constructor params
+--  are filled in.  They can be overriden later, before test functions are called, if necessary.
+CREATE OR REPLACE FUNCTION init_fit_mult(
+    source_table            VARCHAR,
+    model_selection_table   VARCHAR
+) RETURNS VOID AS
+$$
+    import sys
+    from mock import Mock, patch
+
+    PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model)
+    schema_madlib = 'madlib_installcheck_deep_learning'
+
+    GD['fit_mult'] = madlib_keras_fit_multiple_model.FitMultipleModel(
+        schema_madlib,
+        source_table,
+        'orig_model_out',
+        model_selection_table,
+        1
+    )
+    
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_init_schedule(
+    schedule_table VARCHAR
+) RETURNS BOOLEAN AS
+$$
+    fit_mult = GD['fit_mult']
+    fit_mult.schedule_tbl = schedule_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table))
+    if fit_mult.init_schedule_tbl():
+        err_msg = None
+    else:
+        err_msg = 'FitMultiple.init_schedule_tbl() returned False'
+
+    return err_msg
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_rotate_schedule(
+    schedule_table          VARCHAR
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    if fit_mult.schedule_tbl != schedule_table:
+        fit_mult.init_schedule_tbl()
+
+    fit_mult.rotate_schedule_tbl()
+
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+-- Mock fit_transition function, for testing
+--  madlib_keras_fit_multiple_model() python code
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
+    dependent_var               BYTEA,
+    independent_var             BYTEA,
+    dependent_var_shape         INTEGER[],
+    independent_var_shape       INTEGER[],
+    model_architecture          TEXT,
+    compile_params              TEXT,
+    fit_params                  TEXT,
+    dist_key                    INTEGER,
+    dist_key_mapping            INTEGER[],
+    current_seg_id              INTEGER,
+    segments_per_host           INTEGER,
+    images_per_seg              INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
+    serialized_weights          BYTEA,
+    is_final_training_call      BOOLEAN,
+    use_caching                 BOOLEAN,
+    custom_function_map         BYTEA
+) RETURNS BYTEA AS
+$$
+    param_keys = [ 'compile_params', 'accessible_gpus_for_seg', 'dependent_var_shape', 'dist_key_mapping',
+                   'current_seg_id', 'segments_per_host', 'custom_function_map', 'is_final_training_call',
+                   'dist_key', 'serialized_weights', 'images_per_seg', 'model_architecture', 'fit_params',
+                   'independent_var_shape', 'use_caching' ]
+
+    num_calls = 1
+    if 'transition_function_params' in GD:
+        if dist_key in GD['transition_function_params']:
+            if not 'reset' in GD['transition_function_params'][dist_key]:
+                num_calls = GD['transition_function_params'][dist_key]['num_calls']
+                num_calls += 1
+
+    g = globals()
+    params = dict()
+
+    for k in param_keys:
+        params[k] = g[k]
+
+    params['dependent_var'] = len(dependent_var) if dependent_var else 0
+    params['independent_var'] = len(independent_var) if independent_var else 0
+    params['num_calls'] = num_calls
+
+    if not 'transition_function_params' in GD:
+        GD['transition_function_params'] = dict()
+    GD['transition_function_params'][dist_key] = params
+
+    # compute simulated seg_id ( current_seg_id is the actual seg_id )
+    seg_id = dist_key_mapping.index( dist_key )
+
+    if dependent_var_shape and dependent_var_shape[0] * num_calls < images_per_seg [ seg_id ]:
+        return None
+    else:
+        GD['transition_function_params'][dist_key]['reset'] = True
+        return serialized_weights
+$$ LANGUAGE plpythonu VOLATILE;
+
+CREATE OR REPLACE FUNCTION validate_transition_function_params(
+    current_seg_id                       INTEGER,
+    segments_per_host                    INTEGER,
+    images_per_seg                       INTEGER[],
+    expected_num_calls                   INTEGER,
+    expected_dist_key                    INTEGER,
+    expected_is_final_training_call      BOOLEAN,
+    expected_dist_key_mapping            INTEGER[],
+    dependent_var_len                    INTEGER,
+    independent_var_len                  INTEGER,
+    use_caching                          BOOLEAN
+) RETURNS TEXT AS
+$$
+    err_msg = "transition function was not called on segment "
+
+    if 'transition_function_params' not in GD:
+        return err_msg.format(current_seg_id)

Review comment:
       the err_msg variable will need a `{}` at the end for this format to work
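
   A quick illustration of the point (not the PR's code): str.format() only substitutes where a placeholder exists, so without a `{}` the segment id is silently dropped.

    ```python
    # Hypothetical snippet showing why the trailing {} matters.
    err_msg = "transition function was not called on segment "
    assert err_msg.format(7) == err_msg                    # no placeholder: the 7 is dropped

    err_msg = "transition function was not called on segment {}"
    assert err_msg.format(7).endswith("segment 7")         # placeholder: the seg id appears
    ```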

##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
##########
@@ -0,0 +1,845 @@
+m4_include(`SQLCommon.m4')
+m4_changequote(<<<,>>>)
+m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres
+,<<<
+m4_changequote(<!,!>)
+
+-- =================== Setup & Initialization for FitMultiple tests ========================
+--
+--  For fit multiple, we test end-to-end functionality along with performance elsewhere.
+--  They take a long time to run.  Including similar tests here would probably not be worth
+--  the extra time added to dev-check.
+--
+--  Instead, we just want to unit test different python functions in the FitMultiple class.
+--  However, most of the important behavior we need to test requires access to an actual
+--  Greenplum database... mostly, we want to make sure that the models hop around to the
+--  right segments in the right order.  Therefore, the unit tests are here, as a part of
+--  dev-check. we mock fit_transition() and some validation functions in FitMultiple, but
+--  do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to
+--  get through to gpdb. We also want to mock the number of segments, so we can test what
+--  the model hopping behavior will be for a large cluster, even though dev-check should be
+--  able to run on a single dev host.
+
+\i m4_regexp(MODULE_PATHNAME,
+             <!\(.*\)libmadlib\.so!>,
+            <!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!>
+)
+
+-- Mock version() function to convince the InputValidator this is the real madlib schema
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version() RETURNS VARCHAR AS
+$$
+    SELECT MADLIB_SCHEMA.version();
+$$ LANGUAGE sql IMMUTABLE;
+
+-- Call this first to initialize the FitMultiple object, before anything else happens.
+-- Pass a real mst table and source table, rest of FitMultipleModel() constructor params
+--  are filled in.  They can be overriden later, before test functions are called, if necessary.
+CREATE OR REPLACE FUNCTION init_fit_mult(
+    source_table            VARCHAR,
+    model_selection_table   VARCHAR
+) RETURNS VOID AS
+$$
+    import sys
+    from mock import Mock, patch
+
+    PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model)
+    schema_madlib = 'madlib_installcheck_deep_learning'
+
+    GD['fit_mult'] = madlib_keras_fit_multiple_model.FitMultipleModel(
+        schema_madlib,
+        source_table,
+        'orig_model_out',
+        model_selection_table,
+        1
+    )
+    
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_init_schedule(
+    schedule_table VARCHAR
+) RETURNS BOOLEAN AS
+$$
+    fit_mult = GD['fit_mult']
+    fit_mult.schedule_tbl = schedule_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table))
+    if fit_mult.init_schedule_tbl():
+        err_msg = None
+    else:
+        err_msg = 'FitMultiple.init_schedule_tbl() returned False'
+
+    return err_msg
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_rotate_schedule(
+    schedule_table          VARCHAR
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    if fit_mult.schedule_tbl != schedule_table:
+        fit_mult.init_schedule_tbl()
+
+    fit_mult.rotate_schedule_tbl()
+
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+-- Mock fit_transition function, for testing
+--  madlib_keras_fit_multiple_model() python code
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
+    dependent_var               BYTEA,
+    independent_var             BYTEA,
+    dependent_var_shape         INTEGER[],
+    independent_var_shape       INTEGER[],
+    model_architecture          TEXT,
+    compile_params              TEXT,
+    fit_params                  TEXT,
+    dist_key                    INTEGER,
+    dist_key_mapping            INTEGER[],
+    current_seg_id              INTEGER,
+    segments_per_host           INTEGER,
+    images_per_seg              INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
+    serialized_weights          BYTEA,
+    is_final_training_call      BOOLEAN,
+    use_caching                 BOOLEAN,
+    custom_function_map         BYTEA
+) RETURNS BYTEA AS
+$$
+    param_keys = [ 'compile_params', 'accessible_gpus_for_seg', 'dependent_var_shape', 'dist_key_mapping',
+                   'current_seg_id', 'segments_per_host', 'custom_function_map', 'is_final_training_call',
+                   'dist_key', 'serialized_weights', 'images_per_seg', 'model_architecture', 'fit_params',
+                   'independent_var_shape', 'use_caching' ]
+
+    num_calls = 1
+    if 'transition_function_params' in GD:
+        if dist_key in GD['transition_function_params']:
+            if not 'reset' in GD['transition_function_params'][dist_key]:
+                num_calls = GD['transition_function_params'][dist_key]['num_calls']
+                num_calls += 1
+
+    g = globals()
+    params = dict()
+
+    for k in param_keys:
+        params[k] = g[k]
+
+    params['dependent_var'] = len(dependent_var) if dependent_var else 0
+    params['independent_var'] = len(independent_var) if independent_var else 0
+    params['num_calls'] = num_calls
+
+    if not 'transition_function_params' in GD:
+        GD['transition_function_params'] = dict()
+    GD['transition_function_params'][dist_key] = params
+
+    # compute simulated seg_id ( current_seg_id is the actual seg_id )
+    seg_id = dist_key_mapping.index( dist_key )
+
+    if dependent_var_shape and dependent_var_shape[0] * num_calls < images_per_seg [ seg_id ]:
+        return None
+    else:
+        GD['transition_function_params'][dist_key]['reset'] = True
+        return serialized_weights
+$$ LANGUAGE plpythonu VOLATILE;
+
+CREATE OR REPLACE FUNCTION validate_transition_function_params(
+    current_seg_id                       INTEGER,
+    segments_per_host                    INTEGER,
+    images_per_seg                       INTEGER[],
+    expected_num_calls                   INTEGER,
+    expected_dist_key                    INTEGER,
+    expected_is_final_training_call      BOOLEAN,
+    expected_dist_key_mapping            INTEGER[],
+    dependent_var_len                    INTEGER,
+    independent_var_len                  INTEGER,
+    use_caching                          BOOLEAN
+) RETURNS TEXT AS
+$$
+    err_msg = "transition function was not called on segment "
+
+    if 'transition_function_params' not in GD:
+        return err_msg.format(current_seg_id)
+    elif expected_dist_key not in GD['transition_function_params']:
+        return err_msg + " for __dist_key__ = {}".format(expected_dist_key)
+    actual = GD['transition_function_params'][expected_dist_key]
+
+    err_msg = """Incorrect value for {} param passed to fit_transition_multiple_model:
+       Actual={}, Expected={}"""
+
+    validation_map = {
+        'current_seg_id'         : current_seg_id,
+        'segments_per_host'      : segments_per_host,
+        'num_calls'              : expected_num_calls,
+        'is_final_training_call' : expected_is_final_training_call,
+        'dist_key'               : expected_dist_key,
+        'dependent_var'          : dependent_var_len,
+        'independent_var'        : independent_var_len,
+        'use_caching'            : use_caching
+    }
+
+    for param, expected in validation_map.items():
+        if actual[param] != expected:
+            return err_msg.format(
+                param,
+                actual[param],
+                expected
+            )
+
+    return 'PASS'  # actual params match expected params
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Helper to rotate an array of int's
+CREATE OR REPLACE FUNCTION rotate_keys(
+    keys    INTEGER[]
+) RETURNS INTEGER[]
+AS $$
+   return keys[-1:] + keys[:-1]
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION reverse_rotate_keys(
+    keys    INTEGER[]
+) RETURNS INTEGER[]
+AS $$
+   return keys[1:] + keys[:1]
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION setup_model_tables(
+    input_table TEXT,
+    output_table TEXT,
+    cached_source_table TEXT
+) RETURNS TEXT AS
+$$ 
+    fit_mult = GD['fit_mult']
+
+    fit_mult.model_input_tbl = input_table
+    fit_mult.model_output_tbl = output_table
+    fit_mult.cached_source_table = cached_source_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(output_table))
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(cached_source_table))
+    fit_mult.init_model_output_tbl()
+    q = """
+        UPDATE {model_out} -- Reduce size of model for faster tests
+            SET ( model_weights, model_arch, compile_params, fit_params )
+                  = ( mst_key::TEXT::BYTEA,
+                      ( '{{ "a" : ' || mst_key::TEXT || ' }}' )::JSON,
+                      'c' || mst_key::TEXT,
+                      'f' || mst_key::TEXT 
+                    )
+        WHERE mst_key IS NOT NULL;
+    """.format(model_out=fit_mult.model_output_tbl)
+    plpy.execute(q) 
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Updates dist keys in src table and internal fit_mult class variables
+--    num_data_segs can be larger than actual number of segments, since this
+--    is just for simulated testing.  This will also write to expected_distkey_mappings_tbl
+--    which can be used for validating dist key mappings and images per seg later.
+CREATE OR REPLACE FUNCTION update_dist_keys(
+    src_table TEXT,
+    num_data_segs INTEGER,
+    num_models INTEGER,
+    expected_distkey_mappings_tbl TEXT
+) RETURNS VOID AS
+$$ 
+    redist_cmd = """
+        UPDATE {src_table}
+            SET __dist_key__ = (buffer_id % {num_data_segs})
+    """.format(**globals())
+    plpy.execute(redist_cmd)
+
+    fit_mult = GD['fit_mult']
+
+    q = """
+        SELECT SUM(independent_var_shape[1]) AS image_count,
+            __dist_key__
+        FROM {src_table}
+        GROUP BY __dist_key__
+        ORDER BY __dist_key__
+    """.format(**globals())
+    res = plpy.execute(q)
+
+    images_per_seg = [ int(r['image_count']) for r in res ]
+    dist_keys = [ int(r['__dist_key__']) for r in res ]
+    num_dist_keys = len(dist_keys)
+
+    fit_mult.source_table = src_table
+    fit_mult.max_dist_key = sorted(dist_keys)[-1]
+    fit_mult.images_per_seg_train = images_per_seg
+    fit_mult.dist_key_mapping = fit_mult.dist_keys = dist_keys
+    fit_mult.accessible_gpus_per_seg = [0] * num_dist_keys
+    fit_mult.segments_per_host = num_data_segs
+
+    fit_mult.msts_for_schedule = fit_mult.msts[:num_models]
+    if num_models < num_dist_keys:
+        fit_mult.msts_for_schedule += [None] * \
+                                 (num_dist_keys - num_models)
+    fit_mult.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                              for mst in fit_mult.msts_for_schedule ]
+    fit_mult.num_msts = num_models
+
+    fit_mult.extra_dist_keys = []
+    for i in range(num_models - num_dist_keys):
+        fit_mult.extra_dist_keys.append(fit_mult.max_dist_key + 1 + i)
+    fit_mult.all_dist_keys = fit_mult.dist_key_mapping + fit_mult.extra_dist_keys
+
+    create_distkey_map_tbl_cmd = """
+        DROP TABLE IF EXISTS {exp_table};
+        CREATE TABLE {exp_table} AS
+        SELECT
+            ARRAY(  -- map of dist_keys to seg_ids from source table
+                SELECT __dist_key__
+                FROM {fm.source_table}
+                GROUP BY __dist_key__
+                ORDER BY __dist_key__  -- This would be gp_segment_id if it weren't a simulation
+            ) AS expected_dist_key_mapping,
+            ARRAY{fm.images_per_seg_train} AS expected_images_per_seg,
+            {num_data_segs} AS segments_per_host,
+            __dist_key__
+        FROM {fm.source_table}
+        GROUP BY __dist_key__
+        DISTRIBUTED BY (__dist_key__);
+    """.format(
+            fm=fit_mult,
+            num_data_segs=num_data_segs,
+            exp_table=expected_distkey_mappings_tbl
+        )
+    plpy.execute(create_distkey_map_tbl_cmd)
+$$ LANGUAGE plpythonu VOLATILE;
+
+CREATE OR REPLACE FUNCTION test_run_training(
+    source_table TEXT,
+    hop INTEGER,
+    is_very_first_hop BOOLEAN,
+    is_final_training_call BOOLEAN,
+    use_caching BOOLEAN
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    # Each time we start a new test, clear out stats
+    #   like num_calls from GD so we don't end up validating
+    #   against old results
+    if 'transition_function_params' in GD:
+        del GD['transition_function_params']
+
+    fit_mult.source_tbl = source_table
+    fit_mult.is_very_first_hop = is_very_first_hop
+    fit_mult.is_final_training_call = is_final_training_call
+    if use_caching != fit_mult.use_caching:
+        fit_mult.udf_plan = None  # Otherwise it will execute the wrong
+                                  # query when use_caching changes!
+    fit_mult.use_caching = use_caching
+
+    fit_mult.run_training(hop=hop, is_very_first_hop=is_very_first_hop)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+m4_ifelse(m4_eval(__DBMS_VERSION_MAJOR__ <= 6),1,<!
+
+-- format() function for dynamic SQL was introduced in gpdb6,
+--   so for gpdb5 we need to define a simple version of it
+CREATE OR REPLACE FUNCTION format(s TEXT, vars VARIADIC TEXT[])
+RETURNS TEXT AS
+$$
+    import re
+    from internal.db_utils import quote_nullable
+    from utilities.utilities import quote_ident
+
+    seps = re.findall(r'%.', s)
+    pieces = re.split(r'%.', s)
+
+    res = pieces[0]
+    for i in range(len(seps)):
+        if vars[i] is None:
+            raise ValueError('Not enough params passed to format({})'.format(s))
+        else:
+            c = seps[i][1]
+            if c == '%':
+                quoted = '%'
+            elif c == 's':
+                quoted = vars[i]
+            elif c == 'L':
+                quoted = quote_nullable(vars[i])
+            elif c == 'I':
+                quoted = quote_ident(vars[i])
+            else:
+                raise ValueError('{} in format({}) not recognized'.format(seps[i],s))
+            res += quoted
+        res += pieces[i+1]
+    return res
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION format(s TEXT,
+                                  v1 INTEGER[],
+                                  v2 INTEGER[],
+                                  v3 TEXT,
+                                  v4 TEXT,
+                                  v5 INTEGER[],
+                                  v6 TEXT,
+                                  v7 INTEGER[]
+)
+RETURNS TEXT AS
+$$
+    SELECT format($1,
+               $2::TEXT,
+               $3::TEXT,
+               $4,
+               $5,
+               $6::TEXT,
+               $7,
+               $8::TEXT
+        );
+$$ LANGUAGE SQL;
+
+CREATE OR REPLACE FUNCTION format(s TEXT, vars VARIADIC ANYARRAY)
+RETURNS TEXT AS
+$$
+    --SELECT format($1, $2::TEXT[])
+    SELECT $2;
+$$ LANGUAGE sql IMMUTABLE;
+
+!>) -- m4_endif DBMS_VERSION_MAJOR <= 5
+
+CREATE OR REPLACE FUNCTION validate_mst_key_order(output_tbl TEXT, expected_tbl TEXT)
+RETURNS VOID AS
+$$
+DECLARE
+    actual INTEGER[];
+    expected INTEGER[];
+    res RECORD;
+BEGIN
+    EXECUTE format(
+        'SELECT ARRAY(' ||
+            'SELECT mst_key FROM %I ' ||
+                'ORDER BY __dist_key__)',
+        output_tbl
+    ) INTO actual;
+
+    EXECUTE format(
+        'SELECT mst_keys FROM %I',
+        expected_tbl
+    ) INTO expected;
+
+    EXECUTE format(
+        'SELECT assert(' ||
+            '%L = %L, ' ||
+            '%L || ' ||
+            '%L || %L::TEXT || ' ||
+            '%L || %L::TEXT' ||
+        ')',
+        actual,
+        expected,
+        'mst keys found in wrong order / wrong segments!',
+        E'\nActual: ',
+        actual,
+        E'\nExpected: ',
+        expected
+    ) INTO res;
+END
+$$ LANGUAGE PLpgSQL VOLATILE;
+
+-- Create mst table
+DROP TABLE IF EXISTS iris_mst_table, iris_mst_table_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'iris_mst_table',
+    ARRAY[1],
+    ARRAY[
+        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$,
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']$$,
+        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=5,epochs=1$$,
+        $$batch_size=10,epochs=1$$
+    ]
+);
+
+-- Create FitMultiple object for running test functions
+SELECT init_fit_mult('iris_data_15buf_packed', 'iris_mst_table');
+
+CREATE TABLE src_3segs AS
+    SELECT * FROM iris_data_15buf_packed
+    DISTRIBUTED BY (__dist_key__);
+
+-- Simulate 6 models on 3 segments --
+SELECT update_dist_keys('src_3segs', 3, 6, 'expected_dist_key_mappings');
+
+--=== Test init_schedule_tbl() ===--
+-- ====================================================================
+-- ===========  Enough setup, now for the actual tests! ===============
+-- ====================================================================
+
+SELECT test_init_schedule('current_schedule');
+SELECT assert(
+    s.mst_key IS NOT NULL AND m.mst_key IS NOT NULL,
+    'mst_keys in schedule table created by test_init_schedule() does not match keys in mst_table'
+) FROM current_schedule s FULL JOIN iris_mst_table m USING (mst_key);
+
+-- Save order of mst keys in schedule for tracking 
+DROP TABLE IF EXISTS expected_order;
+CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys;
+
+--=== Test rotate_schedule() ===--
+SELECT test_rotate_schedule('current_schedule');
+UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+SELECT validate_mst_key_order('current_schedule', 'expected_order');
+UPDATE expected_order SET mst_keys=reverse_rotate_keys(mst_keys);  -- Undo for later
+
+-- Initialize model_output table, and set model_input & cached_src table names
+SELECT setup_model_tables('model_input', 'model_output', 'cached_src');
+SELECT validate_mst_key_order('model_output', 'expected_order');
+
+-- Order of params in run_training test function (for reference below):
+--
+--  test_run_training(src_tbl, hop, is_v_first_hop, is_final_call, use_caching)
+
+--=== Test first hop of an iteration - no caching (# msts > # segs) ===--
+SELECT test_run_training('src_3segs', 0, False, False, False);

Review comment:
       If we are testing the very first hop, shouldn't the third argument be True?

##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
##########
@@ -0,0 +1,845 @@
+m4_include(`SQLCommon.m4')
+m4_changequote(<<<,>>>)
+m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres
+,<<<
+m4_changequote(<!,!>)
+
+-- =================== Setup & Initialization for FitMultiple tests ========================
+--
+--  For fit multiple, we test end-to-end functionality along with performance elsewhere.
+--  They take a long time to run.  Including similar tests here would probably not be worth
+--  the extra time added to dev-check.
+--
+--  Instead, we just want to unit test different python functions in the FitMultiple class.
+--  However, most of the important behavior we need to test requires access to an actual
+--  Greenplum database... mostly, we want to make sure that the models hop around to the
+--  right segments in the right order.  Therefore, the unit tests are here, as a part of
+--  dev-check. we mock fit_transition() and some validation functions in FitMultiple, but
+--  do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to
+--  get through to gpdb. We also want to mock the number of segments, so we can test what
+--  the model hopping behavior will be for a large cluster, even though dev-check should be
+--  able to run on a single dev host.
+
+\i m4_regexp(MODULE_PATHNAME,
+             <!\(.*\)libmadlib\.so!>,
+            <!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!>
+)
+
+-- Mock version() function to convince the InputValidator this is the real madlib schema
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version() RETURNS VARCHAR AS
+$$
+    SELECT MADLIB_SCHEMA.version();
+$$ LANGUAGE sql IMMUTABLE;
+
+-- Call this first to initialize the FitMultiple object, before anything else happens.
+-- Pass a real mst table and source table, rest of FitMultipleModel() constructor params
+--  are filled in.  They can be overriden later, before test functions are called, if necessary.
+CREATE OR REPLACE FUNCTION init_fit_mult(
+    source_table            VARCHAR,
+    model_selection_table   VARCHAR
+) RETURNS VOID AS
+$$
+    import sys
+    from mock import Mock, patch
+
+    PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model)
+    schema_madlib = 'madlib_installcheck_deep_learning'
+
+    GD['fit_mult'] = madlib_keras_fit_multiple_model.FitMultipleModel(
+        schema_madlib,
+        source_table,
+        'orig_model_out',
+        model_selection_table,
+        1
+    )
+    
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_init_schedule(
+    schedule_table VARCHAR
+) RETURNS BOOLEAN AS
+$$
+    fit_mult = GD['fit_mult']
+    fit_mult.schedule_tbl = schedule_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table))
+    if fit_mult.init_schedule_tbl():
+        err_msg = None
+    else:
+        err_msg = 'FitMultiple.init_schedule_tbl() returned False'
+
+    return err_msg
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_rotate_schedule(
+    schedule_table          VARCHAR
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    if fit_mult.schedule_tbl != schedule_table:
+        fit_mult.init_schedule_tbl()
+
+    fit_mult.rotate_schedule_tbl()
+
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+-- Mock fit_transition function, for testing
+--  madlib_keras_fit_multiple_model() python code
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
+    dependent_var               BYTEA,
+    independent_var             BYTEA,
+    dependent_var_shape         INTEGER[],
+    independent_var_shape       INTEGER[],
+    model_architecture          TEXT,
+    compile_params              TEXT,
+    fit_params                  TEXT,
+    dist_key                    INTEGER,
+    dist_key_mapping            INTEGER[],
+    current_seg_id              INTEGER,
+    segments_per_host           INTEGER,
+    images_per_seg              INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
+    serialized_weights          BYTEA,
+    is_final_training_call      BOOLEAN,
+    use_caching                 BOOLEAN,
+    custom_function_map         BYTEA
+) RETURNS BYTEA AS
+$$
+    param_keys = [ 'compile_params', 'accessible_gpus_for_seg', 'dependent_var_shape', 'dist_key_mapping',
+                   'current_seg_id', 'segments_per_host', 'custom_function_map', 'is_final_training_call',
+                   'dist_key', 'serialized_weights', 'images_per_seg', 'model_architecture', 'fit_params',
+                   'independent_var_shape', 'use_caching' ]
+
+    num_calls = 1
+    if 'transition_function_params' in GD:
+        if dist_key in GD['transition_function_params']:
+            if not 'reset' in GD['transition_function_params'][dist_key]:
+                num_calls = GD['transition_function_params'][dist_key]['num_calls']
+                num_calls += 1
+
+    g = globals()
+    params = dict()
+
+    for k in param_keys:
+        params[k] = g[k]
+
+    params['dependent_var'] = len(dependent_var) if dependent_var else 0
+    params['independent_var'] = len(independent_var) if independent_var else 0
+    params['num_calls'] = num_calls
+
+    if not 'transition_function_params' in GD:
+        GD['transition_function_params'] = dict()
+    GD['transition_function_params'][dist_key] = params
+
+    # compute simulated seg_id ( current_seg_id is the actual seg_id )
+    seg_id = dist_key_mapping.index( dist_key )
+
+    if dependent_var_shape and dependent_var_shape[0] * num_calls < images_per_seg [ seg_id ]:
+        return None
+    else:
+        GD['transition_function_params'][dist_key]['reset'] = True
+        return serialized_weights
+$$ LANGUAGE plpythonu VOLATILE;
+
+CREATE OR REPLACE FUNCTION validate_transition_function_params(
+    current_seg_id                       INTEGER,
+    segments_per_host                    INTEGER,
+    images_per_seg                       INTEGER[],
+    expected_num_calls                   INTEGER,
+    expected_dist_key                    INTEGER,
+    expected_is_final_training_call      BOOLEAN,
+    expected_dist_key_mapping            INTEGER[],
+    dependent_var_len                    INTEGER,
+    independent_var_len                  INTEGER,
+    use_caching                          BOOLEAN
+) RETURNS TEXT AS
+$$
+    err_msg = "transition function was not called on segment "
+
+    if 'transition_function_params' not in GD:
+        return err_msg.format(current_seg_id)
+    elif expected_dist_key not in GD['transition_function_params']:
+        return err_msg + " for __dist_key__ = {}".format(expected_dist_key)
+    actual = GD['transition_function_params'][expected_dist_key]
+
+    err_msg = """Incorrect value for {} param passed to fit_transition_multiple_model:
+       Actual={}, Expected={}"""
+
+    validation_map = {
+        'current_seg_id'         : current_seg_id,
+        'segments_per_host'      : segments_per_host,
+        'num_calls'              : expected_num_calls,
+        'is_final_training_call' : expected_is_final_training_call,
+        'dist_key'               : expected_dist_key,
+        'dependent_var'          : dependent_var_len,
+        'independent_var'        : independent_var_len,
+        'use_caching'            : use_caching
+    }
+
+    for param, expected in validation_map.items():
+        if actual[param] != expected:
+            return err_msg.format(
+                param,
+                actual[param],
+                expected
+            )
+
+    return 'PASS'  # actual params match expected params
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Helper to rotate an array of int's
+CREATE OR REPLACE FUNCTION rotate_keys(
+    keys    INTEGER[]
+) RETURNS INTEGER[]
+AS $$
+   return keys[-1:] + keys[:-1]
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION reverse_rotate_keys(
+    keys    INTEGER[]
+) RETURNS INTEGER[]
+AS $$
+   return keys[1:] + keys[:1]
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION setup_model_tables(
+    input_table TEXT,
+    output_table TEXT,
+    cached_source_table TEXT
+) RETURNS TEXT AS
+$$ 
+    fit_mult = GD['fit_mult']
+
+    fit_mult.model_input_tbl = input_table
+    fit_mult.model_output_tbl = output_table
+    fit_mult.cached_source_table = cached_source_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(output_table))
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(cached_source_table))
+    fit_mult.init_model_output_tbl()
+    q = """
+        UPDATE {model_out} -- Reduce size of model for faster tests
+            SET ( model_weights, model_arch, compile_params, fit_params )
+                  = ( mst_key::TEXT::BYTEA,
+                      ( '{{ "a" : ' || mst_key::TEXT || ' }}' )::JSON,
+                      'c' || mst_key::TEXT,
+                      'f' || mst_key::TEXT 
+                    )
+        WHERE mst_key IS NOT NULL;
+    """.format(model_out=fit_mult.model_output_tbl)
+    plpy.execute(q) 
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Updates dist keys in src table and internal fit_mult class variables
+--    num_data_segs can be larger than actual number of segments, since this
+--    is just for simulated testing.  This will also write to expected_distkey_mappings_tbl
+--    which can be used for validating dist key mappings and images per seg later.
+CREATE OR REPLACE FUNCTION update_dist_keys(
+    src_table TEXT,
+    num_data_segs INTEGER,
+    num_models INTEGER,
+    expected_distkey_mappings_tbl TEXT
+) RETURNS VOID AS
+$$ 
+    redist_cmd = """
+        UPDATE {src_table}
+            SET __dist_key__ = (buffer_id % {num_data_segs})
+    """.format(**globals())
+    plpy.execute(redist_cmd)
+
+    fit_mult = GD['fit_mult']
+
+    q = """
+        SELECT SUM(independent_var_shape[1]) AS image_count,
+            __dist_key__
+        FROM {src_table}
+        GROUP BY __dist_key__
+        ORDER BY __dist_key__
+    """.format(**globals())
+    res = plpy.execute(q)
+
+    images_per_seg = [ int(r['image_count']) for r in res ]
+    dist_keys = [ int(r['__dist_key__']) for r in res ]
+    num_dist_keys = len(dist_keys)
+
+    fit_mult.source_table = src_table
+    fit_mult.max_dist_key = sorted(dist_keys)[-1]
+    fit_mult.images_per_seg_train = images_per_seg
+    fit_mult.dist_key_mapping = fit_mult.dist_keys = dist_keys
+    fit_mult.accessible_gpus_per_seg = [0] * num_dist_keys
+    fit_mult.segments_per_host = num_data_segs
+
+    fit_mult.msts_for_schedule = fit_mult.msts[:num_models]
+    if num_models < num_dist_keys:
+        fit_mult.msts_for_schedule += [None] * \
+                                 (num_dist_keys - num_models)
+    fit_mult.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                              for mst in fit_mult.msts_for_schedule ]
+    fit_mult.num_msts = num_models
+
+    fit_mult.extra_dist_keys = []
+    for i in range(num_models - num_dist_keys):
+        fit_mult.extra_dist_keys.append(fit_mult.max_dist_key + 1 + i)
+    fit_mult.all_dist_keys = fit_mult.dist_key_mapping + fit_mult.extra_dist_keys
+
+    create_distkey_map_tbl_cmd = """
+        DROP TABLE IF EXISTS {exp_table};
+        CREATE TABLE {exp_table} AS
+        SELECT
+            ARRAY(  -- map of dist_keys to seg_ids from source table
+                SELECT __dist_key__
+                FROM {fm.source_table}
+                GROUP BY __dist_key__
+                ORDER BY __dist_key__  -- This would be gp_segment_id if it weren't a simulation
+            ) AS expected_dist_key_mapping,
+            ARRAY{fm.images_per_seg_train} AS expected_images_per_seg,
+            {num_data_segs} AS segments_per_host,
+            __dist_key__
+        FROM {fm.source_table}
+        GROUP BY __dist_key__
+        DISTRIBUTED BY (__dist_key__);
+    """.format(
+            fm=fit_mult,
+            num_data_segs=num_data_segs,
+            exp_table=expected_distkey_mappings_tbl
+        )
+    plpy.execute(create_distkey_map_tbl_cmd)
+$$ LANGUAGE plpythonu VOLATILE;
+
+CREATE OR REPLACE FUNCTION test_run_training(
+    source_table TEXT,
+    hop INTEGER,
+    is_very_first_hop BOOLEAN,
+    is_final_training_call BOOLEAN,
+    use_caching BOOLEAN
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    # Each time we start a new test, clear out stats
+    #   like num_calls from GD so we don't end up validating
+    #   against old results
+    if 'transition_function_params' in GD:
+        del GD['transition_function_params']
+
+    fit_mult.source_tbl = source_table
+    fit_mult.is_very_first_hop = is_very_first_hop
+    fit_mult.is_final_training_call = is_final_training_call
+    if use_caching != fit_mult.use_caching:
+        fit_mult.udf_plan = None  # Otherwise it will execute the wrong
+                                  # query when use_caching changes!
+    fit_mult.use_caching = use_caching
+
+    fit_mult.run_training(hop=hop, is_very_first_hop=is_very_first_hop)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+m4_ifelse(m4_eval(__DBMS_VERSION_MAJOR__ <= 6),1,<!
+
+-- format() function for dynamic SQL was introduced in gpdb6,
+--   so for gpdb5 we need to define a simple version of it
+CREATE OR REPLACE FUNCTION format(s TEXT, vars VARIADIC TEXT[])
+RETURNS TEXT AS
+$$
+    import re
+    from internal.db_utils import quote_nullable
+    from utilities.utilities import quote_ident
+
+    seps = re.findall(r'%.', s)
+    pieces = re.split(r'%.', s)
+
+    res = pieces[0]
+    for i in range(len(seps)):
+        if vars[i] is None:
+            raise ValueError('Not enough params passed to format({})'.format(s))
+        else:
+            c = seps[i][1]
+            if c == '%':
+                quoted = '%'
+            elif c == 's':
+                quoted = vars[i]
+            elif c == 'L':
+                quoted = quote_nullable(vars[i])
+            elif c == 'I':
+                quoted = quote_ident(vars[i])
+            else:
+                raise ValueError('{} in format({}) not recognized'.format(seps[i],s))
+            res += quoted
+        res += pieces[i+1]
+    return res
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION format(s TEXT,
+                                  v1 INTEGER[],
+                                  v2 INTEGER[],
+                                  v3 TEXT,
+                                  v4 TEXT,
+                                  v5 INTEGER[],
+                                  v6 TEXT,
+                                  v7 INTEGER[]
+)
+RETURNS TEXT AS
+$$
+    SELECT format($1,
+               $2::TEXT,
+               $3::TEXT,
+               $4,
+               $5,
+               $6::TEXT,
+               $7,
+               $8::TEXT
+        );
+$$ LANGUAGE SQL;
+
+CREATE OR REPLACE FUNCTION format(s TEXT, vars VARIADIC ANYARRAY)
+RETURNS TEXT AS
+$$
+    --SELECT format($1, $2::TEXT[])
+    SELECT $2;
+$$ LANGUAGE sql IMMUTABLE;
+
+!>) -- m4_endif DBMS_VERSION_MAJOR <= 5
+
+CREATE OR REPLACE FUNCTION validate_mst_key_order(output_tbl TEXT, expected_tbl TEXT)
+RETURNS VOID AS
+$$
+DECLARE
+    actual INTEGER[];
+    expected INTEGER[];
+    res RECORD;
+BEGIN
+    EXECUTE format(
+        'SELECT ARRAY(' ||
+            'SELECT mst_key FROM %I ' ||
+                'ORDER BY __dist_key__)',
+        output_tbl
+    ) INTO actual;
+
+    EXECUTE format(
+        'SELECT mst_keys FROM %I',
+        expected_tbl
+    ) INTO expected;
+
+    EXECUTE format(

Review comment:
       Do we need to call the format() function for any of these queries? Can't we format it ourselves? Something like:
   ```
   create table foo (i int, j int);
   insert into foo values(1,1);
   insert into foo values(2,2);
   insert into foo values(3,4);
   create or replace function test()
   returns void as
   $$
   DECLARE
       actual INTEGER[];
       expected INTEGER[];
   BEGIN
   execute 'select ARRAY(select i from foo order by i)' into actual;
   execute 'select ARRAY(select j from foo order by j)' into expected;
   PERFORM madlib.assert(actual = expected, 'mst keys found in wrong order / wrong segments! ' || E'\nActual: ' || actual::text || E'\nExpected: ' || expected::text);
   END
   $$ language PLpgSQL;
   
   madlib=# select test();
   ERROR:  Failed assertion: mst keys found in wrong order / wrong segments!
   DETAIL:
   Actual: {1,2,3}
   Expected: {1,2,4}
   CONTEXT:  SQL statement "SELECT madlib.assert(actual = expected, 'mst keys found in wrong order / wrong segments! ' || E'\nActual: ' || actual::text || E'\nExpected: ' || expected::text)"
   PL/pgSQL function test() line 8 at PERFORM
   ```







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r558735883



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -62,23 +68,29 @@ class GD_STORE:
     def clear(GD):
         del GD[GD_STORE.SEGMENT_MODEL]
         del GD[GD_STORE.SESS]
+        if GD_STORE.AGG_IMAGE_COUNT in GD:
+            del GD[GD_STORE.AGG_IMAGE_COUNT]
 
 def get_init_model_and_sess(GD, device_name, gpu_count, segments_per_host,
                                model_architecture, compile_params, custom_function_map):
     # If a live session is present, re-use it. Otherwise, recreate it.
-    if GD_STORE.SESS in GD:
+
+    if GD_STORE.SESS in GD :
+        if GD_STORE.AGG_IMAGE_COUNT not in GD:

Review comment:
       ok







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r538651733



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})

Review comment:
       We would if we wanted the weights to stay fixed... but the purpose of this query is to redistribute the weights from the previous distribution to the next distribution.  The output has to be distributed by __dist_key__ since that's what will be used as the JOIN key with the source table's __dist_key__ in the next query.
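   To make the routing concrete, here is a minimal sketch of the hop query being described (table and column names simplified from the diff; not the exact generated SQL).  Each schedule row maps a model's previous location to its next one, and the result is re-distributed by the new __dist_key__ so the training query can join it to the source table without further motion:

       -- Sketch: move each model from the segment it was on (__prev_dist_key__)
       -- to the segment it is scheduled for next (__dist_key__).
       CREATE UNLOGGED TABLE model_input AS
           SELECT o.mst_key,
                  o.model_weights,
                  s.__dist_key__
           FROM model_output o
               JOIN schedule s
               ON o.__dist_key__ = s.__prev_dist_key__
           DISTRIBUTED BY (__dist_key__);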







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r556101604



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -663,17 +717,20 @@ def get_state_to_return(segment_model, is_last_row, is_multiple_model, agg_image
     :param is_last_row: boolean to indicate if last row for that hop
     :param is_multiple_model: boolean
     :param agg_image_count: aggregated image count per hop
-    :param total_images: total images per segment
+    :param total_images: total images per segment (only used for madlib_keras_fit() )
     :return:
     """
-    if is_last_row:
-        updated_model_weights = segment_model.get_weights()
-        if is_multiple_model:
+    if is_multiple_model:
+        if is_last_row:
+            updated_model_weights = segment_model.get_weights()
             new_state = madlib_keras_serializer.serialize_nd_weights(updated_model_weights)
         else:
-            updated_model_weights = [total_images * w for w in updated_model_weights]
-            new_state = madlib_keras_serializer.serialize_state_with_nd_weights(
-                agg_image_count, updated_model_weights)
+            new_state = None
+    elif is_last_row:

Review comment:
       The `new_state = float(agg_image_count)` line is only executed for `is_fit_multiple=False`.  For `is_fit_multiple=True`, we do set `new_state = None`.
   
   As it stands, changing the fit branch to also return NULL would cause the `madlib_keras_fit()` tests to fail.  Since this was only supposed to be a refactor of fit_multiple, I didn't want to mess with that code.  But maybe we could get rid of it for `fit()` as well.







[GitHub] [madlib] kaknikhil commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r560506480



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -520,21 +544,35 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
         and only gets cleared in eval transition at the last row of the last iteration.
 
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape:

Review comment:
       You are right that x or y being NULL should probably raise an exception, since our minibatch code should have minibatched everything.
   For dependent_var vs dependent_var_shape, I'll leave it up to you to choose which one to check (I'm leaning towards dependent_var).

##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
##########
@@ -0,0 +1,845 @@
+m4_include(`SQLCommon.m4')
+m4_changequote(<<<,>>>)
+m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres
+,<<<
+m4_changequote(<!,!>)
+
+-- =================== Setup & Initialization for FitMultiple tests ========================
+--
+--  For fit multiple, we test end-to-end functionality along with performance elsewhere.
+--  They take a long time to run.  Including similar tests here would probably not be worth
+--  the extra time added to dev-check.
+--
+--  Instead, we just want to unit test different python functions in the FitMultiple class.
+--  However, most of the important behavior we need to test requires access to an actual
+--  Greenplum database... mostly, we want to make sure that the models hop around to the
+--  right segments in the right order.  Therefore, the unit tests are here, as a part of
+--  dev-check. We mock fit_transition() and some validation functions in FitMultiple, but
+--  do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to
+--  get through to gpdb. We also want to mock the number of segments, so we can test what
+--  the model hopping behavior will be for a large cluster, even though dev-check should be
+--  able to run on a single dev host.
+
+\i m4_regexp(MODULE_PATHNAME,
+             <!\(.*\)libmadlib\.so!>,
+            <!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!>
+)
+
+-- Mock version() function to convince the InputValidator this is the real madlib schema
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version() RETURNS VARCHAR AS
+$$
+    SELECT MADLIB_SCHEMA.version();
+$$ LANGUAGE sql IMMUTABLE;
+
+-- Call this first to initialize the FitMultiple object, before anything else happens.
+-- Pass a real mst table and source table, rest of FitMultipleModel() constructor params
+--  are filled in.  They can be overridden later, before test functions are called, if necessary.
+CREATE OR REPLACE FUNCTION init_fit_mult(
+    source_table            VARCHAR,
+    model_selection_table   VARCHAR
+) RETURNS VOID AS
+$$
+    import sys
+    from mock import Mock, patch
+
+    PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model)
+    schema_madlib = 'madlib_installcheck_deep_learning'
+
+    GD['fit_mult'] = madlib_keras_fit_multiple_model.FitMultipleModel(
+        schema_madlib,
+        source_table,
+        'orig_model_out',
+        model_selection_table,
+        1
+    )
+    
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_init_schedule(
+    schedule_table VARCHAR
+) RETURNS BOOLEAN AS
+$$
+    fit_mult = GD['fit_mult']
+    fit_mult.schedule_tbl = schedule_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table))
+    if fit_mult.init_schedule_tbl():
+        err_msg = None
+    else:
+        err_msg = 'FitMultiple.init_schedule_tbl() returned False'
+
+    return err_msg
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_rotate_schedule(
+    schedule_table          VARCHAR
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    if fit_mult.schedule_tbl != schedule_table:
+        fit_mult.init_schedule_tbl()
+
+    fit_mult.rotate_schedule_tbl()
+
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+-- Mock fit_transition function, for testing
+--  madlib_keras_fit_multiple_model() python code
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
+    dependent_var               BYTEA,
+    independent_var             BYTEA,
+    dependent_var_shape         INTEGER[],
+    independent_var_shape       INTEGER[],
+    model_architecture          TEXT,
+    compile_params              TEXT,
+    fit_params                  TEXT,
+    dist_key                    INTEGER,
+    dist_key_mapping            INTEGER[],
+    current_seg_id              INTEGER,
+    segments_per_host           INTEGER,
+    images_per_seg              INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
+    serialized_weights          BYTEA,
+    is_final_training_call      BOOLEAN,
+    use_caching                 BOOLEAN,
+    custom_function_map         BYTEA
+) RETURNS BYTEA AS
+$$
+    param_keys = [ 'compile_params', 'accessible_gpus_for_seg', 'dependent_var_shape', 'dist_key_mapping',
+                   'current_seg_id', 'segments_per_host', 'custom_function_map', 'is_final_training_call',
+                   'dist_key', 'serialized_weights', 'images_per_seg', 'model_architecture', 'fit_params',
+                   'independent_var_shape', 'use_caching' ]
+
+    num_calls = 1
+    if 'transition_function_params' in GD:
+        if dist_key in GD['transition_function_params']:
+            if not 'reset' in GD['transition_function_params'][dist_key]:
+                num_calls = GD['transition_function_params'][dist_key]['num_calls']
+                num_calls += 1
+
+    g = globals()
+    params = dict()
+
+    for k in param_keys:
+        params[k] = g[k]
+
+    params['dependent_var'] = len(dependent_var) if dependent_var else 0
+    params['independent_var'] = len(independent_var) if independent_var else 0
+    params['num_calls'] = num_calls
+
+    if not 'transition_function_params' in GD:
+        GD['transition_function_params'] = dict()
+    GD['transition_function_params'][dist_key] = params
+
+    # compute simulated seg_id ( current_seg_id is the actual seg_id )
+    seg_id = dist_key_mapping.index( dist_key )
+
+    if dependent_var_shape and dependent_var_shape[0] * num_calls < images_per_seg [ seg_id ]:
+        return None
+    else:
+        GD['transition_function_params'][dist_key]['reset'] = True
+        return serialized_weights
+$$ LANGUAGE plpythonu VOLATILE;
+
+CREATE OR REPLACE FUNCTION validate_transition_function_params(
+    current_seg_id                       INTEGER,
+    segments_per_host                    INTEGER,
+    images_per_seg                       INTEGER[],
+    expected_num_calls                   INTEGER,
+    expected_dist_key                    INTEGER,
+    expected_is_final_training_call      BOOLEAN,
+    expected_dist_key_mapping            INTEGER[],
+    dependent_var_len                    INTEGER,
+    independent_var_len                  INTEGER,
+    use_caching                          BOOLEAN
+) RETURNS TEXT AS
+$$
+    err_msg = "transition function was not called on segment "
+
+    if 'transition_function_params' not in GD:
+        return err_msg.format(current_seg_id)
+    elif expected_dist_key not in GD['transition_function_params']:
+        return err_msg + " for __dist_key__ = {}".format(expected_dist_key)
+    actual = GD['transition_function_params'][expected_dist_key]
+
+    err_msg = """Incorrect value for {} param passed to fit_transition_multiple_model:
+       Actual={}, Expected={}"""
+
+    validation_map = {
+        'current_seg_id'         : current_seg_id,
+        'segments_per_host'      : segments_per_host,
+        'num_calls'              : expected_num_calls,
+        'is_final_training_call' : expected_is_final_training_call,
+        'dist_key'               : expected_dist_key,
+        'dependent_var'          : dependent_var_len,
+        'independent_var'        : independent_var_len,
+        'use_caching'            : use_caching
+    }
+
+    for param, expected in validation_map.items():
+        if actual[param] != expected:
+            return err_msg.format(
+                param,
+                actual[param],
+                expected
+            )
+
+    return 'PASS'  # actual params match expected params
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Helper to rotate an array of int's
+CREATE OR REPLACE FUNCTION rotate_keys(
+    keys    INTEGER[]
+) RETURNS INTEGER[]
+AS $$
+   return keys[-1:] + keys[:-1]
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION reverse_rotate_keys(
+    keys    INTEGER[]
+) RETURNS INTEGER[]
+AS $$
+   return keys[1:] + keys[:1]
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION setup_model_tables(
+    input_table TEXT,
+    output_table TEXT,
+    cached_source_table TEXT
+) RETURNS TEXT AS
+$$ 
+    fit_mult = GD['fit_mult']
+
+    fit_mult.model_input_tbl = input_table
+    fit_mult.model_output_tbl = output_table
+    fit_mult.cached_source_table = cached_source_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(output_table))
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(cached_source_table))
+    fit_mult.init_model_output_tbl()
+    q = """
+        UPDATE {model_out} -- Reduce size of model for faster tests
+            SET ( model_weights, model_arch, compile_params, fit_params )
+                  = ( mst_key::TEXT::BYTEA,
+                      ( '{{ "a" : ' || mst_key::TEXT || ' }}' )::JSON,
+                      'c' || mst_key::TEXT,
+                      'f' || mst_key::TEXT 
+                    )
+        WHERE mst_key IS NOT NULL;
+    """.format(model_out=fit_mult.model_output_tbl)
+    plpy.execute(q) 
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Updates dist keys in src table and internal fit_mult class variables
+--    num_data_segs can be larger than actual number of segments, since this
+--    is just for simulated testing.  This will also write to expected_distkey_mappings_tbl
+--    which can be used for validating dist key mappings and images per seg later.
+CREATE OR REPLACE FUNCTION update_dist_keys(
+    src_table TEXT,
+    num_data_segs INTEGER,
+    num_models INTEGER,
+    expected_distkey_mappings_tbl TEXT
+) RETURNS VOID AS
+$$ 
+    redist_cmd = """
+        UPDATE {src_table}
+            SET __dist_key__ = (buffer_id % {num_data_segs})
+    """.format(**globals())
+    plpy.execute(redist_cmd)
+
+    fit_mult = GD['fit_mult']
+
+    q = """
+        SELECT SUM(independent_var_shape[1]) AS image_count,
+            __dist_key__
+        FROM {src_table}
+        GROUP BY __dist_key__
+        ORDER BY __dist_key__
+    """.format(**globals())
+    res = plpy.execute(q)
+
+    images_per_seg = [ int(r['image_count']) for r in res ]
+    dist_keys = [ int(r['__dist_key__']) for r in res ]
+    num_dist_keys = len(dist_keys)
+
+    fit_mult.source_table = src_table
+    fit_mult.max_dist_key = sorted(dist_keys)[-1]
+    fit_mult.images_per_seg_train = images_per_seg
+    fit_mult.dist_key_mapping = fit_mult.dist_keys = dist_keys
+    fit_mult.accessible_gpus_per_seg = [0] * num_dist_keys
+    fit_mult.segments_per_host = num_data_segs
+
+    fit_mult.msts_for_schedule = fit_mult.msts[:num_models]
+    if num_models < num_dist_keys:
+        fit_mult.msts_for_schedule += [None] * \
+                                 (num_dist_keys - num_models)
+    fit_mult.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                              for mst in fit_mult.msts_for_schedule ]
+    fit_mult.num_msts = num_models
+
+    fit_mult.extra_dist_keys = []
+    for i in range(num_models - num_dist_keys):
+        fit_mult.extra_dist_keys.append(fit_mult.max_dist_key + 1 + i)
+    fit_mult.all_dist_keys = fit_mult.dist_key_mapping + fit_mult.extra_dist_keys
+
+    create_distkey_map_tbl_cmd = """
+        DROP TABLE IF EXISTS {exp_table};
+        CREATE TABLE {exp_table} AS
+        SELECT
+            ARRAY(  -- map of dist_keys to seg_ids from source table
+                SELECT __dist_key__
+                FROM {fm.source_table}
+                GROUP BY __dist_key__
+                ORDER BY __dist_key__  -- This would be gp_segment_id if it weren't a simulation
+            ) AS expected_dist_key_mapping,
+            ARRAY{fm.images_per_seg_train} AS expected_images_per_seg,
+            {num_data_segs} AS segments_per_host,
+            __dist_key__
+        FROM {fm.source_table}
+        GROUP BY __dist_key__
+        DISTRIBUTED BY (__dist_key__);
+    """.format(
+            fm=fit_mult,
+            num_data_segs=num_data_segs,
+            exp_table=expected_distkey_mappings_tbl
+        )
+    plpy.execute(create_distkey_map_tbl_cmd)
+$$ LANGUAGE plpythonu VOLATILE;
+
+CREATE OR REPLACE FUNCTION test_run_training(
+    source_table TEXT,
+    hop INTEGER,
+    is_very_first_hop BOOLEAN,
+    is_final_training_call BOOLEAN,
+    use_caching BOOLEAN
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    # Each time we start a new test, clear out stats
+    #   like num_calls from GD so we don't end up validating
+    #   against old results
+    if 'transition_function_params' in GD:
+        del GD['transition_function_params']
+
+    fit_mult.source_tbl = source_table
+    fit_mult.is_very_first_hop = is_very_first_hop
+    fit_mult.is_final_training_call = is_final_training_call
+    if use_caching != fit_mult.use_caching:
+        fit_mult.udf_plan = None  # Otherwise it will execute the wrong
+                                  # query when use_caching changes!
+    fit_mult.use_caching = use_caching
+
+    fit_mult.run_training(hop=hop, is_very_first_hop=is_very_first_hop)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+m4_ifelse(m4_eval(__DBMS_VERSION_MAJOR__ <= 6),1,<!
+
+-- format() function for dynamic SQL was introduced in gpdb6,
+--   so for gpdb5 we need to define a simple version of it
+CREATE OR REPLACE FUNCTION format(s TEXT, vars VARIADIC TEXT[])
+RETURNS TEXT AS
+$$
+    import re
+    from internal.db_utils import quote_nullable
+    from utilities.utilities import quote_ident
+
+    seps = re.findall(r'%.', s)
+    pieces = re.split(r'%.', s)
+
+    res = pieces[0]
+    for i in range(len(seps)):
+        if vars[i] is None:
+            raise ValueError('Not enough params passed to format({})'.format(s))
+        else:
+            c = seps[i][1]
+            if c == '%':
+                quoted = '%'
+            elif c == 's':
+                quoted = vars[i]
+            elif c == 'L':
+                quoted = quote_nullable(vars[i])
+            elif c == 'I':
+                quoted = quote_ident(vars[i])
+            else:
+                raise ValueError('{} in format({}) not recognized'.format(seps[i],s))
+            res += quoted
+        res += pieces[i+1]
+    return res
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION format(s TEXT,
+                                  v1 INTEGER[],
+                                  v2 INTEGER[],
+                                  v3 TEXT,
+                                  v4 TEXT,
+                                  v5 INTEGER[],
+                                  v6 TEXT,
+                                  v7 INTEGER[]
+)
+RETURNS TEXT AS
+$$
+    SELECT format($1,
+               $2::TEXT,
+               $3::TEXT,
+               $4,
+               $5,
+               $6::TEXT,
+               $7,
+               $8::TEXT
+        );
+$$ LANGUAGE SQL;
+
+CREATE OR REPLACE FUNCTION format(s TEXT, vars VARIADIC ANYARRAY)
+RETURNS TEXT AS
+$$
+    --SELECT format($1, $2::TEXT[])
+    SELECT $2;
+$$ LANGUAGE sql IMMUTABLE;
+
+!>) -- m4_endif DBMS_VERSION_MAJOR <= 5
+
+CREATE OR REPLACE FUNCTION validate_mst_key_order(output_tbl TEXT, expected_tbl TEXT)
+RETURNS VOID AS
+$$
+DECLARE
+    actual INTEGER[];
+    expected INTEGER[];
+    res RECORD;
+BEGIN
+    EXECUTE format(
+        'SELECT ARRAY(' ||
+            'SELECT mst_key FROM %I ' ||
+                'ORDER BY __dist_key__)',
+        output_tbl
+    ) INTO actual;
+
+    EXECUTE format(
+        'SELECT mst_keys FROM %I',
+        expected_tbl
+    ) INTO expected;
+
+    EXECUTE format(
+        'SELECT assert(' ||
+            '%L = %L, ' ||
+            '%L || ' ||
+            '%L || %L::TEXT || ' ||
+            '%L || %L::TEXT' ||
+        ')',
+        actual,
+        expected,
+        'mst keys found in wrong order / wrong segments!',
+        E'\nActual: ',
+        actual,
+        E'\nExpected: ',
+        expected
+    ) INTO res;
+END
+$$ LANGUAGE PLpgSQL VOLATILE;
+
+-- Create mst table
+DROP TABLE IF EXISTS iris_mst_table, iris_mst_table_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'iris_mst_table',
+    ARRAY[1],
+    ARRAY[
+        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$,
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']$$,
+        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=5,epochs=1$$,
+        $$batch_size=10,epochs=1$$
+    ]
+);
+
+-- Create FitMultiple object for running test functions
+SELECT init_fit_mult('iris_data_15buf_packed', 'iris_mst_table');
+
+CREATE TABLE src_3segs AS
+    SELECT * FROM iris_data_15buf_packed
+    DISTRIBUTED BY (__dist_key__);
+
+-- Simulate 6 models on 3 segments --
+SELECT update_dist_keys('src_3segs', 3, 6, 'expected_dist_key_mappings');
+
+--=== Test init_schedule_tbl() ===--
+-- ====================================================================
+-- ===========  Enough setup, now for the actual tests! ===============
+-- ====================================================================
+
+SELECT test_init_schedule('current_schedule');
+SELECT assert(
+    s.mst_key IS NOT NULL AND m.mst_key IS NOT NULL,
+    'mst_keys in schedule table created by test_init_schedule() does not match keys in mst_table'
+) FROM current_schedule s FULL JOIN iris_mst_table m USING (mst_key);
+
+-- Save order of mst keys in schedule for tracking 
+DROP TABLE IF EXISTS expected_order;
+CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys;
+
+--=== Test rotate_schedule() ===--
+SELECT test_rotate_schedule('current_schedule');
+UPDATE expected_order SET mst_keys=rotate_keys(mst_keys);
+SELECT validate_mst_key_order('current_schedule', 'expected_order');
+UPDATE expected_order SET mst_keys=reverse_rotate_keys(mst_keys);  -- Undo for later
+
+-- Initialize model_output table, and set model_input & cached_src table names
+SELECT setup_model_tables('model_input', 'model_output', 'cached_src');
+SELECT validate_mst_key_order('model_output', 'expected_order');
+
+-- Order of params in run_training test function (for reference below):
+--
+--  test_run_training(src_tbl, hop, is_v_first_hop, is_final_call, use_caching)
+
+--=== Test first hop of an iteration - no caching (# msts > # segs) ===--
+SELECT test_run_training('src_3segs', 0, False, False, False);

Review comment:
       I see, that makes sense. Do you think it would make sense to add a test case for "very first hop" with caching disabled?
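   For reference, such a case would just be one more invocation of the helper above with is_very_first_hop set; a hypothetical example, using the same table and parameter order as the existing calls:

       --=== Hypothetical: very first hop, caching disabled ===--
       SELECT test_run_training('src_3segs', 0, True, False, False);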







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r538674096



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})

Review comment:
       Oh, I responded at first thinking this was the hop query.
   
   The distribution of this table shouldn't matter much, since there are no weights in it... but yes, I think you're right, I'll change it to prev dist key.
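   Concretely, the suggested change would only touch the DISTRIBUTED BY clause of the query quoted above.  A sketch with placeholder mst_key and dist_key values (the real code substitutes the class attributes):

       -- Distribute the schedule by the key the weights are currently on, so the
       -- hop query's join on __prev_dist_key__ is co-located with model_output.
       CREATE TABLE schedule AS
           WITH map AS
               (SELECT unnest(ARRAY[1,2,3])    AS mst_key,
                       unnest(ARRAY[10,20,30]) AS __dist_key__)
           SELECT map.mst_key,
                  model_id,
                  map.__dist_key__ AS __prev_dist_key__,
                  map.__dist_key__
           FROM map LEFT JOIN model_selection_table
               USING (mst_key)
           DISTRIBUTED BY (__prev_dist_key__);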







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r537797983



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -196,72 +195,110 @@ class FitMultipleModel():
             self.dist_key_mapping_valid, self.images_per_seg_valid = \
                 get_image_count_per_seg_for_minibatched_data_from_db(
                     self.validation_table)
-        self.mst_weights_tbl = unique_string(desp='mst_weights')
-        self.mst_current_schedule_tbl = unique_string(desp='mst_current_schedule')
+        self.model_input_tbl = unique_string(desp='model_input')
+        self.schedule_tbl = unique_string(desp='schedule')
 
-        self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
-        if len(self.msts) < len(self.dist_keys):
+        self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
+        DEBUG.plpy.info("init_dist_keys = {0}".format(self.dist_keys))
+        self.max_dist_key = sorted(self.dist_keys)[-1]
+        DEBUG.plpy.info("sorted_dist_keys = {0}".format(sorted(self.dist_keys)))
+        DEBUG.plpy.info("max_dist_key = {0}".format(self.max_dist_key))
+        self.extra_dist_keys = []
+
+        num_msts = len(self.msts)
+        num_dist_keys = len(self.dist_keys)
+
+        if num_msts < num_dist_keys:
             self.msts_for_schedule = self.msts + [None] * \
-                                     (len(self.dist_keys) - len(self.msts))
+                                     (num_dist_keys - num_msts)
         else:
             self.msts_for_schedule = self.msts
+            if num_msts > num_dist_keys:
+                for i in range(num_msts - num_dist_keys):
+                    self.extra_dist_keys.append(self.max_dist_key + 1 + i)
+
+        DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys))
+        DEBUG.plpy.info('extra_dist_keys : {}'.format(self.extra_dist_keys))
+
         random.shuffle(self.msts_for_schedule)
-        self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
-        self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
-        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
-        if self.warm_start:
-            self.create_model_output_table_warm_start()
-        else:
-            self.create_model_output_table()
+        # Comma-separated list of the mst_keys, including NULL's
+        #  This will be used to pass the mst keys to the db as
+        #  a sql ARRAY[]
+        self.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                for mst in self.msts_for_schedule ]
 
-        self.weights_to_update_tbl = unique_string(desp='weights_to_update')
-        self.fit_multiple_model()
+        # List of all dist_keys, including any extra dist keys beyond
+        #  the # segments we'll be training on--these represent the
+        #  segments models will rest on while not training, which
+        #  may overlap with the ones that will have training on them.
+        self.all_dist_keys = self.dist_keys + self.extra_dist_keys
 
-        # Update and cleanup metadata tables
-        self.insert_info_table()
-        self.create_model_summary_table()
-        if self.warm_start:
-            self.cleanup_for_warm_start()
-        reset_cuda_env(original_cuda_env)
+        self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
+        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
     def fit_multiple_model(self):
+        self.init_schedule_tbl()
+        self.init_model_output_tbl()
+        self.init_model_info_tbl()
+
         # WARNING: set orca off to prevent unwanted redistribution
         with OptimizerControl(False):
             self.start_training_time = datetime.datetime.now()
             self.metrics_elapsed_start_time = time.time()
             self.train_multiple_model()
             self.end_training_time = datetime.datetime.now()
 
-    def cleanup_for_warm_start(self):
+        # Update and cleanup metadata tables
+        self.insert_info_table()
+        self.create_model_summary_table()
+        self.write_final_model_output_tbl()
+        reset_cuda_env(self.original_cuda_env)
+
+    def write_final_model_output_tbl(self):
         """
-        1. drop original model table
+        1. drop original model table if exists
         2. rename temp to original
         :return:
         """
-        drop_query = "DROP TABLE IF EXISTS {}".format(
-            self.original_model_output_table)
-        plpy.execute(drop_query)
-        rename_table(self.schema_madlib, self.model_output_table,
-                     self.original_model_output_table)
+        final_output_table_create_query = """
+                                    DROP TABLE IF EXISTS {self.original_model_output_tbl};
+                                    CREATE TABLE {self.original_model_output_tbl} AS

Review comment:
       I did try that, but it gets a bit messy.
   
   There are 3 differences between the schemas:
      1. The temporary model table has 3 extra columns:  compile_params, fit_params, and object_map
      2. The temporary model table has a PRIMARY KEY constraint (a combination of __dist_key__ and mst_key), which means it also has a SEQUENCE relation that exists in the same schema with a similar name.
      3. The temporary table is an unlogged table, while presumably the user expects the output to be an ordinary logged table.
   
   The first we could fix by DROP'ing the 3 extra columns after the ALTER TABLE RENAME.  But even then, there's still a minor issue with it:  the user is left with a table which is probably fragmented on disk in a weird way... with the extra columns probably still there but marked as deleted.  So to clean that up, we'd probably want to run VACUUM or something on it, to consolidate the table a bit so it's not taking up a larger footprint than it otherwise would.
   
   The second one isn't too bad, but also requires extra work:  if you do the table rename, it doesn't rename the SEQUENCE, and gets errors when you try to access it.  So you either have to do another RENAME on the SEQUENCE, or just DROP it.
   
   The third one I see as the most problematic, unless there is an easy way to change it from UNLOGGED to LOGGED.  Seems like an undesirable outcome the user might not expect.
   
   I didn't get as far as looking up how to fix the 3rd one, since at that point, it just seemed simpler and cleaner to DROP the temp table and create a nice fresh table as output that we know isn't going to have any issues.
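   For illustration, a rough sketch of the copy-into-a-fresh-table approach described here, with simplified table and column names (the real query is built from the class attributes):

       -- Copy only the user-facing columns out of the unlogged working table into
       -- a fresh, ordinary (logged) table under the user's output name, then drop
       -- the working table instead of renaming it.
       DROP TABLE IF EXISTS model_out;            -- user-facing output name
       CREATE TABLE model_out AS
           SELECT mst_key, model_weights, model_arch
           FROM model_out_working                 -- unlogged temp table
           WHERE mst_key IS NOT NULL;
       DROP TABLE model_out_working;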







[GitHub] [madlib] reductionista commented on pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on pull request #525:
URL: https://github.com/apache/madlib/pull/525#issuecomment-732492520


   Thanks for testing it out Frank!
   
   You're the first to test this with tensorflow 1.13, and based on one of the errors you're seeing, I'm guessing you're also the first to test it with gpdb5.  I only tested it on gpdb6 and postgres.
   
   > WARNING:  This version of tensorflow does not support XLA auto-cluster JIT optimization.  HINT:  upgrading tensorflow may improve performance.  (seg1 slice1 10.128.0.41:40001 pid=6271)
   > CONTEXT:  PL/Python function "fit_transition"
   > ```
   > 
   > What does user need to do to enable XLA? I am on TF 1.13.1 currently.
   
   Nothing, probably just need to upgrade to TF 1.14.  I should add the specific minimum version requirement in the warning message.  I think it was added in 1.14 but I've only used 1.10 and 1.14, so I'll need to confirm that.  Could also be that it's been in there for a while, and they just changed how you access it.  I'll look both of those up, and also run some more comprehensive performance tests to see how much of a difference it actually makes and in what kind of setups.  If it looks like it's not helping much anyway (even though the TF docs say it should) then we can get rid of this message... or maybe even not bother to enable it.  I figure if it does give better performance consistently across models, then we may as well let the user know there's an easy way to get that, without having to tune any parameters or upgrade hardware.  Would also be important to know if there are some cases where it could hurt performance.
    
   > ERROR:  plpy.Error: madlib_keras_fit_multiple_model error: No GPUs configured on hosts. (plpython.c:5038)
   > CONTEXT:  Traceback (most recent call last):
   >   PL/Python function "madlib_keras_fit_multiple_model", line 23, in <module>
   >     fit_obj = madlib_keras_fit_multiple_model.FitMultipleModel(**globals())
   >   PL/Python function "madlib_keras_fit_multiple_model", line 147, in __init__
   >   PL/Python function "madlib_keras_fit_multiple_model", line 295, in get_accessible_gpus_for_seg
   > PL/Python function "madlib_keras_fit_multiple_model"
   > ```
   > 
   > so it looks like `use gpus=NULL` is now defaulting to `TRUE` but it should default to `FALSE` i.e., CPUs. It used to default to CPUs.
   
   Good catch!  For some reason, I got a different error when I tried to reproduce this (about using BOOLEAN::None in a query).  But I think I see the cause of it--fixing.  I thought we had some dev-check tests for this; will add one if we don't.
   
   > ```
   > ERROR:  plpy.SPIError: PRIMARY KEY and DISTRIBUTED BY definitions incompatible
   > HINT:  When there is both a PRIMARY KEY, and a DISTRIBUTED BY clause, the DISTRIBUTED BY clause must be equal to or a left-subset of the PRIMARY KEY
   > CONTEXT:  Traceback (most recent call last):
   >   PL/Python function "madlib_keras_fit_multiple_model", line 24, in <module>
   >     fit_obj.fit_multiple_model()
   >   PL/Python function "madlib_keras_fit_multiple_model", line 241, in fit_multiple_model
   >   PL/Python function "madlib_keras_fit_multiple_model", line 509, in init_model_output_tbl
   > PL/Python function "madlib_keras_fit_multiple_model"
   > ```
   > 
   > I have also seen this error when running on a single segment.
   
   Interesting!  Looks like this was due to an unreasonably strict check in the gpdb5 server code, where it required a specific order for the keys in the PRIMARY KEY constraint.  The server team has fixed that in gpdb6:
   https://github.com/greenplum-db/gpdb/commit/e572af54e4255d501d261d51d32e1f3d3cf4b9c9
   
   I can't recall for sure if we decided we care about supporting deep learning in gpdb5 for this release.  But since it's an easy fix, I'll just reverse the order of the keys so that it works on both.
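   To make the gpdb5 constraint concrete, a minimal sketch with a simplified column list (the actual output table carries more columns):

       -- gpdb5 requires the DISTRIBUTED BY column(s) to be equal to, or a
       -- left-subset of, the PRIMARY KEY; listing __dist_key__ first in the
       -- constraint satisfies both gpdb5 and gpdb6.
       CREATE UNLOGGED TABLE model_output_tmp (
           __dist_key__   INTEGER,
           mst_key        INTEGER,
           model_weights  BYTEA,
           model_arch     JSON,
           PRIMARY KEY (__dist_key__, mst_key)
       ) DISTRIBUTED BY (__dist_key__);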





[GitHub] [madlib] khannaekta commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
khannaekta commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r556801794



##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
##########
@@ -0,0 +1,845 @@
+m4_include(`SQLCommon.m4')
+m4_changequote(<<<,>>>)
+m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres
+,<<<
+m4_changequote(<!,!>)
+
+-- =================== Setup & Initialization for FitMultiple tests ========================
+--
+--  For fit multiple, we test end-to-end functionality along with performance elsewhere.
+--  They take a long time to run.  Including similar tests here would probably not be worth
+--  the extra time added to dev-check.
+--
+--  Instead, we just want to unit test different python functions in the FitMultiple class.
+--  However, most of the important behavior we need to test requires access to an actual
+--  Greenplum database... mostly, we want to make sure that the models hop around to the
+--  right segments in the right order.  Therefore, the unit tests are here, as a part of
+--  dev-check. We mock fit_transition() and some validation functions in FitMultiple, but
+--  do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to
+--  get through to gpdb. We also want to mock the number of segments, so we can test what
+--  the model hopping behavior will be for a large cluster, even though dev-check should be
+--  able to run on a single dev host.
+
+\i m4_regexp(MODULE_PATHNAME,
+             <!\(.*\)libmadlib\.so!>,
+            <!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!>
+)
+
+-- Mock version() function to convince the InputValidator this is the real madlib schema
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version() RETURNS VARCHAR AS
+$$
+    SELECT MADLIB_SCHEMA.version();
+$$ LANGUAGE sql IMMUTABLE;
+
+-- Call this first to initialize the FitMultiple object, before anything else happens.
+-- Pass a real mst table and source table, rest of FitMultipleModel() constructor params
+--  are filled in.  They can be overridden later, before test functions are called, if necessary.
+CREATE OR REPLACE FUNCTION init_fit_mult(
+    source_table            VARCHAR,
+    model_selection_table   VARCHAR
+) RETURNS VOID AS
+$$
+    import sys
+    from mock import Mock, patch
+
+    PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model)
+    schema_madlib = 'madlib_installcheck_deep_learning'
+
+    GD['fit_mult'] = madlib_keras_fit_multiple_model.FitMultipleModel(
+        schema_madlib,
+        source_table,
+        'orig_model_out',
+        model_selection_table,
+        1
+    )
+    
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_init_schedule(
+    schedule_table VARCHAR
+) RETURNS BOOLEAN AS
+$$
+    fit_mult = GD['fit_mult']
+    fit_mult.schedule_tbl = schedule_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table))
+    if fit_mult.init_schedule_tbl():
+        err_msg = None
+    else:
+        err_msg = 'FitMultiple.init_schedule_tbl() returned False'
+
+    return err_msg
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_rotate_schedule(
+    schedule_table          VARCHAR
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    if fit_mult.schedule_tbl != schedule_table:
+        fit_mult.init_schedule_tbl()
+
+    fit_mult.rotate_schedule_tbl()
+
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+-- Mock fit_transition function, for testing
+--  madlib_keras_fit_multiple_model() python code
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(

Review comment:
       I am wondering if the CREATE OR REPLACE here for this function would cause issues.
   Does this function get dropped once this test is over? If not, then for the tests that run after this, whenever we call `fit_transition_multiple_model()`, won't it refer to this mock instead of the actual madlib.fit_transition_multiple_model()?
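   If cleanup does turn out to be needed, a hypothetical teardown at the end of the test file could drop the mock by its full argument signature, e.g.:

       -- Hypothetical teardown: drop the mock explicitly so it cannot shadow the
       -- real function in anything that runs after this test file.
       DROP FUNCTION IF EXISTS madlib_installcheck_deep_learning.fit_transition_multiple_model(
           BYTEA, BYTEA, INTEGER[], INTEGER[], TEXT, TEXT, TEXT, INTEGER, INTEGER[],
           INTEGER, INTEGER, INTEGER[], INTEGER[], BYTEA, BOOLEAN, BOOLEAN, BYTEA
       );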
   







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r537825924



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -629,149 +794,187 @@ class FitMultipleModel():
         # Therefore we want to have queries that do not add motions and all the
         # sub-queries running Keras/tensorflow operations reuse the same slice(process)
         # that was used for initializing GPU memory.
-        use_gpus = self.use_gpus if self.use_gpus else False
-        mst_weights_query = """
-            CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
-                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
-                       model_arch_tbl.{self.model_arch_col}
-                FROM
-                    {self.mst_current_schedule_tbl} mst_tbl
-                    LEFT JOIN {self.model_output_table} wgh_tbl
-                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
-                        LEFT JOIN {self.model_arch_table} model_arch_tbl
-                        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
-                DISTRIBUTED BY ({dist_key_col})
-        """.format(dist_key_col=dist_key_col,
-                   **locals())
-        plpy.execute(mst_weights_query)
-        use_gpus = self.use_gpus if self.use_gpus else False
-        dep_shape_col = self.dep_shape_col
-        ind_shape_col = self.ind_shape_col
+
+        DEBUG.start_timing("run_training")
+        if hop > 0:
+            DEBUG.print_mst_keys(self.model_output_tbl, 'before_hop')
+            DEBUG.start_timing("hop")
+            hop_query = """
+                CREATE {self.unlogged_table} TABLE {self.model_input_tbl} AS
+                    SELECT o.{self.mst_key_col},
+                           o.{self.model_weights_col},
+                           o.{self.model_arch_col},
+                           o.{self.compile_params_col},
+                           o.{self.fit_params_col},
+                           o.{self.object_map_col},
+                           s.{self.dist_key_col}
+                    FROM {self.model_output_tbl} o JOIN {self.schedule_tbl} s
+                        ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
+                    DISTRIBUTED BY ({self.dist_key_col});
+            """.format(self=self)
+
+            DEBUG.plpy.execute(hop_query)
+
+            DEBUG.print_timing("hop")
+            DEBUG.print_mst_keys(self.model_input_tbl, 'after_hop')
+
+            DEBUG.start_timing("truncate_output")
+            self.truncate_and_drop(self.model_output_tbl)
+            DEBUG.print_timing("truncate_output")
+        else:
+            # Skip hop if it's the first in an iteration, just rename
+            plpy.execute("""
+                ALTER TABLE {self.model_output_tbl}
+                    RENAME TO {self.model_input_tbl}
+            """.format(self=self))
+ 
+        ind_shape = self.ind_shape_col
+        dep_shape = self.dep_shape_col
         dep_var = mb_dep_var_col
         indep_var = mb_indep_var_col
         source_table = self.source_table
-        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+
         if self.use_caching:
             # Caching populates the independent_var and dependent_var into the cache on the very first hop
             # For the very_first_hop, we want to run the transition function on all segments, including
-            # the one's where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
+            # the ones where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
             # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
             # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
             # table to join for referencing the dist_key
             if is_very_first_hop:
                 plpy.execute("""
                     DROP TABLE IF EXISTS {self.cached_source_table};
-                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
-                    """.format(self=self, dist_key_col=dist_key_col))
+                    CREATE TABLE {self.cached_source_table} AS
+                        SELECT {self.dist_key_col} FROM {self.source_table}
+                            GROUP BY {self.dist_key_col}
+                                DISTRIBUTED BY({self.dist_key_col});
+                    """.format(self=self))
             else:
-                dep_shape_col = 'ARRAY[0]'
-                ind_shape_col = 'ARRAY[0]'
-                dep_var = 'NULL'
-                indep_var = 'NULL'
+                dep_shape = ind_shape = 'NULL'
+                dep_var = indep_var = 'NULL'
                 source_table = self.cached_source_table
-            if is_very_first_hop or self.is_final_training_call:
-                where_clause = ""
-
-        uda_query = """
-            CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
-            SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
-                {mb_indep_var_col},
-                {dep_shape_col},
-                {ind_shape_col},
-                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
-                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
-                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                src.{dist_key_col},
-                ARRAY{self.dist_key_mapping},
-                src.{self.gp_segment_id_col},
-                {self.segments_per_host},
-                ARRAY{self.images_per_seg_train},
-                {use_gpus}::BOOLEAN,
-                ARRAY{self.accessible_gpus_for_seg},
-                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_training_call}::BOOLEAN,
-                {use_caching}::BOOLEAN,
-                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
-                )::BYTEA AS {self.model_weights_col},
-                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
-                ,src.{dist_key_col} AS {dist_key_col}
-            FROM {source_table} src JOIN {self.mst_weights_tbl}
-                USING ({dist_key_col})
-            {where_clause}
-            GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
-            DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=dep_var,
-                       mb_indep_var_col=indep_var,
-                       dep_shape_col=dep_shape_col,
-                       ind_shape_col=ind_shape_col,
-                       is_final_training_call=self.is_final_training_call,
+
+        res = plpy.execute("""
+            SELECT count(*)
+            FROM {self.model_input_tbl}
+        """.format(self=self))
+        if res:
+            DEBUG.plpy.info("rows in model_input table: {}".format(res[0]['count']))
+        else:
+            DEBUG.plpy.error("No rows in model_input table!")
+
+#TODO: prepare this statement once, then just fill in the params with execute()
+#      on all the rest of the hops / iterations
+
+        DEBUG.start_timing("udf")
+        udf_query = plpy.prepare("""
+            CREATE {self.unlogged_table} TABLE {self.model_output_tbl} AS
+            SELECT
+                model_in.{self.mst_key_col},
+                CASE WHEN model_in.{self.dist_key_col} > {self.max_dist_key}
+                THEN
+                    model_in.{self.model_weights_col}
+                ELSE
+                    {self.schema_madlib}.fit_transition_multiple_model(
+                        {dep_var_col},
+                        {indep_var_col},
+                        {dep_shape},
+                        {ind_shape},
+                        model_in.{self.model_arch_col}::TEXT,
+                        model_in.{self.compile_params_col}::TEXT,
+                        model_in.{self.fit_params_col}::TEXT,
+                        src.{self.dist_key_col},
+                        ARRAY{self.dist_key_mapping},
+                        src.{self.gp_segment_id_col},
+                        {self.segments_per_host},
+                        ARRAY{self.images_per_seg_train},
+                        {self.use_gpus}::BOOLEAN,
+                        ARRAY{self.accessible_gpus_for_seg},
+                        model_in.{self.model_weights_col}::BYTEA,
+                        {self.is_final_training_call}::BOOLEAN,
+                        {use_caching}::BOOLEAN,
+                        model_in.{self.object_map_col}::BYTEA
+                    )
+                END::BYTEA AS {self.model_weights_col},
+                model_in.{self.model_arch_col},
+                model_in.{self.compile_params_col},
+                model_in.{self.fit_params_col},
+                model_in.{self.object_map_col},
+                model_in.{self.dist_key_col}
+            FROM {self.model_input_tbl} model_in
+                FULL JOIN {source_table} src
+                USING ({self.dist_key_col}) 
+            DISTRIBUTED BY({self.dist_key_col})
+            """.format(dep_var_col=dep_var,
+                       indep_var_col=indep_var,
+                       dep_shape=dep_shape,
+                       ind_shape=ind_shape,
                        use_caching=self.use_caching,
-                       dist_key_col=dist_key_col,
-                       use_gpus=use_gpus,
                        source_table=source_table,
-                       where_clause=where_clause,
                        self=self
                        )
-        plpy.execute(uda_query)
+        )
+
+        try:
+            plpy.execute(udf_query)
+        except plpy.SPIError as e:
+            msg = e.message
+            if not 'TransAggDetail' in msg:
+                raise e
+            e.message, detail = msg.split('TransAggDetail')
+            # Extract Traceback from segment, add to
+            #  DETAIL of error message on coordinator
+            e.args = (e.message,)
+            spidata = list(e.spidata)
+            spidata[1] = detail
+            e.spidata = tuple(spidata)
+            raise e
+
+        DEBUG.print_timing("udf")
+
+        res = plpy.execute("""
+            SELECT {self.mst_key_col} AS mst_key, {self.model_weights_col} IS NOT NULL AS weights
+                FROM {self.model_output_tbl}
+        """.format(self=self))
+        if res:
+            null_msts = len([None for row in res if row['mst_key'] is None])
+            null_weights = len([None for row in res if row['weights'] is False])
+            DEBUG.plpy.info(
+                "{} rows total ({} mst_key=NULL and {} weights=NULL) in model_output table."\
+                    .format(res.nrows(), null_msts, null_weights))
+        else:
+            plpy.error("No rows in output of UDF!")
 
-        update_query = """
-            UPDATE {self.model_output_table}
-            SET {self.model_weights_col} = {self.weights_to_update_tbl}.{self.model_weights_col}
-            FROM {self.weights_to_update_tbl}
-            WHERE {self.model_output_table}.{self.mst_key_col} = {self.weights_to_update_tbl}.{self.mst_key_col}
-        """.format(self=self)
-        plpy.execute(update_query)
+        plpy.execute("DELETE FROM {self.model_output_tbl} WHERE model_weights IS NULL".format(self=self))

Review comment:
       I was also thinking at first that we could do it that way, using a WHERE clause.  But unfortunately, that can only be used to filter the input rows of a query, not the output rows.  The only way I can think of to filter the output rows is to make this query a sub-query of an outer query that does the filtering.
   
   In terms of performance, adding a subquery would be pretty similar to the way it is now, I think, except that the filtering operation would run as a separate slice... one that mostly has to wait for the first slice to complete before it can run.  And I think we want to avoid any extra slices for this one, in case it messes up SD/GD.
   
   I also wonder if we could use a GROUP BY with a HAVING clause.  But that might also add another slice, and/or slightly reduce performance... possibly for the same reason the UDA ran slightly slower.
   
   The DELETE operation is really fast, since it just scans through the rows and marks some as deleted.  The only undesirable effect is that it leaves the rows fragmented on disk, but since we're about to TRUNCATE and DROP the table anyway, it shouldn't matter.
   
   We could also change the hop query to filter them out with a WHERE clause when it's copying from `model_output_tbl` to `model_input_tbl`.  But then I guess we'd have to do something similar for eval.  And essentially, this would just be another way of accomplishing the same thing as the DELETE: telling both of them to "ignore these rows".
   
   If you have any other ideas, let me know and I can try it.  (Perhaps using the WITH keyword?  My impression has always been that WITH is just another way to introduce a separate sub-query / slice, but I'm not sure.)
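   
   For reference, a rough sketch of the two approaches being compared (the DELETE matches the patch; the subquery alternative is shown schematically rather than repeating the full UDF training query):
   
   ```
   -- Current approach: run the training query, then mark the rows with no
   -- weights as deleted before the table is truncated and dropped later.
   DELETE FROM model_output_tbl WHERE model_weights IS NULL;
   
   -- Subquery alternative discussed above: filter the output rows in an outer
   -- query, at the cost of an extra slice.
   -- CREATE TABLE model_output_tbl AS
   --   SELECT * FROM ( /* full UDF training query goes here */ ) AS trained
   --   WHERE trained.model_weights IS NOT NULL
   --   DISTRIBUTED BY (__dist_key__);
   ```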







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r556087855



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -27,6 +27,7 @@ from madlib_keras_model_selection import ModelSelectionSchema
 from keras_model_arch_table import ModelArchSchema
 from utilities.validate_args import table_exists, drop_tables, input_tbl_valid
 from utilities.validate_args import quote_ident
+from madlib_keras_helper import DISTRIBUTION_KEY_COLNAME

Review comment:
       Right







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r556075238



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -520,21 +544,35 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
         and only gets cleared in eval transition at the last row of the last iteration.
 
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape:

Review comment:
       As I understand it, we require the input table to be minibatched--so every input row must contain an array, and the sizes of all but the first dimension of the arrays must match across rows.  A NULL value for a class should be okay, and maybe even a NULL image somewhere in the array (although I'm not sure what that would mean), but if either the x or y array itself is NULL, or the shape is NULL, that's invalid input, since the table wasn't minibatched properly.
   
   I'm not sure if we explicitly error out in the validation code during initialization in that case.  I figured that if we were allowing it before, it was probably unintentional.  Is there a reason why we would need to support that?
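   
   A validation query along these lines could catch improperly packed rows during initialization.  This is purely illustrative: `source_table_packed` is an assumed name for the minibatched input, and the column names follow the fit_transition parameters above.
   
   ```
   -- Count minibatched rows whose arrays or shapes are NULL; any such row
   -- means the input table was not minibatched properly.
   SELECT count(*) AS invalid_rows
   FROM source_table_packed
   WHERE dependent_var IS NULL
      OR independent_var IS NULL
      OR dependent_var_shape IS NULL
      OR independent_var_shape IS NULL;
   ```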







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r537838544



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +377,307 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
 
-    def create_model_output_table_warm_start(self):
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):
+            self.next_schedule_tbl = unique_string('next_schedule')
+            rotate_schedule_tbl_query = """
+                CREATE TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl};
+            """.format(self=self)
+            self.rotate_schedule_tbl_plan = plpy.prepare(rotate_schedule_tbl_query)
+
+        plpy.execute(self.rotate_schedule_tbl_plan)
+
+        self.truncate_and_drop(self.schedule_tbl)
+        plpy.execute("""
+            ALTER TABLE {self.next_schedule_tbl}
+            RENAME TO {self.schedule_tbl}
+        """.format(self=self))
+
+    def load_warm_start_weights(self):
         """
-        For warm start, we need to copy the model output table to a temp table
-        because we call truncate on the model output table while training.
-        If the query gets aborted, we need to make sure that the user passed
-        model output table can be recovered.
+        For warm start, we need to copy any rows of the model output
+        table provided by the user whose mst keys appear in the
+        supplied model selection table.  We also copy over the 
+        compile & fit params from the model_selection_table, and
+        the dist_key's from the schedule table.
         """
-        plpy.execute("""
-            CREATE TABLE {self.model_output_table} (
-            LIKE {self.original_model_output_table} INCLUDING indexes);
-            """.format(self=self))
+        load_warm_start_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    o.{self.model_weights_col},
+                    o.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.original_model_output_tbl} o
+                        USING ({self.mst_key_col})
+        """.format(self=self)
+        DEBUG.plpy.execute(load_warm_start_weights_query)
 
-        plpy.execute("""INSERT INTO {self.model_output_table}
-            SELECT * FROM {self.original_model_output_table};
-            """.format(self=self))
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(self.model_info_tbl))
 
-        plpy.execute(""" DELETE FROM {self.model_output_table}
-                WHERE {self.mst_key_col} NOT IN (
-                    SELECT {self.mst_key_col} FROM {self.model_selection_table})
-                """.format(self=self))
-        self.warm_start_msts = plpy.execute(
-            """ SELECT array_agg({0}) AS a FROM {1}
-            """.format(self.mst_key_col, self.model_output_table))[0]['a']
-        plpy.execute("DROP TABLE {0}".format(self.model_info_table))
-        self.initialize_model_output_and_info()
-
-    def initialize_model_output_and_info(self):
+    def load_xfer_learning_weights(self, warm_start=False):
+        """
+            Copy transfer learning weights from
+            model_arch table.  Ignore models with
+            no xfer learning weights, these will
+            be generated by keras and added one at a
+            time later.
+        """
+        load_xfer_learning_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    a.{self.model_weights_col},
+                    a.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.model_arch_table} a
+                        ON m.{self.model_id_col} = a.{self.model_id_col}
+                WHERE a.{self.model_weights_col} IS NOT NULL;
+        """.format(self=self)
+        DEBUG.plpy.execute(load_xfer_learning_weights_query)
+
+    def init_model_output_tbl(self):
+        DEBUG.start_timing('init_model_output_and_info')
+
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_tbl}
+                                    ({self.mst_key_col} INTEGER,
+                                     {self.model_weights_col} BYTEA,
+                                     {self.model_arch_col} JSON,
+                                     {self.compile_params_col} TEXT,
+                                     {self.fit_params_col} TEXT,
+                                     {self.object_map_col} BYTEA,
+                                     {self.dist_key_col} INTEGER,
+                                     PRIMARY KEY ({self.mst_key_col}, {self.dist_key_col}))
+                                     DISTRIBUTED BY ({self.dist_key_col})
+                                    """.format(self=self)
+        plpy.execute(output_table_create_query)
+
+        if self.warm_start:

Review comment:
       Re: 1, I'll think about it.  Any suggestions?
   
   2. It's only for transfer learning; that's actually the important part of the name.  Notice the IS NOT NULL part of the query, which makes sure it's only copying the transfer learning weights.
   
   The steps for initializing the model output table are:
   
   1.  Bulk load any warm start weights from previous model output table
   2.  Bulk load any transfer learning weights from model arch table
   3.  Iterate over all remaining mst_keys that haven't already been loaded, and initialize them with random weights, writing them one by one with INSERTs to the model output table (a rough sketch of this per-model INSERT is below).
   
   We could drop the IS NOT NULL condition in step 2, but since there are no weights, it would just be copying the tiny stuff (compile params, fit params, etc.)... so bulk loading doesn't really help there anyway.  And it would make things more difficult later, since we'd have to INSERT the weights in a different way for rows that already have some of the columns initialized but not others.  (Or just write over them, in which case nothing useful was accomplished by copying them the first time.)
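   
   A minimal sketch of the step-3 INSERT, with table names kept illustrative (`schedule_tbl`, `mst_table`, `model_arch_tbl`) rather than the exact names in the patch; `$1` stands for the serialized weights passed through a prepared statement after keras generates them:
   
   ```
   -- Insert one not-yet-initialized model, reusing the arch and params from the
   -- mst / model-arch tables and the dist key from the schedule table.
   INSERT INTO model_output_tbl
       (mst_key, model_weights, model_arch, compile_params, fit_params,
        object_map, __dist_key__)
   SELECT s.mst_key, $1, a.model_arch, m.compile_params, m.fit_params,
          NULL, s.__dist_key__
   FROM   schedule_tbl s
          JOIN mst_table m USING (mst_key)
          JOIN model_arch_tbl a ON (m.model_id = a.model_id)
   WHERE  s.mst_key = 42;  -- the single model being initialized in this pass
   ```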







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r558695581



##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
##########
@@ -0,0 +1,845 @@
+m4_include(`SQLCommon.m4')
+m4_changequote(<<<,>>>)
+m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres
+,<<<
+m4_changequote(<!,!>)
+
+-- =================== Setup & Initialization for FitMultiple tests ========================
+--
+--  For fit multiple, we test end-to-end functionality along with performance elsewhere.
+--  They take a long time to run.  Including similar tests here would probably not be worth
+--  the extra time added to dev-check.
+--
+--  Instead, we just want to unit test different python functions in the FitMultiple class.
+--  However, most of the important behavior we need to test requires access to an actual
+--  Greenplum database... mostly, we want to make sure that the models hop around to the
+--  right segments in the right order.  Therefore, the unit tests are here, as a part of
+--  dev-check. we mock fit_transition() and some validation functions in FitMultiple, but
+--  do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to
+--  get through to gpdb. We also want to mock the number of segments, so we can test what
+--  the model hopping behavior will be for a large cluster, even though dev-check should be
+--  able to run on a single dev host.
+
+\i m4_regexp(MODULE_PATHNAME,
+             <!\(.*\)libmadlib\.so!>,
+            <!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!>
+)
+
+-- Mock version() function to convince the InputValidator this is the real madlib schema
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version() RETURNS VARCHAR AS
+$$
+    SELECT MADLIB_SCHEMA.version();
+$$ LANGUAGE sql IMMUTABLE;
+
+-- Call this first to initialize the FitMultiple object, before anything else happens.
+-- Pass a real mst table and source table, rest of FitMultipleModel() constructor params
+--  are filled in.  They can be overriden later, before test functions are called, if necessary.
+CREATE OR REPLACE FUNCTION init_fit_mult(
+    source_table            VARCHAR,
+    model_selection_table   VARCHAR
+) RETURNS VOID AS
+$$
+    import sys
+    from mock import Mock, patch
+
+    PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model)
+    schema_madlib = 'madlib_installcheck_deep_learning'
+
+    GD['fit_mult'] = madlib_keras_fit_multiple_model.FitMultipleModel(
+        schema_madlib,
+        source_table,
+        'orig_model_out',
+        model_selection_table,
+        1
+    )
+    
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_init_schedule(
+    schedule_table VARCHAR
+) RETURNS BOOLEAN AS
+$$
+    fit_mult = GD['fit_mult']
+    fit_mult.schedule_tbl = schedule_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table))
+    if fit_mult.init_schedule_tbl():
+        err_msg = None
+    else:
+        err_msg = 'FitMultiple.init_schedule_tbl() returned False'
+
+    return err_msg
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_rotate_schedule(
+    schedule_table          VARCHAR
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    if fit_mult.schedule_tbl != schedule_table:
+        fit_mult.init_schedule_tbl()
+
+    fit_mult.rotate_schedule_tbl()
+
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+-- Mock fit_transition function, for testing
+--  madlib_keras_fit_multiple_model() python code
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
+    dependent_var               BYTEA,
+    independent_var             BYTEA,
+    dependent_var_shape         INTEGER[],
+    independent_var_shape       INTEGER[],
+    model_architecture          TEXT,
+    compile_params              TEXT,
+    fit_params                  TEXT,
+    dist_key                    INTEGER,
+    dist_key_mapping            INTEGER[],
+    current_seg_id              INTEGER,
+    segments_per_host           INTEGER,
+    images_per_seg              INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
+    serialized_weights          BYTEA,
+    is_final_training_call      BOOLEAN,
+    use_caching                 BOOLEAN,
+    custom_function_map         BYTEA
+) RETURNS BYTEA AS
+$$
+    param_keys = [ 'compile_params', 'accessible_gpus_for_seg', 'dependent_var_shape', 'dist_key_mapping',
+                   'current_seg_id', 'segments_per_host', 'custom_function_map', 'is_final_training_call',
+                   'dist_key', 'serialized_weights', 'images_per_seg', 'model_architecture', 'fit_params',
+                   'independent_var_shape', 'use_caching' ]
+
+    num_calls = 1
+    if 'transition_function_params' in GD:
+        if dist_key in GD['transition_function_params']:
+            if not 'reset' in GD['transition_function_params'][dist_key]:
+                num_calls = GD['transition_function_params'][dist_key]['num_calls']
+                num_calls += 1
+
+    g = globals()
+    params = dict()
+
+    for k in param_keys:
+        params[k] = g[k]
+
+    params['dependent_var'] = len(dependent_var) if dependent_var else 0
+    params['independent_var'] = len(independent_var) if independent_var else 0
+    params['num_calls'] = num_calls
+
+    if not 'transition_function_params' in GD:
+        GD['transition_function_params'] = dict()
+    GD['transition_function_params'][dist_key] = params
+
+    # compute simulated seg_id ( current_seg_id is the actual seg_id )
+    seg_id = dist_key_mapping.index( dist_key )
+
+    if dependent_var_shape and dependent_var_shape[0] * num_calls < images_per_seg [ seg_id ]:
+        return None
+    else:
+        GD['transition_function_params'][dist_key]['reset'] = True
+        return serialized_weights
+$$ LANGUAGE plpythonu VOLATILE;
+
+CREATE OR REPLACE FUNCTION validate_transition_function_params(
+    current_seg_id                       INTEGER,
+    segments_per_host                    INTEGER,
+    images_per_seg                       INTEGER[],
+    expected_num_calls                   INTEGER,
+    expected_dist_key                    INTEGER,
+    expected_is_final_training_call      BOOLEAN,
+    expected_dist_key_mapping            INTEGER[],
+    dependent_var_len                    INTEGER,
+    independent_var_len                  INTEGER,
+    use_caching                          BOOLEAN
+) RETURNS TEXT AS
+$$
+    err_msg = "transition function was not called on segment "
+
+    if 'transition_function_params' not in GD:
+        return err_msg.format(current_seg_id)
+    elif expected_dist_key not in GD['transition_function_params']:
+        return err_msg + " for __dist_key__ = {}".format(expected_dist_key)
+    actual = GD['transition_function_params'][expected_dist_key]
+
+    err_msg = """Incorrect value for {} param passed to fit_transition_multiple_model:
+       Actual={}, Expected={}"""
+
+    validation_map = {
+        'current_seg_id'         : current_seg_id,
+        'segments_per_host'      : segments_per_host,
+        'num_calls'              : expected_num_calls,
+        'is_final_training_call' : expected_is_final_training_call,
+        'dist_key'               : expected_dist_key,
+        'dependent_var'          : dependent_var_len,
+        'independent_var'        : independent_var_len,
+        'use_caching'            : use_caching
+    }
+
+    for param, expected in validation_map.items():
+        if actual[param] != expected:
+            return err_msg.format(
+                param,
+                actual[param],
+                expected
+            )
+
+    return 'PASS'  # actual params match expected params
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Helper to rotate an array of int's
+CREATE OR REPLACE FUNCTION rotate_keys(
+    keys    INTEGER[]
+) RETURNS INTEGER[]
+AS $$
+   return keys[-1:] + keys[:-1]
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION reverse_rotate_keys(
+    keys    INTEGER[]
+) RETURNS INTEGER[]
+AS $$
+   return keys[1:] + keys[:1]
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION setup_model_tables(
+    input_table TEXT,
+    output_table TEXT,
+    cached_source_table TEXT
+) RETURNS TEXT AS
+$$ 
+    fit_mult = GD['fit_mult']
+
+    fit_mult.model_input_tbl = input_table
+    fit_mult.model_output_tbl = output_table
+    fit_mult.cached_source_table = cached_source_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(output_table))
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(cached_source_table))
+    fit_mult.init_model_output_tbl()
+    q = """
+        UPDATE {model_out} -- Reduce size of model for faster tests
+            SET ( model_weights, model_arch, compile_params, fit_params )
+                  = ( mst_key::TEXT::BYTEA,
+                      ( '{{ "a" : ' || mst_key::TEXT || ' }}' )::JSON,
+                      'c' || mst_key::TEXT,
+                      'f' || mst_key::TEXT 
+                    )
+        WHERE mst_key IS NOT NULL;
+    """.format(model_out=fit_mult.model_output_tbl)
+    plpy.execute(q) 
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Updates dist keys in src table and internal fit_mult class variables
+--    num_data_segs can be larger than actual number of segments, since this
+--    is just for simulated testing.  This will also write to expected_distkey_mappings_tbl
+--    which can be used for validating dist key mappings and images per seg later.
+CREATE OR REPLACE FUNCTION update_dist_keys(
+    src_table TEXT,
+    num_data_segs INTEGER,
+    num_models INTEGER,
+    expected_distkey_mappings_tbl TEXT
+) RETURNS VOID AS
+$$ 
+    redist_cmd = """
+        UPDATE {src_table}
+            SET __dist_key__ = (buffer_id % {num_data_segs})
+    """.format(**globals())
+    plpy.execute(redist_cmd)
+
+    fit_mult = GD['fit_mult']
+
+    q = """
+        SELECT SUM(independent_var_shape[1]) AS image_count,
+            __dist_key__
+        FROM {src_table}
+        GROUP BY __dist_key__
+        ORDER BY __dist_key__
+    """.format(**globals())
+    res = plpy.execute(q)
+
+    images_per_seg = [ int(r['image_count']) for r in res ]
+    dist_keys = [ int(r['__dist_key__']) for r in res ]
+    num_dist_keys = len(dist_keys)
+
+    fit_mult.source_table = src_table
+    fit_mult.max_dist_key = sorted(dist_keys)[-1]
+    fit_mult.images_per_seg_train = images_per_seg
+    fit_mult.dist_key_mapping = fit_mult.dist_keys = dist_keys
+    fit_mult.accessible_gpus_per_seg = [0] * num_dist_keys
+    fit_mult.segments_per_host = num_data_segs
+
+    fit_mult.msts_for_schedule = fit_mult.msts[:num_models]
+    if num_models < num_dist_keys:
+        fit_mult.msts_for_schedule += [None] * \
+                                 (num_dist_keys - num_models)
+    fit_mult.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                              for mst in fit_mult.msts_for_schedule ]
+    fit_mult.num_msts = num_models
+
+    fit_mult.extra_dist_keys = []
+    for i in range(num_models - num_dist_keys):
+        fit_mult.extra_dist_keys.append(fit_mult.max_dist_key + 1 + i)
+    fit_mult.all_dist_keys = fit_mult.dist_key_mapping + fit_mult.extra_dist_keys
+
+    create_distkey_map_tbl_cmd = """
+        DROP TABLE IF EXISTS {exp_table};
+        CREATE TABLE {exp_table} AS
+        SELECT
+            ARRAY(  -- map of dist_keys to seg_ids from source table
+                SELECT __dist_key__
+                FROM {fm.source_table}
+                GROUP BY __dist_key__
+                ORDER BY __dist_key__  -- This would be gp_segment_id if it weren't a simulation
+            ) AS expected_dist_key_mapping,
+            ARRAY{fm.images_per_seg_train} AS expected_images_per_seg,
+            {num_data_segs} AS segments_per_host,
+            __dist_key__
+        FROM {fm.source_table}
+        GROUP BY __dist_key__
+        DISTRIBUTED BY (__dist_key__);
+    """.format(
+            fm=fit_mult,
+            num_data_segs=num_data_segs,
+            exp_table=expected_distkey_mappings_tbl
+        )
+    plpy.execute(create_distkey_map_tbl_cmd)
+$$ LANGUAGE plpythonu VOLATILE;
+
+CREATE OR REPLACE FUNCTION test_run_training(
+    source_table TEXT,
+    hop INTEGER,
+    is_very_first_hop BOOLEAN,
+    is_final_training_call BOOLEAN,
+    use_caching BOOLEAN
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    # Each time we start a new test, clear out stats
+    #   like num_calls from GD so we don't end up validating
+    #   against old results
+    if 'transition_function_params' in GD:
+        del GD['transition_function_params']
+
+    fit_mult.source_tbl = source_table
+    fit_mult.is_very_first_hop = is_very_first_hop
+    fit_mult.is_final_training_call = is_final_training_call
+    if use_caching != fit_mult.use_caching:
+        fit_mult.udf_plan = None  # Otherwise it will execute the wrong
+                                  # query when use_caching changes!
+    fit_mult.use_caching = use_caching
+
+    fit_mult.run_training(hop=hop, is_very_first_hop=is_very_first_hop)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+m4_ifelse(m4_eval(__DBMS_VERSION_MAJOR__ <= 6),1,<!
+
+-- format() function for dynamic SQL was introduced in gpdb6,
+--   so for gpdb5 we need to define a simple version of it
+CREATE OR REPLACE FUNCTION format(s TEXT, vars VARIADIC TEXT[])
+RETURNS TEXT AS
+$$
+    import re
+    from internal.db_utils import quote_nullable
+    from utilities.utilities import quote_ident
+
+    seps = re.findall(r'%.', s)
+    pieces = re.split(r'%.', s)
+
+    res = pieces[0]
+    for i in range(len(seps)):
+        if vars[i] is None:
+            raise ValueError('Not enough params passed to format({})'.format(s))
+        else:
+            c = seps[i][1]
+            if c == '%':
+                quoted = '%'
+            elif c == 's':
+                quoted = vars[i]
+            elif c == 'L':
+                quoted = quote_nullable(vars[i])
+            elif c == 'I':
+                quoted = quote_ident(vars[i])
+            else:
+                raise ValueError('{} in format({}) not recognized'.format(seps[i],s))
+            res += quoted
+        res += pieces[i+1]
+    return res
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION format(s TEXT,
+                                  v1 INTEGER[],
+                                  v2 INTEGER[],
+                                  v3 TEXT,
+                                  v4 TEXT,
+                                  v5 INTEGER[],
+                                  v6 TEXT,
+                                  v7 INTEGER[]
+)
+RETURNS TEXT AS
+$$
+    SELECT format($1,
+               $2::TEXT,
+               $3::TEXT,
+               $4,
+               $5,
+               $6::TEXT,
+               $7,
+               $8::TEXT
+        );
+$$ LANGUAGE SQL;
+
+CREATE OR REPLACE FUNCTION format(s TEXT, vars VARIADIC ANYARRAY)
+RETURNS TEXT AS
+$$
+    --SELECT format($1, $2::TEXT[])
+    SELECT $2;
+$$ LANGUAGE sql IMMUTABLE;
+
+!>) -- m4_endif DBMS_VERSION_MAJOR <= 5
+
+CREATE OR REPLACE FUNCTION validate_mst_key_order(output_tbl TEXT, expected_tbl TEXT)
+RETURNS VOID AS
+$$
+DECLARE
+    actual INTEGER[];
+    expected INTEGER[];
+    res RECORD;
+BEGIN
+    EXECUTE format(
+        'SELECT ARRAY(' ||
+            'SELECT mst_key FROM %I ' ||
+                'ORDER BY __dist_key__)',
+        output_tbl
+    ) INTO actual;
+
+    EXECUTE format(
+        'SELECT mst_keys FROM %I',
+        expected_tbl
+    ) INTO expected;
+
+    EXECUTE format(

Review comment:
       You're right, this is simpler and more readable.
   
   I probably would have just done it this way initially, if I'd realized gpdb5 didn't have format().  After a series of refactors, it ended up a lot more complicated than it needed to be.







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r558637008



##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
##########
@@ -0,0 +1,845 @@
+m4_include(`SQLCommon.m4')
+m4_changequote(<<<,>>>)
+m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres
+,<<<
+m4_changequote(<!,!>)
+
+-- =================== Setup & Initialization for FitMultiple tests ========================
+--
+--  For fit multiple, we test end-to-end functionality along with performance elsewhere.
+--  They take a long time to run.  Including similar tests here would probably not be worth
+--  the extra time added to dev-check.
+--
+--  Instead, we just want to unit test different python functions in the FitMultiple class.
+--  However, most of the important behavior we need to test requires access to an actual
+--  Greenplum database... mostly, we want to make sure that the models hop around to the
+--  right segments in the right order.  Therefore, the unit tests are here, as a part of
+--  dev-check. we mock fit_transition() and some validation functions in FitMultiple, but
+--  do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to
+--  get through to gpdb. We also want to mock the number of segments, so we can test what
+--  the model hopping behavior will be for a large cluster, even though dev-check should be
+--  able to run on a single dev host.
+
+\i m4_regexp(MODULE_PATHNAME,
+             <!\(.*\)libmadlib\.so!>,
+            <!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!>
+)
+
+-- Mock version() function to convince the InputValidator this is the real madlib schema
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version() RETURNS VARCHAR AS
+$$
+    SELECT MADLIB_SCHEMA.version();
+$$ LANGUAGE sql IMMUTABLE;
+
+-- Call this first to initialize the FitMultiple object, before anything else happens.
+-- Pass a real mst table and source table, rest of FitMultipleModel() constructor params
+--  are filled in.  They can be overriden later, before test functions are called, if necessary.
+CREATE OR REPLACE FUNCTION init_fit_mult(
+    source_table            VARCHAR,
+    model_selection_table   VARCHAR
+) RETURNS VOID AS
+$$
+    import sys
+    from mock import Mock, patch
+
+    PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model)
+    schema_madlib = 'madlib_installcheck_deep_learning'
+
+    GD['fit_mult'] = madlib_keras_fit_multiple_model.FitMultipleModel(
+        schema_madlib,
+        source_table,
+        'orig_model_out',
+        model_selection_table,
+        1
+    )
+    
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_init_schedule(
+    schedule_table VARCHAR
+) RETURNS BOOLEAN AS
+$$
+    fit_mult = GD['fit_mult']
+    fit_mult.schedule_tbl = schedule_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table))
+    if fit_mult.init_schedule_tbl():
+        err_msg = None
+    else:
+        err_msg = 'FitMultiple.init_schedule_tbl() returned False'
+
+    return err_msg
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_rotate_schedule(
+    schedule_table          VARCHAR
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    if fit_mult.schedule_tbl != schedule_table:
+        fit_mult.init_schedule_tbl()
+
+    fit_mult.rotate_schedule_tbl()
+
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+-- Mock fit_transition function, for testing
+--  madlib_keras_fit_multiple_model() python code
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(

Review comment:
       Great point!
   
   In MADlib it gets called as `{self.schema_madlib}.fit_transition_multiple_model()`, so I think it should be okay?  Unless some other tests call it without the schema qualification.  But I agree, leaving this function lying around afterwards seems very risky... I'll add a DROP FUNCTION at the end.
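   
   A minimal sketch of that cleanup (the argument types have to match the mock's signature above exactly):
   
   ```
   -- Drop the mocked transition function so it can't shadow the real one
   -- in any later tests.
   DROP FUNCTION IF EXISTS madlib_installcheck_deep_learning.fit_transition_multiple_model(
       BYTEA, BYTEA, INTEGER[], INTEGER[], TEXT, TEXT, TEXT, INTEGER, INTEGER[],
       INTEGER, INTEGER, INTEGER[], INTEGER[], BYTEA, BOOLEAN, BOOLEAN, BYTEA
   );
   ```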







[GitHub] [madlib] kaknikhil commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r535759055



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):
+            self.next_schedule_tbl = unique_string('next_schedule')
+            rotate_schedule_tbl_query = """
+                CREATE TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl};
+            """.format(self=self)
+            self.rotate_schedule_tbl_plan = plpy.prepare(rotate_schedule_tbl_query)
+
+        DEBUG.plpy.execute(self.rotate_schedule_tbl_plan)
+
+        self.truncate_and_drop(self.schedule_tbl)
+        plpy.execute("""
+            ALTER TABLE {self.next_schedule_tbl}
+            RENAME TO {self.schedule_tbl}
+        """.format(self=self))
 
-    def create_model_output_table_warm_start(self):
+    def load_warm_start_weights(self):
         """
-        For warm start, we need to copy the model output table to a temp table
-        because we call truncate on the model output table while training.
-        If the query gets aborted, we need to make sure that the user passed
-        model output table can be recovered.
+        For warm start, we need to copy any rows of the model output
+        table provided by the user whose mst keys appear in the
+        supplied model selection table.  We also copy over the 
+        compile & fit params from the model_selection_table, and
+        the dist_key's from the schedule table.
         """
-        plpy.execute("""
-            CREATE TABLE {self.model_output_table} (
-            LIKE {self.original_model_output_table} INCLUDING indexes);
-            """.format(self=self))
+        load_warm_start_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    o.{self.model_weights_col},
+                    o.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.original_model_output_tbl} o
+                        USING ({self.mst_key_col})
+        """.format(self=self)
+        DEBUG.plpy.execute(load_warm_start_weights_query)
 
-        plpy.execute("""INSERT INTO {self.model_output_table}
-            SELECT * FROM {self.original_model_output_table};
-            """.format(self=self))
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(self.model_info_tbl))
 
-        plpy.execute(""" DELETE FROM {self.model_output_table}
-                WHERE {self.mst_key_col} NOT IN (
-                    SELECT {self.mst_key_col} FROM {self.model_selection_table})
-                """.format(self=self))
-        self.warm_start_msts = plpy.execute(
-            """ SELECT array_agg({0}) AS a FROM {1}
-            """.format(self.mst_key_col, self.model_output_table))[0]['a']
-        plpy.execute("DROP TABLE {0}".format(self.model_info_table))
-        self.initialize_model_output_and_info()
-
-    def initialize_model_output_and_info(self):
+    def load_xfer_learning_weights(self, warm_start=False):
+        """
+            Copy transfer learning weights from
+            model_arch table.  Ignore models with
+            no xfer learning weights, these will
+            be generated by keras and added one at a
+            time later.
+        """
+        load_xfer_learning_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    a.{self.model_weights_col},
+                    a.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.model_arch_table} a
+                        ON m.{self.model_id_col} = a.{self.model_id_col}
+                WHERE a.{self.model_weights_col} IS NOT NULL;
+        """.format(self=self)
+        DEBUG.plpy.execute(load_xfer_learning_weights_query)
+
+    def init_model_output_tbl(self):
+        DEBUG.start_timing('init_model_output_and_info')
+
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_tbl}
+                                    ({self.mst_key_col} INTEGER,
+                                     {self.model_weights_col} BYTEA,
+                                     {self.model_arch_col} JSON,
+                                     {self.compile_params_col} TEXT,
+                                     {self.fit_params_col} TEXT,
+                                     {self.object_map_col} BYTEA,
+                                     {self.dist_key_col} INTEGER,
+                                     PRIMARY KEY ({self.dist_key_col}, {self.mst_key_col})
+                                    )
+                                    DISTRIBUTED BY ({self.dist_key_col})
+                                    """.format(self=self)
+        plpy.execute(output_table_create_query)
+
+        if self.warm_start:
+            self.load_warm_start_weights()
+        else:  # Note:  We only support xfer learning when warm_start=False
+            self.load_xfer_learning_weights()
+
+        res = DEBUG.plpy.execute("""
+            SELECT {self.mst_key_col} AS mst_keys FROM {self.model_output_tbl}
+        """.format(self=self))
+       
+        if res:
+            initialized_msts = set([ row['mst_keys'] for row in res ])
+        else:
+            initialized_msts = set()
+
+        DEBUG.plpy.info("Pre-initialized mst keys: {}".format(initialized_msts))

Review comment:
       We should try to reduce the frequency of `DEBUG.plpy.info` lines, since they can be feature-specific, and having a lot of them makes the code slightly harder to read.
   Any developer working on a feature can add their own `DEBUG.plpy.info` calls as needed.
   
   I think we can remove this `DEBUG.plpy.info` statement; we can add it back for debugging as and when needed.







[GitHub] [madlib] reductionista edited a comment on pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista edited a comment on pull request #525:
URL: https://github.com/apache/madlib/pull/525#issuecomment-732492520


   Thanks for testing it out, Frank!
   
   You're the first to test this with tensorflow 1.13, and based on one of the errors you're seeing, I'm guessing you're also the first to test it with gpdb5.  I only tested it on gpdb6 and postgres.
   
   > WARNING:  This version of tensorflow does not support XLA auto-cluster JIT optimization.  HINT:  upgrading tensorflow may improve performance.  (seg1 slice1 10.128.0.41:40001 pid=6271)
   > CONTEXT:  PL/Python function "fit_transition"
   > ```
   > 
   > What does user need to do to enable XLA? I am on TF 1.13.1 currently.
   
   Nothing, probably just need to upgrade to TF 1.14.  I should add the specific minimum version requirement in the warning message.  I think it was added in 1.14 but I've only used 1.10 and 1.14, so I'll need to confirm that.  Could also be that it's been in there for a while, and they just changed how you access it.  I'll look both of those up, and also run some more comprehensive performance tests to see how much of a difference it actually makes and in what kind of setups.  If it looks like it's not helping much anyway (even though the TF docs say it should) then we can get rid of this message... or maybe even not bother to enable it.  I figure if it does give better performance consistently across models, then we may as well let the user know there's an easy way to get that, without having to tune any parameters or upgrade hardware.  Would also be important to know if there are any cases where it might hurt performance.
    
   > ERROR:  plpy.Error: madlib_keras_fit_multiple_model error: No GPUs configured on hosts. (plpython.c:5038)
   > CONTEXT:  Traceback (most recent call last):
   >   PL/Python function "madlib_keras_fit_multiple_model", line 23, in <module>
   >     fit_obj = madlib_keras_fit_multiple_model.FitMultipleModel(**globals())
   >   PL/Python function "madlib_keras_fit_multiple_model", line 147, in __init__
   >   PL/Python function "madlib_keras_fit_multiple_model", line 295, in get_accessible_gpus_for_seg
   > PL/Python function "madlib_keras_fit_multiple_model"
   > ```
   > 
   > so it looks like `use gpus=NULL` is now defaulting to `TRUE` but it should default to `FALSE` i.e., CPUs. It used to default to CPUs.
   
   Good catch!  For some reason, I got a different error when I tried to reproduce this (about using BOOLEAN::None in a query).  But I think I see the cause of it--fixing.  I thought we had some dev-check tests for this; I'll add one if we don't.
   
   > ```
   > ERROR:  plpy.SPIError: PRIMARY KEY and DISTRIBUTED BY definitions incompatible
   > HINT:  When there is both a PRIMARY KEY, and a DISTRIBUTED BY clause, the DISTRIBUTED BY clause must be equal to or a left-subset of the PRIMARY KEY
   > CONTEXT:  Traceback (most recent call last):
   >   PL/Python function "madlib_keras_fit_multiple_model", line 24, in <module>
   >     fit_obj.fit_multiple_model()
   >   PL/Python function "madlib_keras_fit_multiple_model", line 241, in fit_multiple_model
   >   PL/Python function "madlib_keras_fit_multiple_model", line 509, in init_model_output_tbl
   > PL/Python function "madlib_keras_fit_multiple_model"
   > ```
   > 
   > I have also seen this error when running on a single segment.
   
   Interesting!  Looks like this was due to an unnecessarily strict check in the gpdb5 server code, where it required a specific order for the keys in the PRIMARY KEY constraint.  The server team has fixed that in gpdb6:
   https://github.com/greenplum-db/gpdb/commit/e572af54e4255d501d261d51d32e1f3d3cf4b9c9
   
   I can't recall for sure if we decided we care about supporting deep learning in gpdb5 for this release.  But since it's an easy fix, I'll just reverse the order of the keys so that it works on both.
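   
   In other words, the table definition just needs the distribution key listed first in the constraint.  A sketch, with the column list abbreviated for illustration:
   
   ```
   CREATE TABLE model_output_tbl (
       mst_key        INTEGER,
       model_weights  BYTEA,
       __dist_key__   INTEGER,
       -- gpdb5 requires the DISTRIBUTED BY column(s) to be a left-subset of the
       -- PRIMARY KEY, so __dist_key__ comes first in the key.
       PRIMARY KEY (__dist_key__, mst_key)
   )
   DISTRIBUTED BY (__dist_key__);
   ```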





[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r537883975



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):
+            self.next_schedule_tbl = unique_string('next_schedule')
+            rotate_schedule_tbl_query = """
+                CREATE TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl};
+            """.format(self=self)
+            self.rotate_schedule_tbl_plan = plpy.prepare(rotate_schedule_tbl_query)

Review comment:
       Yes; any time the same query is going to be executed many times in a loop, best practice is to prepare the statement once and then execute it repeatedly, rather than re-preparing it on every pass.
   
   That reminds me: I was hoping to do the same for the Hop and UDA queries but never got to it.  I did include a TODO comment next to the UDA.  If I get a chance, I'll update those; it might make more of a difference there.
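
   A minimal sketch of the prepare-once, execute-many pattern being described (table and column names below are placeholders, not the module's actual identifiers):
   ```python
    # Prepare the plan a single time; plpy caches the parsed and planned query.
    plan = plpy.prepare(
        "UPDATE schedule_tbl SET dist_key = $1 WHERE mst_key = $2",
        ["INTEGER", "INTEGER"])

    # Illustrative (dist_key, mst_key) pairs.
    assignments = [(2, 1), (3, 2), (1, 3)]
    for dist_key, mst_key in assignments:
        # Each call only binds fresh parameters; no re-parsing or re-planning.
        plpy.execute(plan, [dist_key, mst_key])
   ```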







[GitHub] [madlib] reductionista commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r558628129



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -520,21 +544,35 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
         and only gets cleared in eval transition at the last row of the last iteration.
 
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape:
+        plpy.error("fit_transition called with no data")
+
+    if not prev_serialized_weights or not model_architecture:
         return state

Review comment:
       When # msts < # segs, any segments not currently training a model will be called with model_weights = model_arch = NULL.  Since those segments have nothing to do, we just return early and wait for the other segments to finish training.
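
   Restating those guards as a sketch with the reasoning attached (same logic as the hunk above, comments added for clarity):
   ```python
    if not dependent_var_shape:
        # Every segment should receive data rows; an empty shape means the
        # query routed no data here, which indicates a bug.
        plpy.error("fit_transition called with no data")

    if not prev_serialized_weights or not model_architecture:
        # With fewer msts than segments, idle segments are handed NULL
        # weights and NULL architecture; there is nothing to train, so
        # return early and wait for the busy segments to finish.
        return state
   ```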







[GitHub] [madlib] reductionista merged pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
reductionista merged pull request #525:
URL: https://github.com/apache/madlib/pull/525


   





[GitHub] [madlib] khannaekta commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
khannaekta commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r533783387



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):
+            self.next_schedule_tbl = unique_string('next_schedule')
+            rotate_schedule_tbl_query = """
+                CREATE TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl};
+            """.format(self=self)
+            self.rotate_schedule_tbl_plan = plpy.prepare(rotate_schedule_tbl_query)

Review comment:
       Is storing this prepared plan here mainly to cache the plan used to create the schedule table?

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +377,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):

Review comment:
       The attribute should be `rotate_schedule_tbl_plan` instead of `rotate_schedule_plan`; otherwise this `if` check will always be true and the plan will be re-prepared on every call.
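
   A sketch of the intended memoization once the names match (the query body is a placeholder, and the final execute step is assumed since it isn't visible in this hunk):
   ```python
    def rotate_schedule_tbl(self):
        # Cache the prepared plan under the same name that hasattr() tests,
        # so it really is prepared only once.
        if not hasattr(self, 'rotate_schedule_tbl_plan'):
            self.next_schedule_tbl = unique_string('next_schedule')
            rotate_query = """
                CREATE TABLE {self.next_schedule_tbl} AS
                    SELECT * FROM {self.schedule_tbl}
            """.format(self=self)  # placeholder body; the real query is shown above
            self.rotate_schedule_tbl_plan = plpy.prepare(rotate_query)
        # Reuse of the cached plan is assumed to follow here.
        plpy.execute(self.rotate_schedule_tbl_plan)
   ```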

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):
+            self.next_schedule_tbl = unique_string('next_schedule')
+            rotate_schedule_tbl_query = """
+                CREATE TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl};

Review comment:
       This should also have a `DISTRIBUTED BY` clause, the same as when we initialize the table in `init_schedule_tbl()`.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})

Review comment:
       Shouldn't we distribute this by `prev_dist_key_col`, since that is what we use as the JOIN key when creating the `model_input_tbl`?

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -518,111 +567,130 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
     # Fit segment model on data
     #TODO consider not doing this every time
     fit_params = parse_and_validate_fit_params(fit_params)
-    segment_model.fit(x_train, y_train, **fit_params)
+    with K.tf.device(device_name):
+        segment_model.fit(x_train, y_train, **fit_params)
 
     # Aggregating number of images, loss and accuracy
     agg_image_count += len(x_train)
+    SD[SD_STORE.AGG_IMAGE_COUNT] = agg_image_count
     total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
                                                       images_per_seg)
     is_last_row = agg_image_count == total_images
     return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
                                        agg_image_count, total_images)
+
     if is_last_row:
+        SD[SD_STORE.AGG_IMAGE_COUNT] = 0  # Must be reset after each pass through images
         if is_final_iteration or is_multiple_model:
             SD_STORE.clear_SD(SD)
             clear_keras_session(sess)
 
+    trans_exit_time = time.time()
+    DEBUG.plpy.info("|_fit_transition_time_|{}|".format(trans_exit_time - trans_enter_time))
+
+    SD[SD_STORE.TRANS_EXIT_TIME] = trans_exit_time
     return return_state
 
-def fit_multiple_transition_caching(state, dependent_var, independent_var, dependent_var_shape,
-                             independent_var_shape, model_architecture,
-                             compile_params, fit_params, dist_key, dist_key_mapping,
-                             current_seg_id, segments_per_host, images_per_seg, use_gpus,
-                             accessible_gpus_for_seg, prev_serialized_weights,
-                             is_final_training_call, custom_function_map=None, **kwargs):
+def fit_multiple_transition_caching(
+    dependent_var, independent_var, dependent_var_shape, independent_var_shape,
+    model_architecture, compile_params, fit_params, dist_key, dist_key_mapping,
+    current_seg_id, segments_per_host, images_per_seg, use_gpus, accessible_gpus_for_seg,
+    serialized_weights, is_final_training_call, custom_function_map=None, **kwargs):
     """
     This transition function is called when caching is called for
     madlib_keras_fit_multiple_model().
-    The input params: dependent_var, independent_var are passed in
-    as None and dependent_var_shape, independent_var_shape as [0]
-    for all hops except the very first hop
+    The input params: dependent_var, independent_var,
+    dependent_var_shape and independent_var_shape are passed
+    in as None for all hops except the very first hop
     Some things to note in this function are:
-    - prev_serialized_weights can be passed in as None for the
+    - weights can be passed in as None for the

Review comment:
       Let's also specify when this is possible:
   ```
    weights can be passed in as None only
    in the case where # msts < # segments
   ```

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -75,8 +79,7 @@ segment.
 Note that this function is disabled for Postgres.
 """
 
-@MinWarning("warning")

Review comment:
       Shouldn't we have `@MinWarning("warning")` here?  There are a couple of NOTICEs that get printed when creating the next schedule table and the model output summary table:
   
   ```
   NOTICE:  Table doesn't have 'DISTRIBUTED BY' clause -- Using column(s) named 'mst_key' as the Greenplum Database data distribution key for this table.
   HINT:  The 'DISTRIBUTED BY' clause determines the distribution of data. Make sure column(s) chosen are the optimal data distribution key to minimize skew.
   CONTEXT:  SQL statement "
                   CREATE TABLE __madlib_temp_next_schedule210606_1606863675_3469755__ AS
                       SELECT
                           mst_key,
                           model_id,
                           __dist_key__ AS __prev_dist_key__,
                           COALESCE(
                               LEAD(__dist_key__)
                                   OVER(ORDER BY __dist_key__),
                               FIRST_VALUE(__dist_key__)
                                   OVER(ORDER BY __dist_key__)
                           ) AS __dist_key__
                       FROM __madlib_temp_schedule9385959_1606863670_2396302__;
               "
   ```

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -616,8 +780,8 @@ class FitMultipleModel():
             self.update_info_table(mst, True)
             if self.validation_table:
                 self.update_info_table(mst, False)
-
-    def run_training(self, mst_idx, is_very_first_hop):
+    
+    def run_training(self, hop, is_very_first_hop):

Review comment:
       I think it's worth adding a comment explaining the flow in this function, since we refer to it later.
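
   For instance, a short docstring along these lines (wording is only a suggestion, summarizing the flow visible in the hunk below):
   ```python
    def run_training(self, hop, is_very_first_hop):
        """Perform one hop followed by one round of training.

        1. HOP: for hop > 0, build model_input_tbl by joining model_output_tbl
           with schedule_tbl on the previous dist key, so each model lands on
           the segment it is scheduled to train on next; for hop == 0, just
           rename model_output_tbl to model_input_tbl.
        2. TRAIN: run fit_transition_multiple_model() as a UDF over the source
           table joined with model_input_tbl, writing the updated weights into
           a fresh model_output_tbl.
        3. CLEANUP: truncate and drop model_input_tbl to release disk space.
        """
   ```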

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -629,149 +793,187 @@ class FitMultipleModel():
         # Therefore we want to have queries that do not add motions and all the
         # sub-queries running Keras/tensorflow operations reuse the same slice(process)
         # that was used for initializing GPU memory.
-        use_gpus = self.use_gpus if self.use_gpus else False
-        mst_weights_query = """
-            CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
-                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
-                       model_arch_tbl.{self.model_arch_col}
-                FROM
-                    {self.mst_current_schedule_tbl} mst_tbl
-                    LEFT JOIN {self.model_output_table} wgh_tbl
-                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
-                        LEFT JOIN {self.model_arch_table} model_arch_tbl
-                        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
-                DISTRIBUTED BY ({dist_key_col})
-        """.format(dist_key_col=dist_key_col,
-                   **locals())
-        plpy.execute(mst_weights_query)
-        use_gpus = self.use_gpus if self.use_gpus else False
-        dep_shape_col = self.dep_shape_col
-        ind_shape_col = self.ind_shape_col
+
+        DEBUG.start_timing("run_training")
+        if hop > 0:
+            DEBUG.print_mst_keys(self.model_output_tbl, 'before_hop')
+            DEBUG.start_timing("hop")
+            hop_query = """
+                CREATE {self.unlogged_table} TABLE {self.model_input_tbl} AS
+                    SELECT o.{self.mst_key_col},
+                           o.{self.model_weights_col},
+                           o.{self.model_arch_col},
+                           o.{self.compile_params_col},
+                           o.{self.fit_params_col},
+                           o.{self.object_map_col},
+                           s.{self.dist_key_col}
+                    FROM {self.model_output_tbl} o JOIN {self.schedule_tbl} s
+                        ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
+                    DISTRIBUTED BY ({self.dist_key_col});
+            """.format(self=self)
+
+            DEBUG.plpy.execute(hop_query)
+
+            DEBUG.print_timing("hop")
+            DEBUG.print_mst_keys(self.model_input_tbl, 'after_hop')
+
+            DEBUG.start_timing("truncate_output")
+            self.truncate_and_drop(self.model_output_tbl)
+            DEBUG.print_timing("truncate_output")
+        else:
+            # Skip hop if it's the first in an iteration, just rename
+            plpy.execute("""
+                ALTER TABLE {self.model_output_tbl}
+                    RENAME TO {self.model_input_tbl}
+            """.format(self=self))
+ 
+        ind_shape = self.ind_shape_col
+        dep_shape = self.dep_shape_col
         dep_var = mb_dep_var_col
         indep_var = mb_indep_var_col
         source_table = self.source_table
-        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+
         if self.use_caching:
             # Caching populates the independent_var and dependent_var into the cache on the very first hop
             # For the very_first_hop, we want to run the transition function on all segments, including
-            # the one's where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
+            # the ones where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
             # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
             # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
             # table to join for referencing the dist_key
             if is_very_first_hop:
                 plpy.execute("""
                     DROP TABLE IF EXISTS {self.cached_source_table};
-                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
-                    """.format(self=self, dist_key_col=dist_key_col))
+                    CREATE TABLE {self.cached_source_table} AS
+                        SELECT {self.dist_key_col} FROM {self.source_table}
+                            GROUP BY {self.dist_key_col}
+                                DISTRIBUTED BY({self.dist_key_col});
+                    """.format(self=self))
             else:
-                dep_shape_col = 'ARRAY[0]'
-                ind_shape_col = 'ARRAY[0]'
-                dep_var = 'NULL'
-                indep_var = 'NULL'
+                dep_shape = ind_shape = 'NULL'
+                dep_var = indep_var = 'NULL'
                 source_table = self.cached_source_table
-            if is_very_first_hop or self.is_final_training_call:
-                where_clause = ""
-
-        uda_query = """
-            CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
-            SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
-                {mb_indep_var_col},
-                {dep_shape_col},
-                {ind_shape_col},
-                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
-                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
-                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                src.{dist_key_col},
-                ARRAY{self.dist_key_mapping},
-                src.{self.gp_segment_id_col},
-                {self.segments_per_host},
-                ARRAY{self.images_per_seg_train},
-                {use_gpus}::BOOLEAN,
-                ARRAY{self.accessible_gpus_for_seg},
-                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_training_call}::BOOLEAN,
-                {use_caching}::BOOLEAN,
-                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
-                )::BYTEA AS {self.model_weights_col},
-                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
-                ,src.{dist_key_col} AS {dist_key_col}
-            FROM {source_table} src JOIN {self.mst_weights_tbl}
-                USING ({dist_key_col})
-            {where_clause}
-            GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
-            DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=dep_var,
-                       mb_indep_var_col=indep_var,
-                       dep_shape_col=dep_shape_col,
-                       ind_shape_col=ind_shape_col,
-                       is_final_training_call=self.is_final_training_call,
+
+        res = plpy.execute("""
+            SELECT count(*)
+            FROM {self.model_input_tbl}
+        """.format(self=self))
+        if res:
+            DEBUG.plpy.info("rows in model_input table: {}".format(res[0]['count']))
+        else:
+            DEBUG.plpy.error("No rows in model_input table!")
+
+#TODO: prepare this statement once, then just fill in the params with execute()
+#      on all the rest of the hops / iterations
+
+        DEBUG.start_timing("udf")
+        udf_query = plpy.prepare("""
+            CREATE {self.unlogged_table} TABLE {self.model_output_tbl} AS
+            SELECT
+                model_in.{self.mst_key_col},
+                CASE WHEN model_in.{self.dist_key_col} > {self.max_dist_key}
+                THEN
+                    model_in.{self.model_weights_col}
+                ELSE
+                    {self.schema_madlib}.fit_transition_multiple_model(
+                        {dep_var_col},
+                        {indep_var_col},
+                        {dep_shape},
+                        {ind_shape},
+                        model_in.{self.model_arch_col}::TEXT,
+                        model_in.{self.compile_params_col}::TEXT,
+                        model_in.{self.fit_params_col}::TEXT,
+                        src.{self.dist_key_col},
+                        ARRAY{self.dist_key_mapping},
+                        src.{self.gp_segment_id_col},
+                        {self.segments_per_host},
+                        ARRAY{self.images_per_seg_train},
+                        {self.use_gpus}::BOOLEAN,
+                        ARRAY{self.accessible_gpus_for_seg},
+                        model_in.{self.model_weights_col}::BYTEA,
+                        {self.is_final_training_call}::BOOLEAN,
+                        {use_caching}::BOOLEAN,
+                        model_in.{self.object_map_col}::BYTEA
+                    )
+                END::BYTEA AS {self.model_weights_col},
+                model_in.{self.model_arch_col},
+                model_in.{self.compile_params_col},
+                model_in.{self.fit_params_col},
+                model_in.{self.object_map_col},
+                model_in.{self.dist_key_col}
+            FROM {self.model_input_tbl} model_in
+                FULL JOIN {source_table} src
+                USING ({self.dist_key_col}) 
+            DISTRIBUTED BY({self.dist_key_col})
+            """.format(dep_var_col=dep_var,
+                       indep_var_col=indep_var,
+                       dep_shape=dep_shape,
+                       ind_shape=ind_shape,
                        use_caching=self.use_caching,
-                       dist_key_col=dist_key_col,
-                       use_gpus=use_gpus,
                        source_table=source_table,
-                       where_clause=where_clause,
                        self=self
                        )
-        plpy.execute(uda_query)
+        )
+
+        try:
+            plpy.execute(udf_query)
+        except plpy.SPIError as e:
+            msg = e.message
+            if not 'TransAggDetail' in msg:
+                raise e
+            e.message, detail = msg.split('TransAggDetail')
+            # Extract Traceback from segment, add to
+            #  DETAIL of error message on coordinator
+            e.args = (e.message,)
+            spidata = list(e.spidata)
+            spidata[1] = detail
+            e.spidata = tuple(spidata)
+            raise e
+
+        DEBUG.print_timing("udf")
+
+        res = plpy.execute("""
+            SELECT {self.mst_key_col} AS mst_key, {self.model_weights_col} IS NOT NULL AS weights
+                FROM {self.model_output_tbl}
+        """.format(self=self))
+        if res:
+            null_msts = len([None for row in res if row['mst_key'] is None])
+            null_weights = len([None for row in res if row['weights'] is False])
+            DEBUG.plpy.info(
+                "{} rows total ({} mst_key=NULL and {} weights=NULL) in model_output table."\
+                    .format(res.nrows(), null_msts, null_weights))
+        else:
+            plpy.error("No rows in output of UDF!")
 
-        update_query = """
-            UPDATE {self.model_output_table}
-            SET {self.model_weights_col} = {self.weights_to_update_tbl}.{self.model_weights_col}
-            FROM {self.weights_to_update_tbl}
-            WHERE {self.model_output_table}.{self.mst_key_col} = {self.weights_to_update_tbl}.{self.mst_key_col}
-        """.format(self=self)
-        plpy.execute(update_query)
+        plpy.execute("DELETE FROM {self.model_output_tbl} WHERE model_weights IS NULL".format(self=self))
 
-        self.truncate_and_drop_tables()
+        DEBUG.start_timing("truncate_input")
+        self.truncate_and_drop(self.model_input_tbl)
+        DEBUG.print_timing("truncate_input")
+        DEBUG.print_timing("run_training")
 
-    def truncate_and_drop_tables(self):
+    def truncate_and_drop(self, table):
         """
-        Context: UPDATE statements in postgres are not in-place replacements but
-        the row to be updated is marked for deletion(note that the disk space for
-        this row doesn't get released until vaccuum is called) and a new row in
-        inserted.
-
-        This function will clear out the disk space used by the model_output_table
-        and also drop all the other intermediate tables.
-        If available, set the `` guc so that the truncate command can release the
-        disk space. The disk space will be released immediately and hence the
-        model_output table won't grow in size with each UPDATE statement.
+        This function truncates and drops one of the intermediate tables used
+        during an iteration (mostly, model_input_tbl and model_output_tbl).

Review comment:
       Also the schedule table gets truncated and dropped with this function.







[GitHub] [madlib] kaknikhil commented on a change in pull request #525: DL: Model Hopper Refactor

Posted by GitBox <gi...@apache.org>.
kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r560525168



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -663,17 +717,20 @@ def get_state_to_return(segment_model, is_last_row, is_multiple_model, agg_image
     :param is_last_row: boolean to indicate if last row for that hop
     :param is_multiple_model: boolean
     :param agg_image_count: aggregated image count per hop
-    :param total_images: total images per segment
+    :param total_images: total images per segment (only used for madlib_keras_fit() )
     :return:
     """
-    if is_last_row:
-        updated_model_weights = segment_model.get_weights()
-        if is_multiple_model:
+    if is_multiple_model:
+        if is_last_row:
+            updated_model_weights = segment_model.get_weights()
             new_state = madlib_keras_serializer.serialize_nd_weights(updated_model_weights)
         else:
-            updated_model_weights = [total_images * w for w in updated_model_weights]
-            new_state = madlib_keras_serializer.serialize_state_with_nd_weights(
-                agg_image_count, updated_model_weights)
+            new_state = None
+    elif is_last_row:

Review comment:
       Yeah, I think we should get rid of it for fit as well, but you're right, it's better to do that in a future PR.



