You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@madlib.apache.org by GitBox <gi...@apache.org> on 2020/03/17 19:26:04 UTC

[GitHub] [madlib] khannaekta commented on a change in pull request #490: DL: Don't include weights as part of state except for the last row.

khannaekta commented on a change in pull request #490: DL: Don't include weights as part of state except for the last row.
URL: https://github.com/apache/madlib/pull/490#discussion_r393872808
 
 

 ##########
 File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
 ##########
 @@ -507,51 +509,50 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
     #TODO consider not doing this every time
     fit_params = parse_and_validate_fit_params(fit_params)
     segment_model.fit(x_train, y_train, **fit_params)
-    updated_model_weights = segment_model.get_weights()
 
     # Aggregating number of images, loss and accuracy
     agg_image_count += len(x_train)
     total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
                                                       images_per_seg)
     is_last_row = agg_image_count == total_images
+    return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
+                                  agg_image_count, total_images)
     if is_last_row:
         if is_final_iteration or is_multiple_model:
             SD_STORE.clear_SD(SD)
             clear_keras_session(sess)
 
-    return get_state_to_return(is_last_row, is_multiple_model, agg_image_count,
-                               total_images, updated_model_weights)
+    return return_state
 
-def get_state_to_return(is_last_row, is_multiple_model, agg_image_count,
-                        total_images, updated_model_weights):
+def get_state_to_return(segment_model, is_last_row, is_multiple_model, agg_image_count,
+                        total_images):
     """
-    1. For model averaging fit_transition, the state always contains the image count
-    as well as the model weights
-    2. For fit multiple transition,
-        a. The state that gets passed from one row/buffer (within the same hop)
-        to the next needs to have the image_count and model weights. image_count
-        is needed to keep track of the last image for that hop.
-        b. Once we get to the last row, the state only needs the model
-        weights. This state is the output of the UDA for that hop. We don't need
-        the image_count here because unlike model averaging, model hopper does
-        not have a merge function and there is no need to average the weights
-        based on the image count.
+    1. For both model averaging fit_transition and fit multiple transition, the state
+    only needs to have the image count except for the last row.
+    2. For model averaging fit_transition, the last row state must always contain the
+    image count as well as the model weights.
+    3. For fit multiple transition, the last row state only needs the model
+    weights. This state is the output of the UDA for that hop. We don't need
+    the image_count here because unlike model averaging, model hopper does
+    not have a merge/final function and there is no need to average the weights
+    based on the image count.
+    :param segment_model: cached model for that segment
     :param is_last_row: boolean to indicate if last row for that hop
     :param is_multiple_model: boolean
     :param agg_image_count: aggregated image count per hop
-    :param updated_model_weights: updated weights after learning (calling keras.fit)
+    :param total_images: total images per segment
     :return:
     """
     if is_last_row:
+        updated_model_weights = segment_model.get_weights()
         if is_multiple_model:
             new_state = madlib_keras_serializer.serialize_nd_weights(updated_model_weights)
         else:
             updated_model_weights = [total_images * w for w in updated_model_weights]
             new_state = madlib_keras_serializer.serialize_state_with_nd_weights(
                 agg_image_count, updated_model_weights)
     else:
-        new_state = madlib_keras_serializer.serialize_state_with_nd_weights(
-            agg_image_count, updated_model_weights)
+        new_state = float(agg_image_count)
 
 Review comment:
   Why do we cast the `agg_image_count` to float?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services