This is an automated email from the ASF dual-hosted git repository.
njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
The following commit(s) were added to refs/heads/master by this push:
new ec85ba0 Deep Learning: Use compile_and_set_weights() in predict
ec85ba0 is described below
commit ec85ba0652fb551c1739683ebbc9d6aff84a26b2
Author: Nandish Jayaram <[email protected]>
AuthorDate: Mon Apr 1 14:58:48 2019 -0700
Deep Learning: Use compile_and_set_weights() in predict
Commit SHA 137ba49 changed the way we process compile_params. Although
the new processing was used in fit, it wasn't used in the predict function.
This commit makes the necessary changes so that predict also processes
compile_params through compile_and_set_weights(). The predict function
now also uses KerasWeightsSerializer.get_model_shapes to get the model
shapes.
---
.../modules/deep_learning/madlib_keras_predict.py_in | 14 ++++----------
.../modules/deep_learning/test/madlib_keras.sql_in | 2 +-
2 files changed, 5 insertions(+), 11 deletions(-)
diff --git
a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 5e4e62b..c484c09 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -32,6 +32,7 @@ from utilities.utilities import add_postfix
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
+from madlib_keras_wrapper import compile_and_set_weights
from madlib_keras_wrapper import convert_string_of_args_to_dict
from madlib_keras_helper import get_class_values_and_type
from madlib_keras_helper import KerasWeightsSerializer
@@ -76,19 +77,12 @@ def predict(schema_madlib, model_table, test_table, id_col,
model_arch_table,
def internal_keras_predict(x_test, model_arch, model_data, input_shape,
compile_params, class_values):
model = model_from_json(model_arch)
- compile_params = convert_string_of_args_to_dict(compile_params)
device_name = '/cpu:0'
os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
+ model_shapes = KerasWeightsSerializer.get_model_shapes(model)
+ compile_and_set_weights(model, compile_params, device_name,
+ model_data, model_shapes)
- with K.tf.device(device_name):
- model.compile(**compile_params)
-
- model_shapes = []
- for weight_arr in model.get_weights():
- model_shapes.append(weight_arr.shape)
- _,_,_, model_weights = KerasWeightsSerializer.deserialize_weights(
- model_data, model_shapes)
- model.set_weights(model_weights)
x_test = np.array(x_test).reshape(1, *input_shape)
x_test /= 255
proba_argmax = model.predict_classes(x_test)
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index f393db3..3e8a18c 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -171,5 +171,5 @@ SELECT madlib_keras_predict(
'model_arch',
1,
'x',
- '''optimizer''=SGD(lr=0.01, decay=1e-6, nesterov=True),
''loss''=''categorical_crossentropy'', ''metrics''=[''accuracy'']'::text,
+ $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True),
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
'cifar10_predict');