reductionista commented on a change in pull request #525: URL: https://github.com/apache/madlib/pull/525#discussion_r538744805
########## File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in ########## @@ -337,183 +376,308 @@ class FitMultipleModel(): local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None if local_loss and (local_loss not in [a.lower() for a in builtin_losses]): - custom_fn_names.append(local_loss) - custom_fn_mst_idx.append(mst_idx) + custom_fn_names.add(local_loss) + custom_msts.append(mst) if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]): - custom_fn_names.append(local_metric) - custom_fn_mst_idx.append(mst_idx) - - if len(custom_fn_names) > 0: - # Pass only unique custom_fn_names to query from object table - custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names))) - for mst_idx in custom_fn_mst_idx: - self.msts[mst_idx][self.object_map_col] = custom_fn_object_map - - def create_mst_schedule_table(self, mst_row): - mst_temp_query = """ - CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl} - ({self.model_id_col} INTEGER, - {self.compile_params_col} VARCHAR, - {self.fit_params_col} VARCHAR, - {dist_key_col} INTEGER, - {self.mst_key_col} INTEGER, - {self.object_map_col} BYTEA) - """.format(dist_key_col=dist_key_col, **locals()) - plpy.execute(mst_temp_query) - for mst, dist_key in zip(mst_row, self.dist_keys): - if mst: - model_id = mst[self.model_id_col] - compile_params = mst[self.compile_params_col] - fit_params = mst[self.fit_params_col] - mst_key = mst[self.mst_key_col] - object_map = mst[self.object_map_col] - else: - model_id = "NULL" - compile_params = "NULL" - fit_params = "NULL" - mst_key = "NULL" - object_map = None - mst_insert_query = plpy.prepare( - """ - INSERT INTO {self.mst_current_schedule_tbl} - VALUES ({model_id}, - $madlib${compile_params}$madlib$, - $madlib${fit_params}$madlib$, - {dist_key}, - {mst_key}, - $1) - 
""".format(**locals()), ["BYTEA"]) - plpy.execute(mst_insert_query, [object_map]) - - def create_model_output_table(self): - output_table_create_query = """ - CREATE TABLE {self.model_output_table} - ({self.mst_key_col} INTEGER PRIMARY KEY, - {self.model_weights_col} BYTEA, - {self.model_arch_col} JSON) - """.format(self=self) - plpy.execute(output_table_create_query) - self.initialize_model_output_and_info() + custom_fn_names.add(local_metric) + custom_msts.append(mst) + + self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names) + + for mst in custom_msts: + mst[self.object_map_col] = self.custom_fn_object_map + + self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts } + + def init_schedule_tbl(self): + self.prev_dist_key_col = '__prev_dist_key__' + mst_key_list = '[' + ','.join(self.all_mst_keys) + ']' + + create_sched_query = """ + CREATE TABLE {self.schedule_tbl} AS + WITH map AS + (SELECT + unnest(ARRAY{mst_key_list}) {self.mst_key_col}, + unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col} + ) + SELECT + map.{self.mst_key_col}, + {self.model_id_col}, + map.{self.dist_key_col} AS {self.prev_dist_key_col}, + map.{self.dist_key_col} + FROM map LEFT JOIN {self.model_selection_table} + USING ({self.mst_key_col}) + DISTRIBUTED BY ({self.dist_key_col}) + """.format(self=self, mst_key_list=mst_key_list) + DEBUG.plpy.execute(create_sched_query) + + def rotate_schedule_tbl(self): + if not hasattr(self, 'rotate_schedule_plan'): + self.next_schedule_tbl = unique_string('next_schedule') + rotate_schedule_tbl_query = """ + CREATE TABLE {self.next_schedule_tbl} AS + SELECT + {self.mst_key_col}, + {self.model_id_col}, + {self.dist_key_col} AS {self.prev_dist_key_col}, + COALESCE( + LEAD({self.dist_key_col}) + OVER(ORDER BY {self.dist_key_col}), + FIRST_VALUE({self.dist_key_col}) + OVER(ORDER BY {self.dist_key_col}) + ) AS {self.dist_key_col} + FROM {self.schedule_tbl}; Review comment: I'm not sure whether the DISTRIBUTE BY in the 
`rotate_schedule_tbl` actually matters. If I recall correctly, in the hop query there is one Redistribute Motion based on the dist key and another one based on the prev dist key, so I left it off, figuring that either way this small amount of data is going to have to move. It's the movement of the other tables (input & output) that we care about, since they're large. I haven't checked today, but it's possible the schedule table will just get broadcast anyway, since it's much smaller than the model_input and model_output tables. But I am going to add a DISTRIBUTED BY `prev_dist_key_col`, because that seems like the safest of the three options, in case the planner or ORCA gets changed in the future and starts picking a bad plan. Choosing either `prev_dist_key_col` or `dist_key_col` at least seems better than letting it distribute randomly. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org