kaknikhil commented on a change in pull request #388:  DL: Add new param 
metrics_compute_frequency to madlib_keras_fit()
URL: https://github.com/apache/madlib/pull/388#discussion_r283041422
 
 

 ##########
 File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
 ##########
 @@ -173,134 +179,131 @@ def fit(schema_madlib, source_table, 
model,model_arch_table,
     # Run distributed training for specified number of iterations
     for i in range(num_iterations):
         start_iteration = time.time()
-        iteration_result = plpy.execute(run_training_iteration, 
[model_state])[0]['iteration_result']
+        iteration_result = plpy.execute(run_training_iteration,
+                                        [model_state])[0]['iteration_result']
         end_iteration = time.time()
         plpy.info("Time for iteration {0}: {1} sec".
                   format(i + 1, end_iteration - start_iteration))
         aggregate_runtime.append(datetime.datetime.now())
-        avg_loss, avg_accuracy, model_state = 
madlib_keras_serializer.deserialize_iteration_state(iteration_result)
+        avg_loss, avg_metric, model_state = madlib_keras_serializer.\
+            deserialize_iteration_state(iteration_result)
         plpy.info("Average loss after training iteration {0}: {1}".format(
             i + 1, avg_loss))
         plpy.info("Average accuracy after training iteration {0}: {1}".format(
-            i + 1, avg_accuracy))
-        if validation_set_provided:
-            _, _, _, updated_weights = 
madlib_keras_serializer.deserialize_weights(
-                model_state, model_shapes)
-            master_model.set_weights(updated_weights)
-            start_val = time.time()
-            evaluate_result = get_loss_acc_from_keras_eval(schema_madlib,
-                                                           validation_table,
-                                                           dependent_varname,
-                                                           independent_varname,
-                                                           
compile_params_to_pass,
-                                                           model_arch, 
model_state,
-                                                           gpus_per_host,
-                                                           segments_per_host,
-                                                           seg_ids_val,
-                                                           rows_per_seg_val,
-                                                           gp_segment_id_col)
-            end_val = time.time()
-            plpy.info("Time for validation in iteration {0}: {1} sec". 
format(i + 1, end_val - start_val))
-            if len(evaluate_result) < 2:
-                plpy.error('Calling evaluate on validation data returned < 2 '
-                           'metrics. Expected metrics are loss and accuracy')
-            validation_loss = evaluate_result[0]
-            validation_accuracy = evaluate_result[1]
-            plpy.info("Validation set accuracy after iteration {0}: {1}".
-                      format(i + 1, validation_accuracy))
-            validation_aggregate_accuracy.append(validation_accuracy)
-            validation_aggregate_loss.append(validation_loss)
-        aggregate_loss.append(avg_loss)
-        aggregate_accuracy.append(avg_accuracy)
+            i + 1, avg_metric))
+
+        if should_compute_metrics_this_iter(i, metrics_compute_frequency,
+                                            num_iterations):
+            # TODO: Do we need this code?
+            # _, _, _, updated_weights = 
madlib_keras_serializer.deserialize_weights(
+            #     model_state, model_shapes)
+            # master_model.set_weights(updated_weights)
+            # Compute loss/accuracy for training data.
+            # TODO: Uncomment this once JIRA MADLIB-1332 is merged to master
+            # compute_loss_and_metrics(
+            #     schema_madlib, source_table, dependent_varname,
+            #     independent_varname, compile_params_to_pass, model_arch,
+            #     model_state, gpus_per_host, segments_per_host, seg_ids_val,
+            #     rows_per_seg_val, gp_segment_id_col,
+            #     training_metrics, training_loss,
+            #     i, "Training")
+            metrics_iters.append(i)
+            if validation_set_provided:
+                # Compute loss/accuracy for validation data.
+                compute_loss_and_metrics(
+                    schema_madlib, validation_table, dependent_varname,
+                    independent_varname, compile_params_to_pass, model_arch,
+                    model_state, gpus_per_host, segments_per_host, seg_ids_val,
+                    rows_per_seg_val, gp_segment_id_col,
+                    validation_metrics, validation_loss,
+                    i, "Validation")
+        training_loss.append(avg_loss)
+        training_metrics.append(avg_metric)
 
     end_training_time = datetime.datetime.now()
 
-    final_validation_acc = None
-    if validation_aggregate_accuracy and len(validation_aggregate_accuracy) > 
0:
-        final_validation_acc = validation_aggregate_accuracy[-1]
-
-    final_validation_loss = None
-    if validation_aggregate_loss and len(validation_aggregate_loss) > 0:
-        final_validation_loss = validation_aggregate_loss[-1]
     version = madlib_version(schema_madlib)
     class_values, class_values_type = get_col_value_and_type(
         fit_validator.source_summary_table, CLASS_VALUES_COLNAME)
     norm_const, norm_const_type = get_col_value_and_type(
         fit_validator.source_summary_table, NORMALIZING_CONST_COLNAME)
     dep_vartype = plpy.execute("SELECT {0} AS dep FROM {1}".format(
         DEPENDENT_VARTYPE_COLNAME, 
fit_validator.source_summary_table))[0]['dep']
-    dependent_varname_in_source_table = quote_ident(plpy.execute("SELECT {0} 
FROM {1}".format(
-        'dependent_varname', 
fit_validator.source_summary_table))[0]['dependent_varname'])
-    independent_varname_in_source_table = quote_ident(plpy.execute("SELECT {0} 
FROM {1}".format(
-        'independent_varname', 
fit_validator.source_summary_table))[0]['independent_varname'])
+    # Quote_ident TEXT values to be inserted into the summary table
+    dependent_varname_in_source_table = plpy.execute("SELECT {0} FROM 
{1}".format(
+        'dependent_varname', 
fit_validator.source_summary_table))[0]['dependent_varname']
+    independent_varname_in_source_table = plpy.execute("SELECT {0} FROM 
{1}".format(
+        'independent_varname', 
fit_validator.source_summary_table))[0]['independent_varname']
+    # Define some constants to be inserted into the summary table.
+    model_type = "madlib_keras"
+    model_size = sys.getsizeof(model)
+    metrics_iters = metrics_iters if metrics_iters else 'NULL'
+    # We always compute the training loss and metrics, at least once.
+    training_metrics_final = training_metrics[-1]
+    training_loss_final = training_loss[-1]
+    training_metrics = training_metrics if training_metrics else 'NULL'
+    training_loss = training_loss if training_loss else 'NULL'
+    # Validation loss and metrics are computed only if validation_table
+    # is provided.
+    if validation_set_provided:
+        validation_metrics_final = validation_metrics[-1]
+        validation_loss_final = validation_loss[-1]
+        validation_metrics = 'ARRAY{0}'.format(validation_metrics)
+        validation_loss = 'ARRAY{0}'.format(validation_loss)
+        # Must quote the string before inserting to table. Explicitly
+        # quoting it here since this can also take a NULL value, done
+        # in the else part.
+        validation_table = "$MAD${0}$MAD$".format(validation_table)
+    else:
+        validation_metrics = validation_loss = 'NULL'
+        validation_metrics_final = validation_loss_final = 'NULL'
+        validation_table = 'NULL'
+
     create_output_summary_table = plpy.prepare("""
-        CREATE TABLE {0}_summary AS
+        CREATE TABLE {output_summary_model_table} AS
         SELECT
-        $1 AS model_arch_table,
-        $2 AS model_arch_id,
-        $3 AS model_type,
-        $4 AS start_training_time,
-        $5 AS end_training_time,
-        $6 AS source_table,
-        $7 AS validation_table,
-        $8 AS model,
-        $9 AS dependent_varname,
-        $10 AS independent_varname,
-        $11 AS name,
-        $12 AS description,
-        $13 AS model_size,
-        $14 AS madlib_version,
-        $15 AS compile_params,
-        $16 AS fit_params,
-        $17 AS num_iterations,
-        $18 AS num_classes,
-        $19 AS accuracy,
-        $20 AS loss,
-        $21 AS accuracy_iter,
-        $22 AS loss_iter,
-        $23 AS time_iter,
-        $24 AS accuracy_validation,
-        $25 AS loss_validation,
-        $26 AS accuracy_iter_validation,
-        $27 AS loss_iter_validation,
-        $28 AS {1},
-        $29 AS {2},
-        $30 AS {3}
-        """.format(model, CLASS_VALUES_COLNAME, DEPENDENT_VARTYPE_COLNAME,
-                   NORMALIZING_CONST_COLNAME),
-                   ["TEXT", "INTEGER", "TEXT", "TIMESTAMP",
-                    "TIMESTAMP", "TEXT", "TEXT","TEXT",
-                    "TEXT", "TEXT", "TEXT", "TEXT", "INTEGER",
-                    "TEXT", "TEXT", "TEXT", "INTEGER",
-                    "INTEGER", "DOUBLE PRECISION",
-                    "DOUBLE PRECISION", "DOUBLE PRECISION[]",
-                    "DOUBLE PRECISION[]", "TIMESTAMP[]",
-                    "DOUBLE PRECISION", "DOUBLE PRECISION",
-                    "DOUBLE PRECISION[]", "DOUBLE PRECISION[]",
-                    class_values_type, "TEXT", norm_const_type])
-    plpy.execute(
-        create_output_summary_table,
-        [
-            model_arch_table, model_arch_id,
-            "madlib_keras",
-            start_training_time, end_training_time,
-            source_table, validation_table,
-            model, dependent_varname_in_source_table,
-            independent_varname_in_source_table, name, description,
-            sys.getsizeof(model), version, compile_params,
-            fit_params, num_iterations, num_classes,
-            aggregate_accuracy[-1],
-            aggregate_loss[-1],
-            aggregate_accuracy, aggregate_loss,
-            aggregate_runtime, final_validation_acc,
-            final_validation_loss,
-            validation_aggregate_accuracy,
-            validation_aggregate_loss,
-            class_values,
-            dep_vartype,
-            norm_const
-        ]
-        )
+            $MAD${source_table}$MAD$::TEXT AS source_table,
+            $MAD${model}$MAD$::TEXT AS model,
+            $MAD${dependent_varname_in_source_table}$MAD$::TEXT AS 
dependent_varname,
+            $MAD${independent_varname_in_source_table}$MAD$::TEXT AS 
independent_varname,
+            $MAD${model_arch_table}$MAD$::TEXT AS model_arch_table,
+            {model_arch_id} AS model_arch_id,
+            $1 AS compile_params,
+            $2 AS fit_params,
+            {num_iterations} AS num_iterations,
+            {validation_table}::TEXT AS validation_table,
+            {metrics_compute_frequency} AS metrics_compute_frequency,
+            $3 AS name,
+            $4 AS description,
+            '{model_type}'::TEXT AS model_type,
+            {model_size} AS model_size,
+            '{start_training_time}'::TIMESTAMP AS start_training_time,
 
Review comment:
   Do we really need to quote the timestamp columns?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to