Posted to dev@madlib.apache.org by "kaknikhil (via GitHub)" <gi...@apache.org> on 2023/02/17 03:43:32 UTC

[GitHub] [madlib] kaknikhil commented on a diff in pull request #594: WCC: Add warm start

kaknikhil commented on code in PR #594:
URL: https://github.com/apache/madlib/pull/594#discussion_r1107670385


##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -47,14 +48,18 @@ from graph_utils import validate_output_and_summary_tables
 
 def validate_wcc_args(schema_madlib, vertex_table, vertex_id, edge_table,
                       edge_params, out_table, out_table_summary,
-                      grouping_cols_list, module_name):
+                      grouping_cols_list, warm_start, module_name):
     """
     Function to validate input parameters for wcc
     """
     validate_graph_coding(vertex_table, vertex_id, edge_table, edge_params,
-                          out_table, module_name)
-    _assert(not table_exists(out_table_summary),
-            "Graph {module_name}: Output summary table already exists!".format(**locals()))
+                          out_table, module_name, warm_start)
+    if not warm_start:

Review Comment:
   Shouldn't the summary table validation also happen in `validate_graph_coding`?
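A minimal sketch of what folding the summary-table check into a single validator could look like. It assumes that a warm start requires both the output and summary tables to already exist and that a cold start requires both to be absent; that rule is an assumption, not something the diff confirms. `validate_output_tables` is a hypothetical name, `table_exists` is injected as a stand-in for the MADlib helper of the same name, `ValueError` stands in for `_assert`, and plain string concatenation replaces `add_postfix`.

```python
from typing import Callable


def validate_output_tables(out_table: str, func_name: str, warm_start: bool,
                           table_exists: Callable[[str], bool]) -> None:
    """Hypothetical validator covering both the output and the summary table."""
    summary_table = out_table + "_summary"
    if warm_start:
        # Assumption: warm start resumes a previous run, so both tables must exist.
        for table in (out_table, summary_table):
            if not table_exists(table):
                raise ValueError("Graph {0}: {1} does not exist for warm start!"
                                 .format(func_name, table))
    else:
        # Cold start: refuse to overwrite tables left over from an earlier run.
        for table in (out_table, summary_table):
            if table_exists(table):
                raise ValueError("Graph {0}: {1} already exists!"
                                 .format(func_name, table))
```

Testing the positive `warm_start` branch first also reads naturally as the flipped check suggested later in this review.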



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -211,6 +221,18 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
     else:
         edge_inverse = edge_table
 
+    if warm_start:

Review Comment:
   1. I think it would be a good idea to explain the workflow for wcc when warm_start is set to true vs. when it's set to false. This could be added to the commit message, the PR description, and also as a comment in the Python file.
   2. Should we also update the design doc and user docs?
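As a sketch of the kind of workflow description being asked for, the two setup paths visible in the diff can be written side by side. This assumes a cold start seeds `{newupdate}` with `BIGINT_MAX` component ids and `{message}` with each vertex's own id (the diff uses `{single_id}`; here the scalar vertex id is assumed), while a warm start copies both tables from the previous run's `{out_table}` and `{out_table_message}`. `wcc_setup_sql` and the hard-coded `component_id` column name are hypothetical, and grouping and distribution clauses are omitted.

```python
from typing import List

BIGINT_MAX = 9223372036854775807  # 2**63 - 1, largest PostgreSQL BIGINT


def wcc_setup_sql(warm_start: bool, vertex_table: str, vertex_id: str,
                  out_table: str, out_table_message: str,
                  newupdate: str, message: str) -> List[str]:
    """Return simplified statements that seed the WCC iteration tables."""
    if warm_start:
        # Warm start: resume from the component ids and pending messages
        # persisted by a previous, iteration-limited run.
        return [
            "CREATE TABLE {0} AS SELECT * FROM {1};".format(newupdate, out_table),
            "CREATE TABLE {0} AS SELECT * FROM {1};".format(message, out_table_message),
        ]
    # Cold start: every vertex starts with an "infinite" component id, and the
    # first round of messages proposes each vertex's own id as its component.
    return [
        ("CREATE TABLE {nu} AS SELECT {vid}, CAST({mx} AS BIGINT) AS component_id "
         "FROM {vt};").format(nu=newupdate, vid=vertex_id, mx=BIGINT_MAX, vt=vertex_table),
        ("CREATE TABLE {msg} AS SELECT {vid}, CAST({vid} AS BIGINT) AS component_id "
         "FROM {vt};").format(msg=message, vid=vertex_id, vt=vertex_table),
    ]
```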



##########
src/ports/postgres/modules/graph/wcc.sql_in:
##########
@@ -115,6 +117,17 @@ weakly connected components are generated for all data
 (single graph).
 @note Expressions are not currently supported for 'grouping_cols'.</dd>
 
+<dt>iteration_limit (optional)</dt>
+<dd>INTEGER, default: NULL. Maximum number of iterations to run wcc. This

Review Comment:
   We should explicitly call out all the tables that get created in the various scenarios of `iteration_limit` and `nodes_to_update`.
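A sketch of how those scenarios could be enumerated. The rule here is inferred only from the error messages quoted later in this review: after a run that stops at the iteration limit with nodes still to update, `<out_table>_message` is left behind, and after a run that converges it is not. Treat that rule, and the hypothetical `expected_output_tables` helper, as assumptions to verify against the PR.

```python
from typing import List


def expected_output_tables(out_table: str, nodes_to_update: int) -> List[str]:
    """Tables a caller should expect after a WCC run (inferred, not confirmed)."""
    tables = [out_table, out_table + "_summary"]
    if nodes_to_update > 0:
        # Stopped early at iteration_limit: pending messages are kept so that
        # a later warm start can resume from them.
        tables.append(out_table + "_message")
    return tables
```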



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -211,6 +221,18 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
     else:
         edge_inverse = edge_table
 
+    if warm_start:
+        new_update_sql = """
+            CREATE TABLE {newupdate} AS SELECT * FROM {out_table};
+        """.format(**locals())
+        msg_sql = """
+            CREATE TABLE {message} AS SELECT * FROM {out_table_message};
+        """.format(**locals())
+        if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
+            new_update_sql += """

Review Comment:
   Maybe use `.format` instead of `+`. This also applies to the other places where we use `+`.
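For example, the conditional rename can become an optional fragment substituted into one template instead of being appended with `+`. This is a sketch: `build_new_update_sql` is a hypothetical name, and the rename direction (user column back to the internal `id`) is an assumption, because the appended SQL is cut off in the quoted diff.

```python
def build_new_update_sql(newupdate: str, out_table: str,
                         vertex_type: str, vertex_id_in: str) -> str:
    """Compose the warm-start seeding SQL from one template instead of `+`."""
    rename_sql = ""  # optional fragment, empty unless a rename is needed
    if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
        # Assumed: restore the internal column name `id` used by the iteration.
        rename_sql = "ALTER TABLE {newupdate} RENAME COLUMN {vertex_id_in} TO id;".format(
            newupdate=newupdate, vertex_id_in=vertex_id_in)
    return """
        CREATE TABLE {newupdate} AS SELECT * FROM {out_table};
        {rename_sql}
    """.format(newupdate=newupdate, out_table=out_table, rename_sql=rename_sql)
```

The whole statement's shape is now visible in one template, which is the readability gain the comment is after.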



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -366,44 +406,61 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
         # found in the current iteration.
         with SetGUC("dev_opt_unsafe_truncate_in_subtransaction", "on"):
 
-            plpy.execute(loop_sql.format(**locals()))
-
-            if grouping_cols:
-                nodes_to_update = plpy.execute("""
-                                    SELECT SUM(cnt) AS cnt_sum
-                                    FROM (
-                                        SELECT COUNT(*) AS cnt
-                                        FROM {toupdate}
-                                        GROUP BY {grouping_cols}
-                                    ) t
-                    """.format(**locals()))[0]["cnt_sum"]
-            else:
-                nodes_to_update = plpy.execute("""
-                                    SELECT COUNT(*) AS cnt FROM {toupdate}
-                                """.format(**locals()))[0]["cnt"]
+            nodes_to_update = plpy.execute(loop_sql.format(**locals()))[0]["cnt_sum"]
+            iteration_counter += 1
+
 
     if not is_platform_pg():
         # Drop intermediate table created for Greenplum
         plpy.execute("DROP TABLE IF EXISTS {0}".format(edge_inverse))
 
-    rename_table(schema_madlib, newupdate, out_table)
-    if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
-        plpy.execute("ALTER TABLE {out_table} RENAME COLUMN id TO {vertex_id_in}".format(**locals()))
+    if not warm_start:
+        rename_table(schema_madlib, newupdate, out_table)
+        if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
+            plpy.execute("ALTER TABLE {out_table} RENAME COLUMN id TO {vertex_id_in}".format(**locals()))
+    else:
+        plpy.execute("""
+            TRUNCATE TABLE {out_table};

Review Comment:
   Is there a performance reason for doing a truncate and insert rather than a drop and rename? If so, we should add that as a comment here.



##########
src/ports/postgres/modules/graph/test/wcc.sql_in:
##########
@@ -276,3 +276,12 @@ SELECT graph_wcc_num_cpts(
 SELECT assert(relative_error(num_components, 3) < 0.00001,
         'Weakly Connected Components: Incorrect largest component value.'
     ) FROM count_table WHERE user_id1=1;
+
+DROP TABLE IF EXISTS wcc_warm_start_out, wcc_warm_start_out_summary;
+SELECT weakly_connected_components('v2',NULL,'e2',NULL,'wcc_warm_start_out', 'user_id', 2);

Review Comment:
   I think there's a minor bug. Consider the following scenario:
   
   Add data:
   ```
   test=# select * from "EDGE";
    src_node | dest_node | user_id
   ----------+-----------+---------
           1 |         2 |       1
           1 |         3 |       1
           6 |         3 |       1
           5 |         4 |       1
           5 |         8 |       1
           2 |         3 |       1
           7 |         6 |       1
           8 |         4 |       1
   (8 rows)
   
   Time: 4.080 ms
   test=# select * from vertex ;
    vertex_id
   -----------
            1
            2
            3
            4
            7
            8
            5
            6
   (8 rows)
   ```
   
   Run with no grouping column but with an iteration limit of 1:
   ```
   SELECT madlib.weakly_connected_components('vertex','vertex_id','"EDGE"','src=src_node,dest=dest_node','wcc_out', NULL, 1);
   ```
   
   Now, without dropping any tables, run the same query again with warm start set to true:
   ```
   SELECT madlib.weakly_connected_components('vertex','vertex_id','"EDGE"','src=src_node,dest=dest_node','wcc_out', NULL, 1, TRUE);
   ERROR:  spiexceptions.DuplicateTable: relation "wcc_out_message" already exists
   CONTEXT:  Traceback (most recent call last):
     PL/Python function "weakly_connected_components", line 21, in <module>
       return wcc.wcc(**globals())
     PL/Python function "weakly_connected_components", line 432, in wcc
     PL/Python function "weakly_connected_components", line 1212, in rename_table
     PL/Python function "weakly_connected_components", line 1261, in __do_rename_and_get_new_name
   PL/Python function "weakly_connected_components"
   ```
   Interestingly, if you set the iteration limit to 4 for the second wcc query, it does not error out. I think that's because it takes 5 iterations for `nodes_to_update` to become 0.
   
   We should also add a test case for this scenario (this is based on the assumption that it takes 5 iterations to update all the nodes):
   1. Run 2 iterations without warm start.
   2. Run 2 more with warm start set to true.
   3. Run 1 or 2 more with warm start set to true.

   For the first two runs, we should at the very least assert that `nodes_to_update` > 0, and for the final run, we should assert the contents of the out table and also that `nodes_to_update` = 0.
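The scenario can be modeled with a simplified, pure-Python min-label propagation in which saved state plays the role of `wcc_out` plus `wcc_out_message`. This is not MADlib's implementation; its SQL iterations may converge in a different number of steps than this model, so the sketch checks only the shape of the proposed test. Interrupted runs report `nodes_to_update > 0`, the final warm run reports 0, and its labels match one uninterrupted run. `WccState`, `run_wcc`, and `build_adjacency` are hypothetical names. The edges and vertices are the ones from the example above.

```python
from typing import Dict, Iterable, List, Optional, Set, Tuple


def build_adjacency(edges: List[Tuple[int, int]]) -> Dict[int, Set[int]]:
    """Undirected adjacency: weakly connected components ignore edge direction."""
    adj: Dict[int, Set[int]] = {}
    for src, dest in edges:
        adj.setdefault(src, set()).add(dest)
        adj.setdefault(dest, set()).add(src)
    return adj


class WccState:
    """Stand-in for the persisted output table (labels) and message table (frontier)."""

    def __init__(self, vertices: Iterable[int]) -> None:
        self.labels: Dict[int, int] = {v: v for v in vertices}
        self.frontier: Set[int] = set(self.labels)


def run_wcc(adj: Dict[int, Set[int]], state: WccState,
            iteration_limit: Optional[int] = None) -> int:
    """Propagate minimum labels; return nodes_to_update (size of the frontier).

    None or 0 means "no limit", mirroring the check in the diff.
    """
    iterations = 0
    while state.frontier and (not iteration_limit or iterations < iteration_limit):
        new_labels = dict(state.labels)
        for v in state.frontier:
            for u in adj.get(v, ()):
                # Each vertex adopts the smallest label offered by a changed neighbor.
                new_labels[u] = min(new_labels[u], state.labels[v])
        state.frontier = {v for v, lbl in new_labels.items() if lbl != state.labels[v]}
        state.labels = new_labels
        iterations += 1
    # An empty frontier means a warm start has nothing to do, rather than an error.
    return len(state.frontier)


EDGES = [(1, 2), (1, 3), (6, 3), (5, 4), (5, 8), (2, 3), (7, 6), (8, 4)]
VERTICES = [1, 2, 3, 4, 7, 8, 5, 6]
```

Usage mirrors the three-step plan. Call `run_wcc(adj, state, 2)` on a fresh `WccState`, then call `run_wcc` again on the same state with a small limit, then run it to convergence. Compare `state.labels` with a cold, unlimited run.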
   



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -211,6 +221,18 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
     else:
         edge_inverse = edge_table
 
+    if warm_start:
+        new_update_sql = """

Review Comment:
   I think we should initialize all these SQL-related variables to "" outside the if check. That way we won't ever run into `variable not defined` issues.
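A small illustration of the failure mode this avoids. In Python, reading a local that was assigned only inside an untaken branch raises `UnboundLocalError`, a subclass of `NameError`. The function names and SQL text below are hypothetical.

```python
def seed_sql_unsafe(warm_start: bool) -> str:
    if warm_start:
        msg_sql = "CREATE TABLE msg AS SELECT * FROM out_message;"
    return msg_sql  # raises UnboundLocalError when warm_start is False


def seed_sql_safe(warm_start: bool) -> str:
    # Initialize every optional SQL fragment up front so any branch may skip it.
    msg_sql = ""
    if warm_start:
        msg_sql = "CREATE TABLE msg AS SELECT * FROM out_message;"
    return msg_sql
```

An empty string is also harmless when fragments are later concatenated or formatted into one batch.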



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -211,6 +221,18 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
     else:
         edge_inverse = edge_table
 
+    if warm_start:
+        new_update_sql = """
+            CREATE TABLE {newupdate} AS SELECT * FROM {out_table};
+        """.format(**locals())
+        msg_sql = """
+            CREATE TABLE {message} AS SELECT * FROM {out_table_message};
+        """.format(**locals())
+        if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':

Review Comment:
   Just curious, why do we need this if check? Why is there a need to rename to `{vertex_id}`?



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -90,6 +96,8 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
     edge_params = extract_keyvalue_params(
         edge_args, params_types, default_args)
 
+    if iteration_limit is None or iteration_limit == 0:

Review Comment:
   Do we also need to guard against values < 0?
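A sketch of such a guard, keeping the diff's convention that `None` and `0` mean "run until convergence". `normalize_iteration_limit` is a hypothetical name, and raising `ValueError` stands in for MADlib's `_assert` / `plpy.error`.

```python
from typing import Optional


def normalize_iteration_limit(iteration_limit: Optional[int]) -> Optional[int]:
    """Return None for "no limit"; reject negative values."""
    if iteration_limit is None or iteration_limit == 0:
        return None
    if iteration_limit < 0:
        raise ValueError(
            "Weakly Connected Components: iteration_limit must be a non-negative "
            "integer, got {0}".format(iteration_limit))
    return iteration_limit
```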



##########
src/ports/postgres/modules/graph/test/wcc.sql_in:
##########
@@ -276,3 +276,12 @@ SELECT graph_wcc_num_cpts(
 SELECT assert(relative_error(num_components, 3) < 0.00001,
         'Weakly Connected Components: Incorrect largest component value.'
     ) FROM count_table WHERE user_id1=1;
+
+DROP TABLE IF EXISTS wcc_warm_start_out, wcc_warm_start_out_summary;
+SELECT weakly_connected_components('v2',NULL,'e2',NULL,'wcc_warm_start_out', 'user_id', 2);

Review Comment:
   Other observations:
   1. There is a table that doesn't get dropped if grouping cols are passed in.
   2. Run without an iteration limit so that eventually there are no nodes to update:
   ```
   SELECT madlib.weakly_connected_components('vertex','vertex_id','"EDGE"','src=src_node,dest=dest_node','wcc_out', 'user_id',0,TRUE);
   ```
   Now run with warm start set to true:
   ```
   SELECT madlib.weakly_connected_components('vertex','vertex_id','"EDGE"','src=src_node,dest=dest_node','wcc_out', 'user_id',2,TRUE);
   ERROR:  spiexceptions.UndefinedTable: relation "wcc_out_message" does not exist
   LINE 2: ...age15774726_1676600954_3919632__ AS SELECT * FROM wcc_out_me...
                                                                ^
   QUERY:
               CREATE TABLE __madlib_temp_message15774726_1676600954_3919632__ AS SELECT * FROM wcc_out_message;
   
   CONTEXT:  Traceback (most recent call last):
     PL/Python function "weakly_connected_components", line 21, in <module>
       return wcc.wcc(**globals())
     PL/Python function "weakly_connected_components", line 340, in wcc
   PL/Python function "weakly_connected_components"
   ```
   Instead of this exception, we could print a message along the lines of "We have already updated all the nodes, nothing to do".
   3. If the user passes in -1 as the `iteration_limit`, the `nodes_to_update` column in the out_summary table is always 1, which doesn't seem right.
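One way to turn the `UndefinedTable` error into the suggested message is to check for the persisted message table before issuing `CREATE TABLE ... AS SELECT * FROM <out_table>_message`. This sketch assumes, as the error above suggests, that the message table is gone once a run has converged. `warm_start_precheck` is a hypothetical name and `table_exists` is injected. In MADlib the returned text would more likely be emitted through `plpy.info` or `plpy.notice`.

```python
from typing import Callable, Optional


def warm_start_precheck(out_table: str,
                        table_exists: Callable[[str], bool]) -> Optional[str]:
    """Return a notice if there is nothing left to update, else None."""
    message_table = out_table + "_message"
    if not table_exists(message_table):
        # A converged run leaves no pending messages, so skip instead of failing.
        return ("Weakly Connected Components: all nodes in {0} are already "
                "updated, nothing to do.".format(out_table))
    return None
```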



##########
src/ports/postgres/modules/graph/graph_utils.py_in:
##########
@@ -74,14 +74,18 @@ def validate_output_and_summary_tables(model_out_table, module_name,
                 "Graph WCC: Output table {0} already exists.".format(out_table))
 
 def validate_graph_coding(vertex_table, vertex_id, edge_table, edge_params,
-                          out_table, func_name, **kwargs):
+                          out_table, func_name, warm_start = False, **kwargs):
     """
     Validates graph tables (vertex and edge) as well as the output table.
     """
     _assert(out_table and out_table.strip().lower() not in ('null', ''),
-            "Graph {func_name}: Invalid output table name!".format(**locals()))
-    _assert(not table_exists(out_table),
-            "Graph {func_name}: Output table already exists!".format(**locals()))
+                "Graph {func_name}: Invalid output table name!".format(**locals()))
+    if not warm_start:

Review Comment:
   Flipping this if check might improve code readability, but I'll leave it up to you to decide.



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -158,20 +166,21 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
     out_table_summary = ''
     if out_table:
         out_table_summary = add_postfix(out_table, "_summary")
+        out_table_message = add_postfix(out_table, "_message")
     grouping_cols_list = split_quoted_delimited_str(grouping_cols)
     validate_wcc_args(schema_madlib, vertex_table, vertex_id, edge_table,
                       edge_params, out_table, out_table_summary,
-                      grouping_cols_list, 'Weakly Connected Components')
+                      grouping_cols_list, warm_start, 'Weakly Connected Components')
 
     vertex_view_sql = vertex_view_sql.format(**locals())
-    plpy.execute(vertex_view_sql)
 
-    sql = """
+    edge_view_sql = """
         CREATE VIEW {edge_view} AS
         SELECT {src} AS src, {dest} AS dest {grouping_sql}
-        FROM {edge_table}
+        FROM {edge_table};
         """.format(**locals())
-    plpy.execute(sql)
+
+    plpy.execute(vertex_view_sql + edge_view_sql)

Review Comment:
   We should add a comment explaining why we need to run all these SQL statements in the same `plpy.execute` call.
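If the reason is to reduce the number of separate SPI calls, a small helper plus a comment would make that intent explicit. The PR does not state the reason, so this is a guess. `plpy.execute` is already given several semicolon-separated statements in one string elsewhere in this file (the old `prep_sql` block), so batching itself is not new. `run_batched` is a hypothetical name, and `execute` is injected in place of `plpy.execute` so the sketch runs outside the database.

```python
from typing import Callable, List


def run_batched(statements: List[str], execute: Callable[[str], object]) -> object:
    """Send several SQL statements in one call instead of one call per statement.

    Each statement is stripped and terminated with a semicolon so that
    concatenation cannot accidentally merge two statements.
    """
    parts = []
    for stmt in statements:
        stmt = stmt.strip()
        if not stmt:
            continue  # skip optional fragments that were initialized to ""
        parts.append(stmt if stmt.endswith(";") else stmt + ";")
    return execute("\n".join(parts))
```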



##########
src/ports/postgres/modules/graph/test/wcc.sql_in:
##########
@@ -276,3 +276,12 @@ SELECT graph_wcc_num_cpts(
 SELECT assert(relative_error(num_components, 3) < 0.00001,
         'Weakly Connected Components: Incorrect largest component value.'
     ) FROM count_table WHERE user_id1=1;
+
+DROP TABLE IF EXISTS wcc_warm_start_out, wcc_warm_start_out_summary;
+SELECT weakly_connected_components('v2',NULL,'e2',NULL,'wcc_warm_start_out', 'user_id', 2);

Review Comment:
   Test improvements/notes:
   1. This is a test for wcc with warm start and grouping cols, right? Shouldn't we also add a test with warm start and no grouping cols?
   2. We should also add an assert for the output table and the summary table. That will help us catch regressions (if any).
   3. Can there be any issues if the user upgrades MADlib and then uses the new function with warm start set to true on the previous out tables? We might have to explicitly call this out as well.



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -235,77 +257,92 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
         join_grouping_cols = _check_groups(subq, distinct_grp_table, grouping_cols_list)
         group_by_clause_newupdate = ('{0}, {1}.{2}'.format(subq_prefixed_grouping_cols,
                                                            subq, vertex_id))
+        select_grouping_cols = ',' + subq_prefixed_grouping_cols
+
+        if not warm_start:
+            new_update_sql = """

Review Comment:
   We should also mention in the commit message the reason for combining all these SQL statements into one `plpy.execute` call.



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -366,44 +406,61 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
         # found in the current iteration.
         with SetGUC("dev_opt_unsafe_truncate_in_subtransaction", "on"):
 
-            plpy.execute(loop_sql.format(**locals()))
-
-            if grouping_cols:
-                nodes_to_update = plpy.execute("""
-                                    SELECT SUM(cnt) AS cnt_sum
-                                    FROM (
-                                        SELECT COUNT(*) AS cnt
-                                        FROM {toupdate}
-                                        GROUP BY {grouping_cols}
-                                    ) t
-                    """.format(**locals()))[0]["cnt_sum"]
-            else:
-                nodes_to_update = plpy.execute("""
-                                    SELECT COUNT(*) AS cnt FROM {toupdate}
-                                """.format(**locals()))[0]["cnt"]
+            nodes_to_update = plpy.execute(loop_sql.format(**locals()))[0]["cnt_sum"]
+            iteration_counter += 1
+
 
     if not is_platform_pg():
         # Drop intermediate table created for Greenplum
         plpy.execute("DROP TABLE IF EXISTS {0}".format(edge_inverse))
 
-    rename_table(schema_madlib, newupdate, out_table)
-    if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
-        plpy.execute("ALTER TABLE {out_table} RENAME COLUMN id TO {vertex_id_in}".format(**locals()))
+    if not warm_start:
+        rename_table(schema_madlib, newupdate, out_table)
+        if vertex_type != "BIGINT[]" and vertex_id_in and vertex_id_in != 'id':
+            plpy.execute("ALTER TABLE {out_table} RENAME COLUMN id TO {vertex_id_in}".format(**locals()))
+    else:
+        plpy.execute("""
+            TRUNCATE TABLE {out_table};
+            INSERT INTO {out_table}
+            SELECT
+                {vertex_id} AS {vertex_id_in},
+                {component_id}
+                {comma_grouping_cols}
+            FROM {newupdate};

Review Comment:
   I'm probably missing something, but don't we also need to drop the newupdate table when warm_start is set to true? We probably do, but I just wanted to make sure.
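The two finalization paths can be written out with the cleanup made explicit. This is a sketch. `finalize_sql` is a hypothetical name, the column list is simplified, and a plain `ALTER TABLE ... RENAME TO` stands in for MADlib's `rename_table` helper. One plausible reason to prefer TRUNCATE + INSERT on warm start, not stated in the PR, is that it keeps the existing `{out_table}` relation in place. In PostgreSQL, `DROP TABLE` fails when views depend on the table unless `CASCADE` is given.

```python
from typing import List


def finalize_sql(warm_start: bool, out_table: str, newupdate: str,
                 select_list: str) -> List[str]:
    """Statements that publish the result table and clean up the scratch table."""
    if not warm_start:
        # Cold start: the scratch table simply becomes the output table,
        # so nothing is left to drop.
        return ["ALTER TABLE {0} RENAME TO {1};".format(newupdate, out_table)]
    # Warm start: refill the existing output table in place, then drop the
    # scratch table so it does not leak into the user's schema.
    return [
        "TRUNCATE TABLE {0};".format(out_table),
        "INSERT INTO {0} SELECT {1} FROM {2};".format(out_table, select_list, newupdate),
        "DROP TABLE IF EXISTS {0};".format(newupdate),
    ]
```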



##########
src/ports/postgres/modules/graph/wcc.py_in:
##########
@@ -235,77 +257,92 @@ def wcc(schema_madlib, vertex_table, vertex_id, edge_table, edge_args,
         join_grouping_cols = _check_groups(subq, distinct_grp_table, grouping_cols_list)
         group_by_clause_newupdate = ('{0}, {1}.{2}'.format(subq_prefixed_grouping_cols,
                                                            subq, vertex_id))
+        select_grouping_cols = ',' + subq_prefixed_grouping_cols
+
+        if not warm_start:
+            new_update_sql = """
+                CREATE TABLE {newupdate} AS
+                SELECT {subq}.{vertex_id},
+                        CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
+                        {select_grouping_cols}
+                FROM {distinct_grp_table} INNER JOIN (
+                    SELECT {grouping_cols_comma} {src} AS {vertex_id}
+                    FROM {edge_table}
+                    UNION
+                    SELECT {grouping_cols_comma} {dest} AS {vertex_id}
+                    FROM {edge_inverse}
+                ) {subq}
+                ON {join_grouping_cols}
+                GROUP BY {group_by_clause_newupdate}
+                {distribution};
+            """.format(**locals())
+            msg_sql = """
+                CREATE TABLE {message} AS
+                SELECT {vertex_table}.{vertex_id},
+                        CAST({vertex_table}.{single_id} AS BIGINT) AS {component_id}
+                        {comma_grouping_cols}
+                FROM {newupdate} INNER JOIN {vertex_table}
+                ON {vertex_table}.{vertex_id} = {newupdate}.{vertex_id}
+                {distribution};
+            """.format(**locals())
+
+        distinct_grp_sql = """
+            CREATE TABLE {distinct_grp_table} AS
+            SELECT DISTINCT {grouping_cols} FROM {edge_table};
 
-        grp_sql = """
-                CREATE TABLE {distinct_grp_table} AS
-                SELECT DISTINCT {grouping_cols} FROM {edge_table};
-            """
-        plpy.execute(grp_sql.format(**locals()))
-
-        prep_sql = """
-            CREATE TABLE {newupdate} AS
-            SELECT {subq}.{vertex_id},
-                    CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
-                    {select_grouping_cols}
-            FROM {distinct_grp_table} INNER JOIN (
-                SELECT {grouping_cols_comma} {src} AS {vertex_id}
-                FROM {edge_table}
-                UNION
-                SELECT {grouping_cols_comma} {dest} AS {vertex_id}
-                FROM {edge_inverse}
-            ) {subq}
-            ON {join_grouping_cols}
-            GROUP BY {group_by_clause_newupdate}
-            {distribution};
-
-            DROP TABLE IF EXISTS {distinct_grp_table};
-
-        """.format(select_grouping_cols=',' + subq_prefixed_grouping_cols,
-                   **locals())
-        plpy.execute(prep_sql)
-
-        message_sql = """
-            CREATE TABLE {message} AS
-            SELECT {vertex_table}.{vertex_id},
-                    CAST({vertex_table}.{single_id} AS BIGINT) AS {component_id}
-                    {comma_grouping_cols}
-            FROM {newupdate} INNER JOIN {vertex_table}
-            ON {vertex_table}.{vertex_id} = {newupdate}.{vertex_id}
-            {distribution};
         """
-        plpy.execute(message_sql.format(**locals()))
+
+        nodes_to_update_sql = """
+            SELECT SUM(cnt) AS cnt_sum
+            FROM (
+                SELECT COUNT(*) AS cnt
+                FROM {toupdate}
+                GROUP BY {grouping_cols}
+                ) t
+        """.format(**locals())
     else:
-        prep_sql = """
-            CREATE TABLE {newupdate} AS
-            SELECT {vertex_id}, CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
-            FROM {vertex_table}
-            {distribution};
-
-            CREATE TABLE {message} AS
-            SELECT {vertex_id}, CAST({single_id} AS BIGINT) AS {component_id}
-            FROM {vertex_table}
-            {distribution};
-        """
-        plpy.execute(prep_sql.format(**locals()))
-
-    oldupdate_sql = """
-            CREATE TABLE {oldupdate} AS
-            SELECT {message}.{vertex_id},
-                    MIN({message}.{component_id}) AS {component_id}
-                    {comma_grouping_cols}
-            FROM {message}
-            GROUP BY {grouping_cols_comma} {vertex_id}
-            LIMIT 0
-            {distribution};
-    """
-    plpy.execute(oldupdate_sql.format(**locals()))
+        if not warm_start:
+            msg_sql = """
+                CREATE TABLE {message} AS
+                SELECT {vertex_id}, CAST({single_id} AS BIGINT) AS {component_id}
+                FROM {vertex_table}
+                {distribution};
+            """.format(**locals())
+            new_update_sql = """
+                CREATE TABLE {newupdate} AS
+                SELECT {vertex_id}, CAST({BIGINT_MAX} AS BIGINT) AS {component_id}
+                FROM {vertex_table}
+                {distribution};
+            """.format(**locals())
+
+        nodes_to_update_sql = """
+            SELECT COUNT(*) AS cnt_sum FROM {toupdate}
+        """.format(**locals())
 
-    toupdate_sql = """
-            CREATE TABLE {toupdate} AS
-            SELECT * FROM {oldupdate}
-            {distribution};
-        """
-    plpy.execute(toupdate_sql.format(**locals()))
+    old_update_sql = """
+        CREATE TABLE {oldupdate} AS
+        SELECT {message}.{vertex_id},
+                MIN({message}.{component_id}) AS {component_id}
+                {comma_grouping_cols}
+        FROM {message}
+        GROUP BY {grouping_cols_comma} {vertex_id}
+        LIMIT 0
+        {distribution};
+    """
+    to_update_sql = """
+        CREATE TABLE {toupdate} AS
+        SELECT * FROM {oldupdate}
+        {distribution};
+    """
+    if is_platform_pg or not is_platform_gp6_or_up():

Review Comment:
   We should also explain the need for this if check. Also, I think the code might be easier to read if we flip the if check to:
   ```
   if is_platform_gp6_or_up():
       plpy.execute((distinct_grp_sql + new_update_sql + msg_sql + old_update_sql + to_update_sql).format(**locals()))
   else:
       .....
   ```
   This should also achieve the same thing, right?
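Separately from the flip, note that the quoted condition `if is_platform_pg or not is_platform_gp6_or_up():` references `is_platform_pg` without calling it. A function object is always truthy, so the condition as written is true on every platform. The flipped form matches the presumably intended condition (with parentheses added), assuming `is_platform_gp6_or_up()` returns False on PostgreSQL. It does not match the condition as written. The two platform functions below are hypothetical stubs that pretend to run on Greenplum 6.

```python
def is_platform_pg() -> bool:
    """Hypothetical stub: pretend we are running on Greenplum 6, not PostgreSQL."""
    return False


def is_platform_gp6_or_up() -> bool:
    """Hypothetical stub: pretend we are running on Greenplum 6."""
    return True


# As quoted in the diff: the function object itself is tested, not its result.
as_written = bool(is_platform_pg or not is_platform_gp6_or_up())

# Presumably intended: call both helpers.
intended = bool(is_platform_pg() or not is_platform_gp6_or_up())

# Reviewer's flipped form, negated so it selects the same branch as the two above.
flipped = not is_platform_gp6_or_up()
```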
   



-- 
This is an automated message from the Apache Git Service.