Github user iyerr3 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/239#discussion_r173080410
--- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
@@ -58,28 +60,64 @@ NOSAMPLE = 'nosample'
NEW_ID_COLUMN = '__madlib_id__'
NULL_IDENTIFIER = '__madlib_null_id__'
-def _get_level_frequency_distribution(source_table, class_col):
- """ Returns a dict containing the number of rows associated with each
class
+
+def _get_level_frequency_distribution(source_table, class_col,
+ grp_by_cols=None):
+ """ Count the number of rows for each class, partitioned by the
grp_by_cols
+
+ Returns a dict containing the number of rows associated with each
class
level. Each class level count is converted to a string using
::text.
None is a valid key in this dict, capturing NULL value in the
database.
"""
+ if grp_by_cols and grp_by_cols.lower() != 'null':
+ is_grouping = True
+ grp_by_cols_comma = grp_by_cols + ', '
+ array_grp_by_cols_comma = "array[{0}]".format(grp_by_cols) + " as
group_values, "
+ else:
+ is_grouping = False
+ grp_by_cols_comma = array_grp_by_cols_comma = ""
+
+ # In below query, the inner query groups the data using grp_by_cols +
classes
+ # and obtains the count for each combination. The outer query then
groups
+ # again by the grp_by_cols to collect the classes and counts in an
array.
query_result = plpy.execute("""
- SELECT {class_col}::text AS classes,
- count(*) AS class_count
- FROM {source_table}
- GROUP BY {class_col}
- """.format(**locals()))
+ SELECT
+ -- For each group get the classes and their rows counts
+ {grp_identifier} as group_values,
+ array_agg(classes) as classes,
+ array_agg(class_count) as class_count
+ FROM(
+ -- for each group and class combination present in source table
+ -- get the count of rows for that combination
+ SELECT
+ {array_grp_by_cols_comma}
+ ({class_col})::TEXT AS classes,
+ count(*) AS class_count
+ FROM {source_table}
+ GROUP BY {grp_by_cols_comma} ({class_col})
+ ) q
+ GROUP BY {grp_identifier}
--- End diff --
Thanks for checking this! I've removed the constant GROUP BY clause to
avoid this.
---