This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new ff7dd5a  Deep Learning: Add support for one-hot encoded dep var
ff7dd5a is described below

commit ff7dd5a5bed581ba1b97a5eff0704d6d10c06497
Author: Ekta Khanna <[email protected]>
AuthorDate: Fri Mar 22 17:23:19 2019 -0700

    Deep Learning: Add support for one-hot encoded dep var
    
    JIRA: MADLIB-1313
    The current madlib_keras_fit function for DL assumes the dependent variable
    is not one-hot encoded. But fit uses data obtained after running
    minibatch_preprocessor_dl, which returns a 1-hot encoded array for each
    dependent var (https://issues.apache.org/jira/browse/MADLIB-1303).
    
    This commit updates madlib_keras_fit to expect a 1-hot encoded dependent
    var to train the model. We also add a new column named class_values to the
    model summary table; its value is copied from the summary table of the
    (mini-batched) input table. Predict then uses class_values to find the
    exact class label for test data. class_values can be NULL in the case where
    the user had already 1-hot encoded the data before calling
    minibatch_preprocessor_dl. In that case, we return the index of the
    predicted class instead of a class label.
    
    Closes #360
    
    Co-authored-by: Nandish Jayaram <[email protected]>
---
 src/ports/postgres/modules/convex/mlp_igd.py_in    |  11 +-
 .../modules/deep_learning/madlib_keras.py_in       | 134 +++++---------------
 .../modules/deep_learning/madlib_keras.sql_in      |   6 +-
 .../deep_learning/madlib_keras_helper.py_in        | 129 +++++++++++++++++--
 .../deep_learning/madlib_keras_predict.py_in       |  62 ++++++++--
 .../modules/deep_learning/test/madlib_keras.sql_in | 136 ++++++++++++---------
 .../test/unit_tests/test_madlib_keras.py_in        |  23 +---
 .../test/unit_tests/test_madlib_keras_helper.py_in |  54 ++++++--
 .../modules/utilities/minibatch_validation.py_in   |  13 +-
 9 files changed, 348 insertions(+), 220 deletions(-)

diff --git a/src/ports/postgres/modules/convex/mlp_igd.py_in 
b/src/ports/postgres/modules/convex/mlp_igd.py_in
index 0126d31..3d93a9a 100644
--- a/src/ports/postgres/modules/convex/mlp_igd.py_in
+++ b/src/ports/postgres/modules/convex/mlp_igd.py_in
@@ -58,7 +58,7 @@ from utilities.validate_args import is_var_valid
 from utilities.validate_args import output_tbl_valid
 from utilities.validate_args import table_exists
 from utilities.validate_args import quote_ident
-from utilities.minibatch_validation import is_var_one_hot_encoded_for_minibatch
+from utilities.minibatch_validation import validate_dependent_var_for_minibatch
 
 
 @MinWarning("error")
@@ -693,14 +693,9 @@ def _validate_dependent_var(source_table, 
dependent_varname,
     classification_types = INTEGER | BOOLEAN | TEXT
 
     if is_minibatch_enabled:
-        # The dependent variable is always a double precision array in
-        # preprocessed data (so check for numeric types)
-        # strip out '[]' from expr_type
-        _assert(is_valid_psql_type(expr_type, NUMERIC | ONLY_ARRAY),
-                "Dependent variable column should be a numeric array.")
-
         if is_classification:
-            
is_var_one_hot_encoded_for_minibatch(source_table,dependent_varname)
+            validate_dependent_var_for_minibatch(
+                source_table, dependent_varname, expr_type)
     else:
         if is_classification:
             _assert((is_valid_psql_type(expr_type, NUMERIC | ONLY_ARRAY)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 32cd921..1c5145c 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -31,102 +31,22 @@ from keras.models import *
 from keras.optimizers import *
 from keras.regularizers import *
 
-from madlib_keras_helper import KerasWeightsSerializer
+from madlib_keras_helper import CLASS_VALUES_COLNAME
+from madlib_keras_helper import FitInputValidator
+from madlib_keras_helper import get_class_values_and_type
 from madlib_keras_helper import get_data_as_np_array
 from madlib_keras_wrapper import *
 
 from utilities.model_arch_info import get_input_shape
-from utilities.validate_args import input_tbl_valid
-from utilities.validate_args import output_tbl_valid
-from utilities.utilities import _assert
-from utilities.utilities import add_postfix
-from utilities.utilities import is_var_valid
 from utilities.utilities import madlib_version
 
-def _validate_input_table(source_table, independent_varname,
-                          dependent_varname):
-    _assert(is_var_valid(source_table, independent_varname),
-            "model_keras error: invalid independent_varname "
-            "('{independent_varname}') for source_table "
-            "({source_table})!".format(
-                independent_varname=independent_varname,
-                source_table=source_table))
-
-    _assert(is_var_valid(source_table, dependent_varname),
-            "model_keras error: invalid dependent_varname "
-            "('{dependent_varname}') for source_table "
-            "({source_table})!".format(
-                dependent_varname=dependent_varname, 
source_table=source_table))
-
-def _validate_input_args(
-    source_table, dependent_varname, independent_varname, model_arch_table,
-    validation_table, output_model_table, num_iterations):
-
-    module_name = 'model_keras'
-    _assert(num_iterations > 0,
-        "model_keras error: Number of iterations cannot be < 1.")
-
-    output_summary_model_table = add_postfix(output_model_table, "_summary")
-    input_tbl_valid(source_table, module_name)
-    # Source table and validation tables must have the same schema
-    _validate_input_table(source_table, independent_varname, dependent_varname)
-    if validation_table and validation_table.strip() != '':
-        input_tbl_valid(validation_table, module_name)
-        _validate_input_table(validation_table, independent_varname,
-                              dependent_varname)
-    # Validate model arch table's schema.
-    input_tbl_valid(model_arch_table, module_name)
-    # Validate output tables
-    output_tbl_valid(output_model_table, module_name)
-    output_tbl_valid(output_summary_model_table, module_name)
-
-def _validate_input_shapes(source_table, independent_varname, input_shape):
-    """
-    Validate if the input shape specified in model architecture is the same
-    as the shape of the image specified in the indepedent var of the input
-    table.
-    """
-    # The weird indexing with 'i+2' and 'i' below has two reasons:
-    # 1) The indexing for array_upper() starts from 1, but indexing in the
-    # input_shape list starts from 0.
-    # 2) Input_shape is only the image's dimension, whereas a row of
-    # independent varname in a table contains buffer size as the first
-    # dimension, followed by the image's dimension. So we must ignore
-    # the first dimension from independent varname.
-    array_upper_query = ", ".join("array_upper({0}, {1}) AS n_{2}".format(
-        independent_varname, i+2, i) for i in range(len(input_shape)))
-    query = """
-        SELECT {0}
-        FROM {1}
-        LIMIT 1
-    """.format(array_upper_query, source_table)
-    # This query will fail if an image in independent var does not have the
-    # same number of dimensions as the input_shape.
-    result = plpy.execute(query)[0]
-    _assert(len(result) == len(input_shape),
-        "model_keras error: The number of dimensions ({0}) of each image in" \
-        " model architecture and {1} in {2} ({3}) do not match.".format(
-            len(input_shape), independent_varname, source_table, len(result)))
-    for i in range(len(input_shape)):
-        key_name = "n_{0}".format(i)
-        if result[key_name] != input_shape[i]:
-            # Construct the shape in independent varname to display meaningful
-            # error msg.
-            input_shape_from_table = [result["n_{0}".format(i)]
-                for i in range(len(input_shape))]
-            plpy.error("model_keras error: Input shape {0} in the model" \
-                " architecture does not match the input shape {1} of column" \
-                " {2} in table {3}.".format(
-                    input_shape, input_shape_from_table, independent_varname,
-                    source_table))
-
 def fit(schema_madlib, source_table, model, dependent_varname,
         independent_varname, model_arch_table, model_arch_id, compile_params,
         fit_params, num_iterations, num_classes, use_gpu = True,
         validation_table=None, name="", description="", **kwargs):
-    _validate_input_args(source_table, dependent_varname, independent_varname,
-                         model_arch_table, validation_table,
-                         model, num_iterations)
+    fit_validator = FitInputValidator(
+        source_table, validation_table, model, model_arch_table,
+        dependent_varname, independent_varname, num_iterations)
 
     start_training_time = datetime.datetime.now()
 
@@ -146,10 +66,9 @@ def fit(schema_madlib, source_table, model, 
dependent_varname,
     query_result = query_result[0]
     model_arch = query_result['model_arch']
     input_shape = get_input_shape(model_arch)
-    _validate_input_shapes(source_table, independent_varname, input_shape)
+    fit_validator.validate_input_shapes(source_table, input_shape)
     if validation_table:
-        _validate_input_shapes(
-            validation_table, independent_varname, input_shape)
+        fit_validator.validate_input_shapes(validation_table, input_shape)
     model_weights_serialized = query_result['model_weights']
 
     # Convert model from json and initialize weights
@@ -234,13 +153,15 @@ def fit(schema_madlib, source_table, model, 
dependent_varname,
         plpy.info("Time for iteration {0}: {1} sec".
                   format(i + 1, end_iteration - start_iteration))
         aggregate_runtime.append(datetime.datetime.now())
-        avg_loss, avg_accuracy, model_state = 
KerasWeightsSerializer.deserialize_iteration_state(iteration_result)
+        avg_loss, avg_accuracy, model_state = \
+            
KerasWeightsSerializer.deserialize_iteration_state(iteration_result)
         plpy.info("Average loss after training iteration {0}: {1}".format(
             i + 1, avg_loss))
         plpy.info("Average accuracy after training iteration {0}: {1}".format(
             i + 1, avg_accuracy))
         if validation_set_provided:
-            _, _, _, updated_weights = 
KerasWeightsSerializer.deserialize_weights(model_state, model_shapes)
+            _, _, _, updated_weights = \
+                KerasWeightsSerializer.deserialize_weights(model_state, 
model_shapes)
             master_model.set_weights(updated_weights)
             (opt_name,final_args,compile_dict) = 
parse_compile_params(compile_params)
             master_model.compile(optimizer=optimizers[opt_name](**final_args),
@@ -270,6 +191,8 @@ def fit(schema_madlib, source_table, model, 
dependent_varname,
     if validation_aggregate_loss and len(validation_aggregate_loss) > 0:
         final_validation_loss = validation_aggregate_loss[-1]
     version = madlib_version(schema_madlib)
+    class_values, class_values_type = get_class_values_and_type(
+        fit_validator.source_summary_table)
     create_output_summary_table = plpy.prepare("""
         CREATE TABLE {0}_summary AS
         SELECT
@@ -299,16 +222,19 @@ def fit(schema_madlib, source_table, model, 
dependent_varname,
         $24 AS accuracy_validation,
         $25 AS loss_validation,
         $26 AS accuracy_iter_validation,
-        $27 AS loss_iter_validation
-        """.format(model), ["TEXT", "INTEGER", "TEXT", "TIMESTAMP",
-                                 "TIMESTAMP", "TEXT", "TEXT","TEXT",
-                                 "TEXT", "TEXT", "TEXT", "TEXT", "INTEGER",
-                                 "TEXT", "TEXT", "TEXT", "INTEGER",
-                                 "INTEGER", "DOUBLE PRECISION",
-                                 "DOUBLE PRECISION", "DOUBLE PRECISION[]",
-                                 "DOUBLE PRECISION[]", "TIMESTAMP[]",
-                                 "DOUBLE PRECISION", "DOUBLE PRECISION",
-                                 "DOUBLE PRECISION[]", "DOUBLE PRECISION[]"])
+        $27 AS loss_iter_validation,
+        $28 AS {1}
+        """.format(model, CLASS_VALUES_COLNAME),
+                   ["TEXT", "INTEGER", "TEXT", "TIMESTAMP",
+                    "TIMESTAMP", "TEXT", "TEXT","TEXT",
+                    "TEXT", "TEXT", "TEXT", "TEXT", "INTEGER",
+                    "TEXT", "TEXT", "TEXT", "INTEGER",
+                    "INTEGER", "DOUBLE PRECISION",
+                    "DOUBLE PRECISION", "DOUBLE PRECISION[]",
+                    "DOUBLE PRECISION[]", "TIMESTAMP[]",
+                    "DOUBLE PRECISION", "DOUBLE PRECISION",
+                    "DOUBLE PRECISION[]", "DOUBLE PRECISION[]",
+                    class_values_type])
     plpy.execute(
         create_output_summary_table,
         [
@@ -326,7 +252,8 @@ def fit(schema_madlib, source_table, model, 
dependent_varname,
             aggregate_runtime, final_validation_acc,
             final_validation_loss,
             validation_aggregate_accuracy,
-            validation_aggregate_loss
+            validation_aggregate_loss,
+            class_values
         ]
         )
 
@@ -394,7 +321,6 @@ def fit_transition(state, ind_var, dep_var, current_seg_id, 
num_classes,
     x_train = np.array(ind_var, dtype='float64').reshape(
         len(ind_var), *input_shape)
     y_train = np.array(dep_var)
-    y_train = keras_utils.to_categorical(y_train, num_classes)
 
     # Fit segment model on data
     start_fit = time.time()
@@ -585,7 +511,7 @@ def internal_keras_evaluate(x_test, y_test, model_arch, 
model_data, input_shape,
     x_test = np.array(x_test).reshape(len(x_test), input_shape[0], 
input_shape[1],
                                       input_shape[2])
     x_test = x_test.astype('float32')
-    y_test = keras_utils.to_categorical(np.array(y_test), 10)
+    y_test = np.array(y_test)
     with K.tf.device(device_name):
         res = model.evaluate(x_test, y_test)
     plpy.info('evaluate result from internal_keras_evaluate is {}'.format(res))
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index 6b2daf7..aebe270 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -199,7 +199,8 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.internal_keras_predict(
    model_architecture TEXT,
    model_data bytea,
    input_shape integer[],
-   compile_params TEXT
+   compile_params TEXT,
+   class_values TEXT[]
 ) RETURNS DOUBLE PRECISION[] AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
     with AOControl(False):
@@ -208,7 +209,8 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.internal_keras_predict(
                model_architecture,
                model_data,
                input_shape,
-               compile_params)
+               compile_params,
+               class_values)
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index a198b7a..bd3963e 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -24,6 +24,14 @@ import plpy
 # Import needed for get_data_as_np_array()
 from keras import utils as keras_utils
 
+from utilities.minibatch_validation import validate_dependent_var_for_minibatch
+from utilities.utilities import _assert
+from utilities.utilities import add_postfix
+from utilities.utilities import is_var_valid
+from utilities.validate_args import get_expr_type
+from utilities.validate_args import input_tbl_valid
+from utilities.validate_args import output_tbl_valid
+
 #######################################################################
 ########### Helper functions to serialize and deserialize weights #####
 #######################################################################
@@ -174,21 +182,128 @@ def get_data_as_np_array(table_name, y, x, input_shape, 
num_classes):
     indep_len = len(val_data[0][x])
     pixels_per_image = int(input_shape[0] * input_shape[1] * input_shape[2])
     x_validation = np.ndarray((0,indep_len, pixels_per_image))
-    y_validation = np.ndarray((0,indep_len))
+    y_validation = np.ndarray((0,indep_len, num_classes))
     for i in range(len(val_data)):
         x_test = np.asarray((val_data[i][x],))
         x_test = x_test.reshape(1, indep_len, pixels_per_image)
         y_test = np.asarray((val_data[i][y],))
-        y_test = y_test.reshape(1, indep_len)
         x_validation=np.concatenate((x_validation, x_test))
         y_validation=np.concatenate((y_validation, y_test))
     num_test_examples = x_validation.shape[0]
     x_validation = x_validation.reshape(indep_len * num_test_examples, 
*input_shape)
     x_validation = x_validation.astype('float64')
-    y_validation = y_validation.reshape(indep_len * num_test_examples)
-
-    x_validation = x_validation.astype('float64')
-    #x_validation /= 255.0
-    y_validation = keras_utils.to_categorical(y_validation, num_classes)
+    y_validation = y_validation.reshape(indep_len * num_test_examples, 
num_classes)
 
     return x_validation, y_validation
+
+CLASS_VALUES_COLNAME = "class_values"
+class FitInputValidator:
+    def __init__(self, source_table, validation_table, output_model_table,
+                 model_arch_table, dependent_varname, independent_varname,
+                 num_iterations):
+        self.source_table = source_table
+        self.validation_table = validation_table
+        self.output_model_table = output_model_table
+        self.model_arch_table = model_arch_table
+        self.dependent_varname = dependent_varname
+        self.independent_varname = independent_varname
+        self.num_iterations = num_iterations
+        self.source_summary_table = None
+        if self.source_table:
+            self.source_summary_table = add_postfix(
+                self.source_table, "_summary")
+        if self.output_model_table:
+            self.output_summary_model_table = add_postfix(
+                self.output_model_table, "_summary")
+        self.module_name = 'model_keras'
+        self._validate_input_args()
+
+    def _validate_input_table(self, table):
+        _assert(is_var_valid(table, self.independent_varname),
+                "model_keras error: invalid independent_varname "
+                "('{independent_varname}') for table "
+                "({table}).".format(
+                    independent_varname=self.independent_varname,
+                    table=table))
+
+        _assert(is_var_valid(table, self.dependent_varname),
+                "model_keras error: invalid dependent_varname "
+                "('{dependent_varname}') for table "
+                "({table}).".format(
+                    dependent_varname=self.dependent_varname,
+                    table=table))
+
+    def _validate_input_args(self):
+        _assert(self.num_iterations > 0,
+            "model_keras error: Number of iterations cannot be < 1.")
+        input_tbl_valid(self.source_table, self.module_name)
+        input_tbl_valid(self.source_summary_table, self.module_name)
+        _assert(is_var_valid(
+            self.source_summary_table, CLASS_VALUES_COLNAME),
+                "model_keras error: invalid class_values varname "
+                "('{class_values}') for source_summary_table "
+                "({source_summary_table}).".format(
+                    class_values=CLASS_VALUES_COLNAME,
+                    source_summary_table=self.source_summary_table))
+        # Source table and validation tables must have the same schema
+        self._validate_input_table(self.source_table)
+        validate_dependent_var_for_minibatch(self.source_table,
+                                             self.dependent_varname)
+        if self.validation_table and self.validation_table.strip() != '':
+            input_tbl_valid(self.validation_table, self.module_name)
+            self._validate_input_table(self.validation_table)
+            validate_dependent_var_for_minibatch(self.validation_table,
+                                                 self.dependent_varname)
+        # Validate model arch table's schema.
+        input_tbl_valid(self.model_arch_table, self.module_name)
+        # Validate output tables
+        output_tbl_valid(self.output_model_table, self.module_name)
+        output_tbl_valid(self.output_summary_model_table, self.module_name)
+
+    def validate_input_shapes(self, table, input_shape):
+        """
+        Validate if the input shape specified in model architecture is the same
+        as the shape of the image specified in the indepedent var of the input
+        table.
+        """
+        # The weird indexing with 'i+2' and 'i' below has two reasons:
+        # 1) The indexing for array_upper() starts from 1, but indexing in the
+        # input_shape list starts from 0.
+        # 2) Input_shape is only the image's dimension, whereas a row of
+        # independent varname in a table contains buffer size as the first
+        # dimension, followed by the image's dimension. So we must ignore
+        # the first dimension from independent varname.
+        array_upper_query = ", ".join("array_upper({0}, {1}) AS n_{2}".format(
+            self.independent_varname, i+2, i) for i in range(len(input_shape)))
+        query = """
+            SELECT {0}
+            FROM {1}
+            LIMIT 1
+        """.format(array_upper_query, table)
+        # This query will fail if an image in independent var does not have the
+        # same number of dimensions as the input_shape.
+        result = plpy.execute(query)[0]
+        _assert(len(result) == len(input_shape),
+            "model_keras error: The number of dimensions ({0}) of each image" \
+            " in model architecture and {1} in {2} ({3}) do not match.".format(
+                len(input_shape), self.independent_varname, table, 
len(result)))
+        for i in range(len(input_shape)):
+            key_name = "n_{0}".format(i)
+            if result[key_name] != input_shape[i]:
+                # Construct the shape in independent varname to display
+                # meaningful error msg.
+                input_shape_from_table = [result["n_{0}".format(i)]
+                    for i in range(len(input_shape))]
+                plpy.error("model_keras error: Input shape {0} in the model" \
+                    " architecture does not match the input shape {1} of 
column" \
+                    " {2} in table {3}.".format(
+                        input_shape, input_shape_from_table,
+                        self.independent_varname, table))
+
+def get_class_values_and_type(source_summary_table):
+    class_values = plpy.execute("SELECT {0} AS class_values FROM {1}".
+        format(CLASS_VALUES_COLNAME, source_summary_table)
+        )[0]['class_values']
+    class_values_type = get_expr_type(CLASS_VALUES_COLNAME,
+                                      source_summary_table)
+    return class_values, class_values_type
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index e2ed883..5e4e62b 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -28,10 +28,13 @@ from keras.optimizers import *
 import numpy as np
 
 from utilities.model_arch_info import get_input_shape
+from utilities.utilities import add_postfix
 from utilities.validate_args import input_tbl_valid
 from utilities.validate_args import output_tbl_valid
 
-from madlib_keras_helper import convert_string_of_args_to_dict
+from madlib_keras_wrapper import convert_string_of_args_to_dict
+from madlib_keras_helper import get_class_values_and_type
+from madlib_keras_helper import KerasWeightsSerializer
 
 def predict(schema_madlib, model_table, test_table, id_col, model_arch_table,
             model_arch_id, independent_varname, compile_params, output_table,
@@ -55,15 +58,23 @@ def predict(schema_madlib, model_table, test_table, id_col, 
model_arch_table,
     model_arch = query_result['model_arch']
     input_shape = get_input_shape(model_arch)
     compile_params = "$madlib$" + compile_params + "$madlib$"
-    predict_query = plpy.prepare("""create table {output_table} as
-        select {id_col}, (madlib.internal_keras_predict({independent_varname},
-                                             $MAD${model_arch}$MAD$,
-                                             $1,ARRAY{input_shape},
-                                             {compile_params}))[1] as 
prediction
+    model_summary_table = add_postfix(model_table, "_summary")
+    class_values, _ = get_class_values_and_type(model_summary_table)
+    predict_query = plpy.prepare("""
+        CREATE TABLE {output_table} AS
+        SELECT {id_col},
+            ({schema_madlib}.internal_keras_predict
+                ({independent_varname},
+                 $MAD${model_arch}$MAD$,
+                 $1,ARRAY{input_shape},
+                 {compile_params},
+                 ARRAY{class_values}::TEXT[])
+            )[1] as prediction
         from {test_table}""".format(**locals()), ["bytea"])
     plpy.execute(predict_query, [model_data])
 
-def internal_keras_predict(x_test, model_arch, model_data, input_shape, 
compile_params):
+def internal_keras_predict(x_test, model_arch, model_data, input_shape,
+                           compile_params, class_values):
     model = model_from_json(model_arch)
     compile_params = convert_string_of_args_to_dict(compile_params)
     device_name = '/cpu:0'
@@ -75,9 +86,40 @@ def internal_keras_predict(x_test, model_arch, model_data, 
input_shape, compile_
     model_shapes = []
     for weight_arr in model.get_weights():
         model_shapes.append(weight_arr.shape)
-    _,_,_, model_weights = deserialize_weights(model_data, model_shapes)
+    _,_,_, model_weights = KerasWeightsSerializer.deserialize_weights(
+        model_data, model_shapes)
     model.set_weights(model_weights)
     x_test = np.array(x_test).reshape(1, *input_shape)
     x_test /= 255
-    res = model.predict_classes(x_test)
-    return res
+    proba_argmax = model.predict_classes(x_test)
+    # proba_argmax is a list with exactly one element in it. That element
+    # refers to the index containing the largest probability value in the
+    # output of Keras' predict function.
+    return _get_class_label(class_values, proba_argmax[0])
+
+def _get_class_label(class_values, class_index):
+    """
+    Returns back the class label associated with the index returned by Keras'
+    predict_classes function. Keras' predict_classes function returns back
+    the index of the 1-hot encoded output that has the highest probability
+    value. We should infer the exact class label corresponding to the index
+    by looking at the class_values list (which is obtained from the
+    class_values column of the model summary table). If class_values is None,
+    we return the index as is.
+    Args:
+        @param class_values: list of class labels.
+        @param class_index: integer representing the index with max
+                            probability value.
+    Returns:
+        scalar. If class_values is None, returns class_index, else returns
+        class_values[class_index].
+    """
+    if class_values:
+        if class_index < len(class_values):
+            return class_values[class_index]
+        else:
+            plpy.error("Invalid class index {0} returned from Keras predict. "\
+                "Index value must be less than {1}".format(
+                    class_index, len(class_values)))
+    else:
+        return class_index
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index fbf6a81..f393db3 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -28,7 +28,7 @@ create table cifar_10_sample(
         );
 copy cifar_10_sample from stdin delimiter '|';
 
1|{{{202,204,199},{202,204,199},{204,206,201},{206,208,203},{208,210,205},{209,211,206},{210,212,207},{212,214,210},{213,215,212},{215,217,214},{216,218,215},{216,218,215},{215,217,214},{216,218,215},{216,218,215},{216,218,214},{217,219,214},{217,219,214},{218,220,215},{218,219,214},{216,217,212},{217,218,213},{218,219,214},{214,215,209},{213,214,207},{212,213,206},{211,212,205},{209,210,203},{208,209,202},{207,208,200},{205,206,199},{203,204,198}},{{206,208,203},{206,208,203},{207,209,2
 [...]
-2|{{{126,118,110},{122,115,108},{126,119,111},{127,119,109},{130,122,111},{130,122,111},{132,124,113},{133,125,114},{130,122,111},{132,124,113},{134,126,115},{131,123,112},{131,123,112},{134,126,115},{133,125,114},{136,128,117},{137,129,118},{137,129,118},{136,128,117},{131,123,112},{130,122,111},{132,124,113},{132,124,113},{132,124,113},{129,122,110},{127,121,109},{127,121,109},{125,119,107},{124,118,106},{124,118,106},{120,114,102},{117,111,99}},{{122,115,107},{119,112,104},{121,114,10
 [...]
+2|{{{126,118,110},{122,115,108},{126,119,111},{127,119,109},{130,122,111},{130,122,111},{132,124,113},{133,125,114},{130,122,111},{132,124,113},{134,126,115},{131,123,112},{131,123,112},{134,126,115},{133,125,114},{136,128,117},{137,129,118},{137,129,118},{136,128,117},{131,123,112},{130,122,111},{132,124,113},{132,124,113},{132,124,113},{129,122,110},{127,121,109},{127,121,109},{125,119,107},{124,118,106},{124,118,106},{120,114,102},{117,111,99}},{{122,115,107},{119,112,104},{121,114,10
 [...]
 \.
 
 DROP TABLE IF EXISTS cifar_10_sample_batched;
@@ -54,7 +54,7 @@ SELECT load_keras_model('model_arch',
        {"class_name": "Dropout", "config": {"rate": 0.25, "noise_shape": null, 
"trainable": true, "seed": null, "name": "dropout_1"}},
        {"class_name": "Flatten", "config": {"trainable": true, "name": 
"flatten_1", "data_format": "channels_last"}},
        {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": 
null, "mode": "fan_avg"}}, "name": "dense_1", "kernel_constraint": null, 
"bias_regularizer": null, "bias_constraint": null, "activation": "softmax", 
"trainable": true, "kernel_regularizer": null, "bias_initializer":
-       {"class_name": "Zeros", "config": {}}, "units": 10, "use_bias": true, 
"activity_regularizer": null}
+       {"class_name": "Zeros", "config": {}}, "units": 2, "use_bias": true, 
"activity_regularizer": null}
        }], "backend": "tensorflow"}$$);
 ALTER TABLE model_arch RENAME model_id TO id;
 
@@ -62,18 +62,19 @@ ALTER TABLE model_arch RENAME model_id TO id;
 -- It might break the assertion
 
 DROP TABLE IF EXISTS keras_out, keras_out_summary;
-SELECT madlib_keras_fit('cifar_10_sample_batched',
-              'keras_out',
-              'dependent_var',
-              'independent_var',
-              'model_arch',
-              1,
-              $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
-              $$ batch_size=2, epochs=1, verbose=0 $$::text,
-              3,
-              10,
-              FALSE,
-              'cifar_10_sample_batched');
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_out',
+    'dependent_var',
+    'independent_var',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    2,
+    FALSE,
+    'cifar_10_sample_batched');
 SELECT assert(
         model_arch_table = 'model_arch' AND
         model_arch_id = 1 AND
@@ -92,7 +93,8 @@ SELECT assert(
         compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text AND
         fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
         num_iterations = 3 AND
-        num_classes = 10 AND
+        num_classes = 2 AND
+        class_values = '{0,1}' AND
         accuracy is not NULL AND
         loss is not NULL AND
         array_upper(accuracy_iter, 1) = 3 AND
@@ -103,57 +105,71 @@ SELECT assert(
         array_upper(accuracy_iter_validation, 1) = 3 AND
         array_upper(loss_iter_validation, 1) = 3 ,
         'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
-from (select * from keras_out_summary) summary;
+FROM (SELECT * FROM keras_out_summary) summary;
 
-SELECT assert(model_data is not NULL , 'Keras model output validation failed') 
from (select * from keras_out) k;
+SELECT assert(model_data IS NOT NULL , 'Keras model output validation failed') 
FROM (SELECT * FROM keras_out) k;
 
 
 -- Test for
   -- Non null name and description columns
        -- Null validation table
 DROP TABLE IF EXISTS keras_out, keras_out_summary;
-SELECT madlib_keras_fit('cifar_10_sample_batched',
-                                               'keras_out',
-                                               'dependent_var',
-                                               'independent_var',
-                                               'model_arch',
-                                               1,
-                        $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
-                        $$ batch_size=2, epochs=1, verbose=0 $$::text,
-                                               1,
-                                               10,
-                                               FALSE,
-                                               NULL,
-                                               'model name', 'model desc');
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_out',
+    'dependent_var',
+    'independent_var',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    1,
+    2,
+    FALSE,
+    NULL,
+    'model name', 'model desc');
 SELECT assert(
-        model_arch_table = 'model_arch' AND
-        model_arch_id = 1 AND
-        model_type = 'madlib_keras' AND
-        start_training_time         < now() AND
-        end_training_time > start_training_time AND
-        source_table = 'cifar_10_sample_batched' AND
-        validation_table = 'cifar_10_sample_batched' AND
-        model = 'keras_out' AND
-        dependent_varname = 'dependent_var' AND
-        independent_varname = 'independent_var' AND
-        name = 'model name' AND
-        description = 'model desc' AND
-        model_size > 0 AND
-        madlib_version is NOT NULL AND
-        compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text AND
-        fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
-        num_iterations = 1 AND
-        num_classes = 10 AND
-        accuracy is not NULL AND
-        loss is not NULL AND
-        array_upper(accuracy_iter, 1) = 1 AND
-        array_upper(loss_iter, 1) = 1 AND
-        array_upper(time_iter, 1) = 1 AND
-        accuracy_validation is  NULL AND
-        loss_validation is  NULL AND
-        array_upper(accuracy_iter_validation,1) = 0 AND
-        array_upper(loss_iter_validation,1) = 0 ,
-        'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
-from (select * from keras_out_summary) summary;
+    model_arch_table = 'model_arch' AND
+    model_arch_id = 1 AND
+    model_type = 'madlib_keras' AND
+    start_training_time         < now() AND
+    end_training_time > start_training_time AND
+    source_table = 'cifar_10_sample_batched' AND
+    validation_table = 'cifar_10_sample_batched' AND
+    model = 'keras_out' AND
+    dependent_varname = 'dependent_var' AND
+    independent_varname = 'independent_var' AND
+    name = 'model name' AND
+    description = 'model desc' AND
+    model_size > 0 AND
+    madlib_version is NOT NULL AND
+    compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text AND
+    fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
+    num_iterations = 1 AND
+    num_classes = 2 AND
+    class_values = '{0,1}' AND
+    accuracy is not NULL AND
+    loss is not NULL AND
+    array_upper(accuracy_iter, 1) = 1 AND
+    array_upper(loss_iter, 1) = 1 AND
+    array_upper(time_iter, 1) = 1 AND
+    accuracy_validation is  NULL AND
+    loss_validation is  NULL AND
+    array_upper(accuracy_iter_validation,1) = 0 AND
+    array_upper(loss_iter_validation,1) = 0 ,
+    'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
+FROM (SELECT * FROM keras_out_summary) summary;
+
+SELECT assert(model_data IS NOT NULL , 'Keras model output validation failed') 
FROM (SELECT * FROM keras_out) k;
 
-SELECT assert(model_data is not NULL , 'Keras model output validation failed') 
from (select * from keras_out) k;
+-- Temporary predict test, to be updated as part of another jira
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_out',
+    'cifar_10_sample',
+    'id',
+    'model_arch',
+    1,
+    'x',
+    '''optimizer''=SGD(lr=0.01, decay=1e-6, nesterov=True), 
''loss''=''categorical_crossentropy'', ''metrics''=[''accuracy'']'::text,
+    'cifar10_predict');
diff --git 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index d9613ca..a66a292 100644
--- 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -82,7 +82,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
 
         k = {'SD': {'buffer_count': buffer_count}}
         new_model_state = self.subject.fit_transition(
-            None, [[0.5]] , [0], 1, 2, self.all_seg_ids, 
self.total_buffers_per_seg,
+            None, [[0.5]] , [[0,1]], 1, 2, self.all_seg_ids, 
self.total_buffers_per_seg,
             self.model.to_json(), self.compile_params, self.fit_params, False,
             previous_state.tostring(), **k)
         buffer_count = np.fromstring(new_model_state, dtype=np.float32)[2]
@@ -113,7 +113,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
                    'model_shapes': self.model_shapes}}
         k['SD']['segment_model'] = self.model
         new_model_state = self.subject.fit_transition(
-            state.tostring(), [[0.5]] , [0], 1, 2, self.all_seg_ids, 
self.total_buffers_per_seg,
+            state.tostring(), [[0.5]] , [[1,0]], 1, 2, self.all_seg_ids, 
self.total_buffers_per_seg,
             self.model.to_json(), None, self.fit_params, False, 
'dummy_previous_state', **k)
 
         buffer_count = np.fromstring(new_model_state, dtype=np.float32)[2]
@@ -142,7 +142,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
                    'model_shapes': self.model_shapes}}
         k['SD']['segment_model'] = self.model
         new_model_state = self.subject.fit_transition(
-            state.tostring(), [[0.5]] , [0], 1, 2, self.all_seg_ids, 
self.total_buffers_per_seg,
+            state.tostring(), [[0.5]] , [[0,1]], 1, 2, self.all_seg_ids, 
self.total_buffers_per_seg,
             self.model.to_json(), None, self.fit_params, False, 
'dummy_previous_state', **k)
 
         buffer_count = np.fromstring(new_model_state, dtype=np.float32)[2]
@@ -177,23 +177,6 @@ class MadlibKerasFitTestCase(unittest.TestCase):
             [0,1,2], [3,3,3], 'dummy_model_json', "foo", "bar", False,
             'dummy_prev_state', **k))
 
-    def test_validate_input_shapes_shapes_do_not_match(self):
-        self.plpy_mock_execute.return_value = [{'n_0': 32, 'n_1': 32}]
-        with self.assertRaises(plpy.PLPYException):
-            self.subject._validate_input_shapes('foo', 'bar', [32,32,3])
-
-        self.plpy_mock_execute.return_value = [{'n_0': 3, 'n_1': 32, 'n_2': 
32}]
-        with self.assertRaises(plpy.PLPYException):
-            self.subject._validate_input_shapes('foo', 'bar', [32,32,3])
-
-        self.plpy_mock_execute.return_value = [{'n_0': 3, 'n_1': None, 'n_2': 
None}]
-        with self.assertRaises(plpy.PLPYException):
-            self.subject._validate_input_shapes('foo', 'bar', [3,32])
-
-    def test_validate_input_shapes_shapes_match(self):
-        self.plpy_mock_execute.return_value = [{'n_0': 32, 'n_1': 32, 'n_2': 
3}]
-        self.subject._validate_input_shapes('foo', 'bar', [32,32,3])
-
 if __name__ == '__main__':
     unittest.main()
 # ---------------------------------------------------------------------
diff --git 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_helper.py_in
 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_helper.py_in
index b9feef8..61439f1 100644
--- 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_helper.py_in
+++ 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_helper.py_in
@@ -25,6 +25,8 @@ 
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath
 
sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
 
 from keras import utils as keras_utils
+from keras.models import *
+from keras.layers import *
 
 import unittest
 from mock import *
@@ -139,8 +141,8 @@ class MadlibKerasHelperTestCase(unittest.TestCase):
                          res)
 
     def test_get_data_as_np_array_one_image_per_row(self):
-        self.plpy_mock_execute.return_value = [{'x': [[1,2]], 'y': 0},
-                                               {'x': [[5,6]], 'y': 1}]
+        self.plpy_mock_execute.return_value = [{'x': [[1,2]], 'y': [[1, 0, 
0]]},
+                                               {'x': [[5,6]], 'y': [[0, 1, 
0]]}]
         x_res, y_res = self.subject.get_data_as_np_array('foo','y','x', 
[1,1,2],
                                                          3)
         self.assertEqual(np.array([[[[1, 2]]], [[[5, 6]]]]).tolist(),
@@ -149,8 +151,8 @@ class MadlibKerasHelperTestCase(unittest.TestCase):
                          y_res.tolist())
 
     def test_get_data_as_np_array_multiple_images_per_row(self):
-        self.plpy_mock_execute.return_value = [{'x': [[1,2], [3,4]], 'y': 
[0,2]},
-                                               {'x': [[5,6], [7,8]], 'y': 
[1,0]}]
+        self.plpy_mock_execute.return_value = [{'x': [[1,2], [3,4]], 'y': 
[[1,0,0],[0,0,1]]},
+                                               {'x': [[5,6], [7,8]], 'y': 
[[0,1,0],[1,0,0]]}]
         x_res, y_res = self.subject.get_data_as_np_array('foo','y','x', 
[1,1,2],
                                                          3)
         self.assertEqual(np.array([[[[1,2]]], [[[3,4]]],
@@ -161,8 +163,8 @@ class MadlibKerasHelperTestCase(unittest.TestCase):
                          y_res.tolist())
 
     def test_get_data_as_np_array_float_input_shape(self):
-        self.plpy_mock_execute.return_value = [{'x': [[1,2]], 'y': 0},
-                                               {'x': [[5,6]], 'y': 1}]
+        self.plpy_mock_execute.return_value = [{'x': [[1,2]], 'y': [[1, 0, 
0]]},
+                                               {'x': [[5,6]], 'y': [[0, 1, 
0]]}]
         x_res, y_res = self.subject.get_data_as_np_array('foo','y','x',
                                                          [1.5,1.9,2.3], 3)
         self.assertEqual(np.array([[[[1, 2]]], [[[5, 6]]]]).tolist(),
@@ -171,9 +173,45 @@ class MadlibKerasHelperTestCase(unittest.TestCase):
                          y_res.tolist())
 
     def test_get_data_as_np_array_invalid_input_shape(self):
-        self.plpy_mock_execute.return_value = [{'x': [[1,2]], 'y': 0},
-                                               {'x': [[5,6]], 'y': 1}]
+        self.plpy_mock_execute.return_value = [{'x': [[1,2]], 'y': [[1, 0, 
0]]},
+                                               {'x': [[5,6]], 'y': [[0, 1, 
0]]}]
         # we expect keras failure(ValueError) because we cannot reshape
         # the input which is of size 2 to input shape of 1,1,3
         with self.assertRaises(ValueError):
             self.subject.get_data_as_np_array('foo','y','x', [1,1,3], 3)
+
+    def test_validate_input_shapes_shapes_do_not_match(self):
+        self.plpy_mock_execute.return_value = [{'n_0': 32, 'n_1': 32}]
+        self.subject.FitInputValidator._validate_input_args = Mock()
+        input_validator_obj = self.subject.FitInputValidator('foo',
+                                                             'foo_valid',
+                                                             'model',
+                                                             
'model_arch_table',
+                                                             
'dependent_varname',
+                                                             
'independent_varname',
+                                                             1)
+        with self.assertRaises(plpy.PLPYException):
+            input_validator_obj.validate_input_shapes('dummy_tbl', [32,32,3])
+
+        self.plpy_mock_execute.return_value = [{'n_0': 3, 'n_1': 32, 'n_2': 
32}]
+        with self.assertRaises(plpy.PLPYException):
+            input_validator_obj.validate_input_shapes('dummy_tbl', [32,32,3])
+
+        self.plpy_mock_execute.return_value = [{'n_0': 3, 'n_1': None, 'n_2': 
None}]
+        with self.assertRaises(plpy.PLPYException):
+            input_validator_obj.validate_input_shapes('dummy_tbl', [3,32])
+
+    def test_validate_input_shapes_shapes_match(self):
+        self.plpy_mock_execute.return_value = [{'n_0': 32, 'n_1': 32, 'n_2': 
3}]
+        self.subject.FitInputValidator._validate_input_args = Mock()
+        input_validator_obj = self.subject.FitInputValidator('foo',
+                                                             'foo_valid',
+                                                             'model',
+                                                             
'model_arch_table',
+                                                             
'dependent_varname',
+                                                             
'independent_varname',
+                                                             1)
+        input_validator_obj.validate_input_shapes('dummy_tbl', [32,32,3])
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/src/ports/postgres/modules/utilities/minibatch_validation.py_in 
b/src/ports/postgres/modules/utilities/minibatch_validation.py_in
index 16b11a9..5c77c2c 100644
--- a/src/ports/postgres/modules/utilities/minibatch_validation.py_in
+++ b/src/ports/postgres/modules/utilities/minibatch_validation.py_in
@@ -18,8 +18,19 @@
 # under the License.
 
 import plpy
+from utilities import _assert
+from utilities import is_valid_psql_type
+from utilities import NUMERIC, ONLY_ARRAY
+from validate_args import get_expr_type
+
+def validate_dependent_var_for_minibatch(table_name, var_name, expr_type=None):
+    # The dependent variable is always a double precision array in
+    # preprocessed data (so check for numeric types)
+    if not expr_type:
+        expr_type = get_expr_type(var_name, table_name)
+    _assert(is_valid_psql_type(expr_type, NUMERIC | ONLY_ARRAY),
+            "Dependent variable column should be a numeric array.")
 
-def is_var_one_hot_encoded_for_minibatch(table_name, var_name):
     query = """SELECT array_upper({var_name}, 2) > 1 AS is_encoded FROM
               {table_name} LIMIT 1;""".format(**locals())
     result = plpy.execute(query)

Reply via email to