reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r538744805



##########
File path: 
src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in 
compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' 
in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in 
builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in 
builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = 
query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE 
{self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = 
query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):
+            self.next_schedule_tbl = unique_string('next_schedule')
+            rotate_schedule_tbl_query = """
+                CREATE TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl};

Review comment:
         I'm not sure whether the DISTRIBUTED BY in `rotate_schedule_tbl`
actually matters. If I recall correctly, the hop query has one Redistribute
Motion based on the dist key and another based on the prev dist key, so I left
it off, figuring this small amount of data will have to move either way. It's
the movement of the other tables (input and output) that we care about, since
they're large. I haven't checked today, but it's possible the schedule table
will just be broadcast anyway, since it's much smaller than the model_input and
model_output tables. Still, I am going to add a DISTRIBUTED BY
`prev_dist_key_col`, because that seems like the safest of the three options
(`prev_dist_key_col`, `dist_key_col`, or random distribution). It guards against
the planner or ORCA changing in the future and starting to pick a bad plan.
Choosing either `prev_dist_key_col` or `dist_key_col` at least seems better than
letting the table be distributed randomly.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to