kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r553567263



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -391,12 +422,14 @@ def get_initial_weights(model_table, model_arch, 
serialized_weights, warm_start,
         will only be used for segment nodes.
         @args:
             @param model_table: Output model table passed in to fit.
-            @param model_arch_result: Dict containing model architecture info.
+            @param model_arch: Dict containing model architecture info.
             @param warm_start: Boolean flag indicating warm start or not.
     """
     if is_platform_pg():
+        # Use GPU's if they are enabled
         _ = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[0], None)
     else:
+        # We are on master, so never use GPU's

Review comment:
       This comment is only valid for GPDB. For Postgres, there is no 
master/coordinator node, so we always use GPUs if available, whereas on GPDB we 
use GPUs only on segment nodes. We should update the comment to make this 
distinction clearer.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -520,21 +544,35 @@ def fit_transition(state, dependent_var, independent_var, 
dependent_var_shape,
         and only gets cleared in eval transition at the last row of the last 
iteration.
 
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape:
+        plpy.error("fit_transition called with no data")
+
+    if not prev_serialized_weights or not model_architecture:
         return state

Review comment:
       Shouldn't we error out if either prev_serialized_weights or 
model_architecture is None?

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -663,17 +717,20 @@ def get_state_to_return(segment_model, is_last_row, 
is_multiple_model, agg_image
     :param is_last_row: boolean to indicate if last row for that hop
     :param is_multiple_model: boolean
     :param agg_image_count: aggregated image count per hop
-    :param total_images: total images per segment
+    :param total_images: total images per segment (only used for 
madlib_keras_fit() )
     :return:
     """
-    if is_last_row:
-        updated_model_weights = segment_model.get_weights()
-        if is_multiple_model:
+    if is_multiple_model:
+        if is_last_row:
+            updated_model_weights = segment_model.get_weights()
             new_state = 
madlib_keras_serializer.serialize_nd_weights(updated_model_weights)
         else:
-            updated_model_weights = [total_images * w for w in 
updated_model_weights]
-            new_state = 
madlib_keras_serializer.serialize_state_with_nd_weights(
-                agg_image_count, updated_model_weights)
+            new_state = None
+    elif is_last_row:

Review comment:
       Now that we don't get `agg_image_count` from the state, do we still 
need to set `new_state = float(agg_image_count)` at line 678/735? This might 
also help simplify the if/elif condition.
   

##########
File path: 
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -338,183 +372,307 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in 
compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' 
in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in 
builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in 
builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = 
query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE 
{self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = 
query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
 
-    def create_model_output_table_warm_start(self):
+        create_sched_query = """
+            CREATE {self.unlogged_table} TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        plpy_execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if self.rotate_schedule_tbl_plan is None:
+            rotate_schedule_tbl_query = """
+                CREATE {self.unlogged_table} TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl}
+                DISTRIBUTED BY ({self.prev_dist_key_col})
+            """.format(self=self)
+            self.rotate_schedule_tbl_plan = 
plpy.prepare(rotate_schedule_tbl_query)
+
+        plpy.execute(self.rotate_schedule_tbl_plan)
+
+        self.truncate_and_drop(self.schedule_tbl)
+        plpy.execute("""
+            ALTER TABLE {self.next_schedule_tbl}
+            RENAME TO {self.schedule_tbl}
+        """.format(self=self))
+
+    def load_warm_start_weights(self):
         """
-        For warm start, we need to copy the model output table to a temp table
-        because we call truncate on the model output table while training.
-        If the query gets aborted, we need to make sure that the user passed
-        model output table can be recovered.
+        For warm start, we need to copy any rows of the model output
+        table provided by the user whose mst keys appear in the
+        supplied model selection table.  We also copy over the 
+        compile & fit params from the model_selection_table, and
+        the dist_key's from the schedule table.
         """
-        plpy.execute("""
-            CREATE TABLE {self.model_output_table} (
-            LIKE {self.original_model_output_table} INCLUDING indexes);
-            """.format(self=self))
+        load_warm_start_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    o.{self.model_weights_col},
+                    o.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.original_model_output_tbl} o
+                        USING ({self.mst_key_col})
+        """.format(self=self)
+        plpy_execute(load_warm_start_weights_query)
 
-        plpy.execute("""INSERT INTO {self.model_output_table}
-            SELECT * FROM {self.original_model_output_table};
-            """.format(self=self))
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(self.model_info_tbl))
 
-        plpy.execute(""" DELETE FROM {self.model_output_table}
-                WHERE {self.mst_key_col} NOT IN (
-                    SELECT {self.mst_key_col} FROM 
{self.model_selection_table})
-                """.format(self=self))
-        self.warm_start_msts = plpy.execute(
-            """ SELECT array_agg({0}) AS a FROM {1}
-            """.format(self.mst_key_col, self.model_output_table))[0]['a']
-        plpy.execute("DROP TABLE {0}".format(self.model_info_table))
-        self.initialize_model_output_and_info()
-
-    def initialize_model_output_and_info(self):
+    def load_xfer_learning_weights(self, warm_start=False):
+        """
+            Copy transfer learning weights from
+            model_arch table.  Ignore models with
+            no xfer learning weights, these will
+            be generated by keras and added one at a
+            time later.
+        """
+        load_xfer_learning_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    a.{self.model_weights_col},
+                    a.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.model_arch_table} a
+                        ON m.{self.model_id_col} = a.{self.model_id_col}
+                WHERE a.{self.model_weights_col} IS NOT NULL;
+        """.format(self=self)
+        plpy_execute(load_xfer_learning_weights_query)
+
+    def init_model_output_tbl(self):
+        DEBUG.start_timing('init_model_output_and_info')
+
+        output_table_create_query = """
+                                    CREATE {self.unlogged_table} TABLE 
{self.model_output_tbl}
+                                    ({self.mst_key_col} INTEGER,
+                                     {self.model_weights_col} BYTEA,
+                                     {self.model_arch_col} JSON,
+                                     {self.compile_params_col} TEXT,
+                                     {self.fit_params_col} TEXT,
+                                     {self.object_map_col} BYTEA,
+                                     {self.dist_key_col} INTEGER,
+                                     PRIMARY KEY ({self.dist_key_col})
+                                    )
+                                    DISTRIBUTED BY ({self.dist_key_col})
+                                    """.format(self=self)
+        plpy.execute(output_table_create_query)
+
+        if self.warm_start:
+            self.load_warm_start_weights()
+        else:  # Note:  We only support xfer learning when warm_start=False
+            self.load_xfer_learning_weights()
+
+        res = plpy.execute("""
+            SELECT {self.mst_key_col} AS mst_keys FROM {self.model_output_tbl}
+        """.format(self=self))
+       
+        if res:
+            initialized_msts = set([ row['mst_keys'] for row in res ])
+        else:
+            initialized_msts = set()
+
+        DEBUG.plpy.info("Pre-initialized mst keys: 
{}".format(initialized_msts))
+
+        # We've already bulk loaded all of the models with user-specified 
weights.
+        #  For the rest of the models, we need to generate the weights for each
+        #  by initializing them with keras and adding them one row at a time.
+        #
+        # TODO:  In the future, we should probably move the weight 
initialization
+        #  into the transition function on the segments.  Here, we would just
+        #  bulk load everything with a single query (or 2, for the warm start 
case),
+        #  and leave the weights column as NULL for any model whose weights 
need
+        #  to be randomly initialized.  Then in fit_transition, if 
prev_weights is
+        #  NULL, and there is nothing in GD, it should just skip the call to
+        #  set_weights(), and keras will automatically initialize them during
+        #  model.from_json(model_arch).
+        #
+        #  This would be a very easy change for fit_multiple(), but might 
require
+        #   some more work to support fit().  All of the segments there need to
+        #   start with the same weights, so we'd at least have to pass a random
+        #   seed to the transition function for keras to use.  Or generate a 
seed
+        #   on the segments in some deterministic way that's the same for all.
+        for index, mst in enumerate(self.msts_for_schedule):
+            if mst is None:
+                continue
+
+            if mst['mst_key'] in initialized_msts:
+                continue  # skip if we've already loaded this mst
+
+            num_dist_keys = len(self.dist_keys)
+
+            if index < num_dist_keys:
+                dist_key = self.dist_keys[index]
+            else:  # For models that won't be trained on first hop
+                dist_key = self.extra_dist_keys[index - num_dist_keys]
+
+            model_arch = get_model_arch(self.model_arch_table, 
mst[self.model_id_col])
+            serialized_weights = get_initial_weights(None, model_arch, None, 
False,
+                                                     
self.accessible_gpus_for_seg)
+
+            DEBUG.plpy.info(

Review comment:
       For code readability, I think we can remove this debug plpy.info 
statement. We can add it back for debugging as and when needed.
   

##########
File path: 
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -197,72 +202,104 @@ class FitMultipleModel():
             self.dist_key_mapping_valid, self.images_per_seg_valid = \
                 get_image_count_per_seg_for_minibatched_data_from_db(
                     self.validation_table)
-        self.mst_weights_tbl = unique_string(desp='mst_weights')
-        self.mst_current_schedule_tbl = 
unique_string(desp='mst_current_schedule')
 
-        self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
-        if len(self.msts) < len(self.dist_keys):
+        self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
+        self.max_dist_key = sorted(self.dist_keys)[-1]
+        self.extra_dist_keys = []
+
+        num_msts = len(self.msts)
+        num_dist_keys = len(self.dist_keys)
+
+        if num_msts < num_dist_keys:
             self.msts_for_schedule = self.msts + [None] * \
-                                     (len(self.dist_keys) - len(self.msts))
+                                     (num_dist_keys - num_msts)
         else:
             self.msts_for_schedule = self.msts
+            if num_msts > num_dist_keys:
+                for i in range(num_msts - num_dist_keys):
+                    self.extra_dist_keys.append(self.max_dist_key + 1 + i)
+
+        DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys))

Review comment:
       For code readability, I think we can remove these two debug plpy.info 
statements. We can add them back for debugging as and when needed.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -62,23 +68,29 @@ class GD_STORE:
     def clear(GD):
         del GD[GD_STORE.SEGMENT_MODEL]
         del GD[GD_STORE.SESS]
+        if GD_STORE.AGG_IMAGE_COUNT in GD:
+            del GD[GD_STORE.AGG_IMAGE_COUNT]
 
 def get_init_model_and_sess(GD, device_name, gpu_count, segments_per_host,
                                model_architecture, compile_params, 
custom_function_map):
     # If a live session is present, re-use it. Otherwise, recreate it.
-    if GD_STORE.SESS in GD:
+
+    if GD_STORE.SESS in GD :
+        if GD_STORE.AGG_IMAGE_COUNT not in GD:

Review comment:
       The logic to initialize agg_image_count is also in the `fit_transition` 
and `fit_multiple_transition_caching` functions, so we may not need to 
initialize it here as well.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -520,21 +544,35 @@ def fit_transition(state, dependent_var, independent_var, 
dependent_var_shape,
         and only gets cleared in eval transition at the last row of the last 
iteration.
 
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape:
+        plpy.error("fit_transition called with no data")
+
+    if not prev_serialized_weights or not model_architecture:
         return state
+
     GD = kwargs['GD']
+
+    trans_enter_time = time.time()
+
     device_name = 
get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], 
current_seg_id)
 
     segment_model, sess = get_init_model_and_sess(GD, device_name,
-                                                  
accessible_gpus_for_seg[current_seg_id],
-                                                  segments_per_host,
-                                                  model_architecture, 
compile_params,
-                                                  custom_function_map)
-    if not state:
-        agg_image_count = 0
-        set_model_weights(segment_model, prev_serialized_weights)
+        accessible_gpus_for_seg[current_seg_id],
+        segments_per_host,
+        model_architecture, compile_params,
+        custom_function_map)
+
+    if GD_STORE.AGG_IMAGE_COUNT in GD:
+        agg_image_count = GD[GD_STORE.AGG_IMAGE_COUNT]
     else:
-        agg_image_count = float(state)
+        agg_image_count = 0
+        GD[GD_STORE.AGG_IMAGE_COUNT] = agg_image_count
+
+    DEBUG.plpy_info("agg_image_count={}".format(agg_image_count))

Review comment:
       I think we can remove this debug plpy.info statement. We can add it back 
for debugging as and when needed.

##########
File path: 
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -43,16 +47,17 @@ from utilities.utilities import is_platform_pg
 from utilities.utilities import get_seg_number
 from utilities.utilities import get_segments_per_host
 from utilities.utilities import rename_table
+import utilities.debug as DEBUG
+from utilities.debug import plpy_prepare
+from utilities.debug import plpy_execute
 
-import json
-from collections import defaultdict
-import random
-import datetime
+DEBUG.timings_enabled = True

Review comment:
       We should set this to `False` before merging the code.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -27,6 +27,7 @@ from madlib_keras_model_selection import ModelSelectionSchema
 from keras_model_arch_table import ModelArchSchema
 from utilities.validate_args import table_exists, drop_tables, input_tbl_valid
 from utilities.validate_args import quote_ident
+from madlib_keras_helper import DISTRIBUTION_KEY_COLNAME

Review comment:
       Changes to the automl code are now part of PR #526, right?

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -520,21 +544,35 @@ def fit_transition(state, dependent_var, independent_var, 
dependent_var_shape,
         and only gets cleared in eval transition at the last row of the last 
iteration.
 
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape:

Review comment:
       Isn't it better to check for the actual data column rather than the 
shape column to assert for missing data?
   
   Also, previously we supported the case where one of the rows had a 
missing x/y, but now it will error out. Is this intentional?

##########
File path: 
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -196,72 +195,110 @@ class FitMultipleModel():
             self.dist_key_mapping_valid, self.images_per_seg_valid = \
                 get_image_count_per_seg_for_minibatched_data_from_db(
                     self.validation_table)
-        self.mst_weights_tbl = unique_string(desp='mst_weights')
-        self.mst_current_schedule_tbl = 
unique_string(desp='mst_current_schedule')
+        self.model_input_tbl = unique_string(desp='model_input')
+        self.schedule_tbl = unique_string(desp='schedule')
 
-        self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
-        if len(self.msts) < len(self.dist_keys):
+        self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
+        DEBUG.plpy.info("init_dist_keys = {0}".format(self.dist_keys))
+        self.max_dist_key = sorted(self.dist_keys)[-1]
+        DEBUG.plpy.info("sorted_dist_keys = 
{0}".format(sorted(self.dist_keys)))
+        DEBUG.plpy.info("max_dist_key = {0}".format(self.max_dist_key))
+        self.extra_dist_keys = []
+
+        num_msts = len(self.msts)
+        num_dist_keys = len(self.dist_keys)
+
+        if num_msts < num_dist_keys:
             self.msts_for_schedule = self.msts + [None] * \
-                                     (len(self.dist_keys) - len(self.msts))
+                                     (num_dist_keys - num_msts)
         else:
             self.msts_for_schedule = self.msts
+            if num_msts > num_dist_keys:
+                for i in range(num_msts - num_dist_keys):
+                    self.extra_dist_keys.append(self.max_dist_key + 1 + i)
+
+        DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys))
+        DEBUG.plpy.info('extra_dist_keys : {}'.format(self.extra_dist_keys))
+
         random.shuffle(self.msts_for_schedule)
-        self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
-        self.gp_segment_id_col = '0' if is_platform_pg() else 
GP_SEGMENT_ID_COLNAME
-        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
-        if self.warm_start:
-            self.create_model_output_table_warm_start()
-        else:
-            self.create_model_output_table()
+        # Comma-separated list of the mst_keys, including NULL's
+        #  This will be used to pass the mst keys to the db as
+        #  a sql ARRAY[]
+        self.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                for mst in self.msts_for_schedule ]
 
-        self.weights_to_update_tbl = unique_string(desp='weights_to_update')
-        self.fit_multiple_model()
+        # List of all dist_keys, including any extra dist keys beyond
+        #  the # segments we'll be training on--these represent the
+        #  segments models will rest on while not training, which
+        #  may overlap with the ones that will have training on them.
+        self.all_dist_keys = self.dist_keys + self.extra_dist_keys
 
-        # Update and cleanup metadata tables
-        self.insert_info_table()
-        self.create_model_summary_table()
-        if self.warm_start:
-            self.cleanup_for_warm_start()
-        reset_cuda_env(original_cuda_env)
+        self.gp_segment_id_col = '0' if is_platform_pg() else 
GP_SEGMENT_ID_COLNAME
+        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
     def fit_multiple_model(self):
+        self.init_schedule_tbl()
+        self.init_model_output_tbl()
+        self.init_model_info_tbl()
+
         # WARNING: set orca off to prevent unwanted redistribution
         with OptimizerControl(False):
             self.start_training_time = datetime.datetime.now()
             self.metrics_elapsed_start_time = time.time()
             self.train_multiple_model()
             self.end_training_time = datetime.datetime.now()
 
-    def cleanup_for_warm_start(self):
+        # Update and cleanup metadata tables
+        self.insert_info_table()
+        self.create_model_summary_table()
+        self.write_final_model_output_tbl()
+        reset_cuda_env(self.original_cuda_env)
+
+    def write_final_model_output_tbl(self):
         """
-        1. drop original model table
+        1. drop original model table if exists
         2. rename temp to original
         :return:
         """
-        drop_query = "DROP TABLE IF EXISTS {}".format(
-            self.original_model_output_table)
-        plpy.execute(drop_query)
-        rename_table(self.schema_madlib, self.model_output_table,
-                     self.original_model_output_table)
+        final_output_table_create_query = """
+                                    DROP TABLE IF EXISTS 
{self.original_model_output_tbl};
+                                    CREATE TABLE 
{self.original_model_output_tbl} AS

Review comment:
       Sure, makes sense. In that case, drop and create looks good.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to