You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2018/02/02 00:47:38 UTC

systemml git commit: [SYSTEMML-445] Added load_keras_weights flag in Keras2DML to avoid transferring randomly initialized weights

Repository: systemml
Updated Branches:
  refs/heads/master ad5275932 -> 416ebc02a


[SYSTEMML-445] Added load_keras_weights flag in Keras2DML to avoid transferring randomly initialized weights

- By default, load_keras_weights is set to True. Hence, the weights will
  be transferred to SystemML by default.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/416ebc02
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/416ebc02
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/416ebc02

Branch: refs/heads/master
Commit: 416ebc02a2a7eddfa2d8e0456003cede7af9fa37
Parents: ad52759
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Thu Feb 1 16:45:25 2018 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Thu Feb 1 16:45:24 2018 -0800

----------------------------------------------------------------------
 src/main/python/systemml/mllearn/estimators.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/416ebc02/src/main/python/systemml/mllearn/estimators.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py
index bbf96c6..3f11d3f 100644
--- a/src/main/python/systemml/mllearn/estimators.py
+++ b/src/main/python/systemml/mllearn/estimators.py
@@ -896,7 +896,7 @@ class Keras2DML(Caffe2DML):
 
     """
 
-    def __init__(self, sparkSession, keras_model, input_shape, transferUsingDF=False, weights=None, labels=None, batch_size=64, max_iter=2000, test_iter=10, test_interval=500, display=100, lr_policy="step", weight_decay=5e-4, regularization_type="L2"):
+    def __init__(self, sparkSession, keras_model, input_shape, transferUsingDF=False, load_keras_weights=True, weights=None, labels=None, batch_size=64, max_iter=2000, test_iter=10, test_interval=500, display=100, lr_policy="step", weight_decay=5e-4, regularization_type="L2"):
         """
         Performs training/prediction for a given keras model.
 
@@ -906,6 +906,7 @@ class Keras2DML(Caffe2DML):
         keras_model: keras model
         input_shape: 3-element list (number of channels, input height, input width)
         transferUsingDF: whether to pass the input dataset via PySpark DataFrame (default: False)
+        load_keras_weights: whether to load weights from the keras_model. If False, the weights will be initialized to random value using NN libraries' init method  (default: True)
         weights: directory whether learned weights are stored (default: None)
         labels: file containing mapping between index and string labels (default: None)
         batch_size: size of the input batch (default: 64)
@@ -931,7 +932,8 @@ class Keras2DML(Caffe2DML):
         convertKerasToCaffeNetwork(keras_model, self.name + ".proto", int(batch_size))
         convertKerasToCaffeSolver(keras_model, self.name + ".proto", self.name + "_solver.proto", int(max_iter), int(test_iter), int(test_interval), int(display), lr_policy, weight_decay, regularization_type)
         self.weights = tempfile.mkdtemp() if weights is None else weights
-        convertKerasToSystemMLModel(sparkSession, keras_model, self.weights)
+        if load_keras_weights:
+            convertKerasToSystemMLModel(sparkSession, keras_model, self.weights)
         if labels is not None and (labels.startswith('https:') or labels.startswith('http:')):
             import urllib
             urllib.urlretrieve(labels, os.path.join(weights, 'labels.txt'))
@@ -939,7 +941,8 @@ class Keras2DML(Caffe2DML):
             from shutil import copyfile
             copyfile(labels, os.path.join(weights, 'labels.txt'))
         super(Keras2DML,self).__init__(sparkSession, self.name + "_solver.proto", input_shape, transferUsingDF)
-        self.load(self.weights)
+        if load_keras_weights:
+            self.load(self.weights)
 
     def close(self):
         import shutil