GitHub user iyerr3 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/269#discussion_r186722850
  
    --- Diff: src/ports/postgres/modules/stats/correlation.py_in ---
    @@ -166,101 +203,184 @@ def _populate_output_table(schema_madlib, 
source_table, output_table,
             start = time()
             col_len = len(col_names)
             col_names_as_text_array = py_list_to_sql_string(col_names, 
"varchar")
    -        temp_table = unique_string()
    +        # Create unique strings to be used in queries.
    +        coalesced_col_array = unique_string(desp='coalesced_col_array')
    +        mean_col = unique_string(desp='mean')
             if get_cov:
    -            function_name = "Covariance"
                 agg_str = """
                     (CASE WHEN count(*) > 0
    -                      THEN {0}.array_scalar_mult({0}.covariance_agg(x, 
mean),
    +                      THEN {0}.array_scalar_mult({0}.covariance_agg({1}, 
{2}),
                                                      1.0 / count(*)::double 
precision)
                           ELSE NULL
    -                END) """.format(schema_madlib)
    +                END) """.format(schema_madlib, coalesced_col_array, 
mean_col)
             else:
    -            function_name = "Correlation"
    -            agg_str = "{0}.correlation_agg(x, mean)".format(schema_madlib)
    +            agg_str = "{0}.correlation_agg({1}, {2})".format(schema_madlib,
    +                                                             
coalesced_col_array,
    +                                                             mean_col)
     
             cols = ','.join(["coalesce({0}, {1})".format(col, add_postfix(col, 
"_avg"))
                             for col in col_names])
             avgs = ','.join(["avg({0}) AS {1}".format(col, add_postfix(col, 
"_avg"))
                             for col in col_names])
             avg_array = ','.join([str(add_postfix(col, "_avg")) for col in 
col_names])
    -        # actual computation
    -        sql1 = """
    -
    -            CREATE TEMP TABLE {temp_table} AS
    -            SELECT
    -                count(*) AS tot_cnt,
    -                mean,
    -                {agg_str} as cor_mat
    -            FROM
    -            (
    -                SELECT ARRAY[ {cols} ] AS x,
    -                        ARRAY [ {avg_array} ] AS mean
    -                FROM {source_table},
    +        # Create unique strings to be used in queries.
    +        tot_cnt = unique_string(desp='tot_cnt')
    +        cor_mat = unique_string(desp='cor_mat')
    +        temp_output_table = unique_string(desp='temp_output')
    +        subquery1 = unique_string(desp='subq1')
    +        subquery2 = unique_string(desp='subq2')
    +
    +        grouping_cols_comma = ''
    +        subquery_grouping_cols_comma = ''
    +        inner_group_by = ''
    +        # Cross join if there are no groups to consider
    +        join_condition = ' ON (1=1) '
    +
    +        if grouping_cols:
    +            group_col_list = split_quoted_delimited_str(grouping_cols)
    +            grouping_cols_comma = add_postfix(grouping_cols, ', ')
    +            subquery_grouping_cols_comma = get_table_qualified_col_str(
    +                                                subquery2, group_col_list) 
+ " , "
    +
    +            inner_group_by = " GROUP BY {0}".format(grouping_cols)
    +            join_condition = " USING ({0})".format(grouping_cols)
    +
    +        create_temp_output_table_query = """
    +                CREATE TEMP TABLE {temp_output_table} AS
    +                SELECT
    +                    {subquery_grouping_cols_comma}
    +                    count(*) AS {tot_cnt},
    +                    {mean_col},
    +                    {agg_str} AS {cor_mat}
    +                FROM
                     (
    -                    SELECT {avgs}
    -                    FROM {source_table}
    -                )sub1
    -            ) sub2
    -            GROUP BY mean
    -            """.format(**locals())
    +                    SELECT {grouping_cols_comma}
    +                           ARRAY[ {cols} ] AS {coalesced_col_array},
    +                           ARRAY [ {avg_array} ] AS {mean_col}
     
    -        plpy.execute(sql1)
    -
    -        # create summary table
    +                    FROM {source_table}
    +                    JOIN
    +                    (
    +                        SELECT {grouping_cols_comma} {avgs}
    +                        FROM {source_table}
    +                        {inner_group_by}
    +                    ) {subquery1}
    +                    {join_condition}
    +                ) {subquery2}
    +                GROUP BY {grouping_cols_comma} {mean_col}
    +                """.format(**locals())
    +        plpy.execute(create_temp_output_table_query)
    +
    +        # Prepare the query for converting the matrix into the lower 
triangle
    +        deconstruction_query = _create_deconstruction_query(schema_madlib,
    +                                                            col_names,
    +                                                            grouping_cols,
    +                                                            
temp_output_table,
    +                                                            cor_mat)
    +
    +        variable_subquery = unique_string(desp='variable_subq')
    +        matrix_subquery = unique_string(desp='matrix_subq')
    +        # create output table
    +        plpy.info(col_names)
    --- End diff --
    
    The `plpy.info(col_names)` debug line at the end of this hunk can be deleted.


---

Reply via email to