reductionista commented on a change in pull request #525: URL: https://github.com/apache/madlib/pull/525#discussion_r537883975
########## File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in ########## @@ -337,183 +376,308 @@ class FitMultipleModel(): local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None if local_loss and (local_loss not in [a.lower() for a in builtin_losses]): - custom_fn_names.append(local_loss) - custom_fn_mst_idx.append(mst_idx) + custom_fn_names.add(local_loss) + custom_msts.append(mst) if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]): - custom_fn_names.append(local_metric) - custom_fn_mst_idx.append(mst_idx) - - if len(custom_fn_names) > 0: - # Pass only unique custom_fn_names to query from object table - custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names))) - for mst_idx in custom_fn_mst_idx: - self.msts[mst_idx][self.object_map_col] = custom_fn_object_map - - def create_mst_schedule_table(self, mst_row): - mst_temp_query = """ - CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl} - ({self.model_id_col} INTEGER, - {self.compile_params_col} VARCHAR, - {self.fit_params_col} VARCHAR, - {dist_key_col} INTEGER, - {self.mst_key_col} INTEGER, - {self.object_map_col} BYTEA) - """.format(dist_key_col=dist_key_col, **locals()) - plpy.execute(mst_temp_query) - for mst, dist_key in zip(mst_row, self.dist_keys): - if mst: - model_id = mst[self.model_id_col] - compile_params = mst[self.compile_params_col] - fit_params = mst[self.fit_params_col] - mst_key = mst[self.mst_key_col] - object_map = mst[self.object_map_col] - else: - model_id = "NULL" - compile_params = "NULL" - fit_params = "NULL" - mst_key = "NULL" - object_map = None - mst_insert_query = plpy.prepare( - """ - INSERT INTO {self.mst_current_schedule_tbl} - VALUES ({model_id}, - $madlib${compile_params}$madlib$, - $madlib${fit_params}$madlib$, - {dist_key}, - {mst_key}, - $1) - 
""".format(**locals()), ["BYTEA"]) - plpy.execute(mst_insert_query, [object_map]) - - def create_model_output_table(self): - output_table_create_query = """ - CREATE TABLE {self.model_output_table} - ({self.mst_key_col} INTEGER PRIMARY KEY, - {self.model_weights_col} BYTEA, - {self.model_arch_col} JSON) - """.format(self=self) - plpy.execute(output_table_create_query) - self.initialize_model_output_and_info() + custom_fn_names.add(local_metric) + custom_msts.append(mst) + + self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names) + + for mst in custom_msts: + mst[self.object_map_col] = self.custom_fn_object_map + + self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts } + + def init_schedule_tbl(self): + self.prev_dist_key_col = '__prev_dist_key__' + mst_key_list = '[' + ','.join(self.all_mst_keys) + ']' + + create_sched_query = """ + CREATE TABLE {self.schedule_tbl} AS + WITH map AS + (SELECT + unnest(ARRAY{mst_key_list}) {self.mst_key_col}, + unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col} + ) + SELECT + map.{self.mst_key_col}, + {self.model_id_col}, + map.{self.dist_key_col} AS {self.prev_dist_key_col}, + map.{self.dist_key_col} + FROM map LEFT JOIN {self.model_selection_table} + USING ({self.mst_key_col}) + DISTRIBUTED BY ({self.dist_key_col}) + """.format(self=self, mst_key_list=mst_key_list) + DEBUG.plpy.execute(create_sched_query) + + def rotate_schedule_tbl(self): + if not hasattr(self, 'rotate_schedule_plan'): + self.next_schedule_tbl = unique_string('next_schedule') + rotate_schedule_tbl_query = """ + CREATE TABLE {self.next_schedule_tbl} AS + SELECT + {self.mst_key_col}, + {self.model_id_col}, + {self.dist_key_col} AS {self.prev_dist_key_col}, + COALESCE( + LEAD({self.dist_key_col}) + OVER(ORDER BY {self.dist_key_col}), + FIRST_VALUE({self.dist_key_col}) + OVER(ORDER BY {self.dist_key_col}) + ) AS {self.dist_key_col} + FROM {self.schedule_tbl}; + """.format(self=self) + self.rotate_schedule_tbl_plan = 
plpy.prepare(rotate_schedule_tbl_query) Review comment: Yes. Any time the same query is going to be executed many times in a loop, best practice is to prepare the statement once and then execute the prepared plan on each iteration; there is no need to prepare the same statement again each time. That reminds me: I had hoped to do the same for the Hop and UDA queries, but never got to it. I did include a TODO comment next to the UDA. If I get a chance, I'll update those as well, since it might make more of a difference there. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org