This is an automated email from the ASF dual-hosted git repository.

okislal pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new 857049e  DL: Add optional parameters for multi model training
857049e is described below

commit 857049ef644fc5c9759c4f0d7c74dd62fc79f527
Author: Orhan Kislal <[email protected]>
AuthorDate: Thu Dec 5 13:13:54 2019 -0500

    DL: Add optional parameters for multi model training
    
    JIRA: MADLIB-1397
    
    This commit adds the following optional params:
    - metrics_compute_frequency
    - warm_start
    - name
    - description
    
    It also fixes a bug where the user's CUDA environment variable is
    overwritten before it can be saved.
    
    Closes #461
    
    Co-authored-by: Ekta Khanna <[email protected]>
---
 .../modules/deep_learning/madlib_keras.py_in       |   3 +-
 .../madlib_keras_fit_multiple_model.py_in          |  63 ++++++---
 .../madlib_keras_fit_multiple_model.sql_in         |  62 ++++++++-
 .../deep_learning/madlib_keras_validator.py_in     |  12 +-
 .../deep_learning/madlib_keras_wrapper.py_in       |   3 +-
 .../test/madlib_keras_model_selection.sql_in       |  38 ++++--
 .../test/madlib_keras_transfer_learning.sql_in     | 146 ++++++++++++++++++---
 7 files changed, 271 insertions(+), 56 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 702c288..5e35b46 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -306,7 +306,8 @@ def fit(schema_madlib, source_table, model, 
model_arch_table,
     #TODO add a unit test for this in a future PR
     reset_cuda_env(original_cuda_env)
 
-def get_initial_weights(model_table, model_arch, serialized_weights, 
warm_start, gpus_per_host):
+def get_initial_weights(model_table, model_arch, serialized_weights, 
warm_start,
+                        gpus_per_host):
     """
         If warm_start is True, return back initial weights from model table.
         If warm_start is False, first try to get the weights from model_arch
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
 
b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index 4ce21dd..883ed22 100644
--- 
a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ 
b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -27,6 +27,7 @@ from madlib_keras import get_initial_weights
 from madlib_keras import get_model_arch_weights
 from madlib_keras import get_segments_and_gpus
 from madlib_keras import get_source_summary_table_dict
+from madlib_keras import should_compute_metrics_this_iter
 from madlib_keras_helper import *
 from madlib_keras_model_selection import ModelSelectionSchema
 from madlib_keras_validator import *
@@ -73,7 +74,9 @@ Note that this function is disabled for Postgres.
 class FitMultipleModel():
     def __init__(self, schema_madlib, source_table, model_output_table,
                  model_selection_table, num_iterations,
-                 gpus_per_host=0, validation_table=None, **kwargs):
+                 gpus_per_host=0, validation_table=None,
+                 metrics_compute_frequency=None, warm_start=False, name="",
+                 description="", **kwargs):
         # set the random seed for visit order/scheduling
         random.seed(1)
         if is_platform_pg():
@@ -90,6 +93,9 @@ class FitMultipleModel():
             self.model_summary_table = add_postfix(
                 model_output_table, '_summary')
         self.num_iterations = num_iterations
+        self.metrics_compute_frequency = metrics_compute_frequency
+        self.name = name
+        self.description = description
         self.module_name = 'madlib_keras_fit_multiple_model'
         self.schema_madlib = schema_madlib
         self.version = madlib_version(self.schema_madlib)
@@ -111,9 +117,18 @@ class FitMultipleModel():
             self.model_selection_table, self.model_selection_summary_table,
             mb_dep_var_col, mb_indep_var_col, self.num_iterations,
             self.model_info_table, self.mst_key_col, self.model_arch_table_col,
-            1, False)
+            self.metrics_compute_frequency, warm_start)
+        if self.metrics_compute_frequency is None:
+            self.metrics_compute_frequency = num_iterations
+        self.warm_start = bool(warm_start)
         self.msts = self.fit_validator_train.msts
         self.model_arch_table = self.fit_validator_train.model_arch_table
+        self.metrics_iters = []
+
+        original_cuda_env = None
+        if CUDA_VISIBLE_DEVICES_KEY in os.environ:
+            original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]
+
         self.seg_ids_train, self.images_per_seg_train = \
             get_image_count_per_seg_for_minibatched_data_from_db(
                 self.source_table)
@@ -138,26 +153,24 @@ class FitMultipleModel():
         self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
         self.segments_per_host, self.gpus_per_host = get_segments_and_gpus(
             gpus_per_host)
-        self.create_model_output_table()
+        if not self.warm_start:
+            self.create_model_output_table()
         self.weights_to_update_tbl = unique_string(desp='weights_to_update')
         self.fit_multiple_model()
+        reset_cuda_env(original_cuda_env)
 
     def fit_multiple_model(self):
         # WARNING: set orca off to prevent unwanted redistribution
         with OptimizerControl(False):
-            original_cuda_env = None
-            if CUDA_VISIBLE_DEVICES_KEY in os.environ:
-                original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]
             self.start_training_time = datetime.datetime.now()
             self.train_multiple_model()
             self.end_training_time = datetime.datetime.now()
             self.insert_info_table()
             self.create_model_summary_table()
-        reset_cuda_env(original_cuda_env)
 
     def train_multiple_model(self):
         total_msts = len(self.msts_for_schedule)
-        for iter in range(self.num_iterations):
+        for iter in range(1, self.num_iterations+1):
             for mst_idx in range(total_msts):
                 mst_row = [self.grand_schedule[dist_key][mst_idx]
                            for dist_key in self.dist_keys]
@@ -167,12 +180,16 @@ class FitMultipleModel():
                 self.run_training()
                 if mst_idx == (total_msts - 1):
                     end_iteration = time.time()
-                    self.info_str = "\tTime for training in iteration {0}: {1} 
sec\n".format(
-                        iter, end_iteration - start_iteration)
-            self.info_str += "\tTraining set after iteration {0}:".format(iter)
-            self.evaluate_model(iter, self.source_table, True)
-            if self.validation_table:
-                self.evaluate_model(iter, self.validation_table, False)
+                    self.info_str = "\tTime for training in iteration {0}: {1} 
sec\n".format(iter,
+                                        end_iteration - start_iteration)
+            if should_compute_metrics_this_iter(iter,
+                                                self.metrics_compute_frequency,
+                                                self.num_iterations):
+                self.metrics_iters.append(iter)
+                self.info_str += "\tTraining set after iteration 
{0}:".format(iter)
+                self.evaluate_model(iter, self.source_table, True)
+                if self.validation_table:
+                    self.evaluate_model(iter, self.validation_table, False)
             plpy.info("\n"+self.info_str)
 
     def evaluate_model(self, epoch, table, is_train):
@@ -246,7 +263,6 @@ class FitMultipleModel():
             plpy.execute(mst_insert_query)
 
     def create_model_output_table(self):
-
         output_table_create_query = """
                                     CREATE TABLE {self.model_output_table}
                                     ({self.mst_key_col} INTEGER PRIMARY KEY,
@@ -282,10 +298,8 @@ class FitMultipleModel():
                                                      model_arch,
                                                      model_weights,
                                                      False,
-                                                     self.gpus_per_host
-                                                     )
+                                                     self.gpus_per_host)
             model = model_from_json(model_arch)
-
             serialized_state = model_weights if model_weights else \
                 
madlib_keras_serializer.serialize_nd_weights(model.get_weights())
 
@@ -295,7 +309,6 @@ class FitMultipleModel():
             is_metrics_specified = True if metrics_list else False
             metrics_type = 'ARRAY{0}'.format(
                 metrics_list) if is_metrics_specified else 'NULL'
-
             output_table_insert_query = """
                                 INSERT INTO {self.model_output_table}(
                                     {self.mst_key_col}, 
{self.model_weights_col},
@@ -327,6 +340,8 @@ class FitMultipleModel():
             plpy.execute(info_table_insert_query)
 
     def create_model_summary_table(self):
+        if self.warm_start:
+            plpy.execute("DROP TABLE {0}".format(self.model_summary_table))
         src_summary_dict = 
get_source_summary_table_dict(self.fit_validator_train)
         class_values = src_summary_dict['class_values']
         dep_vartype = src_summary_dict['dep_vartype']
@@ -344,6 +359,9 @@ class FitMultipleModel():
             class_values_str = 'ARRAY{0}::{1}'.format(class_values,
                                                       
src_summary_dict['class_values_type'])
             num_classes = len(class_values)
+        name = 'NULL' if self.name is None else 
'$MAD${0}$MAD$'.format(self.name)
+        descr = 'NULL' if self.description is None else 
'$MAD${0}$MAD$'.format(self.description)
+        metrics_iters = self.metrics_iters if self.metrics_iters else 'NULL'
         class_values_colname = CLASS_VALUES_COLNAME
         dependent_vartype_colname = DEPENDENT_VARTYPE_COLNAME
         normalizing_const_colname = NORMALIZING_CONST_COLNAME
@@ -359,13 +377,18 @@ class FitMultipleModel():
                     $MAD${independent_varname}$MAD$::TEXT AS 
independent_varname,
                     $MAD${self.model_arch_table}$MAD$::TEXT AS 
model_arch_table,
                     {self.num_iterations}::INTEGER AS num_iterations,
+                    {self.metrics_compute_frequency}::INTEGER AS 
metrics_compute_frequency,
+                    {self.warm_start} AS warm_start,
+                    {name}::TEXT AS name,
+                    {descr}::TEXT AS description,
                     '{self.start_training_time}'::TIMESTAMP AS 
start_training_time,
                     '{self.end_training_time}'::TIMESTAMP AS end_training_time,
                     '{self.version}'::TEXT AS madlib_version,
                     {num_classes}::INTEGER AS num_classes,
                     {class_values_str} AS {class_values_colname},
                     $MAD${dep_vartype}$MAD$::TEXT AS 
{dependent_vartype_colname},
-                    {norm_const}::{float32_sql_type} AS 
{normalizing_const_colname}
+                    {norm_const}::{float32_sql_type} AS 
{normalizing_const_colname},
+                    ARRAY{metrics_iters}::INTEGER[] AS metrics_iters
             """.format(**locals())
         plpy.execute(update_query)
 
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
 
b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index ea19523..433535d 100644
--- 
a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ 
b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -35,7 +35,11 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
     model_selection_table   VARCHAR,
     num_iterations          INTEGER,
     gpus_per_host           INTEGER,
-    validation_table        VARCHAR
+    validation_table        VARCHAR,
+    metrics_compute_frequency  INTEGER,
+    warm_start              BOOLEAN,
+    name                    VARCHAR,
+    description             VARCHAR
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_fit_multiple_model')
     with AOControl(False):
@@ -48,9 +52,63 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
     model_output_table      VARCHAR,
     model_selection_table   VARCHAR,
     num_iterations          INTEGER,
+    gpus_per_host           INTEGER,
+    validation_table        VARCHAR,
+    metrics_compute_frequency  INTEGER,
+    warm_start              BOOLEAN,
+    name                    VARCHAR
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, 
$6, $7, $8, $9, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
+    source_table            VARCHAR,
+    model_output_table      VARCHAR,
+    model_selection_table   VARCHAR,
+    num_iterations          INTEGER,
+    gpus_per_host           INTEGER,
+    validation_table        VARCHAR,
+    metrics_compute_frequency  INTEGER,
+    warm_start              BOOLEAN
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, 
$6, $7, $8, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
+    source_table            VARCHAR,
+    model_output_table      VARCHAR,
+    model_selection_table   VARCHAR,
+    num_iterations          INTEGER,
+    gpus_per_host           INTEGER,
+    validation_table        VARCHAR,
+    metrics_compute_frequency  INTEGER
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, 
$6, $7, FALSE, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
+    source_table            VARCHAR,
+    model_output_table      VARCHAR,
+    model_selection_table   VARCHAR,
+    num_iterations          INTEGER,
+    gpus_per_host           INTEGER,
+    validation_table        VARCHAR
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, 
$6, NULL, FALSE, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
+    source_table            VARCHAR,
+    model_output_table      VARCHAR,
+    model_selection_table   VARCHAR,
+    num_iterations          INTEGER,
     gpus_per_host           INTEGER
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, 
NULL);
+    SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, 
NULL, NULL, FALSE, NULL, NULL);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
 
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index e24b8bd..49b8934 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -394,7 +394,10 @@ class FitMultipleInputValidator(FitCommonValidator):
         self.msts, self.model_arch_table = query_model_configs(
             model_selection_table, model_selection_summary_table,
             mst_key_col, model_arch_table_col)
-        output_tbl_valid(model_info_table, self.module_name)
+        if warm_start:
+            input_tbl_valid(model_info_table, self.module_name)
+        else:
+            output_tbl_valid(model_info_table, self.module_name)
         super(FitMultipleInputValidator, self).__init__(source_table,
                                                         validation_table,
                                                         output_model_table,
@@ -407,9 +410,12 @@ class FitMultipleInputValidator(FitCommonValidator):
                                                         warm_start,
                                                         self.module_name)
 
+        if warm_start:
+            mst_count = plpy.execute("SELECT count(*) FROM 
{0}".format(model_selection_table))[0]['count']
+            warm_count = plpy.execute("SELECT count(*) FROM 
{0}".format(output_model_table))[0]['count']
 
-
-
+            _assert(mst_count <= warm_count,
+                "{self.module_name} error: Model table and mst table do not 
match".format(self=self))
 
 class MstLoaderInputValidator():
     def __init__(self,
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index 73df519..c9511d9 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -54,7 +54,8 @@ def reset_cuda_env(value):
     if value:
         set_cuda_env(value)
     else:
-        del os.environ[CUDA_VISIBLE_DEVICES_KEY]
+        if CUDA_VISIBLE_DEVICES_KEY in os.environ:
+            del os.environ[CUDA_VISIBLE_DEVICES_KEY]
 
 def get_device_name_and_set_cuda_env(gpus_per_host, seg):
     if gpus_per_host > 0:
diff --git 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index 2a20467..adc771e 100644
--- 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -205,7 +205,8 @@ SELECT assert(
         dependent_vartype = 'integer[]' AND
         num_classes = NULL AND
         class_values = NULL AND
-        normalizing_const = 1,
+        normalizing_const = 1 AND
+        metrics_iters = ARRAY[3],
         'Keras Fit Multiple Output Summary Validation failed when user passes 
in 1-hot encoded label vector. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
@@ -236,7 +237,10 @@ SELECT assert(
         num_classes = 3 AND
         class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
         dependent_vartype LIKE '%char%' AND
-        normalizing_const = 1,
+        normalizing_const = 1 AND
+        name IS NULL AND
+        description IS NULL AND
+        metrics_compute_frequency = 6,
         'Keras Fit Multiple Output Summary Validation failed. Actual:' || 
__to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
@@ -251,13 +255,13 @@ SELECT assert(
         metrics_type = '{accuracy}' AND
         training_metrics_final >= 0  AND
         training_loss_final  >= 0  AND
-        array_upper(training_metrics, 1) = 6 AND
-        array_upper(training_loss, 1) = 6 AND
+        array_upper(training_metrics, 1) = 1 AND
+        array_upper(training_loss, 1) = 1 AND
         validation_metrics_final >= 0  AND
         validation_loss_final  >= 0  AND
-        array_upper(validation_metrics, 1) = 6 AND
-        array_upper(validation_loss, 1) = 6 AND
-        array_upper(metrics_elapsed_time, 1) = 6,
+        array_upper(validation_metrics, 1) = 1 AND
+        array_upper(validation_loss, 1) = 1 AND
+        array_upper(metrics_elapsed_time, 1) = 1,
         'Keras Fit Multiple Output Info Validation failed. Actual:' || 
__to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
@@ -291,7 +295,12 @@ SELECT madlib_keras_fit_multiple_model(
        'iris_multiple_model',
        'mst_table_1row',
        3,
-       0
+       0,
+       NULL,
+       1,
+       FALSE,
+       'multi_model_name',
+       'multi_model_descr'
 );
 
 SELECT assert(COUNT(*)=1, 'Info table must have exactly same rows as the 
number of msts.')
@@ -311,6 +320,13 @@ SELECT assert(
         'Keras Fit Multiple Output Info Validation failed. Actual:' || 
__to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
+SELECT assert(
+        name = 'multi_model_name' AND
+        description = 'multi_model_descr' AND
+        metrics_compute_frequency = 1,
+        'Keras Fit Multiple Output Summary Validation failed. Actual:' || 
__to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+
 SELECT assert(cnt = 1,
        'Keras Fit Multiple Output Info compile params validation failed. 
Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
@@ -336,9 +352,9 @@ SELECT assert(
         metrics_type = '{accuracy}' AND
         training_metrics_final >= 0  AND
         training_loss_final  >= 0  AND
-        array_upper(training_metrics, 1) = 3 AND
-        array_upper(training_loss, 1) = 3 AND
-        array_upper(metrics_elapsed_time, 1) = 3,
+        array_upper(training_metrics, 1) = 1 AND
+        array_upper(training_loss, 1) = 1 AND
+        array_upper(metrics_elapsed_time, 1) = 1,
         'Keras Fit Multiple Output Info Validation failed. Actual:' || 
__to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
diff --git 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
index 220569d..0ab09b7 100644
--- 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
+++ 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
@@ -126,6 +126,129 @@ FROM iris_model_first_run AS first, 
iris_model_transfer_summary AS second;
 m4_changequote(`<!', `!>')
 m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
 
+DROP TABLE IF EXISTS mst_table, mst_table_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'mst_table',
+    ARRAY[1,2],
+    ARRAY[
+        $$loss='categorical_crossentropy', 
optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=5,epochs=1$$
+    ]
+);
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, 
iris_multiple_model_info;
+SELECT setseed(0);
+SELECT madlib_keras_fit_multiple_model(
+  'iris_data_packed',
+  'iris_multiple_model',
+  'mst_table',
+  3,
+  0, NULL, 1
+);
+
+DROP TABLE IF EXISTS iris_model_first_run;
+CREATE TABLE iris_model_first_run AS
+SELECT mst_key, model_id, training_loss, training_metrics,
+    training_loss_final, training_metrics_final
+FROM iris_multiple_model_info;
+
+-- warm start for fit multiple model
+SELECT madlib_keras_fit_multiple_model(
+  'iris_data_packed',
+  'iris_multiple_model',
+  'mst_table',
+  3,
+  0,
+  NULL, 1,
+  TRUE -- warm_start
+);
+
+SELECT assert(
+  array_upper(training_loss, 1) = 3 AND
+  array_upper(training_metrics, 1) = 3,
+  'metrics compute frequency must be 1.')
+FROM iris_multiple_model_info;
+
+SELECT assert(
+  abs(first.training_loss_final-second.training_loss[1]) < 1e-6 AND
+  abs(first.training_loss_final-second.training_loss[2]) < 1e-6 AND
+  abs(first.training_metrics_final-second.training_metrics[1]) < 1e-10 AND
+  abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
+  'warm start test failed because training loss and metrics don''t match the 
expected value from the previous run of keras fit.')
+FROM iris_model_first_run AS first, iris_multiple_model_info AS second
+WHERE first.mst_key = second.mst_key AND first.model_id = 2;
+
+-- warm start with different mst tables
+DROP TABLE IF EXISTS mst_table, mst_table_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'mst_table',
+    ARRAY[1],
+    ARRAY[
+        $$loss='categorical_crossentropy', 
optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=5,epochs=1$$,
+        $$batch_size=10,epochs=1$$,
+        $$batch_size=15,epochs=1$$,
+        $$batch_size=20,epochs=1$$
+    ]
+);
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, 
iris_multiple_model_info;
+SELECT setseed(0);
+SELECT madlib_keras_fit_multiple_model(
+  'iris_data_packed',
+  'iris_multiple_model',
+  'mst_table',
+  3,
+  0, NULL, 1
+);
+
+DROP TABLE IF EXISTS iris_model_first_run;
+CREATE TABLE iris_model_first_run AS
+SELECT mst_key, model_id, training_loss, training_metrics,
+    training_loss_final, training_metrics_final
+FROM iris_multiple_model_info;
+
+DELETE FROM mst_table WHERE mst_key = 4;
+
+SELECT madlib_keras_fit_multiple_model(
+  'iris_data_packed',
+  'iris_multiple_model',
+  'mst_table',
+  3,
+  0, NULL, 1,
+  TRUE);
+
+
+SELECT assert(
+  abs(first.training_loss_final-second.training_loss_final) < 1e-6,
+  'The loss should not change for mst_key 4 since it has been removed from 
mst_table')
+FROM iris_model_first_run AS first, iris_multiple_model_info AS second
+WHERE first.mst_key = second.mst_key AND second.mst_key = 4;
+
+INSERT INTO mst_table SELECT 4 AS mst_key, model_id, compile_params,
+    'batch_size=8, epochs=1' FROM mst_table WHERE mst_key = 1;
+
+INSERT INTO mst_table SELECT 5 AS mst_key, model_id, compile_params,
+    'batch_size=18, epochs=1' FROM mst_table WHERE mst_key = 1;
+
+SELECT assert(trap_error($TRAP$madlib_keras_fit_multiple_model(
+  'iris_data_packed',
+  'iris_multiple_model',
+  'mst_table',
+  3,
+  0,
+  NULL, 1,
+  TRUE -- warm_start
+);$TRAP$) = 1, 'Warm start with extra mst keys should fail.');
+
+-- Transfer learning tests
+
 -- Load the same arch again so that we can compare transfer learning results
 SELECT load_keras_model('iris_model_arch',  -- Output table,
 $$
@@ -169,15 +292,9 @@ SELECT load_model_selection_table(
     ]
 );
 
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, 
iris_multiple_model_info;
-SELECT setseed(0);
-SELECT madlib_keras_fit_multiple_model(
-  'iris_data_packed',
-  'iris_multiple_model',
-  'mst_table',
-  3,
-  0
-);
+UPDATE iris_model_arch
+SET model_weights = (SELECT model_weights FROM iris_multiple_model WHERE 
mst_key=1)
+WHERE model_id = 1;
 
 DROP TABLE IF EXISTS iris_model_first_run;
 CREATE TABLE iris_model_first_run AS
@@ -185,10 +302,6 @@ SELECT mst_key, model_id, training_loss, training_metrics,
     training_loss_final, training_metrics_final
 FROM iris_multiple_model_info;
 
-UPDATE iris_model_arch
-SET model_weights = (SELECT model_weights FROM iris_multiple_model WHERE 
mst_key=1)
-WHERE model_id = 1;
-
 DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, 
iris_multiple_model_info;
 SELECT madlib_keras_fit_multiple_model(
   'iris_data_packed',
@@ -199,11 +312,8 @@ SELECT madlib_keras_fit_multiple_model(
 );
 
 SELECT assert(
-  abs(first.training_loss_final-second.training_loss[1]) < 1e-6 AND
-  abs(first.training_loss_final-second.training_loss[2]) < 1e-6 AND
-  abs(first.training_metrics_final-second.training_metrics[1]) < 1e-10 AND
-  abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
+  (first.training_loss_final-second.training_loss_final) > 1e-6,
   'Transfer learning test failed because training loss and metrics don''t 
match the expected value.')
 FROM iris_model_first_run AS first, iris_multiple_model_info AS second
-WHERE first.mst_key = second.mst_key AND first.model_id = 2;
+WHERE first.mst_key = second.mst_key AND first.model_id = 1;
 !>)

Reply via email to