This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new ec85ba0  Deep Learning: Use compile_and_set_weights() in predict
ec85ba0 is described below

commit ec85ba0652fb551c1739683ebbc9d6aff84a26b2
Author: Nandish Jayaram <[email protected]>
AuthorDate: Mon Apr 1 14:58:48 2019 -0700

    Deep Learning: Use compile_and_set_weights() in predict
    
    Commit SHA 137ba49 changed the way we process compile_params, but the
    new processing function (compile_and_set_weights()) was only used in
    fit, not in the predict function. This commit makes the changes needed
    to process compile_params with this new function in predict as well. We
    also now use KerasWeightsSerializer.get_model_shapes to get the model
    shapes in the predict function.
---
 .../modules/deep_learning/madlib_keras_predict.py_in       | 14 ++++----------
 .../modules/deep_learning/test/madlib_keras.sql_in         |  2 +-
 2 files changed, 5 insertions(+), 11 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 5e4e62b..c484c09 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -32,6 +32,7 @@ from utilities.utilities import add_postfix
 from utilities.validate_args import input_tbl_valid
 from utilities.validate_args import output_tbl_valid
 
+from madlib_keras_wrapper import compile_and_set_weights
 from madlib_keras_wrapper import convert_string_of_args_to_dict
 from madlib_keras_helper import get_class_values_and_type
 from madlib_keras_helper import KerasWeightsSerializer
@@ -76,19 +77,12 @@ def predict(schema_madlib, model_table, test_table, id_col, model_arch_table,
 def internal_keras_predict(x_test, model_arch, model_data, input_shape,
                            compile_params, class_values):
     model = model_from_json(model_arch)
-    compile_params = convert_string_of_args_to_dict(compile_params)
     device_name = '/cpu:0'
     os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
+    model_shapes = KerasWeightsSerializer.get_model_shapes(model)
+    compile_and_set_weights(model, compile_params, device_name,
+                            model_data, model_shapes)
 
-    with K.tf.device(device_name):
-        model.compile(**compile_params)
-
-    model_shapes = []
-    for weight_arr in model.get_weights():
-        model_shapes.append(weight_arr.shape)
-    _,_,_, model_weights = KerasWeightsSerializer.deserialize_weights(
-        model_data, model_shapes)
-    model.set_weights(model_weights)
     x_test = np.array(x_test).reshape(1, *input_shape)
     x_test /= 255
     proba_argmax = model.predict_classes(x_test)
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index f393db3..3e8a18c 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -171,5 +171,5 @@ SELECT madlib_keras_predict(
     'model_arch',
     1,
     'x',
-    '''optimizer''=SGD(lr=0.01, decay=1e-6, nesterov=True), 
''loss''=''categorical_crossentropy'', ''metrics''=[''accuracy'']'::text,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
     'cifar10_predict');

Reply via email to