Github user iyerr3 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/269#discussion_r186722850
--- Diff: src/ports/postgres/modules/stats/correlation.py_in ---
@@ -166,101 +203,184 @@ def _populate_output_table(schema_madlib,
source_table, output_table,
start = time()
col_len = len(col_names)
col_names_as_text_array = py_list_to_sql_string(col_names,
"varchar")
- temp_table = unique_string()
+ # Create unique strings to be used in queries.
+ coalesced_col_array = unique_string(desp='coalesced_col_array')
+ mean_col = unique_string(desp='mean')
if get_cov:
- function_name = "Covariance"
agg_str = """
(CASE WHEN count(*) > 0
- THEN {0}.array_scalar_mult({0}.covariance_agg(x,
mean),
+ THEN {0}.array_scalar_mult({0}.covariance_agg({1},
{2}),
1.0 / count(*)::double
precision)
ELSE NULL
- END) """.format(schema_madlib)
+ END) """.format(schema_madlib, coalesced_col_array,
mean_col)
else:
- function_name = "Correlation"
- agg_str = "{0}.correlation_agg(x, mean)".format(schema_madlib)
+ agg_str = "{0}.correlation_agg({1}, {2})".format(schema_madlib,
+
coalesced_col_array,
+ mean_col)
cols = ','.join(["coalesce({0}, {1})".format(col, add_postfix(col,
"_avg"))
for col in col_names])
avgs = ','.join(["avg({0}) AS {1}".format(col, add_postfix(col,
"_avg"))
for col in col_names])
avg_array = ','.join([str(add_postfix(col, "_avg")) for col in
col_names])
- # actual computation
- sql1 = """
-
- CREATE TEMP TABLE {temp_table} AS
- SELECT
- count(*) AS tot_cnt,
- mean,
- {agg_str} as cor_mat
- FROM
- (
- SELECT ARRAY[ {cols} ] AS x,
- ARRAY [ {avg_array} ] AS mean
- FROM {source_table},
+ # Create unique strings to be used in queries.
+ tot_cnt = unique_string(desp='tot_cnt')
+ cor_mat = unique_string(desp='cor_mat')
+ temp_output_table = unique_string(desp='temp_output')
+ subquery1 = unique_string(desp='subq1')
+ subquery2 = unique_string(desp='subq2')
+
+ grouping_cols_comma = ''
+ subquery_grouping_cols_comma = ''
+ inner_group_by = ''
+ # Cross join if there are no groups to consider
+ join_condition = ' ON (1=1) '
+
+ if grouping_cols:
+ group_col_list = split_quoted_delimited_str(grouping_cols)
+ grouping_cols_comma = add_postfix(grouping_cols, ', ')
+ subquery_grouping_cols_comma = get_table_qualified_col_str(
+ subquery2, group_col_list)
+ " , "
+
+ inner_group_by = " GROUP BY {0}".format(grouping_cols)
+ join_condition = " USING ({0})".format(grouping_cols)
+
+ create_temp_output_table_query = """
+ CREATE TEMP TABLE {temp_output_table} AS
+ SELECT
+ {subquery_grouping_cols_comma}
+ count(*) AS {tot_cnt},
+ {mean_col},
+ {agg_str} AS {cor_mat}
+ FROM
(
- SELECT {avgs}
- FROM {source_table}
- )sub1
- ) sub2
- GROUP BY mean
- """.format(**locals())
+ SELECT {grouping_cols_comma}
+ ARRAY[ {cols} ] AS {coalesced_col_array},
+ ARRAY [ {avg_array} ] AS {mean_col}
- plpy.execute(sql1)
-
- # create summary table
+ FROM {source_table}
+ JOIN
+ (
+ SELECT {grouping_cols_comma} {avgs}
+ FROM {source_table}
+ {inner_group_by}
+ ) {subquery1}
+ {join_condition}
+ ) {subquery2}
+ GROUP BY {grouping_cols_comma} {mean_col}
+ """.format(**locals())
+ plpy.execute(create_temp_output_table_query)
+
+ # Prepare the query for converting the matrix into the lower triangle
+ deconstruction_query = _create_deconstruction_query(schema_madlib,
+ col_names,
+ grouping_cols,
+
temp_output_table,
+ cor_mat)
+
+ variable_subquery = unique_string(desp='variable_subq')
+ matrix_subquery = unique_string(desp='matrix_subq')
+ # create output table
+ plpy.info(col_names)
--- End diff --
The `plpy.info(col_names)` line at the end of this diff can be deleted.
---