kaknikhil commented on code in PR #594:
URL: https://github.com/apache/madlib/pull/594#discussion_r1107670385


##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -47,14 +48,18 @@ from graph_utils import validate_output_and_summary_tables
 
 def validate_wcc_args(schema_madlib, vertex_table, vertex_id, edge_table,
                       edge_params, out_table, out_table_summary,
-                      grouping_cols_list, module_name):
+                      grouping_cols_list, warm_start, module_name):
     """
     Function to validate input parameters for wcc
     """
     validate_graph_coding(vertex_table, vertex_id, edge_table, edge_params,
-                          out_table, module_name)
-    _assert(not table_exists(out_table_summary),
-            "Graph {module_name}: Output summary table already exists!".format(**locals()))
+                          out_table, module_name, warm_start)
+    if not warm_start:

Review Comment:
   Shouldn't the summary table validation also happen in `validate_graph_coding`?
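A minimal sketch of what moving the summary-table check into `validate_graph_coding` could look like. The `out_table_summary` keyword is a hypothetical addition, the vertex/edge checks are omitted, and `table_exists`/`_assert` are stubbed so the snippet runs on its own; this is not the code in the PR.

```python
# Stubs standing in for MADlib's helpers (assumption: same semantics).
EXISTING_TABLES = set()


def table_exists(name):
    return name in EXISTING_TABLES


def _assert(condition, msg):
    if not condition:
        raise ValueError(msg)


def validate_graph_coding(vertex_table, vertex_id, edge_table, edge_params,
                          out_table, func_name, warm_start=False,
                          out_table_summary=None, **kwargs):
    """Validate the output tables (vertex/edge checks omitted here)."""
    _assert(out_table and out_table.strip().lower() not in ('null', ''),
            "Graph {0}: Invalid output table name!".format(func_name))
    if not warm_start:
        # A cold start must not overwrite existing results.
        _assert(not table_exists(out_table),
                "Graph {0}: Output table already exists!".format(func_name))
        # Hypothetical: the summary check lives here too, so every caller
        # gets it without repeating it in validate_wcc_args.
        if out_table_summary:
            _assert(not table_exists(out_table_summary),
                    "Graph {0}: Output summary table already exists!"
                    .format(func_name))
```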



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -211,6 +221,18 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
     else:
         edge_inverse = edge_table
 
+    if warm_start:

Review Comment:
   1. I think it would be a good idea to explain the workflow for wcc when warm_start is set to true vs. when it's set to false. This could be added to the commit message, the PR description, and also as a comment in the Python file.
   2. Should we also update the design doc and user docs?
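One way to write that workflow down, as a sketch inferred from the diff in this PR: a cold start seeds `newupdate` and `message` from the vertex table, while a warm start copies the previous run's output table and `_message` table and resumes from there. Grouping columns and the `{distribution}` clause are left out, `component_id` is hard-coded, and the dict keys are illustrative names rather than MADlib's.

```python
# PostgreSQL BIGINT upper bound; the diff uses a {BIGINT_MAX} placeholder.
BIGINT_MAX = 9223372036854775807


def setup_sql(warm_start, t):
    """Return the statements that create the working tables.

    `t` maps illustrative names to table/column names: newupdate, message,
    out_table, out_table_message, vertex_table, vertex_id, single_id.
    """
    if warm_start:
        # Resume: start from the state saved by the previous call.
        return [
            "CREATE TABLE {newupdate} AS SELECT * FROM {out_table};".format(**t),
            "CREATE TABLE {message} AS SELECT * FROM {out_table_message};"
            .format(**t),
        ]
    # Cold start: every vertex begins with component_id = BIGINT_MAX and
    # its first message carries its own id.
    return [
        "CREATE TABLE {newupdate} AS SELECT {vertex_id}, "
        "CAST({big} AS BIGINT) AS component_id FROM {vertex_table};"
        .format(big=BIGINT_MAX, **t),
        "CREATE TABLE {message} AS SELECT {vertex_id}, "
        "CAST({single_id} AS BIGINT) AS component_id FROM {vertex_table};"
        .format(**t),
    ]
```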



##########
src/ports/postgres/modules/graph/wcc.sql_in:
##########
@@ -115,6 +117,17 @@ weakly connected components are generated for all data
 (single graph).
 @note Expressions are not currently supported for 'grouping_cols'.</dd>
 
+<dt>iteration_limit (optional)</dt>
+<dd>INTEGER, default: NULL. Maximum number of iterations to run wcc. This

Review Comment:
   We should explicitly call out all the tables that get created for the various scenarios of iteration_limit and nodes_to_update.
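Going by the two errors reported further down in this review (`wcc_out_message` still exists after a run stopped by `iteration_limit`, and is missing after a run that converged), the rule the docs need to state might look like the sketch below. This is an inference from those reports, not verified against the code, and the helper name is hypothetical.

```python
def expected_output_tables(out_table, nodes_to_update):
    """Tables a user should expect after one wcc call (inferred behaviour)."""
    tables = [out_table, out_table + "_summary"]
    if nodes_to_update > 0:
        # Stopped early by iteration_limit: the message table is kept so a
        # later warm start can resume from it.
        tables.append(out_table + "_message")
    return tables
```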



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -211,6 +221,18 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
     else:
         edge_inverse = edge_table
 
+    if warm_start:
+        new_update_sql = """
+            CREATE TABLE {newupdate} AS SELECT * FROM {out_table};
+        """.format(**locals())
+        msg_sql = """
+            CREATE TABLE {message} AS SELECT * FROM {out_table_message};
+        """.format(**locals())
+        if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
+            new_update_sql += """

Review Comment:
   Maybe use `.format` instead of `+`. This also applies to other places where we use `+`.
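A sketch of the `.format`-only style: build the optional clause first (empty by default), then drop it into a single template. The body of the appended string is cut off in the diff, so the `RENAME COLUMN ... TO id` statement below is an assumption; swap in the real statement.

```python
def build_new_update_sql(newupdate, out_table, vertex_type, vertex_id_in):
    # Optional piece: empty unless the vertex column needs renaming.
    # ASSUMPTION: the truncated `+=` in the diff renames the user's column
    # back to `id`.
    rename_sql = ""
    if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
        rename_sql = "ALTER TABLE {0} RENAME COLUMN {1} TO id;".format(
            newupdate, vertex_id_in)
    # One template, one .format call, no string concatenation.
    return """
        CREATE TABLE {newupdate} AS SELECT * FROM {out_table};
        {rename_sql}
    """.format(newupdate=newupdate, out_table=out_table,
               rename_sql=rename_sql)
```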



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -366,44 +406,61 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
         # found in the current iteration.
         with SetGUC("dev_opt_unsafe_truncate_in_subtransaction", "on"):
 
-            plpy.execute(loop_sql.format(**locals()))
-
-            if grouping_cols:
-                nodes_to_update = plpy.execute("""
-                                    SELECT SUM(cnt) AS cnt_sum
-                                    FROM (
-                                        SELECT COUNT(*) AS cnt
-                                        FROM {toupdate}
-                                        GROUP BY {grouping_cols}
-                                    ) t
-                    """.format(**locals()))[0]["cnt_sum"]
-            else:
-                nodes_to_update = plpy.execute("""
-                                    SELECT COUNT(*) AS cnt FROM {toupdate}
-                                """.format(**locals()))[0]["cnt"]
+            nodes_to_update = plpy.execute(loop_sql.format(**locals()))[0]["cnt_sum"]
+            iteration_counter += 1
+
 
     if not is_platform_pg():
         # Drop intermediate table created for Greenplum
         plpy.execute("DROP TABLE IF EXISTS {0}".format(edge_inverse))
 
-    rename_table(schema_madlib, newupdate, out_table)
-    if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
-        plpy.execute("ALTER TABLE {out_table} RENAME COLUMN id TO {vertex_id_in}".format(**locals()))
+    if not warm_start:
+        rename_table(schema_madlib, newupdate, out_table)
+        if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
+            plpy.execute("ALTER TABLE {out_table} RENAME COLUMN id TO {vertex_id_in}".format(**locals()))
+    else:
+        plpy.execute("""
+            TRUNCATE TABLE {out_table};

Review Comment:
   Is there a performance reason for doing a truncate and insert rather than a drop and rename? If so, we should add that as a comment here.



##########
src/ports/postgres/modules/graph/test/wcc.sql_in:
##########
@@ -276,3 +276,12 @@ SELECT graph_wcc_num_cpts(
 SELECT assert(relative_error(num_components, 3) < 0.00001,
         'Weakly Connected Components: Incorrect largest component value.'
     ) FROM count_table WHERE user_id1=1;
+
+DROP TABLE IF EXISTS wcc_warm_start_out, wcc_warm_start_out_summary;
+SELECT weakly_connected_components('v2',NULL,'e2',NULL,'wcc_warm_start_out', 'user_id', 2);

Review Comment:
   I think there's a minor bug. Consider the following scenario
   
   Add data
   ```
   test=# select * from "EDGE";
    src_node | dest_node | user_id
   ----------+-----------+---------
           1 |         2 |       1
           1 |         3 |       1
           6 |         3 |       1
           5 |         4 |       1
           5 |         8 |       1
           2 |         3 |       1
           7 |         6 |       1
           8 |         4 |       1
   (8 rows)
   
   Time: 4.080 ms
   test=# select * from vertex ;
    vertex_id
   -----------
            1
            2
            3
            4
            7
            8
            5
            6
   (8 rows)
   ```
   
   Run with no grouping col but an iteration limit of 1
   ```
   SELECT madlib.weakly_connected_components('vertex','vertex_id','"EDGE"','src=src_node,dest=dest_node','wcc_out', NULL, 1);
   ```
   
   Now, without dropping any tables, run the same query again with warm start set to true:
   ```
   SELECT madlib.weakly_connected_components('vertex','vertex_id','"EDGE"','src=src_node,dest=dest_node','wcc_out', NULL, 1, TRUE);
   ERROR:  spiexceptions.DuplicateTable: relation "wcc_out_message" already exists
   CONTEXT:  Traceback (most recent call last):
     PL/Python function "weakly_connected_components", line 21, in <module>
       return wcc.wcc(**globals())
     PL/Python function "weakly_connected_components", line 432, in wcc
     PL/Python function "weakly_connected_components", line 1212, in rename_table
     PL/Python function "weakly_connected_components", line 1261, in __do_rename_and_get_new_name
   PL/Python function "weakly_connected_components"
   ```
   The interesting thing is that if you set the iteration limit to 4 for the second wcc query, it does not error out. I think that's because it takes 5 iterations for `nodes_to_update` to become 0.
   
   We should also add a test case for this scenario (assuming it takes 5 iterations to update all the nodes):
   1. Run 2 iterations without warm start.
   2. Run 2 more with warm start set to true.
   3. Run 1 or 2 more with warm start set to true.
   For the first two runs, we should at the very least assert that nodes_to_update > 0, and for the final run we should assert the contents of the out table and that nodes_to_update = 0.
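To make the resume behaviour concrete, here is a pure-Python simulation on the EDGE/vertex data above. It uses plain synchronous min-label propagation, not MADlib's message/toupdate table loop, so it converges in 4 rounds here rather than the 5 observed above. The point is only the property the proposed test should check: stopping at `iteration_limit` and warm-starting from the saved labels must land on the same components as one uninterrupted run.

```python
from collections import defaultdict

# Edge list and vertex list from the example above.
EDGES = [(1, 2), (1, 3), (6, 3), (5, 4), (5, 8), (2, 3), (7, 6), (8, 4)]
VERTICES = [1, 2, 3, 4, 7, 8, 5, 6]


def neighbours(edges):
    """Undirected adjacency sets (WCC ignores edge direction)."""
    adj = defaultdict(set)
    for src, dest in edges:
        adj[src].add(dest)
        adj[dest].add(src)
    return adj


def run_wcc(labels, adj, iteration_limit=None):
    """Synchronous min-label propagation with an optional iteration cap.

    `labels` is the state to start from: {v: v} for a cold start, or the
    labels returned by a previous call for a warm start.
    Returns (labels, nodes_to_update, iterations_run).
    """
    labels = dict(labels)
    nodes_to_update = None
    iterations = 0
    while iteration_limit is None or iterations < iteration_limit:
        # Each vertex takes the smallest label among itself and its neighbours.
        new = {v: min([labels[v]] + [labels[n] for n in adj[v]])
               for v in labels}
        nodes_to_update = sum(1 for v in labels if new[v] != labels[v])
        labels = new
        iterations += 1
        if nodes_to_update == 0:
            break
    return labels, nodes_to_update, iterations
```

A three-stage run (2 cold iterations, then warm starts of 1 and 2) mirrors the test plan: the first two stages end with `nodes_to_update > 0` and the last ends at 0 with the same labels as the uninterrupted run.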
   



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -211,6 +221,18 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
     else:
         edge_inverse = edge_table
 
+    if warm_start:
+        new_update_sql = """

Review Comment:
   I think we should initialize all these SQL-related variables to "" outside the if check. That way we won't ever run into `variable not defined` issues.
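A minimal sketch of the pattern: every optional statement starts as an empty string, so the final concatenation is always defined no matter which branches ran. Table names and the placeholder SQL are illustrative only.

```python
def build_setup_sql(warm_start, grouping_cols):
    # Every optional statement starts as "" so the final concatenation is
    # always defined, whichever branches run.
    distinct_grp_sql = ""
    new_update_sql = ""
    msg_sql = ""
    if grouping_cols:
        distinct_grp_sql = (
            "CREATE TABLE distinct_grp AS "
            "SELECT DISTINCT {0} FROM edge_table;".format(grouping_cols))
    if not warm_start:
        # Placeholders for the cold-start CREATE TABLE statements.
        new_update_sql = "CREATE TABLE newupdate AS /* cold-start query */;"
        msg_sql = "CREATE TABLE message AS /* cold-start query */;"
    return distinct_grp_sql + new_update_sql + msg_sql
```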



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -211,6 +221,18 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
     else:
         edge_inverse = edge_table
 
+    if warm_start:
+        new_update_sql = """
+            CREATE TABLE {newupdate} AS SELECT * FROM {out_table};
+        """.format(**locals())
+        msg_sql = """
+            CREATE TABLE {message} AS SELECT * FROM {out_table_message};
+        """.format(**locals())
+        if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':

Review Comment:
   Just curious, why do we need this if check? Why is there a need to rename to `{vertex_id}`?



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -90,6 +96,8 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
     edge_params = extract_keyvalue_params(
         edge_args, params_types, default_args)
 
+    if iteration_limit is None or iteration_limit == 0:

Review Comment:
   Do we also need to guard against values < 0?
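A sketch of the extra guard, assuming (as the second example further down does) that `None` and `0` both mean "run until convergence". The helper name is hypothetical; inside MADlib this would more likely be an `_assert` or `plpy.error` call.

```python
def normalize_iteration_limit(iteration_limit):
    """Map None/0 to 'no limit' and reject negative values explicitly."""
    if iteration_limit is None or iteration_limit == 0:
        return None  # run until nodes_to_update reaches 0
    if iteration_limit < 0:
        raise ValueError("Weakly Connected Components: iteration_limit must "
                         "be a non-negative integer, got {0}"
                         .format(iteration_limit))
    return iteration_limit
```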



##########
src/ports/postgres/modules/graph/test/wcc.sql_in:
##########
@@ -276,3 +276,12 @@ SELECT graph_wcc_num_cpts(
 SELECT assert(relative_error(num_components, 3) < 0.00001,
         'Weakly Connected Components: Incorrect largest component value.'
     ) FROM count_table WHERE user_id1=1;
+
+DROP TABLE IF EXISTS wcc_warm_start_out, wcc_warm_start_out_summary;
+SELECT weakly_connected_components('v2',NULL,'e2',NULL,'wcc_warm_start_out', 'user_id', 2);

Review Comment:
   Other observations:
   1. There is a table that doesn't get dropped if grouping cols are passed in.
   2. Run without an iteration limit so that eventually there are no nodes to update:
   ```
   SELECT madlib.weakly_connected_components('vertex','vertex_id','"EDGE"','src=src_node,dest=dest_node','wcc_out', 'user_id',0,TRUE);
   ```
   Now run with warm start set to true
   ```
   SELECT madlib.weakly_connected_components('vertex','vertex_id','"EDGE"','src=src_node,dest=dest_node','wcc_out', 'user_id',2,TRUE);
   ERROR:  spiexceptions.UndefinedTable: relation "wcc_out_message" does not exist
   LINE 2: ...age15774726_1676600954_3919632__ AS SELECT * FROM wcc_out_me...
                                                                ^
   QUERY:
               CREATE TABLE __madlib_temp_message15774726_1676600954_3919632__ AS SELECT * FROM wcc_out_message;
   
   CONTEXT:  Traceback (most recent call last):
     PL/Python function "weakly_connected_components", line 21, in <module>
       return wcc.wcc(**globals())
     PL/Python function "weakly_connected_components", line 340, in wcc
   PL/Python function "weakly_connected_components"
   ```
   Instead of this exception, we could print a message along the lines of "We have already updated all the nodes, nothing to do".
   3. If the user passes in -1 as the iteration_limit, the `nodes_to_update` column in the out_summary table is always 1, which doesn't seem right.
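A sketch of the friendlier early exit for observation 2. It assumes, based on the error above, that a missing `<out_table>_message` table means the previous run already converged; the helper name is hypothetical, and `table_exists` is passed in so the snippet runs standalone. In PL/Python the caller could emit the returned text with `plpy.notice` (or `plpy.info`) and return before creating any temp tables.

```python
def warm_start_precheck(out_table, table_exists):
    """Return an informational message if there is nothing left to do,
    otherwise None.

    ASSUMPTION: a missing <out_table>_message table means the previous run
    already converged (nodes_to_update reached 0).
    """
    message_table = out_table + "_message"
    if not table_exists(message_table):
        return ("Weakly Connected Components: all nodes in {0} have already "
                "been updated, nothing to do.".format(out_table))
    return None
```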



##########
src/ports/postgres/modules/graph/graph_utils.py_in:
##########
@@ -74,14 +74,18 @@ def validate_output_and_summary_tables(model_out_table, module_name,
                 "Graph WCC: Output table {0} already exists.".format(out_table))
 
 def validate_graph_coding(vertex_table, vertex_id, edge_table, edge_params,
-                          out_table, func_name, **kwargs):
+                          out_table, func_name, warm_start = False, **kwargs):
     """
     Validates graph tables (vertex and edge) as well as the output table.
     """
     _assert(out_table and out_table.strip().lower() not in ('null', ''),
-            "Graph {func_name}: Invalid output table name!".format(**locals()))
-    _assert(not table_exists(out_table),
-            "Graph {func_name}: Output table already exists!".format(**locals()))
+                "Graph {func_name}: Invalid output table name!".format(**locals()))
+    if not warm_start:

Review Comment:
   Flipping this if check might improve code readability, but I'll leave it up to you to decide.
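For comparison, a sketch of the flipped form with an early return for the warm-start case. The warm-start branch here also checks that the previous output exists; that extra check is a suggestion, not something the PR does. `_assert` is stubbed and `table_exists` is passed in so the snippet runs standalone.

```python
def _assert(condition, msg):
    # Stub standing in for MADlib's _assert.
    if not condition:
        raise ValueError(msg)


def validate_output_table(out_table, func_name, table_exists, warm_start=False):
    _assert(out_table and out_table.strip().lower() not in ('null', ''),
            "Graph {0}: Invalid output table name!".format(func_name))
    if warm_start:
        # Suggested extra: a warm start needs a previous output to resume from.
        _assert(table_exists(out_table),
                "Graph {0}: Output table does not exist, cannot warm start!"
                .format(func_name))
        return
    # Cold start: refuse to overwrite existing results.
    _assert(not table_exists(out_table),
            "Graph {0}: Output table already exists!".format(func_name))
```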



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -158,20 +166,21 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
     out_table_summary = ''
     if out_table:
         out_table_summary = add_postfix(out_table, "_summary")
+        out_table_message = add_postfix(out_table, "_message")
     grouping_cols_list = split_quoted_delimited_str(grouping_cols)
     validate_wcc_args(schema_madlib, vertex_table, vertex_id, edge_table,
                       edge_params, out_table, out_table_summary,
-                      grouping_cols_list, 'Weakly Connected Components')
+                      grouping_cols_list, warm_start, 'Weakly Connected Components')
 
     vertex_view_sql = vertex_view_sql.format(**locals())
-    plpy.execute(vertex_view_sql)
 
-    sql = """
+    edge_view_sql = """
         CREATE VIEW {edge_view} AS
         SELECT {src} AS src, {dest} AS dest {grouping_sql}
-        FROM {edge_table}
+        FROM {edge_table};
         """.format(**locals())
-    plpy.execute(sql)
+
+    plpy.execute(vertex_view_sql + edge_view_sql)

Review Comment:
   We should add a comment explaining why we need to run all these SQL statements in the same plpy.execute call.



##########
src/ports/postgres/modules/graph/test/wcc.sql_in:
##########
@@ -276,3 +276,12 @@ SELECT graph_wcc_num_cpts(
 SELECT assert(relative_error(num_components, 3) < 0.00001,
         'Weakly Connected Components: Incorrect largest component value.'
     ) FROM count_table WHERE user_id1=1;
+
+DROP TABLE IF EXISTS wcc_warm_start_out, wcc_warm_start_out_summary;
+SELECT weakly_connected_components('v2',NULL,'e2',NULL,'wcc_warm_start_out', 'user_id', 2);

Review Comment:
   Test improvements/notes:
   1. This is a test for wcc with warm start and grouping cols, right? Shouldn't we also add a test with warm start and no grouping cols?
   2. We should also add asserts for the output table and the summary table. That will help us catch regressions (if any).
   3. Can there be any issues if the user upgrades MADlib and then uses the new function with warm start set to true on output tables created by the previous version? We might have to call this out explicitly as well.



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -235,77 +257,92 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
         join_grouping_cols = _check_groups(subq, distinct_grp_table, grouping_cols_list)
         group_by_clause_newupdate = ('{0}, {1}.{2}'.format(subq_prefixed_grouping_cols,
                                                            subq, vertex_id))
+        select_grouping_cols = ',' + subq_prefixed_grouping_cols
+
+        if not warm_start:
+            new_update_sql = """

Review Comment:
   We should also mention in the commit message the reason for combining all these SQL statements into one plpy.execute call.



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -366,44 +406,61 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
         # found in the current iteration.
         with SetGUC("dev_opt_unsafe_truncate_in_subtransaction", "on"):
 
-            plpy.execute(loop_sql.format(**locals()))
-
-            if grouping_cols:
-                nodes_to_update = plpy.execute("""
-                                    SELECT SUM(cnt) AS cnt_sum
-                                    FROM (
-                                        SELECT COUNT(*) AS cnt
-                                        FROM {toupdate}
-                                        GROUP BY {grouping_cols}
-                                    ) t
-                    """.format(**locals()))[0]["cnt_sum"]
-            else:
-                nodes_to_update = plpy.execute("""
-                                    SELECT COUNT(*) AS cnt FROM {toupdate}
-                                """.format(**locals()))[0]["cnt"]
+            nodes_to_update = plpy.execute(loop_sql.format(**locals()))[0]["cnt_sum"]
+            iteration_counter += 1
+
 
     if not is_platform_pg():
         # Drop intermediate table created for Greenplum
         plpy.execute("DROP TABLE IF EXISTS {0}".format(edge_inverse))
 
-    rename_table(schema_madlib, newupdate, out_table)
-    if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
-        plpy.execute("ALTER TABLE {out_table} RENAME COLUMN id TO {vertex_id_in}".format(**locals()))
+    if not warm_start:
+        rename_table(schema_madlib, newupdate, out_table)
+        if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
+            plpy.execute("ALTER TABLE {out_table} RENAME COLUMN id TO {vertex_id_in}".format(**locals()))
+    else:
+        plpy.execute("""
+            TRUNCATE TABLE {out_table};
+            INSERT INTO {out_table}
+            SELECT
+                {vertex_id} AS {vertex_id_in},
+                {component_id}
+                {comma_grouping_cols}
+            FROM {newupdate};

Review Comment:
   I'm probably missing something, but don't we also need to drop the newupdate table when warm_start is set to true? We probably are, but I just wanted to make sure.
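A sketch of the finalization step with the cleanup made explicit. `rename_table` is approximated here by a plain `ALTER TABLE ... RENAME TO`, and `select_list` stands in for the `{vertex_id} AS {vertex_id_in}, {component_id} {comma_grouping_cols}` projection from the diff; neither is the PR's actual code.

```python
def finalize_sql(warm_start, newupdate, out_table, select_list):
    """Statements that publish the result of the last iteration."""
    if not warm_start:
        # Cold start: the working table simply becomes the output table,
        # so there is nothing left to drop.
        return ["ALTER TABLE {0} RENAME TO {1};".format(newupdate, out_table)]
    # Warm start: refresh the existing output table in place, then drop the
    # working copy so it does not leak between calls.
    return [
        "TRUNCATE TABLE {0};".format(out_table),
        "INSERT INTO {0} SELECT {1} FROM {2};".format(
            out_table, select_list, newupdate),
        "DROP TABLE IF EXISTS {0};".format(newupdate),
    ]
```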



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -235,77 +257,92 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
         join_grouping_cols = _check_groups(subq, distinct_grp_table, grouping_cols_list)
         group_by_clause_newupdate = ('{0}, {1}.{2}'.format(subq_prefixed_grouping_cols,
                                                            subq, vertex_id))
+        select_grouping_cols = ',' + subq_prefixed_grouping_cols
+
+        if not warm_start:
+            new_update_sql = """
+                CREATE TABLE {newupdate} AS
+                SELECT {subq}.{vertex_id},
+                        CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
+                        {select_grouping_cols}
+                FROM {distinct_grp_table} INNER JOIN (
+                    SELECT {grouping_cols_comma} {src} AS {vertex_id}
+                    FROM {edge_table}
+                    UNION
+                    SELECT {grouping_cols_comma} {dest} AS {vertex_id}
+                    FROM {edge_inverse}
+                ) {subq}
+                ON {join_grouping_cols}
+                GROUP BY {group_by_clause_newupdate}
+                {distribution};
+            """.format(**locals())
+            msg_sql = """
+                CREATE TABLE {message} AS
+                SELECT {vertex_table}.{vertex_id},
+                        CAST({vertex_table}.{single_id} AS BIGINT) AS {component_id}
+                        {comma_grouping_cols}
+                FROM {newupdate} INNER JOIN {vertex_table}
+                ON {vertex_table}.{vertex_id} = {newupdate}.{vertex_id}
+                {distribution};
+            """.format(**locals())
+
+        distinct_grp_sql = """
+            CREATE TABLE {distinct_grp_table} AS
+            SELECT DISTINCT {grouping_cols} FROM {edge_table};
 
-        grp_sql = """
-                CREATE TABLE {distinct_grp_table} AS
-                SELECT DISTINCT {grouping_cols} FROM {edge_table};
-            """
-        plpy.execute(grp_sql.format(**locals()))
-
-        prep_sql = """
-            CREATE TABLE {newupdate} AS
-            SELECT {subq}.{vertex_id},
-                    CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
-                    {select_grouping_cols}
-            FROM {distinct_grp_table} INNER JOIN (
-                SELECT {grouping_cols_comma} {src} AS {vertex_id}
-                FROM {edge_table}
-                UNION
-                SELECT {grouping_cols_comma} {dest} AS {vertex_id}
-                FROM {edge_inverse}
-            ) {subq}
-            ON {join_grouping_cols}
-            GROUP BY {group_by_clause_newupdate}
-            {distribution};
-
-            DROP TABLE IF EXISTS {distinct_grp_table};
-
-        """.format(select_grouping_cols=',' + subq_prefixed_grouping_cols,
-                   **locals())
-        plpy.execute(prep_sql)
-
-        message_sql = """
-            CREATE TABLE {message} AS
-            SELECT {vertex_table}.{vertex_id},
-                    CAST({vertex_table}.{single_id} AS BIGINT) AS {component_id}
-                    {comma_grouping_cols}
-            FROM {newupdate} INNER JOIN {vertex_table}
-            ON {vertex_table}.{vertex_id} = {newupdate}.{vertex_id}
-            {distribution};
         """
-        plpy.execute(message_sql.format(**locals()))
+
+        nodes_to_update_sql = """
+            SELECT SUM(cnt) AS cnt_sum
+            FROM (
+                SELECT COUNT(*) AS cnt
+                FROM {toupdate}
+                GROUP BY {grouping_cols}
+                ) t
+        """.format(**locals())
     else:
-        prep_sql = """
-            CREATE TABLE {newupdate} AS
-            SELECT {vertex_id}, CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
-            FROM {vertex_table}
-            {distribution};
-
-            CREATE TABLE {message} AS
-            SELECT {vertex_id}, CAST({single_id} AS BIGINT) AS {component_id}
-            FROM {vertex_table}
-            {distribution};
-        """
-        plpy.execute(prep_sql.format(**locals()))
-
-    oldupdate_sql = """
-            CREATE TABLE {oldupdate} AS
-            SELECT {message}.{vertex_id},
-                    MIN({message}.{component_id}) AS {component_id}
-                    {comma_grouping_cols}
-            FROM {message}
-            GROUP BY {grouping_cols_comma} {vertex_id}
-            LIMIT 0
-            {distribution};
-    """
-    plpy.execute(oldupdate_sql.format(**locals()))
+        if not warm_start:
+            msg_sql = """
+                CREATE TABLE {message} AS
+                SELECT {vertex_id}, CAST({single_id} AS BIGINT) AS {component_id}
+                FROM {vertex_table}
+                {distribution};
+            """.format(**locals())
+            new_update_sql = """
+                CREATE TABLE {newupdate} AS
+                SELECT {vertex_id}, CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
+                FROM {vertex_table}
+                {distribution};
+            """.format(**locals())
+
+        nodes_to_update_sql = """
+            SELECT COUNT(*) AS cnt_sum FROM {toupdate}
+        """.format(**locals())
 
-    toupdate_sql = """
-            CREATE TABLE {toupdate} AS
-            SELECT * FROM {oldupdate}
-            {distribution};
-        """
-    plpy.execute(toupdate_sql.format(**locals()))
+    old_update_sql = """
+        CREATE TABLE {oldupdate} AS
+        SELECT {message}.{vertex_id},
+                MIN({message}.{component_id}) AS {component_id}
+                {comma_grouping_cols}
+        FROM {message}
+        GROUP BY {grouping_cols_comma} {vertex_id}
+        LIMIT 0
+        {distribution};
+    """
+    to_update_sql = """
+        CREATE TABLE {toupdate} AS
+        SELECT * FROM {oldupdate}
+        {distribution};
+    """
+    if is_platform_pg or not is_platform_gp6_or_up():

Review Comment:
   We should also explain the need for this if check. Also, I think the code might be easier to read if we flip the if check to
   ```
   if is_platform_gp6_or_up():
       plpy.execute((distinct_grp_sql + new_update_sql + msg_sql + old_update_sql + to_update_sql).format(**locals()))
   else:
       .....
   ```
   This should also achieve the same thing, right?
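One detail worth checking alongside the flip: in the hunk above, `is_platform_pg` appears without parentheses, while a few hunks earlier it is called as `is_platform_pg()`. A bare function object is always truthy, so as written the first branch would always be taken. A sketch of the intended condition with both helpers called (helpers passed in so it runs standalone):

```python
def run_setup_in_one_call(is_platform_pg, is_platform_gp6_or_up):
    """True when all setup statements go to a single plpy.execute call.

    Same condition as the diff's `if ... else`, with the helpers actually
    called: the else branch runs when not PG and GP6 or later.
    """
    return not is_platform_pg() and is_platform_gp6_or_up()
```

The shorter `if is_platform_gp6_or_up():` form matches this only if that helper returns False on PostgreSQL, which is worth confirming.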
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: dev-unsubscr...@madlib.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
