This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit a5692219a361fad8b8b09c4c6789dd9830a04e8f
Author: Ekta Khanna <[email protected]>
AuthorDate: Wed May 1 17:09:12 2019 -0700

    DL: Add new param metrics_compute_frequency to madlib_keras_fit()
    
    JIRA: MADLIB-1335
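    This commit adds a new optional parameter `metrics_compute_frequency` to
    `madlib_keras_fit()` to control how often the Keras model's `evaluate()`
    is run on training and validation data for computing loss/metrics.
    (Evaluation on training data stays disabled until MADLIB-1332 is merged;
    until then only the validation data is evaluated.) Only one metric is
    supported at the moment.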
    
    Closes #388
    Co-authored-by: Nandish Jayaram <[email protected]>
---
 .../modules/deep_learning/madlib_keras.py_in       | 320 +++++++++++++--------
 .../modules/deep_learning/madlib_keras.sql_in      |  40 ++-
 .../deep_learning/madlib_keras_validator.py_in     |  32 ++-
 .../modules/deep_learning/test/madlib_keras.sql_in |  63 ++--
 4 files changed, 291 insertions(+), 164 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index d0dde1d..9cda792 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -51,11 +51,34 @@ from utilities.utilities import is_platform_pg
 from utilities.utilities import get_segments_per_host
 from utilities.utilities import madlib_version
 from utilities.validate_args import get_col_value_and_type
+from utilities.validate_args import get_expr_type
 from utilities.validate_args import quote_ident
 
+def get_source_summary_table_dict(fit_validator):
+    source_summary = plpy.execute("""
+            SELECT
+                {class_values} AS class_values,
+                {norm_const} AS norm_const,
+                {dep_vartype} AS dep_vartype,
+                {dep_varname} AS dependent_varname_in_source_table,
+                {indep_varname} AS independent_varname_in_source_table
+            FROM {tbl}
+        """.format(class_values=CLASS_VALUES_COLNAME,
+                   norm_const=NORMALIZING_CONST_COLNAME,
+                   dep_vartype=DEPENDENT_VARTYPE_COLNAME,
+                   dep_varname='dependent_varname',
+                   indep_varname='independent_varname',
+                   tbl=fit_validator.source_summary_table))[0]
+    source_summary['class_values_type'] = get_expr_type(
+        CLASS_VALUES_COLNAME, fit_validator.source_summary_table)
+    source_summary['norm_const_type'] = get_expr_type(
+        NORMALIZING_CONST_COLNAME, fit_validator.source_summary_table)
+    return source_summary
+
 def fit(schema_madlib, source_table, model,model_arch_table,
         model_arch_id, compile_params, fit_params, num_iterations,
-        gpus_per_host = 0, validation_table=None, name="",
+        gpus_per_host = 0, validation_table=None,
+        metrics_compute_frequency=None, name="",
         description="", **kwargs):
     source_table = quote_ident(source_table)
     dependent_varname = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
@@ -65,7 +88,10 @@ def fit(schema_madlib, source_table, model,model_arch_table,
 
     fit_validator = FitInputValidator(
         source_table, validation_table, model, model_arch_table,
-        dependent_varname, independent_varname, num_iterations)
+        dependent_varname, independent_varname, num_iterations,
+        metrics_compute_frequency)
+    if metrics_compute_frequency is None:
+        metrics_compute_frequency = num_iterations
 
     start_training_time = datetime.datetime.now()
 
@@ -135,7 +161,7 @@ def fit(schema_madlib, source_table, model,model_arch_table,
 
     # Construct validation dataset if provided
     validation_set_provided = bool(validation_table)
-    validation_aggregate_accuracy = []; validation_aggregate_loss = []
+    validation_metrics = []; validation_loss = []
 
     # Prepare the SQL for running distributed training via UDA
     compile_params_to_pass = "$madlib$" + compile_params + "$madlib$"
@@ -161,143 +187,138 @@ def fit(schema_madlib, source_table, model,model_arch_table,
     # Define the state for the model and loss/accuracy storage lists
     model_state = madlib_keras_serializer.serialize_weights(
         0, 0, 0, model_weights)
-    aggregate_loss, aggregate_accuracy, aggregate_runtime = [], [], []
+    training_loss, training_metrics, aggregate_runtime = [], [], []
+    metrics_iters = []
 
     plpy.info("Model architecture size: {}KB".format(len(model_arch)/1024))
     plpy.info("Model state (serialized) size: {}MB".format(
         len(model_state)/1024/1024))
 
     # Run distributed training for specified number of iterations
-    for i in range(num_iterations):
+    for i in range(1, num_iterations+1):
         start_iteration = time.time()
-        iteration_result = plpy.execute(run_training_iteration, 
[model_state])[0]['iteration_result']
+        iteration_result = plpy.execute(run_training_iteration,
+                                        [model_state])[0]['iteration_result']
         end_iteration = time.time()
         plpy.info("Time for iteration {0}: {1} sec".
-                  format(i + 1, end_iteration - start_iteration))
+                  format(i, end_iteration - start_iteration))
         aggregate_runtime.append(datetime.datetime.now())
-        avg_loss, avg_accuracy, model_state = 
madlib_keras_serializer.deserialize_iteration_state(iteration_result)
+        avg_loss, avg_metric, model_state = madlib_keras_serializer.\
+            deserialize_iteration_state(iteration_result)
         plpy.info("Average loss after training iteration {0}: {1}".format(
-            i + 1, avg_loss))
+            i, avg_loss))
         plpy.info("Average accuracy after training iteration {0}: {1}".format(
-            i + 1, avg_accuracy))
-        if validation_set_provided:
-            _, _, _, updated_weights = 
madlib_keras_serializer.deserialize_weights(
-                model_state, model_shapes)
-            master_model.set_weights(updated_weights)
-            start_val = time.time()
-            evaluate_result = get_loss_acc_from_keras_eval(schema_madlib,
-                                                           validation_table,
-                                                           dependent_varname,
-                                                           independent_varname,
-                                                           
compile_params_to_pass,
-                                                           model_arch, 
model_state,
-                                                           gpus_per_host,
-                                                           segments_per_host,
-                                                           seg_ids_val,
-                                                           images_per_seg_val,
-                                                           gp_segment_id_col)
-            end_val = time.time()
-            plpy.info("Time for validation in iteration {0}: {1} sec". 
format(i + 1, end_val - start_val))
-            if len(evaluate_result) < 2:
-                plpy.error('Calling evaluate on validation data returned < 2 '
-                           'metrics. Expected metrics are loss and accuracy')
-            validation_loss = evaluate_result[0]
-            validation_accuracy = evaluate_result[1]
-            plpy.info("Validation set accuracy after iteration {0}: {1}".
-                      format(i + 1, validation_accuracy))
-            validation_aggregate_accuracy.append(validation_accuracy)
-            validation_aggregate_loss.append(validation_loss)
-        aggregate_loss.append(avg_loss)
-        aggregate_accuracy.append(avg_accuracy)
+            i, avg_metric))
+
+        if should_compute_metrics_this_iter(i, metrics_compute_frequency,
+                                            num_iterations):
+            # TODO: Do we need this code?
+            # _, _, _, updated_weights = 
madlib_keras_serializer.deserialize_weights(
+            #     model_state, model_shapes)
+            # master_model.set_weights(updated_weights)
+            # Compute loss/accuracy for training data.
+            # TODO: Uncomment this once JIRA MADLIB-1332 is merged to master
+            # compute_loss_and_metrics(
+            #     schema_madlib, source_table, dependent_varname,
+            #     independent_varname, compile_params_to_pass, model_arch,
+            #     model_state, gpus_per_host, segments_per_host, seg_ids_val,
+            #     images_per_seg_val, gp_segment_id_col,
+            #     training_metrics, training_loss,
+            #     i, "Training")
+            metrics_iters.append(i)
+            if validation_set_provided:
+                # Compute loss/accuracy for validation data.
+                compute_loss_and_metrics(
+                    schema_madlib, validation_table, dependent_varname,
+                    independent_varname, compile_params_to_pass, model_arch,
+                    model_state, gpus_per_host, segments_per_host, seg_ids_val,
+                    images_per_seg_val, gp_segment_id_col,
+                    validation_metrics, validation_loss,
+                    i, "Validation")
+        training_loss.append(avg_loss)
+        training_metrics.append(avg_metric)
 
     end_training_time = datetime.datetime.now()
 
-    final_validation_acc = None
-    if validation_aggregate_accuracy and len(validation_aggregate_accuracy) > 
0:
-        final_validation_acc = validation_aggregate_accuracy[-1]
-
-    final_validation_loss = None
-    if validation_aggregate_loss and len(validation_aggregate_loss) > 0:
-        final_validation_loss = validation_aggregate_loss[-1]
     version = madlib_version(schema_madlib)
-    class_values, class_values_type = get_col_value_and_type(
-        fit_validator.source_summary_table, CLASS_VALUES_COLNAME)
-    norm_const, norm_const_type = get_col_value_and_type(
-        fit_validator.source_summary_table, NORMALIZING_CONST_COLNAME)
-    dep_vartype = plpy.execute("SELECT {0} AS dep FROM {1}".format(
-        DEPENDENT_VARTYPE_COLNAME, 
fit_validator.source_summary_table))[0]['dep']
-    dependent_varname_in_source_table = quote_ident(plpy.execute("SELECT {0} 
FROM {1}".format(
-        'dependent_varname', 
fit_validator.source_summary_table))[0]['dependent_varname'])
-    independent_varname_in_source_table = quote_ident(plpy.execute("SELECT {0} 
FROM {1}".format(
-        'independent_varname', 
fit_validator.source_summary_table))[0]['independent_varname'])
+    src_summary_dict = get_source_summary_table_dict(fit_validator)
+    class_values = src_summary_dict['class_values']
+    class_values_type = src_summary_dict['class_values_type']
+    norm_const = src_summary_dict['norm_const']
+    norm_const_type = src_summary_dict['norm_const_type']
+    dep_vartype = src_summary_dict['dep_vartype']
+    dependent_varname_in_source_table = 
src_summary_dict['dependent_varname_in_source_table']
+    independent_varname_in_source_table = 
src_summary_dict['independent_varname_in_source_table']
+    # Define some constants to be inserted into the summary table.
+    model_type = "madlib_keras"
+    model_size = sys.getsizeof(model)
+    metrics_iters = metrics_iters if metrics_iters else 'NULL'
+    # We always compute the training loss and metrics, at least once.
+    training_metrics_final = training_metrics[-1]
+    training_loss_final = training_loss[-1]
+    training_metrics = training_metrics if training_metrics else 'NULL'
+    training_loss = training_loss if training_loss else 'NULL'
+    # Validation loss and metrics are computed only if validation_table
+    # is provided.
+    if validation_set_provided:
+        validation_metrics_final = validation_metrics[-1]
+        validation_loss_final = validation_loss[-1]
+        validation_metrics = 'ARRAY{0}'.format(validation_metrics)
+        validation_loss = 'ARRAY{0}'.format(validation_loss)
+        # Must quote the string before inserting to table. Explicitly
+        # quoting it here since this can also take a NULL value, done
+        # in the else part.
+        validation_table = "$MAD${0}$MAD$".format(validation_table)
+    else:
+        validation_metrics = validation_loss = 'NULL'
+        validation_metrics_final = validation_loss_final = 'NULL'
+        validation_table = 'NULL'
+
     create_output_summary_table = plpy.prepare("""
-        CREATE TABLE {0}_summary AS
+        CREATE TABLE {output_summary_model_table} AS
         SELECT
-        $1 AS model_arch_table,
-        $2 AS model_arch_id,
-        $3 AS model_type,
-        $4 AS start_training_time,
-        $5 AS end_training_time,
-        $6 AS source_table,
-        $7 AS validation_table,
-        $8 AS model,
-        $9 AS dependent_varname,
-        $10 AS independent_varname,
-        $11 AS name,
-        $12 AS description,
-        $13 AS model_size,
-        $14 AS madlib_version,
-        $15 AS compile_params,
-        $16 AS fit_params,
-        $17 AS num_iterations,
-        $18 AS num_classes,
-        $19 AS accuracy,
-        $20 AS loss,
-        $21 AS accuracy_iter,
-        $22 AS loss_iter,
-        $23 AS time_iter,
-        $24 AS accuracy_validation,
-        $25 AS loss_validation,
-        $26 AS accuracy_iter_validation,
-        $27 AS loss_iter_validation,
-        $28 AS {1},
-        $29 AS {2},
-        $30 AS {3}
-        """.format(model, CLASS_VALUES_COLNAME, DEPENDENT_VARTYPE_COLNAME,
-                   NORMALIZING_CONST_COLNAME),
-                   ["TEXT", "INTEGER", "TEXT", "TIMESTAMP",
-                    "TIMESTAMP", "TEXT", "TEXT","TEXT",
-                    "TEXT", "TEXT", "TEXT", "TEXT", "INTEGER",
-                    "TEXT", "TEXT", "TEXT", "INTEGER",
-                    "INTEGER", "DOUBLE PRECISION",
-                    "DOUBLE PRECISION", "DOUBLE PRECISION[]",
-                    "DOUBLE PRECISION[]", "TIMESTAMP[]",
-                    "DOUBLE PRECISION", "DOUBLE PRECISION",
-                    "DOUBLE PRECISION[]", "DOUBLE PRECISION[]",
-                    class_values_type, "TEXT", norm_const_type])
-    plpy.execute(
-        create_output_summary_table,
-        [
-            model_arch_table, model_arch_id,
-            "madlib_keras",
-            start_training_time, end_training_time,
-            source_table, validation_table,
-            model, dependent_varname_in_source_table,
-            independent_varname_in_source_table, name, description,
-            sys.getsizeof(model), version, compile_params,
-            fit_params, num_iterations, num_classes,
-            aggregate_accuracy[-1],
-            aggregate_loss[-1],
-            aggregate_accuracy, aggregate_loss,
-            aggregate_runtime, final_validation_acc,
-            final_validation_loss,
-            validation_aggregate_accuracy,
-            validation_aggregate_loss,
-            class_values,
-            dep_vartype,
-            norm_const
-        ]
-        )
+            $MAD${source_table}$MAD$::TEXT AS source_table,
+            $MAD${model}$MAD$::TEXT AS model,
+            $MAD${dependent_varname_in_source_table}$MAD$::TEXT AS 
dependent_varname,
+            $MAD${independent_varname_in_source_table}$MAD$::TEXT AS 
independent_varname,
+            $MAD${model_arch_table}$MAD$::TEXT AS model_arch_table,
+            {model_arch_id} AS model_arch_id,
+            $1 AS compile_params,
+            $2 AS fit_params,
+            {num_iterations} AS num_iterations,
+            {validation_table}::TEXT AS validation_table,
+            {metrics_compute_frequency} AS metrics_compute_frequency,
+            $3 AS name,
+            $4 AS description,
+            '{model_type}'::TEXT AS model_type,
+            {model_size} AS model_size,
+            '{start_training_time}'::TIMESTAMP AS start_training_time,
+            '{end_training_time}'::TIMESTAMP AS end_training_time,
+            $5 AS time_iter,
+            '{version}'::TEXT AS madlib_version,
+            {num_classes} AS num_classes,
+            $6 AS {class_values_colname},
+            '{dep_vartype}' AS {dependent_vartype_colname},
+            {norm_const} AS {normalizing_const_colname},
+            {training_metrics_final} AS training_metrics_final,
+            {training_loss_final} AS training_loss_final,
+            ARRAY{training_metrics}::DOUBLE PRECISION[] AS training_metrics,
+            ARRAY{training_loss}::DOUBLE PRECISION[] AS training_loss,
+            {validation_metrics_final} AS validation_metrics_final,
+            {validation_loss_final} AS validation_loss_final,
+            {validation_metrics}::DOUBLE PRECISION[] AS validation_metrics,
+            {validation_loss}::DOUBLE PRECISION[] AS validation_loss,
+            ARRAY{metrics_iters}::INTEGER[] AS metrics_iters
+        
""".format(output_summary_model_table=fit_validator.output_summary_model_table,
+                   class_values_colname=CLASS_VALUES_COLNAME,
+                   dependent_vartype_colname=DEPENDENT_VARTYPE_COLNAME,
+                   normalizing_const_colname=NORMALIZING_CONST_COLNAME,
+                   **locals()),
+                   ["TEXT", "TEXT", "TEXT","TEXT", "TIMESTAMP[]",
+                    class_values_type])
+    plpy.execute(create_output_summary_table,
+                 [compile_params, fit_params, name,
+                  description, aggregate_runtime, class_values])
 
     create_output_table = plpy.prepare("""
         CREATE TABLE {0} AS
@@ -310,6 +331,61 @@ def fit(schema_madlib, source_table, model,model_arch_table,
     #TODO add a unit test for this in a future PR
     reset_cuda_env(original_cuda_env)
 
+def compute_loss_and_metrics(schema_madlib, table, dependent_varname,
+                             independent_varname, compile_params, model_arch,
+                             model_state, gpus_per_host, segments_per_host,
+                             seg_ids_val, rows_per_seg_val,
+                             gp_segment_id_col, metrics_list, loss_list,
+                             curr_iter, dataset_name):
+    """
+    Compute the loss and metric using a given model (model_state) on the
+    given dataset (table.)
+    """
+    start_val = time.time()
+    evaluate_result = get_loss_acc_from_keras_eval(schema_madlib,
+                                                   table,
+                                                   dependent_varname,
+                                                   independent_varname,
+                                                   compile_params,
+                                                   model_arch, model_state,
+                                                   gpus_per_host,
+                                                   segments_per_host,
+                                                   seg_ids_val,
+                                                   rows_per_seg_val,
+                                                   gp_segment_id_col)
+    end_val = time.time()
+    plpy.info("Time for evaluation in iteration {0}: {1} sec.". format(
+        curr_iter, end_val - start_val))
+    if len(evaluate_result) < 2:
+        plpy.error('Calling evaluate on table {0} returned < 2 '
+                   'metrics. Expected both loss and a metric.'.format(
+            table))
+    loss = evaluate_result[0]
+    metric = evaluate_result[1]
+    plpy.info("{0} set metric after iteration {1}: {2}.".
+              format(dataset_name, curr_iter, metric))
+    plpy.info("{0} set loss after iteration {1}: {2}.".
+              format(dataset_name, curr_iter, loss))
+    metrics_list.append(metric)
+    loss_list.append(loss)
+
+def should_compute_metrics_this_iter(curr_iter, metrics_compute_frequency,
+                                     num_iterations):
+    """
+    Check if we want to compute loss/accuracy for the current iteration
+    :param curr_iter:
+    :param metrics_compute_frequency:
+    :param num_iterations:
+    :return: Returns a boolean
+            return TRUE, if it is the last iteration, or if 
metrics_compute_frequency
+            iterations have elapsed since the last time it was computed.
+            return FALSE otherwise.
+    """
+    # Compute loss/accuracy every metrics_compute_frequency'th iteration,
+    # and also for the last iteration.
+    return (curr_iter)%metrics_compute_frequency == 0 or \
+           curr_iter == num_iterations
+
 def get_images_per_seg(source_table, dependent_varname):
     """
     Compute total images in each segment, by querying source_table.  For
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index 2d7170b..77222a9 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -38,6 +38,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     num_iterations          INTEGER,
     gpus_per_host           INTEGER,
     validation_table        VARCHAR,
+    metrics_compute_frequency  INTEGER,
     name                    VARCHAR,
     description             VARCHAR
 ) RETURNS VOID AS $$
@@ -47,6 +48,39 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
+    source_table            VARCHAR,
+    model                   VARCHAR,
+    model_arch_table        VARCHAR,
+    model_arch_id           INTEGER,
+    compile_params          VARCHAR,
+    fit_params              VARCHAR,
+    num_iterations          INTEGER,
+    gpus_per_host           INTEGER,
+    validation_table        VARCHAR,
+    metrics_compute_frequency  INTEGER,
+    name                    VARCHAR
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, 
$10, $11, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
+    source_table            VARCHAR,
+    model                   VARCHAR,
+    model_arch_table        VARCHAR,
+    model_arch_id           INTEGER,
+    compile_params          VARCHAR,
+    fit_params              VARCHAR,
+    num_iterations          INTEGER,
+    gpus_per_host           INTEGER,
+    validation_table        VARCHAR,
+    metrics_compute_frequency  INTEGER
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, 
$10, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
 
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     source_table            VARCHAR,
@@ -59,7 +93,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     gpus_per_host           INTEGER,
     validation_table        VARCHAR
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, 
NULL, NULL);
+    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, 
NULL, NULL, NULL);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
 
@@ -73,7 +107,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     num_iterations          INTEGER,
     gpus_per_host           INTEGER
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, 
NULL, NULL, NULL);
+    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, 
NULL, NULL, NULL, NULL);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
 
@@ -86,7 +120,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     fit_params              VARCHAR,
     num_iterations          INTEGER
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, 0, NULL, 
NULL, NULL);
+    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, 0, NULL, 
NULL, NULL, NULL);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index d66219c..9210550 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -34,6 +34,7 @@ from utilities.utilities import is_var_valid
 from utilities.utilities import is_valid_psql_type
 from utilities.utilities import NUMERIC
 from utilities.utilities import ONLY_ARRAY
+from utilities.validate_args import cols_in_tbl_valid
 from utilities.validate_args import columns_exist_in_table
 from utilities.validate_args import get_expr_type
 from utilities.validate_args import input_tbl_valid
@@ -157,13 +158,14 @@ class PredictInputValidator:
 class FitInputValidator:
     def __init__(self, source_table, validation_table, output_model_table,
                  model_arch_table, dependent_varname, independent_varname,
-                 num_iterations):
+                 num_iterations, metrics_compute_frequency):
         self.source_table = source_table
         self.validation_table = validation_table
         self.output_model_table = output_model_table
         self.model_arch_table = model_arch_table
         self.dependent_varname = dependent_varname
         self.independent_varname = independent_varname
+        self.metrics_compute_frequency = metrics_compute_frequency
         self.num_iterations = num_iterations
         self.source_summary_table = None
         if self.source_table:
@@ -172,36 +174,42 @@ class FitInputValidator:
         if self.output_model_table:
             self.output_summary_model_table = add_postfix(
                 self.output_model_table, "_summary")
-        self.module_name = 'model_keras'
+        self.module_name = 'madlib_keras_fit'
         self._validate_input_args()
 
     def _validate_input_table(self, table):
         _assert(is_var_valid(table, self.independent_varname),
-                "model_keras error: invalid independent_varname "
+                "{module_name}: invalid independent_varname "
                 "('{independent_varname}') for table "
                 "({table}).".format(
+                    module_name=self.module_name,
                     independent_varname=self.independent_varname,
                     table=table))
 
         _assert(is_var_valid(table, self.dependent_varname),
-                "model_keras error: invalid dependent_varname "
+                "{module_name}: invalid dependent_varname "
                 "('{dependent_varname}') for table "
                 "({table}).".format(
+                    module_name=self.module_name,
                     dependent_varname=self.dependent_varname,
                     table=table))
 
+    def _is_valid_metrics_compute_frequency(self):
+        return self.metrics_compute_frequency is None or \
+               (self.metrics_compute_frequency >= 1 and \
+               self.metrics_compute_frequency <= self.num_iterations)
+
     def _validate_input_args(self):
         _assert(self.num_iterations > 0,
-            "model_keras error: Number of iterations cannot be < 1.")
+            "{0}: Number of iterations cannot be < 
1.".format(self.module_name))
+        _assert(self._is_valid_metrics_compute_frequency(),
+            "{0}: metrics_compute_frequency must be in the range (1 - 
{1}).".format(
+                self.module_name, self.num_iterations))
         input_tbl_valid(self.source_table, self.module_name)
         input_tbl_valid(self.source_summary_table, self.module_name)
-        _assert(is_var_valid(
-            self.source_summary_table, CLASS_VALUES_COLNAME),
-                "model_keras error: invalid class_values varname "
-                "('{class_values}') for source_summary_table "
-                "({source_summary_table}).".format(
-                    class_values=CLASS_VALUES_COLNAME,
-                    source_summary_table=self.source_summary_table))
+        cols_in_tbl_valid(self.source_summary_table, [CLASS_VALUES_COLNAME,
+            NORMALIZING_CONST_COLNAME, DEPENDENT_VARTYPE_COLNAME,
+            'dependent_varname', 'independent_varname'], self.module_name)
         # Source table and validation tables must have the same schema
         self._validate_input_table(self.source_table)
         validate_dependent_var_for_minibatch(self.source_table,
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index bc0789b..8567637 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -137,15 +137,15 @@ SELECT assert(
         num_iterations = 3 AND
         num_classes = 2 AND
         class_values = '{0,1}' AND
-        accuracy >= 0  AND
-        loss  >= 0  AND
-        array_upper(accuracy_iter, 1) = 3 AND
-        array_upper(loss_iter, 1) = 3 AND
+        training_metrics_final >= 0  AND
+        training_loss_final  >= 0  AND
+        array_upper(training_metrics, 1) = 3 AND
+        array_upper(training_loss, 1) = 3 AND
         array_upper(time_iter, 1) = 3 AND
-        accuracy_validation >= 0 AND
-        loss_validation  >= 0  AND
-        array_upper(accuracy_iter_validation, 1) = 3 AND
-        array_upper(loss_iter_validation, 1) = 3 ,
+        validation_metrics_final >= 0 AND
+        validation_loss_final  >= 0  AND
+        array_upper(validation_metrics, 1) = 1 AND
+        array_upper(validation_loss, 1) = 1 ,
         'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
 FROM (SELECT * FROM keras_saved_out_summary) summary;
 
@@ -200,37 +200,42 @@ SELECT madlib_keras_fit(
     2,
     NULL,
     NULL,
+    NULL,
     'model name', 'model desc');
+\x on;
+select * from keras_out_summary;
 SELECT assert(
-    model_arch_table = 'model_arch' AND
-    model_arch_id = 1 AND
-    model_type = 'madlib_keras' AND
-    start_training_time         < now() AND
-    end_training_time > start_training_time AND
     source_table = 'cifar_10_sample_batched' AND
-    validation_table is NULL AND
     model = 'keras_out' AND
     dependent_varname = 'y' AND
-    dependent_vartype = 'smallint' AND
     independent_varname = 'x' AND
+    model_arch_table = 'model_arch' AND
+    model_arch_id = 1 AND
+    compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text AND
+    fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
+    num_iterations = 2 AND
+    validation_table is NULL AND
+    metrics_compute_frequency = 2 AND
     name = 'model name' AND
     description = 'model desc' AND
+    model_type = 'madlib_keras' AND
     model_size > 0 AND
+    start_training_time         < now() AND
+    end_training_time > start_training_time AND
+    array_upper(time_iter, 1) = 2 AND
+    dependent_vartype = 'smallint' AND
     madlib_version is NOT NULL AND
-    compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text AND
-    fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
-    num_iterations = 2 AND
     num_classes = 2 AND
     class_values = '{0,1}' AND
-    accuracy is not NULL AND
-    loss is not NULL AND
-    array_upper(accuracy_iter, 1) = 2 AND
-    array_upper(loss_iter, 1) = 2 AND
-    array_upper(time_iter, 1) = 2 AND
-    accuracy_validation is  NULL AND
-    loss_validation is  NULL AND
-    array_upper(accuracy_iter_validation,1) = 0 AND
-    array_upper(loss_iter_validation,1) = 0 ,
+    normalizing_const = 255.0 AND
+    training_metrics_final is not NULL AND
+    training_loss_final is not NULL AND
+    array_upper(training_metrics, 1) = 2 AND
+    array_upper(training_loss, 1) = 2 AND
+    validation_metrics_final is  NULL AND
+    validation_loss_final is  NULL AND
+    array_upper(validation_metrics, 1) = 0 AND
+    array_upper(validation_loss, 1) = 0 ,
     'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
 FROM (SELECT * FROM keras_out_summary) summary;
 
@@ -287,6 +292,7 @@ SELECT madlib_keras_fit(
     1,
     NULL,
     NULL,
+    NULL,
     'model name', 'model desc');
 
 DROP TABLE IF EXISTS keras_out, keras_out_summary;
@@ -300,6 +306,7 @@ SELECT madlib_keras_fit(
     1,
     NULL,
     NULL,
+    NULL,
     'model name', 'model desc');
 
 DROP TABLE IF EXISTS keras_out, keras_out_summary;
@@ -314,6 +321,7 @@ SELECT madlib_keras_fit(
     1,
     0,
     NULL,
+    NULL,
     'model name', 'model desc');
 
 DROP TABLE IF EXISTS keras_out, keras_out_summary;
@@ -327,6 +335,7 @@ SELECT madlib_keras_fit(
     1,
     NULL,
     NULL,
+    NULL,
     'model name', 'model desc');
 
 DROP TABLE IF EXISTS keras_out, keras_out_summary;

Reply via email to