This is an automated email from the ASF dual-hosted git repository.

nkak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit c8346e2cdf7e52ce919eb6f5d8bc80ecce0d7040
Author: Nikhil Kak <[email protected]>
AuthorDate: Wed May 22 15:44:27 2019 -0700

    DL: Enable warm start
    
    JIRA: MADLIB-1350
    Add code to enable warm start by adding a new parameter named
    warm_start to the fit function, along with the necessary validation
    and code changes.
    This commit also adds dev-check tests for both warm start and transfer
    learning. The existing transfer learning test is replaced with a more
    deterministic one that uses the iris dataset (a more realistic,
    real-world dataset).
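    
    A sketch of the new usage, mirroring the dev-check tests (table names
    are illustrative; the trailing TRUE is the new warm_start argument):
    
        SELECT madlib_keras_fit(
            'iris_data_packed',   -- source table
            'iris_model',         -- output model table; must already exist
                                  -- when warm_start is TRUE
            'iris_model_arch',    -- model arch table
            2,                    -- model arch id
            $$ loss='categorical_crossentropy', optimizer='adam',
               metrics=['accuracy'] $$,
            $$ batch_size=5, epochs=3 $$,
            2,                    -- num_iterations
            NULL,                 -- gpus_per_host
            NULL,                 -- validation_table
            1,                    -- metrics_compute_frequency
            TRUE                  -- warm_start
        );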
    
    Co-authored-by: Nikhil Kak <[email protected]>
    Co-authored-by: Orhan Kislal <[email protected]>
---
 .../modules/deep_learning/madlib_keras.py_in       |  60 +--
 .../modules/deep_learning/madlib_keras.sql_in      |  32 +-
 .../deep_learning/madlib_keras_validator.py_in     |  14 +-
 .../modules/deep_learning/test/madlib_keras.sql_in | 427 ++++++++++++++++-----
 .../test/unit_tests/test_madlib_keras.py_in        |   8 +-
 5 files changed, 403 insertions(+), 138 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index f164ce7..0f85d9e 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -64,7 +64,7 @@ from utilities.control import MinWarning
 def fit(schema_madlib, source_table, model, model_arch_table,
         model_arch_id, compile_params, fit_params, num_iterations,
         gpus_per_host=0, validation_table=None,
-        metrics_compute_frequency=None, name="",
+        metrics_compute_frequency=None, warm_start=False, name="",
         description="", **kwargs):
     source_table = quote_ident(source_table)
     model_arch_table = quote_ident(model_arch_table)
@@ -77,7 +77,7 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     fit_validator = FitInputValidator(
         source_table, validation_table, model, model_arch_table,
         mb_dep_var_col, mb_indep_var_col,
-        num_iterations, metrics_compute_frequency)
+        num_iterations, metrics_compute_frequency, warm_start)
     if metrics_compute_frequency is None:
         metrics_compute_frequency = num_iterations
 
@@ -86,6 +86,7 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     start_training_time = datetime.datetime.now()
 
     segments_per_host, gpus_per_host = get_segments_and_gpus(gpus_per_host)
+    warm_start = bool(warm_start)
 
     #TODO add a unit test for this in a future PR
    # save the original value of the env variable so that we can reset it later.
@@ -96,9 +97,9 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     # Get the serialized master model
     start_deserialization = time.time()
     model_arch_query = "SELECT {0}, {1} FROM {2} WHERE {3} = {4}".format(
-                                        ModelArchSchema.MODEL_ARCH, ModelArchSchema.MODEL_WEIGHTS,
-                                        model_arch_table, ModelArchSchema.MODEL_ID,
-                                        model_arch_id)
+        ModelArchSchema.MODEL_ARCH, ModelArchSchema.MODEL_WEIGHTS,
+        model_arch_table, ModelArchSchema.MODEL_ID,
+        model_arch_id)
     model_arch_result = plpy.execute(model_arch_query)
     if not  model_arch_result:
         plpy.error("no model arch found in table {0} with id {1}".format(
@@ -108,9 +109,11 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     input_shape = get_input_shape(model_arch)
     num_classes = get_num_classes(model_arch)
     fit_validator.validate_input_shapes(input_shape)
-    model_weights_serialized = model_arch_result[ModelArchSchema.MODEL_WEIGHTS]
 
-    #TODO: Refactor the pg related logic in a future PR when we think
+    serialized_weights = get_initial_weights(model, model_arch_result,
+                                             warm_start)
+
+    # TODO: Refactor the pg related logic in a future PR when we think
     # about making the fit function easier to read and maintain.
     if is_platform_pg():
         gp_segment_id_col = '0'
@@ -127,21 +130,6 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     if validation_table:
         seg_ids_val, images_per_seg_val = get_images_per_seg(validation_table)
 
-    # Convert model from json and initialize weights
-    master_model = model_from_json(model_arch)
-    model_weights = master_model.get_weights()
-
-    # Get shape of weights in each layer from model arch
-    model_shapes = []
-    for weight_arr in master_model.get_weights():
-        model_shapes.append(weight_arr.shape)
-
-    if model_weights_serialized:
-        # If warm start from previously trained model, set weights
-        model_weights = madlib_keras_serializer.deserialize_as_nd_weights(
-            model_weights_serialized, model_shapes)
-        master_model.set_weights(model_weights)
-
     # Construct validation dataset if provided
     validation_set_provided = bool(validation_table)
     validation_metrics = []; validation_loss = []
@@ -167,7 +155,6 @@ def fit(schema_madlib, source_table, model, model_arch_table,
         """.format(**locals()), ["bytea"])
 
     # Define the state for the model and loss/metric storage lists
-    serialized_weights = madlib_keras_serializer.serialize_nd_weights(model_weights)
     training_loss, training_metrics, metrics_elapsed_time = [], [], []
     metrics_iters = []
 
@@ -244,6 +231,9 @@ def fit(schema_madlib, source_table, model, model_arch_table,
         validation_metrics_final = validation_loss_final = 'NULL'
         validation_table = 'NULL'
 
+    if warm_start:
+        plpy.execute("DROP TABLE {0}, {1}".format
+                     (model, fit_validator.output_summary_model_table))
     create_output_summary_table = plpy.prepare("""
         CREATE TABLE {output_summary_model_table} AS
         SELECT
@@ -302,6 +292,30 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     #TODO add a unit test for this in a future PR
     reset_cuda_env(original_cuda_env)
 
+def get_initial_weights(model_table, model_arch_result, warm_start):
+    """
+        If warm_start is True, return the initial weights from the model
+        table. If warm_start is False, first try to get the weights from
+        the model_arch table; if no weights are defined there, randomly
+        initialize them using Keras.
+        @args:
+            @param model_table: Output model table passed in to fit.
+            @param model_arch_result: Dict containing model architecture info.
+            @param warm_start: Boolean flag indicating warm start or not.
+    """
+    if warm_start:
+        serialized_weights = plpy.execute("""
+            SELECT model_data FROM {0}
+        """.format(model_table))[0]['model_data']
+    else:
+        serialized_weights = model_arch_result[ModelArchSchema.MODEL_WEIGHTS]
+        if not serialized_weights:
+            master_model = model_from_json(
+                model_arch_result[ModelArchSchema.MODEL_ARCH])
+            serialized_weights = madlib_keras_serializer.serialize_nd_weights(
+                master_model.get_weights())
+    return serialized_weights
+
 def get_source_summary_table_dict(fit_validator):
     source_summary = plpy.execute("""
             SELECT
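
Note on the weight-initialization path above: with warm_start, the initial
weights come from the existing output model table rather than the model arch
table. The warm-start branch of get_initial_weights() reduces to a query of
the form (table name illustrative):

    SELECT model_data FROM iris_model;

Without warm_start, any weights stored in the model arch table are used, and
if none are present the weights fall back to Keras' random initialization
for the architecture.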
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index 6b7b0c0..ccc069f 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -39,6 +39,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     gpus_per_host           INTEGER,
     validation_table        VARCHAR,
     metrics_compute_frequency  INTEGER,
+    warm_start              BOOLEAN,
     name                    VARCHAR,
     description             VARCHAR
 ) RETURNS VOID AS $$
@@ -59,9 +60,10 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     gpus_per_host           INTEGER,
     validation_table        VARCHAR,
     metrics_compute_frequency  INTEGER,
+    warm_start              BOOLEAN,
     name                    VARCHAR
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, NULL);
+    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
 
@@ -75,9 +77,10 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     num_iterations          INTEGER,
     gpus_per_host           INTEGER,
     validation_table        VARCHAR,
-    metrics_compute_frequency  INTEGER
+    metrics_compute_frequency  INTEGER,
+    warm_start              BOOLEAN
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NULL, NULL);
+    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, NULL, NULL);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
 
@@ -91,9 +94,26 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     fit_params              VARCHAR,
     num_iterations          INTEGER,
     gpus_per_host           INTEGER,
+    validation_table        VARCHAR,
+    metrics_compute_frequency  INTEGER
+) RETURNS VOID AS $$
+SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NULL, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+    m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
+    source_table            VARCHAR,
+    model                   VARCHAR,
+    model_arch_table        VARCHAR,
+    model_arch_id           INTEGER,
+    compile_params          VARCHAR,
+    fit_params              VARCHAR,
+    num_iterations          INTEGER,
+    gpus_per_host           INTEGER,
     validation_table        VARCHAR
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, NULL, NULL, NULL);
+    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, NULL, NULL, NULL, NULL);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
 
@@ -107,7 +127,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     num_iterations          INTEGER,
     gpus_per_host           INTEGER
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, NULL, NULL, NULL, NULL);
+    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, NULL, NULL, NULL, NULL, NULL);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
 
@@ -120,7 +140,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     fit_params              VARCHAR,
     num_iterations          INTEGER
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, 0, NULL, NULL, NULL, NULL);
+    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, 0, NULL, NULL, NULL, NULL, NULL);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 4ab5d45..56b4f2f 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -213,7 +213,7 @@ class PredictInputValidator(InputValidator):
 class FitInputValidator:
     def __init__(self, source_table, validation_table, output_model_table,
                  model_arch_table, dependent_varname, independent_varname,
-                 num_iterations, metrics_compute_frequency):
+                 num_iterations, metrics_compute_frequency, warm_start):
         self.source_table = source_table
         self.validation_table = validation_table
         self.output_model_table = output_model_table
@@ -221,6 +221,7 @@ class FitInputValidator:
         self.dependent_varname = dependent_varname
         self.independent_varname = independent_varname
         self.metrics_compute_frequency = metrics_compute_frequency
+        self.warm_start = warm_start
         self.num_iterations = num_iterations
         self.source_summary_table = None
         if self.source_table:
@@ -265,6 +266,7 @@ class FitInputValidator:
         cols_in_tbl_valid(self.source_summary_table, [CLASS_VALUES_COLNAME,
             NORMALIZING_CONST_COLNAME, DEPENDENT_VARTYPE_COLNAME,
             'dependent_varname', 'independent_varname'], self.module_name)
+
         # Source table and validation tables must have the same schema
         self._validate_input_table(self.source_table)
         validate_dependent_var_for_minibatch(self.source_table,
@@ -272,11 +274,13 @@ class FitInputValidator:
 
         self._validate_validation_table()
 
-        # Validate model arch table's schema.
         input_tbl_valid(self.model_arch_table, self.module_name)
-        # Validate output tables
-        output_tbl_valid(self.output_model_table, self.module_name)
-        output_tbl_valid(self.output_summary_model_table, self.module_name)
+        if self.warm_start:
+            input_tbl_valid(self.output_model_table, self.module_name)
+            input_tbl_valid(self.output_summary_model_table, self.module_name)
+        else:
+            output_tbl_valid(self.output_model_table, self.module_name)
+            output_tbl_valid(self.output_summary_model_table, self.module_name)
 
 
     def _validate_validation_table(self):
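
In short, warm_start flips the output tables from outputs to inputs: the
model table and its summary table must already exist (they are read, then
dropped and rebuilt), while a cold start still requires that neither exists.
A sketch of the expected behavior, reusing the dev-check tables as
illustrative names:

    -- errors out unless iris_model and iris_model_summary already exist
    SELECT madlib_keras_fit('iris_data_packed', 'iris_model',
                            'iris_model_arch', 2,
                            $$ loss='categorical_crossentropy',
                               optimizer='adam', metrics=['accuracy'] $$,
                            $$ batch_size=5, epochs=3 $$,
                            2, NULL, NULL, 1,
                            TRUE);  -- warm_start

    -- errors out if iris_model or iris_model_summary already exists
    SELECT madlib_keras_fit('iris_data_packed', 'iris_model',
                            'iris_model_arch', 2,
                            $$ loss='categorical_crossentropy',
                               optimizer='adam', metrics=['accuracy'] $$,
                            $$ batch_size=5, epochs=3 $$,
                            2, NULL, NULL, 1,
                            FALSE);  -- warm_start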
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index 847418f..df3935a 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -265,7 +265,9 @@ SELECT madlib_keras_fit(
     NULL,
     NULL,
     1,
-    'model name', 'model desc');
+    NULL,
+    'model name',
+    'model desc');
 
 SELECT assert(
     source_table = 'cifar_10_sample_batched' AND
@@ -421,7 +423,7 @@ SELECT madlib_keras_fit(
     NULL,
     NULL,
     NULL,
-    'model name', 'model desc');
+    NULL, 'model name', 'model desc');
 
 DROP TABLE IF EXISTS keras_out, keras_out_summary;
 SELECT madlib_keras_fit(
@@ -435,7 +437,7 @@ SELECT madlib_keras_fit(
     NULL,
     NULL,
     NULL,
-    'model name', 'model desc');
+    NULL, 'model name', 'model desc');
 
 DROP TABLE IF EXISTS keras_out, keras_out_summary;
 SELECT madlib_keras_fit(
@@ -449,7 +451,7 @@ SELECT madlib_keras_fit(
     0,
     NULL,
     NULL,
-    'model name', 'model desc');
+    NULL, 'model name', 'model desc');
 
 DROP TABLE IF EXISTS keras_out, keras_out_summary;
 SELECT madlib_keras_fit(
@@ -463,7 +465,7 @@ SELECT madlib_keras_fit(
     NULL,
     NULL,
     NULL,
-    'model name', 'model desc');
+    False, 'model name', 'model desc');
 
 DROP TABLE IF EXISTS keras_out, keras_out_summary;
 SELECT madlib_keras_fit(
@@ -477,7 +479,7 @@ SELECT madlib_keras_fit(
     NULL,
     NULL,
     NULL,
-    'model name', 'model desc');
+    False, 'model name', 'model desc');
 
 -- -- negative test case for passing non numeric y to fit
 -- induce failure by passing a non numeric column
@@ -854,100 +856,325 @@ SELECT madlib_keras_predict(
     'prob',
     0);
 
+-------------------- TRANSFER LEARNING and WARM START -----------------
+
+DROP TABLE IF EXISTS iris_data;
+CREATE TABLE iris_data(
+    id serial,
+    attributes numeric[],
+    class_text varchar
+);
+INSERT INTO iris_data(id, attributes, class_text) VALUES
+(1,ARRAY[5.1,3.5,1.4,0.2],'Iris-setosa'),
+(2,ARRAY[4.9,3.0,1.4,0.2],'Iris-setosa'),
+(3,ARRAY[4.7,3.2,1.3,0.2],'Iris-setosa'),
+(4,ARRAY[4.6,3.1,1.5,0.2],'Iris-setosa'),
+(5,ARRAY[5.0,3.6,1.4,0.2],'Iris-setosa'),
+(6,ARRAY[5.4,3.9,1.7,0.4],'Iris-setosa'),
+(7,ARRAY[4.6,3.4,1.4,0.3],'Iris-setosa'),
+(8,ARRAY[5.0,3.4,1.5,0.2],'Iris-setosa'),
+(9,ARRAY[4.4,2.9,1.4,0.2],'Iris-setosa'),
+(10,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),
+(11,ARRAY[5.4,3.7,1.5,0.2],'Iris-setosa'),
+(12,ARRAY[4.8,3.4,1.6,0.2],'Iris-setosa'),
+(13,ARRAY[4.8,3.0,1.4,0.1],'Iris-setosa'),
+(14,ARRAY[4.3,3.0,1.1,0.1],'Iris-setosa'),
+(15,ARRAY[5.8,4.0,1.2,0.2],'Iris-setosa'),
+(16,ARRAY[5.7,4.4,1.5,0.4],'Iris-setosa'),
+(17,ARRAY[5.4,3.9,1.3,0.4],'Iris-setosa'),
+(18,ARRAY[5.1,3.5,1.4,0.3],'Iris-setosa'),
+(19,ARRAY[5.7,3.8,1.7,0.3],'Iris-setosa'),
+(20,ARRAY[5.1,3.8,1.5,0.3],'Iris-setosa'),
+(21,ARRAY[5.4,3.4,1.7,0.2],'Iris-setosa'),
+(22,ARRAY[5.1,3.7,1.5,0.4],'Iris-setosa'),
+(23,ARRAY[4.6,3.6,1.0,0.2],'Iris-setosa'),
+(24,ARRAY[5.1,3.3,1.7,0.5],'Iris-setosa'),
+(25,ARRAY[4.8,3.4,1.9,0.2],'Iris-setosa'),
+(26,ARRAY[5.0,3.0,1.6,0.2],'Iris-setosa'),
+(27,ARRAY[5.0,3.4,1.6,0.4],'Iris-setosa'),
+(28,ARRAY[5.2,3.5,1.5,0.2],'Iris-setosa'),
+(29,ARRAY[5.2,3.4,1.4,0.2],'Iris-setosa'),
+(30,ARRAY[4.7,3.2,1.6,0.2],'Iris-setosa'),
+(31,ARRAY[4.8,3.1,1.6,0.2],'Iris-setosa'),
+(32,ARRAY[5.4,3.4,1.5,0.4],'Iris-setosa'),
+(33,ARRAY[5.2,4.1,1.5,0.1],'Iris-setosa'),
+(34,ARRAY[5.5,4.2,1.4,0.2],'Iris-setosa'),
+(35,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),
+(36,ARRAY[5.0,3.2,1.2,0.2],'Iris-setosa'),
+(37,ARRAY[5.5,3.5,1.3,0.2],'Iris-setosa'),
+(38,ARRAY[4.9,3.1,1.5,0.1],'Iris-setosa'),
+(39,ARRAY[4.4,3.0,1.3,0.2],'Iris-setosa'),
+(40,ARRAY[5.1,3.4,1.5,0.2],'Iris-setosa'),
+(41,ARRAY[5.0,3.5,1.3,0.3],'Iris-setosa'),
+(42,ARRAY[4.5,2.3,1.3,0.3],'Iris-setosa'),
+(43,ARRAY[4.4,3.2,1.3,0.2],'Iris-setosa'),
+(44,ARRAY[5.0,3.5,1.6,0.6],'Iris-setosa'),
+(45,ARRAY[5.1,3.8,1.9,0.4],'Iris-setosa'),
+(46,ARRAY[4.8,3.0,1.4,0.3],'Iris-setosa'),
+(47,ARRAY[5.1,3.8,1.6,0.2],'Iris-setosa'),
+(48,ARRAY[4.6,3.2,1.4,0.2],'Iris-setosa'),
+(49,ARRAY[5.3,3.7,1.5,0.2],'Iris-setosa'),
+(50,ARRAY[5.0,3.3,1.4,0.2],'Iris-setosa'),
+(51,ARRAY[7.0,3.2,4.7,1.4],'Iris-versicolor'),
+(52,ARRAY[6.4,3.2,4.5,1.5],'Iris-versicolor'),
+(53,ARRAY[6.9,3.1,4.9,1.5],'Iris-versicolor'),
+(54,ARRAY[5.5,2.3,4.0,1.3],'Iris-versicolor'),
+(55,ARRAY[6.5,2.8,4.6,1.5],'Iris-versicolor'),
+(56,ARRAY[5.7,2.8,4.5,1.3],'Iris-versicolor'),
+(57,ARRAY[6.3,3.3,4.7,1.6],'Iris-versicolor'),
+(58,ARRAY[4.9,2.4,3.3,1.0],'Iris-versicolor'),
+(59,ARRAY[6.6,2.9,4.6,1.3],'Iris-versicolor'),
+(60,ARRAY[5.2,2.7,3.9,1.4],'Iris-versicolor'),
+(61,ARRAY[5.0,2.0,3.5,1.0],'Iris-versicolor'),
+(62,ARRAY[5.9,3.0,4.2,1.5],'Iris-versicolor'),
+(63,ARRAY[6.0,2.2,4.0,1.0],'Iris-versicolor'),
+(64,ARRAY[6.1,2.9,4.7,1.4],'Iris-versicolor'),
+(65,ARRAY[5.6,2.9,3.6,1.3],'Iris-versicolor'),
+(66,ARRAY[6.7,3.1,4.4,1.4],'Iris-versicolor'),
+(67,ARRAY[5.6,3.0,4.5,1.5],'Iris-versicolor'),
+(68,ARRAY[5.8,2.7,4.1,1.0],'Iris-versicolor'),
+(69,ARRAY[6.2,2.2,4.5,1.5],'Iris-versicolor'),
+(70,ARRAY[5.6,2.5,3.9,1.1],'Iris-versicolor'),
+(71,ARRAY[5.9,3.2,4.8,1.8],'Iris-versicolor'),
+(72,ARRAY[6.1,2.8,4.0,1.3],'Iris-versicolor'),
+(73,ARRAY[6.3,2.5,4.9,1.5],'Iris-versicolor'),
+(74,ARRAY[6.1,2.8,4.7,1.2],'Iris-versicolor'),
+(75,ARRAY[6.4,2.9,4.3,1.3],'Iris-versicolor'),
+(76,ARRAY[6.6,3.0,4.4,1.4],'Iris-versicolor'),
+(77,ARRAY[6.8,2.8,4.8,1.4],'Iris-versicolor'),
+(78,ARRAY[6.7,3.0,5.0,1.7],'Iris-versicolor'),
+(79,ARRAY[6.0,2.9,4.5,1.5],'Iris-versicolor'),
+(80,ARRAY[5.7,2.6,3.5,1.0],'Iris-versicolor'),
+(81,ARRAY[5.5,2.4,3.8,1.1],'Iris-versicolor'),
+(82,ARRAY[5.5,2.4,3.7,1.0],'Iris-versicolor'),
+(83,ARRAY[5.8,2.7,3.9,1.2],'Iris-versicolor'),
+(84,ARRAY[6.0,2.7,5.1,1.6],'Iris-versicolor'),
+(85,ARRAY[5.4,3.0,4.5,1.5],'Iris-versicolor'),
+(86,ARRAY[6.0,3.4,4.5,1.6],'Iris-versicolor'),
+(87,ARRAY[6.7,3.1,4.7,1.5],'Iris-versicolor'),
+(88,ARRAY[6.3,2.3,4.4,1.3],'Iris-versicolor'),
+(89,ARRAY[5.6,3.0,4.1,1.3],'Iris-versicolor'),
+(90,ARRAY[5.5,2.5,4.0,1.3],'Iris-versicolor'),
+(91,ARRAY[5.5,2.6,4.4,1.2],'Iris-versicolor'),
+(92,ARRAY[6.1,3.0,4.6,1.4],'Iris-versicolor'),
+(93,ARRAY[5.8,2.6,4.0,1.2],'Iris-versicolor'),
+(94,ARRAY[5.0,2.3,3.3,1.0],'Iris-versicolor'),
+(95,ARRAY[5.6,2.7,4.2,1.3],'Iris-versicolor'),
+(96,ARRAY[5.7,3.0,4.2,1.2],'Iris-versicolor'),
+(97,ARRAY[5.7,2.9,4.2,1.3],'Iris-versicolor'),
+(98,ARRAY[6.2,2.9,4.3,1.3],'Iris-versicolor'),
+(99,ARRAY[5.1,2.5,3.0,1.1],'Iris-versicolor'),
+(100,ARRAY[5.7,2.8,4.1,1.3],'Iris-versicolor'),
+(101,ARRAY[6.3,3.3,6.0,2.5],'Iris-virginica'),
+(102,ARRAY[5.8,2.7,5.1,1.9],'Iris-virginica'),
+(103,ARRAY[7.1,3.0,5.9,2.1],'Iris-virginica'),
+(104,ARRAY[6.3,2.9,5.6,1.8],'Iris-virginica'),
+(105,ARRAY[6.5,3.0,5.8,2.2],'Iris-virginica'),
+(106,ARRAY[7.6,3.0,6.6,2.1],'Iris-virginica'),
+(107,ARRAY[4.9,2.5,4.5,1.7],'Iris-virginica'),
+(108,ARRAY[7.3,2.9,6.3,1.8],'Iris-virginica'),
+(109,ARRAY[6.7,2.5,5.8,1.8],'Iris-virginica'),
+(110,ARRAY[7.2,3.6,6.1,2.5],'Iris-virginica'),
+(111,ARRAY[6.5,3.2,5.1,2.0],'Iris-virginica'),
+(112,ARRAY[6.4,2.7,5.3,1.9],'Iris-virginica'),
+(113,ARRAY[6.8,3.0,5.5,2.1],'Iris-virginica'),
+(114,ARRAY[5.7,2.5,5.0,2.0],'Iris-virginica'),
+(115,ARRAY[5.8,2.8,5.1,2.4],'Iris-virginica'),
+(116,ARRAY[6.4,3.2,5.3,2.3],'Iris-virginica'),
+(117,ARRAY[6.5,3.0,5.5,1.8],'Iris-virginica'),
+(118,ARRAY[7.7,3.8,6.7,2.2],'Iris-virginica'),
+(119,ARRAY[7.7,2.6,6.9,2.3],'Iris-virginica'),
+(120,ARRAY[6.0,2.2,5.0,1.5],'Iris-virginica'),
+(121,ARRAY[6.9,3.2,5.7,2.3],'Iris-virginica'),
+(122,ARRAY[5.6,2.8,4.9,2.0],'Iris-virginica'),
+(123,ARRAY[7.7,2.8,6.7,2.0],'Iris-virginica'),
+(124,ARRAY[6.3,2.7,4.9,1.8],'Iris-virginica'),
+(125,ARRAY[6.7,3.3,5.7,2.1],'Iris-virginica'),
+(126,ARRAY[7.2,3.2,6.0,1.8],'Iris-virginica'),
+(127,ARRAY[6.2,2.8,4.8,1.8],'Iris-virginica'),
+(128,ARRAY[6.1,3.0,4.9,1.8],'Iris-virginica'),
+(129,ARRAY[6.4,2.8,5.6,2.1],'Iris-virginica'),
+(130,ARRAY[7.2,3.0,5.8,1.6],'Iris-virginica'),
+(131,ARRAY[7.4,2.8,6.1,1.9],'Iris-virginica'),
+(132,ARRAY[7.9,3.8,6.4,2.0],'Iris-virginica'),
+(133,ARRAY[6.4,2.8,5.6,2.2],'Iris-virginica'),
+(134,ARRAY[6.3,2.8,5.1,1.5],'Iris-virginica'),
+(135,ARRAY[6.1,2.6,5.6,1.4],'Iris-virginica'),
+(136,ARRAY[7.7,3.0,6.1,2.3],'Iris-virginica'),
+(137,ARRAY[6.3,3.4,5.6,2.4],'Iris-virginica'),
+(138,ARRAY[6.4,3.1,5.5,1.8],'Iris-virginica'),
+(139,ARRAY[6.0,3.0,4.8,1.8],'Iris-virginica'),
+(140,ARRAY[6.9,3.1,5.4,2.1],'Iris-virginica'),
+(141,ARRAY[6.7,3.1,5.6,2.4],'Iris-virginica'),
+(142,ARRAY[6.9,3.1,5.1,2.3],'Iris-virginica'),
+(143,ARRAY[5.8,2.7,5.1,1.9],'Iris-virginica'),
+(144,ARRAY[6.8,3.2,5.9,2.3],'Iris-virginica'),
+(145,ARRAY[6.7,3.3,5.7,2.5],'Iris-virginica'),
+(146,ARRAY[6.7,3.0,5.2,2.3],'Iris-virginica'),
+(147,ARRAY[6.3,2.5,5.0,1.9],'Iris-virginica'),
+(148,ARRAY[6.5,3.0,5.2,2.0],'Iris-virginica'),
+(149,ARRAY[6.2,3.4,5.4,2.3],'Iris-virginica'),
+(150,ARRAY[5.9,3.0,5.1,1.8],'Iris-virginica');
+
+DROP TABLE IF EXISTS iris_data_packed, iris_data_packed_summary;
+SELECT training_preprocessor_dl('iris_data',         -- Source table
+                                'iris_data_packed',  -- Output table
+                                'class_text',        -- Dependent variable
+                                'attributes'         -- Independent variable
+                                );
+
+DROP TABLE IF EXISTS iris_model_arch;
+-- NOTE: The seed is set to 0 for every layer.
+SELECT load_keras_model('iris_model_arch',  -- Output table,
+$$
+{
+"class_name": "Sequential",
+"keras_version": "2.1.6",
+"config":
+    [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling",
+    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": 
"fan_avg"}},
+    "name": "dense_1", "kernel_constraint": null, "bias_regularizer": null,
+    "bias_constraint": null, "dtype": "float32", "activation": "relu", 
"trainable": true,
+    "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
+    "config": {}}, "units": 10, "batch_input_shape": [null, 4], "use_bias": 
true,
+    "activity_regularizer": null}}, {"class_name": "Dense",
+    "config": {"kernel_initializer": {"class_name": "VarianceScaling",
+    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": 
"fan_avg"}},
+    "name": "dense_2", "kernel_constraint": null, "bias_regularizer": null,
+    "bias_constraint": null, "activation": "relu", "trainable": true, 
"kernel_regularizer": null,
+    "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, 
"use_bias": true,
+    "activity_regularizer": null}}, {"class_name": "Dense", "config": 
{"kernel_initializer":
+    {"class_name": "VarianceScaling", "config": {"distribution": "uniform", 
"scale": 1.0,
+    "seed": 0, "mode": "fan_avg"}}, "name": "dense_3", "kernel_constraint": 
null,
+    "bias_regularizer": null, "bias_constraint": null, "activation": "softmax",
+    "trainable": true, "kernel_regularizer": null, "bias_initializer": 
{"class_name": "Zeros",
+    "config": {}}, "units": 3, "use_bias": true, "activity_regularizer": 
null}}],
+    "backend": "tensorflow"}
+$$
+);
+
+DROP TABLE IF EXISTS iris_model, iris_model_summary;
+SELECT madlib_keras_fit('iris_data_packed',   -- source table
+                        'iris_model',          -- model output table
+                        'iris_model_arch',  -- model arch table
+                         1,                    -- model arch id
+                         $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,  -- compile_params
+                         $$ batch_size=5, epochs=3 $$,  -- fit_params
+                         5,                    -- num_iterations
+                         NULL, NULL,
+                         1 -- metrics_compute_frequency
+                        );
+
+-- Test that our code is indeed learning something and not broken. The loss
+-- from the 5th iteration should be less than the first, while the accuracy
+-- should be greater.
+SELECT assert(
+  array_upper(training_loss, 1) = 5 AND
+  array_upper(training_metrics, 1) = 5,
+  'metrics compute frequency must be 1.')
+FROM iris_model_summary;
 
--- Test cases for transfer learning
--- 1. Create a model arch table with weights all set to 0.008. 0.008 is just a
--- random number we chose after a few experiments so that we can deterministically
--- assert the loss and metric values reported by madlib_keras_fit.
--- 2. Run keras fit and then update the model arch table with the output of the keras
--- fit run.
-CREATE OR REPLACE FUNCTION create_model_arch_transfer_learning() RETURNS VOID AS $$
-from keras.layers import *
-from keras import Sequential
-import numpy as np
-import plpy
-
-model = Sequential()
-model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(32,32,3,)))
-model.add(MaxPooling2D(pool_size=(2, 2)))
-model.add(Dropout(0.25))
-model.add(Flatten())
-model.add(Dense(2, activation='softmax'))
-
-# we don't really need to get the weights from the model and the flatten them since
-# we are using np.ones_like to replace all the weights with a constant.
-# We want to keep the flatten code and the concatenation code just for reference
-weights = model.get_weights()
-weights_flat = [ w.flatten() for w in weights ]
-weights1d = np.array([j for sub in weights_flat for j in sub])
-# Adjust weights so that the learning for the first iteration can be deterministic
-# 0.008 is just a random number we chose after a few experiments
-weights1d = np.ones_like(weights1d)*0.008
-weights_bytea = weights1d.tostring()
-
-model_config = model.to_json()
-
-plan1 = plpy.prepare("""SELECT load_keras_model(
-                        'test_keras_model_arch_table',
-                        $1, $2)
-                    """, ['json','bytea'])
-plpy.execute(plan1, [model_config, weights_bytea])
-
-$$ LANGUAGE plpythonu VOLATILE;
-
-DROP TABLE IF EXISTS test_keras_model_arch_table;
-SELECT create_model_arch_transfer_learning();
+SELECT assert(
+  training_loss[5]-training_loss[1] < 0 AND
+  training_metrics[5]-training_metrics[1] > 0,
+    'The loss and accuracy should have improved with more iterations.'
+)
+FROM iris_model_summary;
+
+-- Make a copy of the loss and metrics array, to compare it with runs after
+-- warm start and transfer learning.
+DROP TABLE IF EXISTS iris_model_first_run;
+CREATE TABLE iris_model_first_run AS
+SELECT training_loss_final, training_metrics_final
+FROM iris_model_summary;
+
+-- Duplicate the architecture, but note that trainable is set to FALSE.
+-- This ensures we don't learn anything new, which lets us
+-- deterministically assert the accuracy and loss after transfer learning
+-- and warm start.
+SELECT load_keras_model('iris_model_arch',  -- Output table,
+$$
+{
+"class_name": "Sequential",
+"keras_version": "2.1.6",
+"config":
+    [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling",
+    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": 
"fan_avg"}},
+    "name": "dense_1", "kernel_constraint": null, "bias_regularizer": null,
+    "bias_constraint": null, "dtype": "float32", "activation": "relu",
+    "trainable": false,
+    "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
+    "config": {}}, "units": 10, "batch_input_shape": [null, 4], "use_bias": 
true,
+    "activity_regularizer": null}}, {"class_name": "Dense",
+    "config": {"kernel_initializer": {"class_name": "VarianceScaling",
+    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": 
"fan_avg"}},
+    "name": "dense_2", "kernel_constraint": null, "bias_regularizer": null,
+    "bias_constraint": null, "activation": "relu",
+    "trainable": false,
+    "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
+    "config": {}}, "units": 10, "use_bias": true, "activity_regularizer": 
null}},
+    {"class_name": "Dense", "config": {"kernel_initializer":
+    {"class_name": "VarianceScaling", "config": {"distribution": "uniform", 
"scale": 1.0,
+    "seed": 0, "mode": "fan_avg"}}, "name": "dense_3", "kernel_constraint": 
null,
+    "bias_regularizer": null, "bias_constraint": null, "activation": "softmax",
+    "trainable": false,
+    "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
+    "config": {}}, "units": 3, "use_bias": true, "activity_regularizer": 
null}}],
+    "backend": "tensorflow"}
+$$
+);
+-- Copy the weights learnt in the previous run, for transfer learning.
+-- Copy them now, because the warm start run will overwrite iris_model.
+UPDATE iris_model_arch set model_weights = (select model_data from iris_model) WHERE model_id = 2;
+
+-- Warm start test
+SELECT madlib_keras_fit('iris_data_packed',   -- source table
+                               'iris_model',          -- model output table
+                               'iris_model_arch',  -- model arch table
+                                2,                    -- model arch id
+                                $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,  -- compile_params
+                                $$ batch_size=5, epochs=3 $$,  -- fit_params
+                                2,                    -- num_iterations,
+                                NULL, NULL, 1,
+                                true -- warm start
+                              );
 
-DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
-SELECT madlib_keras_fit(
-    'cifar_10_sample_batched',
-    'keras_saved_out',
-    'test_keras_model_arch_table',
-    1,
-    $$ optimizer=SGD(lr=0.001, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
-    $$ batch_size=2, epochs=1, verbose=0 $$::text,
-    1);
-SELECT training_loss_final FROM keras_saved_out_summary;
-
--- We want to keep this select in case the assert fails and we need
--- to know the actual values in the table without re running the entire test
-\x
-select * from keras_saved_out_summary;
-\x
-
--- This assert is a work in progress (we are hoping that these asserts will not be flaky).
--- We want to be able to assert that the loss/metric for the first iteration is
--- deterministic if we set weights using the load_keras function. Although we
--- have seen that the loss/metric values are different in the 3rd/4th decimal
--- every time we run fit after loading the weights.
-
--- TODO https://github.com/apache/madlib/pull/399#discussion_r288336557
--- Might be a cleaner assert if we can assert the weights themselves.
--- For instance, if we use weights1d = np.ones_like(weights1d) instead of
--- weights1d = np.ones_like(weights1d)*0.008, and freeze the first layer,
--- then even after multiple iterations the weights in the first layer should all
--- be 1. Look at How can I "freeze" Keras layers?
--- section in https://keras.io/getting-started/faq/ for how to freeze layers.
-SELECT assert(abs(training_loss_final - 0.6) < 0.1 AND
-              abs(training_metrics_final - 0.4) < 0.1,
-       'Transfer learning test failed.')
-FROM keras_saved_out_summary;
-DROP FUNCTION create_model_arch_transfer_learning();
-
-UPDATE test_keras_model_arch_table SET model_weights = model_data FROM keras_saved_out WHERE model_id = 1;
-DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
-SELECT madlib_keras_fit(
-    'cifar_10_sample_batched',
-    'keras_saved_out',
-    'test_keras_model_arch_table',
-    1,
-    $$ optimizer=SGD(lr=0.001, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
-    $$ batch_size=2, epochs=1, verbose=0 $$::text,
-    3);
-SELECT training_loss_final, training_metrics_final FROM keras_saved_out_summary;
+SELECT assert(
+  array_upper(training_loss, 1) = 2 AND
+  array_upper(training_metrics, 1) = 2,
+  'metrics compute frequency must be 1.')
+FROM iris_model_summary;
 
---assert training loss and metric deterministic
-SELECT assert(abs(training_loss_final - 0.64) < 0.01 AND
-              abs(training_metrics_final - 0.47) < 0.01,
-       'Transfer learning test failed.')
-FROM keras_saved_out_summary;
+SELECT assert(
+  abs(first.training_loss_final-second.training_loss[1]) < 1e-10 AND
+  abs(first.training_loss_final-second.training_loss[2]) < 1e-10 AND
+  abs(first.training_metrics_final-second.training_metrics[1]) < 1e-10 AND
+  abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
+  'warm start test failed because training loss and metrics don''t match the expected value from the previous run of keras fit.')
+FROM iris_model_first_run AS first, iris_model_summary AS second;
+
+-- Transfer learning test
+DROP TABLE IF EXISTS iris_model_transfer, iris_model_transfer_summary;
+SELECT madlib_keras_fit('iris_data_packed',   -- source table
+                               'iris_model_transfer',          -- model output table
+                               'iris_model_arch',  -- model arch table
+                                2,                    -- model arch id
+                                $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,  -- compile_params
+                                $$ batch_size=5, epochs=3 $$,  -- fit_params
+                                2,
+                                NULL, NULL, 1
+                              );
+
+SELECT assert(
+  array_upper(training_loss, 1) = 2 AND
+  array_upper(training_metrics, 1) = 2,
+  'metrics compute frequency must be 1.')
+FROM iris_model_transfer_summary;
+
+SELECT assert(
+  abs(first.training_loss_final-second.training_loss[1]) < 1e-10 AND
+  abs(first.training_loss_final-second.training_loss[2]) < 1e-10 AND
+  abs(first.training_metrics_final-second.training_metrics[1]) < 1e-10 AND
+  abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
+  'Transfer learning test failed because training loss and metrics don''t match the expected value.')
+FROM iris_model_first_run AS first, iris_model_transfer_summary AS second;
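
Why these asserts are deterministic: every layer in model arch id 2 is
frozen (trainable is false), so neither the warm start run nor the transfer
learning run can change the weights, and the final loss/metrics must equal
those of the first run. Under that assumption, an equivalent and slightly
stronger check (a sketch only, not part of the test, and assuming weight
serialization is deterministic) would compare the serialized weights
directly:

    SELECT assert(
        -- hypothetical check: both runs started from the first run's
        -- weights and could not update them
        (SELECT model_data FROM iris_model) =
        (SELECT model_data FROM iris_model_transfer),
        'Frozen layers must leave the model weights unchanged.');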
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 20546a1..8253eaa 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -864,28 +864,28 @@ class MadlibKerasFitInputValidatorTestCase(unittest.TestCase):
         self.subject.FitInputValidator._validate_input_args = Mock()
         obj = self.subject.FitInputValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table',
-            'dep_varname', 'independent_varname', 5, None)
+            'dep_varname', 'independent_varname', 5, None, False)
         self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_True_num(self):
         self.subject.FitInputValidator._validate_input_args = Mock()
         obj = self.subject.FitInputValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table',
-            'dep_varname', 'independent_varname', 5, 3)
+            'dep_varname', 'independent_varname', 5, 3, False)
         self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_False_zero(self):
         self.subject.FitInputValidator._validate_input_args = Mock()
         obj = self.subject.FitInputValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table',
-            'dep_varname', 'independent_varname', 5, 0)
+            'dep_varname', 'independent_varname', 5, 0, False)
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_False_greater(self):
         self.subject.FitInputValidator._validate_input_args = Mock()
         obj = self.subject.FitInputValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table',
-            'dep_varname', 'independent_varname', 5, 6)
+            'dep_varname', 'independent_varname', 5, 6, False)
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
 
 class PredictInputValidatorTestCases(unittest.TestCase):
