This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 27ddd279dc6f3ac4e59a3dc205716f177b5479a5
Author: Nandish Jayaram <[email protected]>
AuthorDate: Thu Apr 25 14:23:08 2019 -0700

    DL: Handle NULL value for optional pred_type param in predict
    
    The pred_type param in predict is optional, so callers may pass NULL
    for it. This commit treats NULL as the default value 'response'
    instead of erroring out.
---
 src/ports/postgres/modules/deep_learning/madlib_keras.sql_in        | 2 +-
 src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index 5f53488..543bbed 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -208,7 +208,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
     independent_varname     VARCHAR,
     output_table            VARCHAR
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_predict($1, $2, $3, $4, $5, 'response', 
TRUE);
+    SELECT MADLIB_SCHEMA.madlib_keras_predict($1, $2, $3, $4, $5, NULL, TRUE);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index d47f53a..4e2a206 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -82,6 +82,8 @@ def _strip_trailing_nulls_from_class_values(class_values):
 
 def predict(schema_madlib, model_table, test_table, id_col,
             independent_varname, output_table, pred_type, use_gpu, **kwargs):
+    if not pred_type:
+        pred_type = 'response'
     input_validator = PredictInputValidator(
         test_table, model_table, id_col, independent_varname,
         output_table, pred_type, MODULE_NAME)

Reply via email to