This is an automated email from the ASF dual-hosted git repository.

okislal pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new 5f10bc8  DL: Modify multi-fit warm start to accept non-matching mst&model tables
5f10bc8 is described below

commit 5f10bc8e72e88986cd109745dddec672fdaa1d84
Author: Orhan Kislal <okis...@apache.org>
AuthorDate: Tue Jan 7 19:36:34 2020 -0500

    DL: Modify multi-fit warm start to accept non-matching mst&model tables
    
    JIRA: MADLIB-1400 #resolve
    
    Warm start previously enforced that the model table had a tuple for each
    mst_key in the mst table. This commit relaxes this requirement so that
    users can add as well as remove mst keys throughout their AutoML
    progress.
    
    Closes #466
---
 .../madlib_keras_fit_multiple_model.py_in          | 70 ++++++++++++++--------
 .../deep_learning/madlib_keras_validator.py_in     |  7 ---
 .../test/madlib_keras_transfer_learning.sql_in     | 24 +++++---
 3 files changed, 60 insertions(+), 41 deletions(-)

diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
 
b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index 5ce555a..273321e 100644
--- 
a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ 
b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -162,8 +162,8 @@ class FitMultipleModel():
         random.shuffle(self.msts_for_schedule)
         self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
         self.gp_segment_id_col = '0' if is_platform_pg() else 
GP_SEGMENT_ID_COLNAME
-        if not self.warm_start:
-            self.create_model_output_table()
+
+        self.create_model_output_table()
         self.weights_to_update_tbl = unique_string(desp='weights_to_update')
         self.fit_multiple_model()
         reset_cuda_env(original_cuda_env)
@@ -274,12 +274,26 @@ class FitMultipleModel():
             plpy.execute(mst_insert_query)
 
     def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
+        warm_start_msts = []
+        if self.warm_start:
+            plpy.execute(""" DELETE FROM {self.model_output_table}
+                WHERE {self.mst_key_col} NOT IN (
+                    SELECT {self.mst_key_col} FROM 
{self.model_selection_table})
+                """.format(self=self))
+            warm_start_msts = plpy.execute(
+                """ SELECT array_agg({0}) AS a FROM {1}
+                """.format(self.mst_key_col, self.model_output_table))[0]['a']
+            plpy.execute("DROP TABLE {0}".format(self.model_info_table))
+
+        else:
+            output_table_create_query = """
+                                        CREATE TABLE {self.model_output_table}
+                                        ({self.mst_key_col} INTEGER PRIMARY 
KEY,
+                                         {self.model_weights_col} BYTEA,
+                                         {self.model_arch_col} JSON)
+                                        """.format(self=self)
+            plpy.execute(output_table_create_query)
+
         info_table_create_query = """
                                   CREATE TABLE {self.model_info_table}
                                   ({self.mst_key_col} INTEGER PRIMARY KEY,
@@ -300,39 +314,32 @@ class FitMultipleModel():
                                    validation_loss DOUBLE PRECISION[])
                                """.format(self=self)
 
-        plpy.execute(output_table_create_query)
         plpy.execute(info_table_create_query)
         for mst in self.msts:
             model_arch, model_weights = 
get_model_arch_weights(self.model_arch_table,
                                                                
mst[self.model_id_col])
+
+
+            # If warm start is enabled, weights from transfer learning cannot 
be
+            # used, even if a particular model doesn't have warm start weigths.
+            if self.warm_start:
+                model_weights = None
+
             serialized_weights = get_initial_weights(self.model_output_table,
                                                      model_arch,
                                                      model_weights,
-                                                     False,
+                                                     mst['mst_key'] in 
warm_start_msts,
                                                      self.use_gpus,
                                                      
self.accessible_gpus_for_seg
                                                      )
-            model = model_from_json(model_arch)
-            serialized_state = model_weights if model_weights else \
-                
madlib_keras_serializer.serialize_nd_weights(model.get_weights())
-
             model_size = sys.getsizeof(serialized_weights) / 1024.0
+
             metrics_list = get_metrics_from_compile_param(
                 mst[self.compile_params_col])
             is_metrics_specified = True if metrics_list else False
             metrics_type = 'ARRAY{0}'.format(
                 metrics_list) if is_metrics_specified else 'NULL'
-            output_table_insert_query = """
-                                INSERT INTO {self.model_output_table}(
-                                    {self.mst_key_col}, 
{self.model_weights_col},
-                                    {self.model_arch_col})
-                                VALUES ({mst_key}, $1, $2)
-                                   """.format(self=self,
-                                              mst_key=mst[self.mst_key_col])
-            output_table_insert_query_prepared = plpy.prepare(
-                output_table_insert_query, ["bytea", "json"])
-            plpy.execute(output_table_insert_query_prepared, [
-                         serialized_state, model_arch])
+
             info_table_insert_query = """
                     INSERT INTO {self.model_info_table}({self.mst_key_col},
                                 {self.model_id_col}, {self.compile_params_col},
@@ -352,6 +359,19 @@ class FitMultipleModel():
                            metrics_type=metrics_type)
             plpy.execute(info_table_insert_query)
 
+            if not mst['mst_key'] in warm_start_msts:
+                output_table_insert_query = """
+                                    INSERT INTO {self.model_output_table}(
+                                        {self.mst_key_col}, 
{self.model_weights_col},
+                                        {self.model_arch_col})
+                                    VALUES ({mst_key}, $1, $2)
+                                       """.format(self=self,
+                                                  
mst_key=mst[self.mst_key_col])
+                output_table_insert_query_prepared = plpy.prepare(
+                    output_table_insert_query, ["bytea", "json"])
+                plpy.execute(output_table_insert_query_prepared, [
+                             serialized_weights, model_arch])
+
     def create_model_summary_table(self):
         if self.warm_start:
             plpy.execute("DROP TABLE {0}".format(self.model_summary_table))
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index ad14087..37a2e25 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -437,13 +437,6 @@ class FitMultipleInputValidator(FitCommonValidator):
                                                         
accessible_gpus_for_seg,
                                                         self.module_name)
 
-        if warm_start:
-            mst_count = plpy.execute("SELECT count(*) FROM 
{0}".format(model_selection_table))[0]['count']
-            warm_count = plpy.execute("SELECT count(*) FROM 
{0}".format(output_model_table))[0]['count']
-
-            _assert(mst_count <= warm_count,
-                "{self.module_name} error: Model table and mst table do not 
match".format(self=self))
-
 class MstLoaderInputValidator():
     def __init__(self,
                  model_arch_table,
diff --git 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
index 3c970a5..d17ea20 100644
--- 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
+++ 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
@@ -224,20 +224,18 @@ SELECT madlib_keras_fit_multiple_model(
   FALSE, NULL, 1,
   TRUE);
 
-
 SELECT assert(
-  abs(first.training_loss_final-second.training_loss_final) < 1e-6,
-  'The loss should not change for mst_key 4 since it has been removed from 
mst_table')
-FROM iris_model_first_run AS first, iris_multiple_model_info AS second
-WHERE first.mst_key = second.mst_key AND second.mst_key = 4;
+  4 NOT IN (SELECT mst_key FROM iris_multiple_model),
+  'mst_key 4 should not be in the model table since it has been removed from 
mst_table');
 
-INSERT INTO mst_table SELECT 4 AS mst_key, model_id, compile_params,
-    'batch_size=8, epochs=1' FROM mst_table WHERE mst_key = 1;
+SELECT assert(
+  4 NOT IN (SELECT mst_key FROM iris_multiple_model_info),
+  'mst_key 4 should not be in the info table since it has been removed from 
mst_table');
 
 INSERT INTO mst_table SELECT 5 AS mst_key, model_id, compile_params,
     'batch_size=18, epochs=1' FROM mst_table WHERE mst_key = 1;
 
-SELECT assert(trap_error($TRAP$madlib_keras_fit_multiple_model(
+SELECT madlib_keras_fit_multiple_model(
   'iris_data_packed',
   'iris_multiple_model',
   'mst_table',
@@ -245,7 +243,15 @@ SELECT 
assert(trap_error($TRAP$madlib_keras_fit_multiple_model(
   FALSE,
   NULL, 1,
   TRUE -- warm_start
-);$TRAP$) = 1, 'Warm start with extra mst keys should fail.');
+);
+
+SELECT assert(
+  5 IN (SELECT mst_key FROM iris_multiple_model),
+  'mst_key 5 should be in the model table since it has been added to 
mst_table');
+
+SELECT assert(
+  5 IN (SELECT mst_key FROM iris_multiple_model_info),
+  'mst_key 5 should be in the info table since it has been added to 
mst_table');
 
 -- Transfer learning tests
 

Reply via email to