GitHub user njayaram2 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/239#discussion_r172958587
--- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
@@ -58,28 +60,64 @@ NOSAMPLE = 'nosample'
NEW_ID_COLUMN = '__madlib_id__'
NULL_IDENTIFIER = '__madlib_null_id__'
-def _get_level_frequency_distribution(source_table, class_col):
- """ Returns a dict containing the number of rows associated with each
class
+
+def _get_level_frequency_distribution(source_table, class_col,
+ grp_by_cols=None):
+ """ Count the number of rows for each class, partitioned by the
grp_by_cols
+
+ Returns a dict containing the number of rows associated with each
class
level. Each class level count is converted to a string using
::text.
None is a valid key in this dict, capturing NULL value in the
database.
"""
+ if grp_by_cols and grp_by_cols.lower() != 'null':
+ is_grouping = True
+ grp_by_cols_comma = grp_by_cols + ', '
+ array_grp_by_cols_comma = "array[{0}]".format(grp_by_cols) + " as
group_values, "
+ else:
+ is_grouping = False
+ grp_by_cols_comma = array_grp_by_cols_comma = ""
+
+ # In below query, the inner query groups the data using grp_by_cols +
classes
+ # and obtains the count for each combination. The outer query then
groups
+ # again by the grp_by_cols to collect the classes and counts in an
array.
query_result = plpy.execute("""
- SELECT {class_col}::text AS classes,
- count(*) AS class_count
- FROM {source_table}
- GROUP BY {class_col}
- """.format(**locals()))
+ SELECT
+ -- For each group get the classes and their rows counts
+ {grp_identifier} as group_values,
+ array_agg(classes) as classes,
+ array_agg(class_count) as class_count
+ FROM(
+ -- for each group and class combination present in source table
+ -- get the count of rows for that combination
+ SELECT
+ {array_grp_by_cols_comma}
+ ({class_col})::TEXT AS classes,
+ count(*) AS class_count
+ FROM {source_table}
+ GROUP BY {grp_by_cols_comma} ({class_col})
+ ) q
+ GROUP BY {grp_identifier}
--- End diff --
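In Greenplum 5.x, the above query fails during install-check with the error
below. When no grouping columns are given, `{grp_identifier}` is rendered as the
constant `true`, and Greenplum rejects a non-integer constant in GROUP BY: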
```
SELECT balance_sample('"TEST_s"', 'out_sr2', 'gr1', 'undersample ', NULL,
NULL, TRUE, TRUE);
psql:/tmp/madlib.2N5sjK/sample/test/balance_sample.sql_in.tmp:111: ERROR:
plpy.SPIError: non-integer constant in GROUP BY
LINE 17: GROUP BY true
^
QUERY:
SELECT
-- For each group get the classes and their rows counts
true as group_values,
array_agg(classes) as classes,
array_agg(class_count) as class_count
FROM(
-- for each group and class combination present in source table
-- get the count of rows for that combination
SELECT
(gr1)::TEXT AS classes,
count(*) AS class_count
FROM "TEST_s"
GROUP BY (gr1)
) q
GROUP BY true
CONTEXT: Traceback (most recent call last):
PL/Python function "balance_sample", line 23, in <module>
return balance_sample.balance_sample(**globals())
PL/Python function "balance_sample", line 575, in balance_sample
PL/Python function "balance_sample", line 100, in
_get_level_frequency_distribution
PL/Python function "balance_sample"
```
---