You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2018/02/02 00:47:38 UTC
systemml git commit: [SYSTEMML-445] Added load_keras_weights flag in
Keras2DML to avoid transferring randomly initialized weights
Repository: systemml
Updated Branches:
refs/heads/master ad5275932 -> 416ebc02a
[SYSTEMML-445] Added load_keras_weights flag in Keras2DML to avoid transferring randomly initialized weights
- By default, load_keras_weights is set to True. Hence, the weights will
be transferred to SystemML by default.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/416ebc02
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/416ebc02
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/416ebc02
Branch: refs/heads/master
Commit: 416ebc02a2a7eddfa2d8e0456003cede7af9fa37
Parents: ad52759
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Thu Feb 1 16:45:25 2018 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Thu Feb 1 16:45:24 2018 -0800
----------------------------------------------------------------------
src/main/python/systemml/mllearn/estimators.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/416ebc02/src/main/python/systemml/mllearn/estimators.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py
index bbf96c6..3f11d3f 100644
--- a/src/main/python/systemml/mllearn/estimators.py
+++ b/src/main/python/systemml/mllearn/estimators.py
@@ -896,7 +896,7 @@ class Keras2DML(Caffe2DML):
"""
- def __init__(self, sparkSession, keras_model, input_shape, transferUsingDF=False, weights=None, labels=None, batch_size=64, max_iter=2000, test_iter=10, test_interval=500, display=100, lr_policy="step", weight_decay=5e-4, regularization_type="L2"):
+ def __init__(self, sparkSession, keras_model, input_shape, transferUsingDF=False, load_keras_weights=True, weights=None, labels=None, batch_size=64, max_iter=2000, test_iter=10, test_interval=500, display=100, lr_policy="step", weight_decay=5e-4, regularization_type="L2"):
"""
Performs training/prediction for a given keras model.
@@ -906,6 +906,7 @@ class Keras2DML(Caffe2DML):
keras_model: keras model
input_shape: 3-element list (number of channels, input height, input width)
transferUsingDF: whether to pass the input dataset via PySpark DataFrame (default: False)
+ load_keras_weights: whether to load weights from the keras_model. If False, the weights will be initialized to random value using NN libraries' init method (default: True)
weights: directory whether learned weights are stored (default: None)
labels: file containing mapping between index and string labels (default: None)
batch_size: size of the input batch (default: 64)
@@ -931,7 +932,8 @@ class Keras2DML(Caffe2DML):
convertKerasToCaffeNetwork(keras_model, self.name + ".proto", int(batch_size))
convertKerasToCaffeSolver(keras_model, self.name + ".proto", self.name + "_solver.proto", int(max_iter), int(test_iter), int(test_interval), int(display), lr_policy, weight_decay, regularization_type)
self.weights = tempfile.mkdtemp() if weights is None else weights
- convertKerasToSystemMLModel(sparkSession, keras_model, self.weights)
+ if load_keras_weights:
+ convertKerasToSystemMLModel(sparkSession, keras_model, self.weights)
if labels is not None and (labels.startswith('https:') or labels.startswith('http:')):
import urllib
urllib.urlretrieve(labels, os.path.join(weights, 'labels.txt'))
@@ -939,7 +941,8 @@ class Keras2DML(Caffe2DML):
from shutil import copyfile
copyfile(labels, os.path.join(weights, 'labels.txt'))
super(Keras2DML,self).__init__(sparkSession, self.name + "_solver.proto", input_shape, transferUsingDF)
- self.load(self.weights)
+ if load_keras_weights:
+ self.load(self.weights)
def close(self):
import shutil