Repository: systemml Updated Branches: refs/heads/master ad5275932 -> 416ebc02a
[SYSTEMML-445] Added load_keras_weights flag in Keras2DML to avoid transferring randomly initialized weights - By default, load_keras_weights is set to True. Hence, the weights will be transferred to SystemML by default. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/416ebc02 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/416ebc02 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/416ebc02 Branch: refs/heads/master Commit: 416ebc02a2a7eddfa2d8e0456003cede7af9fa37 Parents: ad52759 Author: Niketan Pansare <[email protected]> Authored: Thu Feb 1 16:45:25 2018 -0800 Committer: Niketan Pansare <[email protected]> Committed: Thu Feb 1 16:45:24 2018 -0800 ---------------------------------------------------------------------- src/main/python/systemml/mllearn/estimators.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/416ebc02/src/main/python/systemml/mllearn/estimators.py ---------------------------------------------------------------------- diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py index bbf96c6..3f11d3f 100644 --- a/src/main/python/systemml/mllearn/estimators.py +++ b/src/main/python/systemml/mllearn/estimators.py @@ -896,7 +896,7 @@ class Keras2DML(Caffe2DML): """ - def __init__(self, sparkSession, keras_model, input_shape, transferUsingDF=False, weights=None, labels=None, batch_size=64, max_iter=2000, test_iter=10, test_interval=500, display=100, lr_policy="step", weight_decay=5e-4, regularization_type="L2"): + def __init__(self, sparkSession, keras_model, input_shape, transferUsingDF=False, load_keras_weights=True, weights=None, labels=None, batch_size=64, max_iter=2000, test_iter=10, test_interval=500, display=100, lr_policy="step", weight_decay=5e-4, 
regularization_type="L2"): """ Performs training/prediction for a given keras model. @@ -906,6 +906,7 @@ class Keras2DML(Caffe2DML): keras_model: keras model input_shape: 3-element list (number of channels, input height, input width) transferUsingDF: whether to pass the input dataset via PySpark DataFrame (default: False) + load_keras_weights: whether to load weights from the keras_model. If False, the weights will be initialized to random value using NN libraries' init method (default: True) weights: directory whether learned weights are stored (default: None) labels: file containing mapping between index and string labels (default: None) batch_size: size of the input batch (default: 64) @@ -931,7 +932,8 @@ class Keras2DML(Caffe2DML): convertKerasToCaffeNetwork(keras_model, self.name + ".proto", int(batch_size)) convertKerasToCaffeSolver(keras_model, self.name + ".proto", self.name + "_solver.proto", int(max_iter), int(test_iter), int(test_interval), int(display), lr_policy, weight_decay, regularization_type) self.weights = tempfile.mkdtemp() if weights is None else weights - convertKerasToSystemMLModel(sparkSession, keras_model, self.weights) + if load_keras_weights: + convertKerasToSystemMLModel(sparkSession, keras_model, self.weights) if labels is not None and (labels.startswith('https:') or labels.startswith('http:')): import urllib urllib.urlretrieve(labels, os.path.join(weights, 'labels.txt')) @@ -939,7 +941,8 @@ class Keras2DML(Caffe2DML): from shutil import copyfile copyfile(labels, os.path.join(weights, 'labels.txt')) super(Keras2DML,self).__init__(sparkSession, self.name + "_solver.proto", input_shape, transferUsingDF) - self.load(self.weights) + if load_keras_weights: + self.load(self.weights) def close(self): import shutil
