This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit e41cf6e619e969e2c0b733a25fe41c3f6f7bad45
Author: Domino Valdano <[email protected]>
AuthorDate: Wed May 15 18:02:51 2019 -0700

    DL: Replace old evaluate1 function with madlib_keras_evaluate()
    
    This function can be called by a user on a test image set, to run loss
    and metric evaluation for a particular model. It handles either 0 or 1
    metric in compile params and adds a metric_type column to make the
    output table more readable.
    
    Also:
    
     - Fixes bug in final function (shouldn't return image count)
     - Validates if summary test table exists
     - Get independent_varname and dependent_varname from summary table.
       Refactor several functions related to evaluate, so that these
       parameters are only used to represent the original variable names
       input by the user. In cases where it refers to the fixed strings
       'independent_var' and 'dependent_var' (columns of minibatch output
       table), they have been removed.
       As with fit, they have also been removed from the interface, so the
       user no longer has to pass them in.
     - Suppresses extra warning output
     - Adds devcheck tests for evaluate()
    
    Closes #395
    
    Co-authored-by: Orhan Kislal <[email protected]>
    Co-authored-by: Ekta Khanna <[email protected]>
---
 .../modules/deep_learning/madlib_keras.py_in       | 186 ++++++++++++---------
 .../modules/deep_learning/madlib_keras.sql_in      |  33 ++--
 .../deep_learning/madlib_keras_helper.py_in        |   1 +
 .../deep_learning/madlib_keras_validator.py_in     |  80 +++++++--
 .../modules/deep_learning/test/madlib_keras.sql_in |  63 ++++++-
 .../test/unit_tests/test_madlib_keras.py_in        |  91 +++++-----
 6 files changed, 284 insertions(+), 170 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 6393dd8..f164ce7 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -38,10 +38,13 @@ from keras.regularizers import *
 import madlib_keras_serializer
 from madlib_keras_helper import MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
 from madlib_keras_helper import MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
+from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
+from madlib_keras_helper import INDEPENDENT_VARNAME_COLNAME
 from madlib_keras_helper import CLASS_VALUES_COLNAME
 from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME
 from madlib_keras_helper import NORMALIZING_CONST_COLNAME
 from madlib_keras_validator import FitInputValidator
+from madlib_keras_validator import EvaluateInputValidator
 from madlib_keras_wrapper import *
 from keras_model_arch_table import ModelArchSchema
 
@@ -55,23 +58,25 @@ from utilities.utilities import madlib_version
 from utilities.validate_args import get_col_value_and_type
 from utilities.validate_args import get_expr_type
 from utilities.validate_args import quote_ident
+from utilities.control import MinWarning
 
 @MinWarning("warning")
-def fit(schema_madlib, source_table, model,model_arch_table,
+def fit(schema_madlib, source_table, model, model_arch_table,
         model_arch_id, compile_params, fit_params, num_iterations,
-        gpus_per_host = 0, validation_table=None,
+        gpus_per_host=0, validation_table=None,
         metrics_compute_frequency=None, name="",
         description="", **kwargs):
     source_table = quote_ident(source_table)
-    dependent_varname = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
-    independent_varname = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
     model_arch_table = quote_ident(model_arch_table)
     fit_params = "" if not fit_params else fit_params
     _assert(compile_params, "Compile parameters cannot be empty or NULL.")
 
+    mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
+    mb_indep_var_col = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
+
     fit_validator = FitInputValidator(
         source_table, validation_table, model, model_arch_table,
-        dependent_varname, independent_varname,
+        mb_dep_var_col, mb_indep_var_col,
         num_iterations, metrics_compute_frequency)
     if metrics_compute_frequency is None:
         metrics_compute_frequency = num_iterations
@@ -80,13 +85,7 @@ def fit(schema_madlib, source_table, model,model_arch_table,
     metrics_elapsed_start_time = time.time()
     start_training_time = datetime.datetime.now()
 
-    gpus_per_host = 0 if gpus_per_host is None else gpus_per_host
-    segments_per_host = get_segments_per_host()
-
-    if 0 < gpus_per_host < segments_per_host:
-        plpy.warning('The number of gpus per host is less than the number of '
-                     'segments per host. The support for this case is '
-                     'experimental and it may fail.')
+    segments_per_host, gpus_per_host = get_segments_and_gpus(gpus_per_host)
 
     #TODO add a unit test for this in a future PR
     # save the original value of the env variable so that we can reset it 
later.
@@ -123,10 +122,10 @@ def fit(schema_madlib, source_table, 
model,model_arch_table,
         gp_segment_id_col = 'gp_segment_id'
 
     # Compute total images on each segment
-    seg_ids_train, images_per_seg_train = get_images_per_seg(source_table, 
dependent_varname)
+    seg_ids_train, images_per_seg_train = get_images_per_seg(source_table)
 
     if validation_table:
-        seg_ids_val, images_per_seg_val = get_images_per_seg(validation_table, 
dependent_varname)
+        seg_ids_val, images_per_seg_val = get_images_per_seg(validation_table)
 
     # Convert model from json and initialize weights
     master_model = model_from_json(model_arch)
@@ -152,8 +151,8 @@ def fit(schema_madlib, source_table, model,model_arch_table,
     fit_params_to_pass = "$madlib$" + fit_params + "$madlib$"
     run_training_iteration = plpy.prepare("""
         SELECT {schema_madlib}.fit_step(
-            {dependent_varname}::SMALLINT[],
-            {independent_varname}::REAL[],
+            {mb_dep_var_col}::SMALLINT[],
+            {mb_indep_var_col}::REAL[],
             $MAD${model_arch}$MAD$::TEXT,
             {compile_params_to_pass}::TEXT,
             {fit_params_to_pass}::TEXT,
@@ -190,8 +189,7 @@ def fit(schema_madlib, source_table, model,model_arch_table,
                                             num_iterations):
             # Compute loss/accuracy for training data.
             compute_loss_and_metrics(
-                schema_madlib, source_table, dependent_varname,
-                independent_varname, compile_params_to_pass, model_arch,
+                schema_madlib, source_table, compile_params_to_pass, 
model_arch,
                 serialized_weights, gpus_per_host, segments_per_host, 
seg_ids_train,
                 images_per_seg_train, training_metrics, training_loss,
                 i, "Training")
@@ -199,11 +197,10 @@ def fit(schema_madlib, source_table, 
model,model_arch_table,
             if validation_set_provided:
                 # Compute loss/accuracy for validation data.
                 compute_loss_and_metrics(
-                    schema_madlib, validation_table, dependent_varname,
-                    independent_varname, compile_params_to_pass, model_arch,
-                    serialized_weights, gpus_per_host, segments_per_host, 
seg_ids_val,
-                    images_per_seg_val, validation_metrics, validation_loss,
-                    i, "Validation")
+                    schema_madlib, validation_table, compile_params_to_pass,
+                    model_arch, serialized_weights, gpus_per_host, 
segments_per_host,
+                    seg_ids_val, images_per_seg_val, validation_metrics,
+                    validation_loss, i, "Validation")
             metrics_elapsed_end_time = time.time()
             metrics_elapsed_time.append(
                 metrics_elapsed_end_time-metrics_elapsed_start_time)
@@ -217,8 +214,8 @@ def fit(schema_madlib, source_table, model,model_arch_table,
     norm_const = src_summary_dict['norm_const']
     norm_const_type = src_summary_dict['norm_const_type']
     dep_vartype = src_summary_dict['dep_vartype']
-    dependent_varname_in_source_table = 
src_summary_dict['dependent_varname_in_source_table']
-    independent_varname_in_source_table = 
src_summary_dict['independent_varname_in_source_table']
+    dependent_varname = src_summary_dict['dependent_varname_in_source_table']
+    independent_varname = 
src_summary_dict['independent_varname_in_source_table']
     # Define some constants to be inserted into the summary table.
     model_type = "madlib_keras"
     compile_params_dict = convert_string_of_args_to_dict(compile_params)
@@ -252,8 +249,8 @@ def fit(schema_madlib, source_table, model,model_arch_table,
         SELECT
             $MAD${source_table}$MAD$::TEXT AS source_table,
             $MAD${model}$MAD$::TEXT AS model,
-            $MAD${dependent_varname_in_source_table}$MAD$::TEXT AS 
dependent_varname,
-            $MAD${independent_varname_in_source_table}$MAD$::TEXT AS 
independent_varname,
+            $MAD${dependent_varname}$MAD$::TEXT AS dependent_varname,
+            $MAD${independent_varname}$MAD$::TEXT AS independent_varname,
             $MAD${model_arch_table}$MAD$::TEXT AS model_arch_table,
             {model_arch_id}::INTEGER AS model_arch_id,
             $1 AS compile_params,
@@ -288,8 +285,7 @@ def fit(schema_madlib, source_table, model,model_arch_table,
                    dependent_vartype_colname=DEPENDENT_VARTYPE_COLNAME,
                    normalizing_const_colname=NORMALIZING_CONST_COLNAME,
                    **locals()),
-                   ["TEXT", "TEXT", "TEXT","TEXT", "DOUBLE PRECISION[]",
-                    class_values_type])
+                   ["TEXT", "TEXT", "TEXT", "TEXT", "DOUBLE PRECISION[]", 
class_values_type])
     plpy.execute(create_output_summary_table,
                  [compile_params, fit_params, name,
                   description, metrics_elapsed_time, class_values])
@@ -338,11 +334,9 @@ def get_metrics_sql_string(metrics_list, 
is_metrics_specified):
         metrics_final = metrics_all = 'NULL'
     return metrics_final, metrics_all
 
-def compute_loss_and_metrics(schema_madlib, table, dependent_varname,
-                             independent_varname, compile_params, model_arch,
+def compute_loss_and_metrics(schema_madlib, table, compile_params, model_arch,
                              serialized_weights, gpus_per_host, 
segments_per_host,
-                             seg_ids, images_per_seg_val,
-                             metrics_list, loss_list,
+                             seg_ids, images_per_seg_val, metrics_list, 
loss_list,
                              curr_iter, dataset_name):
     """
     Compute the loss and metric using a given model (serialized_weights) on the
@@ -351,8 +345,6 @@ def compute_loss_and_metrics(schema_madlib, table, 
dependent_varname,
     start_val = time.time()
     evaluate_result = get_loss_metric_from_keras_eval(schema_madlib,
                                                    table,
-                                                   dependent_varname,
-                                                   independent_varname,
                                                    compile_params,
                                                    model_arch,
                                                    serialized_weights,
@@ -364,9 +356,8 @@ def compute_loss_and_metrics(schema_madlib, table, 
dependent_varname,
     plpy.info("Time for evaluation in iteration {0}: {1} sec.". format(
         curr_iter, end_val - start_val))
     if len(evaluate_result) not in [1, 2]:
-        plpy.error('Calling evaluate on table {0} must return loss '
-                   'and at most one metric value.'.format(
-            table))
+        plpy.error('Calling evaluate on table {0} returned < 2 '
+                   'metrics. Expected both loss and a metric.'.format(table))
     loss = evaluate_result[0]
     metric = evaluate_result[1]
     plpy.info("{0} set metric after iteration {1}: {2}.".
@@ -393,22 +384,24 @@ def should_compute_metrics_this_iter(curr_iter, 
metrics_compute_frequency,
     return (curr_iter)%metrics_compute_frequency == 0 or \
            curr_iter == num_iterations
 
-def get_images_per_seg(source_table, dependent_varname):
+def get_images_per_seg(source_table):
     """
     Compute total images in each segment, by querying source_table.  For
     postgres, this is just the total number of images in the db.
     :param source_table:
-    :param dependent_var:
     :return: Returns a string and two arrays
     1. An array containing all the segment numbers in ascending order
     1. An array containing the total images on each of the segments in the
     segment array.
     """
+
+    mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
+
     if is_platform_pg():
         res = plpy.execute(
             """ SELECT SUM(ARRAY_LENGTH({0}, 1)) AS images_per_seg
                 FROM {1}
-            """.format(dependent_varname, source_table))
+            """.format(mb_dep_var_col, source_table))
         images_per_seg = [int(res[0]['images_per_seg'])]
         seg_ids = [0]
     else:
@@ -416,11 +409,11 @@ def get_images_per_seg(source_table, dependent_varname):
             """ SELECT gp_segment_id, SUM(ARRAY_LENGTH({0}, 1)) AS 
images_per_seg
                 FROM {1}
                 GROUP BY gp_segment_id
-            """.format(dependent_varname, source_table))
+            """.format(mb_dep_var_col, source_table))
         seg_ids = [int(each_segment["gp_segment_id"])
                    for each_segment in images_per_seg]
         images_per_seg = [int(each_segment["images_per_seg"])
-            for each_segment in images_per_seg]
+                          for each_segment in images_per_seg]
     return seg_ids, images_per_seg
 
 def fit_transition(state, dependent_var, independent_var, model_architecture,
@@ -472,12 +465,16 @@ def fit_transition(state, dependent_var, independent_var, 
model_architecture,
         total_images = images_per_seg[seg_ids.index(current_seg_id)]
 
     if total_images == 0:
-        plpy.error('Total images is 0 in fit_transition on segment 
{0}'.format(current_seg_id))
+        if is_platform_pg():
+            plpy.error('Total images is 0 in fit_transition')
+
+        else:
+            plpy.error('Total images is 0 in fit_transition on segment 
{0}'.format(current_seg_id))
 
     # Re-serialize the weights
     # Update image count, check if we are done
     if agg_image_count == total_images:
-       # Once done with all images on a segment, we update weights
+        # Once done with all images on a segment, we update weights
         # with the total number of images here instead of the merge function.
         # The merge function only deals with aggregating them.
         updated_weights = [ total_images * w for w in updated_weights ]
@@ -487,7 +484,7 @@ def fit_transition(state, dependent_var, independent_var, 
model_architecture,
             clear_keras_session()
     elif agg_image_count > total_images:
         plpy.error('Processed {0} images, but there were supposed to be only 
{1}!'
-            .format(agg_image_count, total_images))
+                   .format(agg_image_count, total_images))
 
     new_state = madlib_keras_serializer.serialize_state_with_nd_weights(
         agg_image_count, updated_weights)
@@ -535,46 +532,74 @@ def fit_final(state, **kwargs):
     return madlib_keras_serializer.serialize_state_with_1d_weights(
         image_count, weights)
 
-def evaluate1(schema_madlib, model_table, test_table, id_col, model_arch_table,
-            model_arch_id, dependent_varname, independent_varname,
-            compile_params, output_table, **kwargs):
-    # module_name = 'madlib_keras_evaluate'
-    # input_tbl_valid(test_table, module_name)
-    # input_tbl_valid(model_arch_table, module_name)
-    # output_tbl_valid(output_table, module_name)
+def get_segments_and_gpus(gpus_per_host):
+    gpus_per_host = 0 if gpus_per_host is None else gpus_per_host
+    segments_per_host = get_segments_per_host()
 
-    # _validate_input_args(test_table, model_arch_table, output_table)
+    if 0 < gpus_per_host < segments_per_host:
+        plpy.warning('The number of gpus per host is less than the number of '
+                     'segments per host. The support for this case is '
+                     'experimental and it may fail.')
 
-    model_data_query = "SELECT model_data from {0}".format(model_table)
-    serialized_weights = plpy.execute(model_data_query)[0]['model_data']
+    return segments_per_host, gpus_per_host
 
-    model_arch_query = "SELECT model_arch, model_weights FROM {0} " \
-                       "WHERE id = {1}".format(model_arch_table, model_arch_id)
-    query_result = plpy.execute(model_arch_query)
-    if not  query_result or len(query_result) == 0:
-        plpy.error("no model arch found in table {0} with id {1}".format(
-            model_arch_table, model_arch_id))
-    query_result = query_result[0]
-    model_arch = query_result[ModelArchSchema.MODEL_ARCH]
-    compile_params = "$madlib$" + compile_params + "$madlib$"
+def evaluate(schema_madlib, model_table, test_table, output_table, 
gpus_per_host, **kwargs):
+    module_name = 'madlib_keras_evaluate'
+    input_validator = EvaluateInputValidator(test_table, model_table, 
output_table, module_name)
 
-    loss_metric = get_loss_metric_from_keras_eval(
-                    schema_madlib, test_table, dependent_varname,
-                    independent_varname, compile_params, model_arch,
-                    serialized_weights, False, None)
+    model_summary_table = input_validator.model_summary_table
+    test_summary_table = input_validator.test_summary_table
+
+    segments_per_host, gpus_per_host = get_segments_and_gpus(gpus_per_host)
+
+    model_data_query = "SELECT model_data, model_arch from 
{0}".format(model_table)
+    res = plpy.execute(model_data_query)[0]
+    model_data = res['model_data']
+    model_arch = res['model_arch']
+
+    input_shape = get_input_shape(model_arch)
+    input_validator.validate_input_shape(input_shape)
 
-    #TODO remove these infos after adding create table command
-    plpy.info('len of evaluate result is {}'.format(len(loss_metric)))
-    plpy.info('evaluate result loss is {}'.format(loss_metric[0]))
-    plpy.info('evaluate result metric is {}'.format(loss_metric[1]))
+    compile_params_query = "SELECT compile_params, metrics_type FROM 
{0}".format(model_summary_table)
+    res = plpy.execute(compile_params_query)[0]
+    metrics_type = res['metrics_type']
+    compile_params = "$madlib$" + res['compile_params'] + "$madlib$"
 
-def get_loss_metric_from_keras_eval(schema_madlib, table, dependent_varname,
-                                 independent_varname, compile_params,
-                                 model_arch, serialized_weights, gpus_per_host,
-                                 segments_per_host, seg_ids, images_per_seg):
+    seg_ids, images_per_seg = get_images_per_seg(test_table)
+
+    res = plpy.execute("""
+        SELECT {dependent_varname_col}, {independent_varname_col}
+            FROM {test_summary_table}
+        """.format(dependent_varname_col=DEPENDENT_VARNAME_COLNAME,
+                   independent_varname_col=INDEPENDENT_VARNAME_COLNAME,
+                   test_summary_table=test_summary_table))
+
+    dependent_varname = res[0][DEPENDENT_VARNAME_COLNAME]
+    independent_varname = res[0][INDEPENDENT_VARNAME_COLNAME]
+
+    loss, metric =\
+        get_loss_metric_from_keras_eval(schema_madlib, test_table, 
compile_params, model_arch,
+                                        model_data, gpus_per_host, 
segments_per_host,
+                                        seg_ids, images_per_seg)
+
+    if not metrics_type:
+        metrics_type = None
+        metric = None
+
+    with MinWarning("error"):
+        create_output_table = plpy.prepare("""
+            CREATE TABLE {0} AS
+            SELECT $1 as loss, $2 as metric, $3 as 
metrics_type""".format(output_table), ["FLOAT", "FLOAT", "TEXT[]"])
+        plpy.execute(create_output_table, [loss, metric, metrics_type])
+
+def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
+                                    model_arch, serialized_weights, 
gpus_per_host,
+                                    segments_per_host, seg_ids, 
images_per_seg):
 
     gp_segment_id_col = '0' if is_platform_pg() else 'gp_segment_id'
 
+    mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
+    mb_indep_var_col = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
     """
     This function will call the internal keras evaluate function to get the 
loss
     and accuracy of each tuple which then gets averaged to get the final 
result.
@@ -585,8 +610,8 @@ def get_loss_metric_from_keras_eval(schema_madlib, table, 
dependent_varname,
     --  SMALLINT to INTEGER, or change the output of minibatch util to produce 
SMALLINT
     --  For the first, we should change fit_step also
     select ({schema_madlib}.internal_keras_evaluate(
-                                            {dependent_varname}::SMALLINT[],
-                                            {independent_varname}::REAL[],
+                                            {mb_dep_var_col}::SMALLINT[],
+                                            {mb_indep_var_col}::REAL[],
                                             $MAD${model_arch}$MAD$,
                                             $1,
                                             {compile_params},
@@ -689,5 +714,4 @@ def internal_keras_eval_final(state, **kwargs):
     loss /= image_count
     metric /= image_count
 
-    state = loss, metric
-    return state
+    return loss, metric
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index f18c63d..6b7b0c0 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -253,32 +253,25 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.internal_keras_predict(
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
-CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_evaluate1(
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_evaluate(
     model_table             VARCHAR,
     test_table              VARCHAR,
-    id_col                  VARCHAR,
-    model_arch_table        VARCHAR,
-    model_arch_id           INTEGER,
-    dependent_varname       VARCHAR,
-    independent_varname     VARCHAR,
-    compile_params          VARCHAR,
-    output_table            VARCHAR
+    output_table            VARCHAR,
+    gpus_per_host           INTEGER
 ) RETURNS VOID AS $$
-    PythonFunctionBodyOnly(`deep_learning', `madlib_keras')
-    with AOControl(False):
-        madlib_keras.evaluate1(schema_madlib,
-               model_table,
-               test_table,
-               id_col,
-               model_arch_table,
-               model_arch_id,
-               dependent_varname,
-               independent_varname,
-               compile_params,
-               output_table)
+    PythonFunction(`deep_learning', `madlib_keras', `evaluate')
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_evaluate(
+    model_table             VARCHAR,
+    test_table              VARCHAR,
+    output_table            VARCHAR
+) RETURNS VOID AS $$
+  SELECT MADLIB_SCHEMA.madlib_keras_evaluate($1, $2, $3, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.internal_keras_eval_transition(
     state                              REAL[3],
     dependent_var                      SMALLINT[],
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index c83f2f0..03a8399 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -72,6 +72,7 @@ INDEPENDENT_VARNAME_COLNAME = "independent_varname"
 MODEL_ARCH_TABLE_COLNAME = "model_arch_table"
 MODEL_ARCH_ID_COLNAME = "model_arch_id"
 MODEL_DATA_COLNAME = "model_data"
+METRIC_TYPE_COLNAME = "metrics_type"
 
 # Name of independent and dependent colnames in batched table.
 # These are readonly variables, do not modify.
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 5892308..e344c05 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -23,10 +23,14 @@ from madlib_keras_helper import CLASS_VALUES_COLNAME
 from madlib_keras_helper import COMPILE_PARAMS_COLNAME
 from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
 from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME
+from madlib_keras_helper import INDEPENDENT_VARNAME_COLNAME
 from madlib_keras_helper import MODEL_ARCH_ID_COLNAME
 from madlib_keras_helper import MODEL_ARCH_TABLE_COLNAME
 from madlib_keras_helper import MODEL_DATA_COLNAME
 from madlib_keras_helper import NORMALIZING_CONST_COLNAME
+from madlib_keras_helper import MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
+from madlib_keras_helper import MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
+from madlib_keras_helper import METRIC_TYPE_COLNAME
 
 from utilities.minibatch_validation import validate_dependent_var_for_minibatch
 from utilities.utilities import _assert
@@ -82,15 +86,13 @@ def _validate_input_shapes(table, independent_varname, 
input_shape, offset):
                     input_shape, input_shape_from_table,
                     independent_varname, table))
 
-class PredictInputValidator:
-    def __init__(self, test_table, model_table, id_col, independent_varname,
-                 output_table, pred_type, module_name):
+class InputValidator:
+    def __init__(self, test_table, model_table, independent_varname,
+                 output_table, module_name):
         self.test_table = test_table
         self.model_table = model_table
-        self.id_col = id_col
         self.independent_varname = independent_varname
         self.output_table = output_table
-        self.pred_type = pred_type
         if self.model_table:
             self.model_summary_table = add_postfix(
                 self.model_table, "_summary")
@@ -99,14 +101,15 @@ class PredictInputValidator:
 
     def _validate_input_args(self):
         input_tbl_valid(self.model_table, self.module_name)
-        self._validate_model_data_col()
+        self._validate_model_data_cols()
         input_tbl_valid(self.model_summary_table, self.module_name)
-        self._validate_summary_tbl_cols()
+        self._validate_model_summary_tbl_cols()
         input_tbl_valid(self.test_table, self.module_name)
         self._validate_test_tbl_cols()
         output_tbl_valid(self.output_table, self.module_name)
 
-    def _validate_model_data_col(self):
+
+    def _validate_model_data_cols(self):
         _assert(is_var_valid(self.model_table, MODEL_DATA_COLNAME),
                 "{module_name} error: column '{model_data}' "
                 "does not exist in model table '{table}'.".format(
@@ -129,14 +132,7 @@ class PredictInputValidator:
                     independent_varname=self.independent_varname,
                     table=self.test_table))
 
-        _assert(is_var_valid(self.test_table, self.id_col),
-                "{module_name} error: invalid id column "
-                "('{id_col}') for test table ({table}).".format(
-                    module_name=self.module_name,
-                    id_col=self.id_col,
-                    table=self.test_table))
-
-    def _validate_summary_tbl_cols(self):
+    def _validate_model_summary_tbl_cols(self):
         cols_to_check_for = [CLASS_VALUES_COLNAME,
                              DEPENDENT_VARNAME_COLNAME,
                              DEPENDENT_VARTYPE_COLNAME,
@@ -149,6 +145,49 @@ class PredictInputValidator:
             "summary table ('{1}'). The expected columns are {2}.".format(
                 self.module_name, self.model_summary_table, cols_to_check_for))
 
+class EvaluateInputValidator(InputValidator):
+    def __init__(self, test_table, model_table, output_table, module_name):
+        self.test_summary_table = None
+        if test_table:
+            self.test_summary_table = add_postfix(test_table, "_summary")
+
+        self.independent_varname = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
+        InputValidator.__init__(self, test_table, model_table,
+                                self.independent_varname,
+                                output_table, module_name)
+
+    def _validate_input_args(self):
+        input_tbl_valid(self.test_summary_table, self.module_name)
+        self._validate_test_summary_tbl_cols()
+        InputValidator._validate_input_args(self)
+        validate_dependent_var_for_minibatch(self.test_table,
+                                             
MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL)
+
+    def _validate_model_summary_tbl_cols(self):
+        cols_to_check_for = [COMPILE_PARAMS_COLNAME, METRIC_TYPE_COLNAME]
+        _assert(columns_exist_in_table(
+            self.model_summary_table, cols_to_check_for),
+            "{0} error: One or more expected columns missing in model "
+            "summary table ('{1}'). The expected columns are {2}.".format(
+                self.module_name, self.model_summary_table, cols_to_check_for))
+
+    def _validate_test_summary_tbl_cols(self):
+        cols_in_tbl_valid(self.test_summary_table, [CLASS_VALUES_COLNAME,
+            NORMALIZING_CONST_COLNAME, DEPENDENT_VARTYPE_COLNAME,
+            DEPENDENT_VARNAME_COLNAME, INDEPENDENT_VARNAME_COLNAME], 
self.module_name)
+
+    def validate_input_shape(self, input_shape_from_arch):
+        _validate_input_shapes(self.test_table, self.independent_varname,
+                               input_shape_from_arch, 2)
+
+class PredictInputValidator(InputValidator):
+    def __init__(self, test_table, model_table, id_col, independent_varname,
+                 output_table, pred_type, module_name):
+        self.id_col = id_col
+        self.pred_type = pred_type
+        InputValidator.__init__(self, test_table, model_table, 
independent_varname,
+                               output_table, module_name)
+
     def validate_pred_type(self, class_values):
         if not self.pred_type in ['prob', 'response']:
             plpy.error("{0}: Invalid value for pred_type param ({1}). Must be 
"\
@@ -162,6 +201,15 @@ class PredictInputValidator:
         _validate_input_shapes(self.test_table, self.independent_varname,
                                input_shape_from_arch, 1)
 
+    def _validate_test_tbl_cols(self):
+        InputValidator._validate_test_tbl_cols(self)
+        _assert(is_var_valid(self.test_table, self.id_col),
+                "{module_name} error: invalid id column "
+                "('{id_col}') for test table ({table}).".format(
+                    module_name=self.module_name,
+                    id_col=self.id_col,
+                    table=self.test_table))
+
 class FitInputValidator:
     def __init__(self, source_table, validation_table, output_model_table,
                  model_arch_table, dependent_varname, independent_varname,
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index 6e6065e..847418f 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -25,12 +25,7 @@ copy cifar_10_sample from stdin delimiter '|';
 
2|1|'dog'|'0/img2.jpg'|{{{126,118,110},{122,115,108},{126,119,111},{127,119,109},{130,122,111},{130,122,111},{132,124,113},{133,125,114},{130,122,111},{132,124,113},{134,126,115},{131,123,112},{131,123,112},{134,126,115},{133,125,114},{136,128,117},{137,129,118},{137,129,118},{136,128,117},{131,123,112},{130,122,111},{132,124,113},{132,124,113},{132,124,113},{129,122,110},{127,121,109},{127,121,109},{125,119,107},{124,118,106},{124,118,106},{120,114,102},{117,111,99}},{{122,115,107},{119
 [...]
 \.
 
-drop table if exists cifar_10_sample_val;
-create table cifar_10_sample_val(independent_var REAL[], dependent_var INTEGER[], buffer_id SMALLINT);
-copy cifar_10_sample_val from stdin delimiter '|';
-{{{{0.494118,0.462745,0.431373},{0.478431,0.45098,0.423529},{0.494118,0.466667,0.435294},{0.498039,0.466667,0.427451},{0.509804,0.478431,0.435294},{0.509804,0.478431,0.435294},{0.517647,0.486275,0.443137},{0.521569,0.490196,0.447059},{0.509804,0.478431,0.435294},{0.517647,0.486275,0.443137},{0.52549,0.494118,0.45098},{0.513726,0.482353,0.439216},{0.513726,0.482353,0.439216},{0.52549,0.494118,0.45098},{0.521569,0.490196,0.447059},{0.533333,0.501961,0.458824},{0.537255,0.505882,0.462745},{
 [...]
-{{{{0.792157,0.8,0.780392},{0.792157,0.8,0.780392},{0.8,0.807843,0.788235},{0.807843,0.815686,0.796079},{0.815686,0.823529,0.803922},{0.819608,0.827451,0.807843},{0.823529,0.831373,0.811765},{0.831373,0.839216,0.823529},{0.835294,0.843137,0.831373},{0.843137,0.85098,0.839216},{0.847059,0.854902,0.843137},{0.847059,0.854902,0.843137},{0.843137,0.85098,0.839216},{0.847059,0.854902,0.843137},{0.847059,0.854902,0.843137},{0.847059,0.854902,0.839216},{0.85098,0.858824,0.839216},{0.85098,0.858
 [...]
-\.
+
 -- normalize the indep variable
-- TODO Calling this function makes keras.fit fail with the exception (investigate later)
 -- NOTICE:  Releasing segworker groups to finish aborting the transaction.
@@ -72,6 +67,17 @@ INSERT INTO cifar_10_sample_batched_summary values (
     1,
     255.0);
 
+drop table if exists cifar_10_sample_val;
+create table cifar_10_sample_val(independent_var REAL[], dependent_var INTEGER[], buffer_id SMALLINT);
+copy cifar_10_sample_val from stdin delimiter '|';
+{{{{0.494118,0.462745,0.431373},{0.478431,0.45098,0.423529},{0.494118,0.466667,0.435294},{0.498039,0.466667,0.427451},{0.509804,0.478431,0.435294},{0.509804,0.478431,0.435294},{0.517647,0.486275,0.443137},{0.521569,0.490196,0.447059},{0.509804,0.478431,0.435294},{0.517647,0.486275,0.443137},{0.52549,0.494118,0.45098},{0.513726,0.482353,0.439216},{0.513726,0.482353,0.439216},{0.52549,0.494118,0.45098},{0.521569,0.490196,0.447059},{0.533333,0.501961,0.458824},{0.537255,0.505882,0.462745},{
 [...]
+{{{{0.792157,0.8,0.780392},{0.792157,0.8,0.780392},{0.8,0.807843,0.788235},{0.807843,0.815686,0.796079},{0.815686,0.823529,0.803922},{0.819608,0.827451,0.807843},{0.823529,0.831373,0.811765},{0.831373,0.839216,0.823529},{0.835294,0.843137,0.831373},{0.843137,0.85098,0.839216},{0.847059,0.854902,0.843137},{0.847059,0.854902,0.843137},{0.843137,0.85098,0.839216},{0.847059,0.854902,0.843137},{0.847059,0.854902,0.843137},{0.847059,0.854902,0.839216},{0.85098,0.858824,0.839216},{0.85098,0.858
 [...]
+\.
+
+DROP TABLE IF EXISTS cifar_10_sample_val_summary;
+CREATE TABLE cifar_10_sample_val_summary AS
+       SELECT * FROM cifar_10_sample_batched_summary;
+
 --- NOTE:  In order to test fit_merge, we need at least 2 rows in the batched 
table (1 on each segment).
 --- ALSO NOTE: As part of supporting Postgres, an issue was reported JIRA 
MADLIB-1326.
 --- Once this bug is fixed, we should uncomment these 2 lines, which was used 
to generate
@@ -157,6 +163,29 @@ SELECT assert(
        model_arch IS NOT NULL, 'Keras model output validation failed. Actual:' || __to_char(k))
 FROM (SELECT * FROM keras_saved_out) k;
 
+-- Test that evaluate works as expected:
+DROP TABLE IF EXISTS evaluate_out;
+SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 'evaluate_out', 0);
+
+SELECT assert(loss IS NOT NULL AND
+        metric IS NOT NULL AND
+        metrics_type = '{mae}', 'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
+FROM evaluate_out;
+
+-- Test that passing NULL / None instead of 0 for gpus_per_host works
+DROP TABLE IF EXISTS evaluate_out;
+SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 'evaluate_out');
+SELECT assert(loss IS NOT NULL AND
+        metric IS NOT NULL AND
+        metrics_type = '{mae}', 'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
+FROM evaluate_out;
+
+-- Test that evaluate errors out correctly if model_arch field missing from fit output
+DROP TABLE IF EXISTS evaluate_out;
+ALTER TABLE keras_saved_out DROP COLUMN model_arch;
+SELECT assert(trap_error($TRAP$
+       SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 'evaluate_out');
+       $TRAP$) = 1, 'Should error out if model_arch column is missing from model_table');
 
 -- Verify number of iterations for which metrics and loss are computed
 DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
@@ -276,7 +305,7 @@ FROM (SELECT * FROM keras_out_summary) summary;
 
 SELECT assert(model_data IS NOT NULL , 'Keras model output validation failed') 
FROM (SELECT * FROM keras_out) k;
 
--- Validate metrics=NULL works fine
+-- Validate metrics=NULL works with fit
 DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
 SELECT madlib_keras_fit(
 'cifar_10_sample_batched',
@@ -299,7 +328,16 @@ SELECT assert(
        'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
 FROM (SELECT * FROM keras_saved_out_summary) summary;
 
--- Validate metrics=[] works fine
+-- Validate that metrics=NULL works with evaluate
+DROP TABLE IF EXISTS evaluate_out;
+SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 'evaluate_out', 0);
+
+SELECT assert(loss IS NOT NULL AND
+        metric IS NULL AND
+        metrics_type IS NULL, 'Evaluate output validation for NULL metric failed.  Actual:' || __to_char(evaluate_out))
+FROM evaluate_out;
+
+-- Validate metrics=[] works with fit
 DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
 SELECT madlib_keras_fit(
 'cifar_10_sample_batched',
@@ -322,6 +360,15 @@ SELECT assert(
        'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
 FROM (SELECT * FROM keras_saved_out_summary) summary;
 
+-- Validate metrics=[] works with evaluate
+DROP TABLE IF EXISTS evaluate_out;
+SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 'evaluate_out', 0);
+
+SELECT assert(loss IS NOT NULL AND
+        metric IS NULL AND
+        metrics_type IS NULL, 'Evaluate output validation for [] metric failed.  Actual:' || __to_char(evaluate_out))
+FROM evaluate_out;
+
 DROP TABLE IF EXISTS cifar10_predict;
 SELECT madlib_keras_predict(
     'keras_saved_out',
diff --git 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 286798f..20546a1 100644
--- 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -258,7 +258,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
                None, self.dependent_var, self.independent_var , self.model.to_json(),
                 self.compile_params, self.fit_params, 0, self.all_seg_ids,
                 total_images_per_seg, 0, 4, previous_state.tostring(), **k)
-        self.assertIn('0 rows', str(error.exception))
+        self.assertIn('Total images is 0', str(error.exception))
 
     def test_fit_transition_too_many_images(self):
         self.subject.K.set_session = Mock()
@@ -819,7 +819,7 @@ class MadlibKerasWrapperTestCase(unittest.TestCase):
         self.assertIn('invalid optimizer', str(error.exception))
 
 
-class MadlibKerasValidatorTestCase(unittest.TestCase):
+class MadlibKerasFitInputValidatorTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')
         patches = {
@@ -888,6 +888,49 @@ class MadlibKerasValidatorTestCase(unittest.TestCase):
             'dep_varname', 'independent_varname', 5, 6)
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
 
+class PredictInputValidatorTestCases(unittest.TestCase):
+    def setUp(self):
+        self.plpy_mock = Mock(spec='error')
+        patches = {
+            'plpy': plpy
+        }
+
+        self.plpy_mock_execute = MagicMock()
+        plpy.execute = self.plpy_mock_execute
+
+        self.module_patcher = patch.dict('sys.modules', patches)
+        self.module_patcher.start()
+        import madlib_keras_validator
+        self.module = madlib_keras_validator
+        self.module.PredictInputValidator._validate_input_args = Mock()
+        self.subject = self.module.PredictInputValidator(
+            'test_table', 'model_table', 'id_col', 'independent_varname',
+            'output_table', 'pred_type', 'module_name')
+        self.classes = ['train', 'boat', 'car', 'airplane']
+
+    def tearDown(self):
+        self.module_patcher.stop()
+
+    def test_validate_pred_type_invalid_pred_type(self):
+        self.subject.pred_type = 'invalid'
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.validate_pred_type(['cat', 'dog'])
+
+    def test_validate_pred_type_valid_pred_type_invalid_num_class_values(self):
+        self.subject.pred_type = 'prob'
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.validate_pred_type(range(1599))
+
+    def test_validate_pred_type_valid_pred_type_valid_class_values_prob(self):
+        self.subject.pred_type = 'prob'
+        self.subject.validate_pred_type(range(1598))
+        self.subject.validate_pred_type(None)
+
+    def test_validate_pred_type_valid_pred_type_valid_class_values_response(self):
+        self.subject.pred_type = 'response'
+        self.subject.validate_pred_type(range(1598))
+        self.subject.validate_pred_type(None)
+
 class MadlibSerializerTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')
@@ -967,49 +1010,6 @@ class MadlibSerializerTestCase(unittest.TestCase):
         self.assertEqual(np.array([0,1,3,4,5], dtype=np.float32).tostring(),
                          res)
 
-class PredictInputPredTypeValidationTestCase(unittest.TestCase):
-    def setUp(self):
-        self.plpy_mock = Mock(spec='error')
-        patches = {
-            'plpy': plpy
-        }
-
-        self.plpy_mock_execute = MagicMock()
-        plpy.execute = self.plpy_mock_execute
-
-        self.module_patcher = patch.dict('sys.modules', patches)
-        self.module_patcher.start()
-        import madlib_keras_validator
-        self.module = madlib_keras_validator
-        self.module.PredictInputValidator._validate_input_args = Mock()
-        self.subject = self.module.PredictInputValidator(
-            'test_table', 'model_table', 'id_col', 'independent_varname',
-            'output_table', 'pred_type', 'module_name')
-        self.classes = ['train', 'boat', 'car', 'airplane']
-
-    def tearDown(self):
-        self.module_patcher.stop()
-
-    def test_validate_pred_type_invalid_pred_type(self):
-        self.subject.pred_type = 'invalid'
-        with self.assertRaises(plpy.PLPYException):
-            self.subject.validate_pred_type(['cat', 'dog'])
-
-    def test_validate_pred_type_valid_pred_type_invalid_num_class_values(self):
-        self.subject.pred_type = 'prob'
-        with self.assertRaises(plpy.PLPYException):
-            self.subject.validate_pred_type(range(1599))
-
-    def test_validate_pred_type_valid_pred_type_valid_class_values_prob(self):
-        self.subject.pred_type = 'prob'
-        self.subject.validate_pred_type(range(1598))
-        self.subject.validate_pred_type(None)
-
-    def test_validate_pred_type_valid_pred_type_valid_class_values_response(self):
-        self.subject.pred_type = 'response'
-        self.subject.validate_pred_type(range(1598))
-        self.subject.validate_pred_type(None)
-
 class MadlibKerasHelperTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')
@@ -1255,6 +1255,7 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
         input_state = [image_count*self.loss, image_count*self.accuracy, image_count]
 
         output_state = self.subject.internal_keras_eval_final(input_state)
+        self.assertEqual(len(output_state), 2)
         agg_loss = output_state[0]
         agg_accuracy = output_state[1]
 

Reply via email to