Github user iyerr3 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/269#discussion_r186729249
--- Diff: src/ports/postgres/modules/stats/correlation.py_in ---
@@ -166,101 +203,184 @@ def _populate_output_table(schema_madlib, source_table, output_table,
start = time()
col_len = len(col_names)
    col_names_as_text_array = py_list_to_sql_string(col_names, "varchar")
- temp_table = unique_string()
+ # Create unique strings to be used in queries.
+ coalesced_col_array = unique_string(desp='coalesced_col_array')
+ mean_col = unique_string(desp='mean')
if get_cov:
- function_name = "Covariance"
agg_str = """
(CASE WHEN count(*) > 0
-                THEN {0}.array_scalar_mult({0}.covariance_agg(x, mean),
+                THEN {0}.array_scalar_mult({0}.covariance_agg({1}, {2}),
                                            1.0 / count(*)::double precision)
                 ELSE NULL
-            END) """.format(schema_madlib)
+            END) """.format(schema_madlib, coalesced_col_array, mean_col)
else:
- function_name = "Correlation"
- agg_str = "{0}.correlation_agg(x, mean)".format(schema_madlib)
+        agg_str = "{0}.correlation_agg({1}, {2})".format(schema_madlib,
+                                                         coalesced_col_array,
+                                                         mean_col)
     cols = ','.join(["coalesce({0}, {1})".format(col, add_postfix(col, "_avg"))
                      for col in col_names])
     avgs = ','.join(["avg({0}) AS {1}".format(col, add_postfix(col, "_avg"))
                      for col in col_names])
     avg_array = ','.join([str(add_postfix(col, "_avg")) for col in col_names])
- # actual computation
- sql1 = """
-
- CREATE TEMP TABLE {temp_table} AS
- SELECT
- count(*) AS tot_cnt,
- mean,
- {agg_str} as cor_mat
- FROM
- (
- SELECT ARRAY[ {cols} ] AS x,
- ARRAY [ {avg_array} ] AS mean
- FROM {source_table},
+ # Create unique strings to be used in queries.
+ tot_cnt = unique_string(desp='tot_cnt')
+ cor_mat = unique_string(desp='cor_mat')
+ temp_output_table = unique_string(desp='temp_output')
+ subquery1 = unique_string(desp='subq1')
+ subquery2 = unique_string(desp='subq2')
+
+ grouping_cols_comma = ''
+ subquery_grouping_cols_comma = ''
+ inner_group_by = ''
+ # Cross join if there are no groups to consider
+ join_condition = ' ON (1=1) '
+
+ if grouping_cols:
+ group_col_list = split_quoted_delimited_str(grouping_cols)
+ grouping_cols_comma = add_postfix(grouping_cols, ', ')
+            subquery_grouping_cols_comma = get_table_qualified_col_str(
+                subquery2, group_col_list) + " , "
+
+ inner_group_by = " GROUP BY {0}".format(grouping_cols)
+ join_condition = " USING ({0})".format(grouping_cols)
+
+ create_temp_output_table_query = """
+ CREATE TEMP TABLE {temp_output_table} AS
+ SELECT
+ {subquery_grouping_cols_comma}
+ count(*) AS {tot_cnt},
+ {mean_col},
+ {agg_str} AS {cor_mat}
+ FROM
(
- SELECT {avgs}
- FROM {source_table}
- )sub1
- ) sub2
- GROUP BY mean
- """.format(**locals())
+ SELECT {grouping_cols_comma}
+ ARRAY[ {cols} ] AS {coalesced_col_array},
+ ARRAY [ {avg_array} ] AS {mean_col}
- plpy.execute(sql1)
-
- # create summary table
+ FROM {source_table}
+ JOIN
+ (
+ SELECT {grouping_cols_comma} {avgs}
+ FROM {source_table}
+ {inner_group_by}
+ ) {subquery1}
+ {join_condition}
+ ) {subquery2}
+ GROUP BY {grouping_cols_comma} {mean_col}
+ """.format(**locals())
+ plpy.execute(create_temp_output_table_query)
+
+        # Prepare the query for converting the matrix into the lower triangle
+        deconstruction_query = _create_deconstruction_query(schema_madlib,
+                                                            col_names,
+                                                            grouping_cols,
+                                                            temp_output_table,
+                                                            cor_mat)
+
+ variable_subquery = unique_string(desp='variable_subq')
+ matrix_subquery = unique_string(desp='matrix_subq')
+ # create output table
+ plpy.info(col_names)
+ create_output_table_query = """
+ CREATE TABLE {output_table} AS
+            SELECT {grouping_cols_comma} column_position, variable, {target_cols}
+ FROM
+ (
+ SELECT
+ generate_series(1, {num_cols}) AS column_position,
+ unnest({col_names_as_text_array}) AS variable
+ ) {variable_subquery}
+ JOIN
+ (
+ {deconstruction_query}
+ ) {matrix_subquery}
+ USING (column_position)
+ """.format( num_cols=len(col_names),
+ target_cols=' , '.join(col_names),
+ **locals())
+ plpy.execute(create_output_table_query)
+
+ # create summary table
summary_table = add_postfix(output_table, "_summary")
- q_summary = """
+ create_summary_table_query = """
CREATE TABLE {summary_table} AS
SELECT
'{function_name}'::varchar AS method,
'{source_table}'::varchar AS source,
'{output_table}'::varchar AS output_table,
{col_names_as_text_array} AS column_names,
- mean AS mean_vector,
- tot_cnt AS total_rows_processed
- FROM {temp_table}
+ {grouping_cols_comma}
+ {mean_col} AS mean_vector,
+ {tot_cnt} AS total_rows_processed
+ FROM {temp_output_table}
""".format(**locals())
-
- plpy.execute(q_summary)
-
- # create output table
- variable_list = []
- for k, c in enumerate(col_names):
- if k % 10 == 0:
- variable_list.append("\n ")
- variable_list.append(str(c) + " float8")
- variable_list.append(",")
-        variable_list_str = ''.join(variable_list[:-1])  # remove the last comma
-
- plpy.execute("""
- CREATE TABLE {output_table} AS
- SELECT
- *
- FROM
- (
- SELECT
- generate_series(1, {num_cols}) AS column_position,
- unnest({col_names_as_text_array}) AS variable
- ) variable_subq
- JOIN
- (
- SELECT
- *
- FROM
- {schema_madlib}.__deconstruct_lower_triangle(
- (SELECT cor_mat FROM {temp_table})
- )
-                AS deconstructed(column_position integer, {variable_list_str})
- ) matrix_subq
- USING (column_position)
- """.format(num_cols=len(col_names), **locals()))
+ plpy.execute(create_summary_table_query)
# clean up and return
- plpy.execute("DROP TABLE {temp_table}".format(**locals()))
+        plpy.execute("DROP TABLE IF EXISTS {temp_output_table}".format(**locals()))
+
end = time()
return (output_table, len(col_names), end - start)
# ------------------------------------------------------------------------------
+def _create_deconstruction_query(schema_madlib, col_names, grouping_cols,
+                                 temp_output_table, cor_mat):
+    """
+    Creates the query to convert the matrix into the lower-triangular format.
+
+    Args:
+        @param schema_madlib        Schema of MADlib
+        @param col_names            Names of all columns to place in the output table
+        @param grouping_cols        Names of all columns to be used for grouping
+        @param temp_output_table    Name of the temporary table that contains
+                                    the matrix to deconstruct
+        @param cor_mat              Name of the column that contains the matrix
+                                    to deconstruct
+
+    Returns:
+        String (SQL query for deconstructing the matrix)
+    """
+ # The matrix that holds the PCC computation must be converted to a
+    # table capturing all pairwise PCC values. That is done using
+ # a UDF named __deconstruct_lower_triangle.
+ # With grouping, calling that UDF becomes a lot more complex, so
+ # construct the query accordingly.
+
+ variable_list = []
+ for k, c in enumerate(col_names):
+ if k % 10 == 0:
+ variable_list.append("\n ")
+ variable_list.append(str(c) + " float8")
+ variable_list.append(",")
+    variable_list_str = ''.join(variable_list[:-1])  # remove the last comma
+
--- End diff ---
I realize that the above 8 lines are from existing code, but I'm wondering if
the version below is easier to understand and read:
```
COL_WIDTH = 10
# Split the col_names into equal-size sets with a newline between them to
# prevent one very long line in the query. col_names_split is a 2d list of
# the col_names, each inner list with at most COL_WIDTH names.
col_names_split = [col_names[x: x + COL_WIDTH]
                   for x in range(0, len(col_names), COL_WIDTH)]
variable_list_str = ',\n    '.join(
    ', '.join('{0} float8'.format(c) for c in chunk)
    for chunk in col_names_split)
```
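To illustrate with made-up column names, `col_names = ['a', 'b', 'c']` would yield
```
a float8, b float8, c float8
```
and once there are more than COL_WIDTH names, each further chunk of 10 goes on its
own line, which is roughly what the existing loop produces but without the manual
comma bookkeeping.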
---