You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@madlib.apache.org by GitBox <gi...@apache.org> on 2020/12/03 22:57:34 UTC

[GitHub] [madlib] khannaekta commented on a change in pull request #525: DL: Model Hopper Refactor

khannaekta commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r533783387



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):
+            self.next_schedule_tbl = unique_string('next_schedule')
+            rotate_schedule_tbl_query = """
+                CREATE TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl};
+            """.format(self=self)
+            self.rotate_schedule_tbl_plan = plpy.prepare(rotate_schedule_tbl_query)

Review comment:
       Is the prepared plan stored here mainly to cache the plan for creating the schedule table?

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +377,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):

Review comment:
       The attribute should be `rotate_schedule_tbl_plan` instead of `rotate_schedule_plan`; otherwise this `if` condition will always be true.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):
+            self.next_schedule_tbl = unique_string('next_schedule')
+            rotate_schedule_tbl_query = """
+                CREATE TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl};

Review comment:
       This should also have the same `DISTRIBUTED BY` clause as when we initialize the table in `init_schedule_tbl()`.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})

Review comment:
       Shouldn't we distribute this by the `prev_dist_key_col`, as that is what we use as the JOIN key when creating the `model_input_tbl`?

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -518,111 +567,130 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
     # Fit segment model on data
     #TODO consider not doing this every time
     fit_params = parse_and_validate_fit_params(fit_params)
-    segment_model.fit(x_train, y_train, **fit_params)
+    with K.tf.device(device_name):
+        segment_model.fit(x_train, y_train, **fit_params)
 
     # Aggregating number of images, loss and accuracy
     agg_image_count += len(x_train)
+    SD[SD_STORE.AGG_IMAGE_COUNT] = agg_image_count
     total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
                                                       images_per_seg)
     is_last_row = agg_image_count == total_images
     return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
                                        agg_image_count, total_images)
+
     if is_last_row:
+        SD[SD_STORE.AGG_IMAGE_COUNT] = 0  # Must be reset after each pass through images
         if is_final_iteration or is_multiple_model:
             SD_STORE.clear_SD(SD)
             clear_keras_session(sess)
 
+    trans_exit_time = time.time()
+    DEBUG.plpy.info("|_fit_transition_time_|{}|".format(trans_exit_time - trans_enter_time))
+
+    SD[SD_STORE.TRANS_EXIT_TIME] = trans_exit_time
     return return_state
 
-def fit_multiple_transition_caching(state, dependent_var, independent_var, dependent_var_shape,
-                             independent_var_shape, model_architecture,
-                             compile_params, fit_params, dist_key, dist_key_mapping,
-                             current_seg_id, segments_per_host, images_per_seg, use_gpus,
-                             accessible_gpus_for_seg, prev_serialized_weights,
-                             is_final_training_call, custom_function_map=None, **kwargs):
+def fit_multiple_transition_caching(
+    dependent_var, independent_var, dependent_var_shape, independent_var_shape,
+    model_architecture, compile_params, fit_params, dist_key, dist_key_mapping,
+    current_seg_id, segments_per_host, images_per_seg, use_gpus, accessible_gpus_for_seg,
+    serialized_weights, is_final_training_call, custom_function_map=None, **kwargs):
     """
     This transition function is called when caching is called for
     madlib_keras_fit_multiple_model().
-    The input params: dependent_var, independent_var are passed in
-    as None and dependent_var_shape, independent_var_shape as [0]
-    for all hops except the very first hop
+    The input params: dependent_var, independent_var,
+    dependent_var_shape and independent_var_shape are passed
+    in as None for all hops except the very first hop
     Some things to note in this function are:
-    - prev_serialized_weights can be passed in as None for the
+    - weights can be passed in as None for the

Review comment:
       Let's also specify when this is possible: 
   ```
   weights can be passed in as None only
   in the case where mst < num of seg
   ```

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -75,8 +79,7 @@ segment.
 Note that this function is disabled for Postgres.
 """
 
-@MinWarning("warning")

Review comment:
       Shouldn't we have `@MinWarning("warning")` here? There are a couple of NOTICEs that get printed when creating the next schedule table and the model output summary table:
   
   ```
   NOTICE:  Table doesn't have 'DISTRIBUTED BY' clause -- Using column(s) named 'mst_key' as the Greenplum Database data distribution key for this table.
   HINT:  The 'DISTRIBUTED BY' clause determines the distribution of data. Make sure column(s) chosen are the optimal data distribution key to minimize skew.
   CONTEXT:  SQL statement "
                   CREATE TABLE __madlib_temp_next_schedule210606_1606863675_3469755__ AS
                       SELECT
                           mst_key,
                           model_id,
                           __dist_key__ AS __prev_dist_key__,
                           COALESCE(
                               LEAD(__dist_key__)
                                   OVER(ORDER BY __dist_key__),
                               FIRST_VALUE(__dist_key__)
                                   OVER(ORDER BY __dist_key__)
                           ) AS __dist_key__
                       FROM __madlib_temp_schedule9385959_1606863670_2396302__;
               "
   ```

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -616,8 +780,8 @@ class FitMultipleModel():
             self.update_info_table(mst, True)
             if self.validation_table:
                 self.update_info_table(mst, False)
-
-    def run_training(self, mst_idx, is_very_first_hop):
+    
+    def run_training(self, hop, is_very_first_hop):

Review comment:
       I think it's worth adding a comment explaining the flow in this function, for when we refer back to it later.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -629,149 +793,187 @@ class FitMultipleModel():
         # Therefore we want to have queries that do not add motions and all the
         # sub-queries running Keras/tensorflow operations reuse the same slice(process)
         # that was used for initializing GPU memory.
-        use_gpus = self.use_gpus if self.use_gpus else False
-        mst_weights_query = """
-            CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
-                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
-                       model_arch_tbl.{self.model_arch_col}
-                FROM
-                    {self.mst_current_schedule_tbl} mst_tbl
-                    LEFT JOIN {self.model_output_table} wgh_tbl
-                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
-                        LEFT JOIN {self.model_arch_table} model_arch_tbl
-                        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
-                DISTRIBUTED BY ({dist_key_col})
-        """.format(dist_key_col=dist_key_col,
-                   **locals())
-        plpy.execute(mst_weights_query)
-        use_gpus = self.use_gpus if self.use_gpus else False
-        dep_shape_col = self.dep_shape_col
-        ind_shape_col = self.ind_shape_col
+
+        DEBUG.start_timing("run_training")
+        if hop > 0:
+            DEBUG.print_mst_keys(self.model_output_tbl, 'before_hop')
+            DEBUG.start_timing("hop")
+            hop_query = """
+                CREATE {self.unlogged_table} TABLE {self.model_input_tbl} AS
+                    SELECT o.{self.mst_key_col},
+                           o.{self.model_weights_col},
+                           o.{self.model_arch_col},
+                           o.{self.compile_params_col},
+                           o.{self.fit_params_col},
+                           o.{self.object_map_col},
+                           s.{self.dist_key_col}
+                    FROM {self.model_output_tbl} o JOIN {self.schedule_tbl} s
+                        ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
+                    DISTRIBUTED BY ({self.dist_key_col});
+            """.format(self=self)
+
+            DEBUG.plpy.execute(hop_query)
+
+            DEBUG.print_timing("hop")
+            DEBUG.print_mst_keys(self.model_input_tbl, 'after_hop')
+
+            DEBUG.start_timing("truncate_output")
+            self.truncate_and_drop(self.model_output_tbl)
+            DEBUG.print_timing("truncate_output")
+        else:
+            # Skip hop if it's the first in an iteration, just rename
+            plpy.execute("""
+                ALTER TABLE {self.model_output_tbl}
+                    RENAME TO {self.model_input_tbl}
+            """.format(self=self))
+ 
+        ind_shape = self.ind_shape_col
+        dep_shape = self.dep_shape_col
         dep_var = mb_dep_var_col
         indep_var = mb_indep_var_col
         source_table = self.source_table
-        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+
         if self.use_caching:
             # Caching populates the independent_var and dependent_var into the cache on the very first hop
             # For the very_first_hop, we want to run the transition function on all segments, including
-            # the one's where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
+            # the ones where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
             # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
             # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
             # table to join for referencing the dist_key
             if is_very_first_hop:
                 plpy.execute("""
                     DROP TABLE IF EXISTS {self.cached_source_table};
-                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
-                    """.format(self=self, dist_key_col=dist_key_col))
+                    CREATE TABLE {self.cached_source_table} AS
+                        SELECT {self.dist_key_col} FROM {self.source_table}
+                            GROUP BY {self.dist_key_col}
+                                DISTRIBUTED BY({self.dist_key_col});
+                    """.format(self=self))
             else:
-                dep_shape_col = 'ARRAY[0]'
-                ind_shape_col = 'ARRAY[0]'
-                dep_var = 'NULL'
-                indep_var = 'NULL'
+                dep_shape = ind_shape = 'NULL'
+                dep_var = indep_var = 'NULL'
                 source_table = self.cached_source_table
-            if is_very_first_hop or self.is_final_training_call:
-                where_clause = ""
-
-        uda_query = """
-            CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
-            SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
-                {mb_indep_var_col},
-                {dep_shape_col},
-                {ind_shape_col},
-                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
-                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
-                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                src.{dist_key_col},
-                ARRAY{self.dist_key_mapping},
-                src.{self.gp_segment_id_col},
-                {self.segments_per_host},
-                ARRAY{self.images_per_seg_train},
-                {use_gpus}::BOOLEAN,
-                ARRAY{self.accessible_gpus_for_seg},
-                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_training_call}::BOOLEAN,
-                {use_caching}::BOOLEAN,
-                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
-                )::BYTEA AS {self.model_weights_col},
-                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
-                ,src.{dist_key_col} AS {dist_key_col}
-            FROM {source_table} src JOIN {self.mst_weights_tbl}
-                USING ({dist_key_col})
-            {where_clause}
-            GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
-            DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=dep_var,
-                       mb_indep_var_col=indep_var,
-                       dep_shape_col=dep_shape_col,
-                       ind_shape_col=ind_shape_col,
-                       is_final_training_call=self.is_final_training_call,
+
+        res = plpy.execute("""
+            SELECT count(*)
+            FROM {self.model_input_tbl}
+        """.format(self=self))
+        if res:
+            DEBUG.plpy.info("rows in model_input table: {}".format(res[0]['count']))
+        else:
+            DEBUG.plpy.error("No rows in model_input table!")
+
+#TODO: prepare this statement once, then just fill in the params with execute()
+#      on all the rest of the hops / iterations
+
+        DEBUG.start_timing("udf")
+        udf_query = plpy.prepare("""
+            CREATE {self.unlogged_table} TABLE {self.model_output_tbl} AS
+            SELECT
+                model_in.{self.mst_key_col},
+                CASE WHEN model_in.{self.dist_key_col} > {self.max_dist_key}
+                THEN
+                    model_in.{self.model_weights_col}
+                ELSE
+                    {self.schema_madlib}.fit_transition_multiple_model(
+                        {dep_var_col},
+                        {indep_var_col},
+                        {dep_shape},
+                        {ind_shape},
+                        model_in.{self.model_arch_col}::TEXT,
+                        model_in.{self.compile_params_col}::TEXT,
+                        model_in.{self.fit_params_col}::TEXT,
+                        src.{self.dist_key_col},
+                        ARRAY{self.dist_key_mapping},
+                        src.{self.gp_segment_id_col},
+                        {self.segments_per_host},
+                        ARRAY{self.images_per_seg_train},
+                        {self.use_gpus}::BOOLEAN,
+                        ARRAY{self.accessible_gpus_for_seg},
+                        model_in.{self.model_weights_col}::BYTEA,
+                        {self.is_final_training_call}::BOOLEAN,
+                        {use_caching}::BOOLEAN,
+                        model_in.{self.object_map_col}::BYTEA
+                    )
+                END::BYTEA AS {self.model_weights_col},
+                model_in.{self.model_arch_col},
+                model_in.{self.compile_params_col},
+                model_in.{self.fit_params_col},
+                model_in.{self.object_map_col},
+                model_in.{self.dist_key_col}
+            FROM {self.model_input_tbl} model_in
+                FULL JOIN {source_table} src
+                USING ({self.dist_key_col}) 
+            DISTRIBUTED BY({self.dist_key_col})
+            """.format(dep_var_col=dep_var,
+                       indep_var_col=indep_var,
+                       dep_shape=dep_shape,
+                       ind_shape=ind_shape,
                        use_caching=self.use_caching,
-                       dist_key_col=dist_key_col,
-                       use_gpus=use_gpus,
                        source_table=source_table,
-                       where_clause=where_clause,
                        self=self
                        )
-        plpy.execute(uda_query)
+        )
+
+        try:
+            plpy.execute(udf_query)
+        except plpy.SPIError as e:
+            msg = e.message
+            if not 'TransAggDetail' in msg:
+                raise e
+            e.message, detail = msg.split('TransAggDetail')
+            # Extract Traceback from segment, add to
+            #  DETAIL of error message on coordinator
+            e.args = (e.message,)
+            spidata = list(e.spidata)
+            spidata[1] = detail
+            e.spidata = tuple(spidata)
+            raise e
+
+        DEBUG.print_timing("udf")
+
+        res = plpy.execute("""
+            SELECT {self.mst_key_col} AS mst_key, {self.model_weights_col} IS NOT NULL AS weights
+                FROM {self.model_output_tbl}
+        """.format(self=self))
+        if res:
+            null_msts = len([None for row in res if row['mst_key'] is None])
+            null_weights = len([None for row in res if row['weights'] is False])
+            DEBUG.plpy.info(
+                "{} rows total ({} mst_key=NULL and {} weights=NULL) in model_output table."\
+                    .format(res.nrows(), null_msts, null_weights))
+        else:
+            plpy.error("No rows in output of UDF!")
 
-        update_query = """
-            UPDATE {self.model_output_table}
-            SET {self.model_weights_col} = {self.weights_to_update_tbl}.{self.model_weights_col}
-            FROM {self.weights_to_update_tbl}
-            WHERE {self.model_output_table}.{self.mst_key_col} = {self.weights_to_update_tbl}.{self.mst_key_col}
-        """.format(self=self)
-        plpy.execute(update_query)
+        plpy.execute("DELETE FROM {self.model_output_tbl} WHERE model_weights IS NULL".format(self=self))
 
-        self.truncate_and_drop_tables()
+        DEBUG.start_timing("truncate_input")
+        self.truncate_and_drop(self.model_input_tbl)
+        DEBUG.print_timing("truncate_input")
+        DEBUG.print_timing("run_training")
 
-    def truncate_and_drop_tables(self):
+    def truncate_and_drop(self, table):
         """
-        Context: UPDATE statements in postgres are not in-place replacements but
-        the row to be updated is marked for deletion(note that the disk space for
-        this row doesn't get released until vaccuum is called) and a new row in
-        inserted.
-
-        This function will clear out the disk space used by the model_output_table
-        and also drop all the other intermediate tables.
-        If available, set the `` guc so that the truncate command can release the
-        disk space. The disk space will be released immediately and hence the
-        model_output table won't grow in size with each UPDATE statement.
+        This function truncates and drops one of the intermediate tables used
+        during an iteration (mostly, model_input_tbl and model_output_tbl).

Review comment:
       This function is also used for the schedule table, so the docstring should mention it as well.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org