GitHub user iyerr3 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/269#discussion_r186729249
  
    --- Diff: src/ports/postgres/modules/stats/correlation.py_in ---
    @@ -166,101 +203,184 @@ def _populate_output_table(schema_madlib, 
source_table, output_table,
             start = time()
             col_len = len(col_names)
             col_names_as_text_array = py_list_to_sql_string(col_names, 
"varchar")
    -        temp_table = unique_string()
    +        # Create unique strings to be used in queries.
    +        coalesced_col_array = unique_string(desp='coalesced_col_array')
    +        mean_col = unique_string(desp='mean')
             if get_cov:
    -            function_name = "Covariance"
                 agg_str = """
                     (CASE WHEN count(*) > 0
    -                      THEN {0}.array_scalar_mult({0}.covariance_agg(x, 
mean),
    +                      THEN {0}.array_scalar_mult({0}.covariance_agg({1}, 
{2}),
                                                      1.0 / count(*)::double 
precision)
                           ELSE NULL
    -                END) """.format(schema_madlib)
    +                END) """.format(schema_madlib, coalesced_col_array, 
mean_col)
             else:
    -            function_name = "Correlation"
    -            agg_str = "{0}.correlation_agg(x, mean)".format(schema_madlib)
    +            agg_str = "{0}.correlation_agg({1}, {2})".format(schema_madlib,
    +                                                             
coalesced_col_array,
    +                                                             mean_col)
     
             cols = ','.join(["coalesce({0}, {1})".format(col, add_postfix(col, 
"_avg"))
                             for col in col_names])
             avgs = ','.join(["avg({0}) AS {1}".format(col, add_postfix(col, 
"_avg"))
                             for col in col_names])
             avg_array = ','.join([str(add_postfix(col, "_avg")) for col in 
col_names])
    -        # actual computation
    -        sql1 = """
    -
    -            CREATE TEMP TABLE {temp_table} AS
    -            SELECT
    -                count(*) AS tot_cnt,
    -                mean,
    -                {agg_str} as cor_mat
    -            FROM
    -            (
    -                SELECT ARRAY[ {cols} ] AS x,
    -                        ARRAY [ {avg_array} ] AS mean
    -                FROM {source_table},
    +        # Create unique strings to be used in queries.
    +        tot_cnt = unique_string(desp='tot_cnt')
    +        cor_mat = unique_string(desp='cor_mat')
    +        temp_output_table = unique_string(desp='temp_output')
    +        subquery1 = unique_string(desp='subq1')
    +        subquery2 = unique_string(desp='subq2')
    +
    +        grouping_cols_comma = ''
    +        subquery_grouping_cols_comma = ''
    +        inner_group_by = ''
    +        # Cross join if there are no groups to consider
    +        join_condition = ' ON (1=1) '
    +
    +        if grouping_cols:
    +            group_col_list = split_quoted_delimited_str(grouping_cols)
    +            grouping_cols_comma = add_postfix(grouping_cols, ', ')
    +            subquery_grouping_cols_comma = get_table_qualified_col_str(
    +                                                subquery2, group_col_list) 
+ " , "
    +
    +            inner_group_by = " GROUP BY {0}".format(grouping_cols)
    +            join_condition = " USING ({0})".format(grouping_cols)
    +
    +        create_temp_output_table_query = """
    +                CREATE TEMP TABLE {temp_output_table} AS
    +                SELECT
    +                    {subquery_grouping_cols_comma}
    +                    count(*) AS {tot_cnt},
    +                    {mean_col},
    +                    {agg_str} AS {cor_mat}
    +                FROM
                     (
    -                    SELECT {avgs}
    -                    FROM {source_table}
    -                )sub1
    -            ) sub2
    -            GROUP BY mean
    -            """.format(**locals())
    +                    SELECT {grouping_cols_comma}
    +                           ARRAY[ {cols} ] AS {coalesced_col_array},
    +                           ARRAY [ {avg_array} ] AS {mean_col}
     
    -        plpy.execute(sql1)
    -
    -        # create summary table
    +                    FROM {source_table}
    +                    JOIN
    +                    (
    +                        SELECT {grouping_cols_comma} {avgs}
    +                        FROM {source_table}
    +                        {inner_group_by}
    +                    ) {subquery1}
    +                    {join_condition}
    +                ) {subquery2}
    +                GROUP BY {grouping_cols_comma} {mean_col}
    +                """.format(**locals())
    +        plpy.execute(create_temp_output_table_query)
    +
    +        # Prepare the query for converting the matrix into the lower 
triangle
    +        deconstruction_query = _create_deconstruction_query(schema_madlib,
    +                                                            col_names,
    +                                                            grouping_cols,
    +                                                            
temp_output_table,
    +                                                            cor_mat)
    +
    +        variable_subquery = unique_string(desp='variable_subq')
    +        matrix_subquery = unique_string(desp='matrix_subq')
    +        # create output table
    +        plpy.info(col_names)
    +        create_output_table_query = """
    +        CREATE TABLE {output_table} AS
    +        SELECT {grouping_cols_comma} column_position, variable, 
{target_cols}
    +        FROM
    +        (
    +            SELECT
    +                generate_series(1, {num_cols}) AS column_position,
    +                unnest({col_names_as_text_array}) AS variable
    +        ) {variable_subquery}
    +        JOIN
    +        (
    +            {deconstruction_query}
    +        ) {matrix_subquery}
    +        USING (column_position)
    +        """.format( num_cols=len(col_names),
    +                    target_cols=' , '.join(col_names),
    +                    **locals())
    +        plpy.execute(create_output_table_query)
    +
    +         # create summary table
             summary_table = add_postfix(output_table, "_summary")
    -        q_summary = """
    +        create_summary_table_query = """
                 CREATE TABLE {summary_table} AS
                 SELECT
                     '{function_name}'::varchar  AS method,
                     '{source_table}'::varchar   AS source,
                     '{output_table}'::varchar   AS output_table,
                     {col_names_as_text_array}   AS column_names,
    -                mean                        AS mean_vector,
    -                tot_cnt                     AS total_rows_processed
    -            FROM {temp_table}
    +                {grouping_cols_comma}
    +                {mean_col}                  AS mean_vector,
    +                {tot_cnt}                   AS total_rows_processed
    +            FROM {temp_output_table}
                 """.format(**locals())
    -
    -        plpy.execute(q_summary)
    -
    -        # create output table
    -        variable_list = []
    -        for k, c in enumerate(col_names):
    -            if k % 10 == 0:
    -                variable_list.append("\n                ")
    -            variable_list.append(str(c) + " float8")
    -            variable_list.append(",")
    -        variable_list_str = ''.join(variable_list[:-1])  # remove the last 
comma
    -
    -        plpy.execute("""
    -            CREATE TABLE {output_table} AS
    -            SELECT
    -                *
    -            FROM
    -            (
    -                SELECT
    -                    generate_series(1, {num_cols}) AS column_position,
    -                    unnest({col_names_as_text_array}) AS variable
    -            ) variable_subq
    -            JOIN
    -            (
    -                SELECT
    -                    *
    -                FROM
    -                    {schema_madlib}.__deconstruct_lower_triangle(
    -                        (SELECT cor_mat FROM {temp_table})
    -                    )
    -                    AS deconstructed(column_position integer, 
{variable_list_str})
    -            ) matrix_subq
    -            USING (column_position)
    -            """.format(num_cols=len(col_names), **locals()))
    +        plpy.execute(create_summary_table_query)
     
             # clean up and return
    -        plpy.execute("DROP TABLE {temp_table}".format(**locals()))
    +        plpy.execute("DROP TABLE IF EXISTS 
{temp_output_table}".format(**locals()))
    +
             end = time()
             return (output_table, len(col_names), end - start)
     # 
------------------------------------------------------------------------------
     
    +def _create_deconstruction_query(schema_madlib, col_names, grouping_cols,
    +                                 temp_output_table, cor_mat):
    +    """
    +    Creates the query to convert the matrix into the lower-traingular 
format.
    +
    +    Args:
    +        @param schema_madlib        Schema of MADlib
    +        @param col_names            Name of all columns to place in output 
table
    +        @param grouping_cols        Name of all columns to be used for 
grouping
    +        @param temp_output_table    Name of the temporary table that 
contains
    +                                    the matrix to deconstruct
    +        @param cor_mat              Name of column that containss the 
matrix
    +                                    to deconstruct
    +
    +    Returns:
    +        String (SQL querry for deconstructing the matrix)
    +    """
    +    # The matrix that holds the PCC computation must be converted to a
    +    # table capturing all pair wise PCC values. That is done using
    +    # a UDF named __deconstruct_lower_triangle.
    +    # With grouping, calling that UDF becomes a lot more complex, so
    +    # construct the query accordingly.
    +
    +    variable_list = []
    +    for k, c in enumerate(col_names):
    +        if k % 10 == 0:
    +            variable_list.append("\n                ")
    +        variable_list.append(str(c) + " float8")
    +        variable_list.append(",")
    +    variable_list_str = ''.join(variable_list[:-1])  # remove the last 
comma
    +
    --- End diff --
    
    I realize that the above 8 lines are from existing code, but I'm wondering
    if the version below is easier to understand and read:
    ```
    COL_WIDTH = 10
    # Split col_names into consecutive groups of COL_WIDTH names and put a
    # newline between groups, so the generated query has no single huge line.
    # Each name keeps its " float8" suffix (as in the loop it replaces): the
    # string is spliced into a column definition list, which needs a type
    # for every column.
    col_names_split = [col_names[x : x + COL_WIDTH]
                       for x in range(0, len(col_names), COL_WIDTH)]
    variable_list_str = ',\n                 '.join(
        ', '.join(str(c) + " float8" for c in group) for group in col_names_split)
    ```


---

Reply via email to