You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by do...@apache.org on 2020/02/21 22:18:11 UTC

[madlib] branch master updated: Fix fit_multiple() warm start

This is an automated email from the ASF dual-hosted git repository.

domino pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new b4ef574  Fix fit_multiple() warm start
b4ef574 is described below

commit b4ef5748753977798ef39a38db51516d8c67d965
Author: Domino Valdano <dv...@pivotal.io>
AuthorDate: Fri Feb 21 00:31:22 2020 +0000

    Fix fit_multiple() warm start
    
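    The get_initial_weights() function is shared by both
    fit() and fit_multiple(), but until now it had not been
    updated to support fit_multiple().  Its warm-start query
    read model_weights from every row of the model table
    and used the first row returned, with no filter on the
    model's mst_key.  When fit_multiple() was called with
    warm_start=TRUE, this caused improper initialization of
    weights for small datasets, and out-of-memory errors
    for larger datasets.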
---
 src/ports/postgres/modules/deep_learning/madlib_keras.py_in  |  6 +++---
 .../deep_learning/madlib_keras_fit_multiple_model.py_in      | 12 +++++++++++-
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 01f9152..04458c4 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -334,7 +334,7 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     reset_cuda_env(original_cuda_env)
 
 def get_initial_weights(model_table, model_arch, serialized_weights, warm_start,
-                        use_gpus, accessible_gpus_for_seg):
+                        use_gpus, accessible_gpus_for_seg, mst_filter=''):
     """
         If warm_start is True, return back initial weights from model table.
         If warm_start is False, first try to get the weights from model_arch
@@ -359,8 +359,8 @@ def get_initial_weights(model_table, model_arch, serialized_weights, warm_start,
 
     if warm_start:
         serialized_weights = plpy.execute("""
-            SELECT model_weights FROM {0}
-        """.format(model_table))[0]['model_weights']
+            SELECT model_weights FROM {model_table} {mst_filter} LIMIT 1
+        """.format(**locals()))[0]['model_weights']
     else:
         if not serialized_weights:
             model = model_from_json(model_arch)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index 97c3c60..1bb77f0 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -341,13 +341,23 @@ class FitMultipleModel():
             # used, even if a particular model doesn't have warm start weigths.
             if self.warm_start:
                 model_weights = None
+                mst_filter = """
+                    WHERE {mst_col}={mst_key}
+                """.format(
+                        mst_col=self.mst_key_col,
+                        mst_key=mst['mst_key']
+                    )
+ 
+            else:
+                mst_filter = ''
 
             serialized_weights = get_initial_weights(self.model_output_table,
                                                      model_arch,
                                                      model_weights,
                                                      mst['mst_key'] in warm_start_msts,
                                                      self.use_gpus,
-                                                     self.accessible_gpus_for_seg
+                                                     self.accessible_gpus_for_seg,
+                                                     mst_filter
                                                      )
             model_size = sys.getsizeof(serialized_weights) / 1024.0