GitHub user iyerr3 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/239#discussion_r173080410
  
    --- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
    @@ -58,28 +60,64 @@ NOSAMPLE = 'nosample'
     NEW_ID_COLUMN = '__madlib_id__'
     NULL_IDENTIFIER = '__madlib_null_id__'
     
    -def _get_level_frequency_distribution(source_table, class_col):
    -    """ Returns a dict containing the number of rows associated with each 
class
    +
    +def _get_level_frequency_distribution(source_table, class_col,
    +                                      grp_by_cols=None):
    +    """ Count the number of rows for each class, partitioned by the 
grp_by_cols
    +
    +        Returns a dict containing the number of rows associated with each 
class
             level. Each class level count is converted to a string using 
::text.
             None is a valid key in this dict, capturing NULL value in the 
database.
         """
    +    if grp_by_cols and grp_by_cols.lower() != 'null':
    +        is_grouping = True
    +        grp_by_cols_comma = grp_by_cols + ', '
    +        array_grp_by_cols_comma = "array[{0}]".format(grp_by_cols) + " as 
group_values, "
    +    else:
    +        is_grouping = False
    +        grp_by_cols_comma = array_grp_by_cols_comma = ""
    +
    +    # In below query, the inner query groups the data using grp_by_cols + 
classes
    +    # and obtains the count for each combination. The outer query then 
groups
    +    # again by the grp_by_cols to collect the classes and counts in an 
array.
         query_result = plpy.execute("""
    -                    SELECT {class_col}::text AS classes,
    -                           count(*) AS class_count
    -                    FROM {source_table}
    -                    GROUP BY {class_col}
    -                 """.format(**locals()))
    +        SELECT
    +            -- For each group get the classes and their rows counts
    +            {grp_identifier} as group_values,
    +            array_agg(classes) as classes,
    +            array_agg(class_count) as class_count
    +        FROM(
    +            -- for each group and class combination present in source table
    +            -- get the count of rows for that combination
    +            SELECT
    +                {array_grp_by_cols_comma}
    +                ({class_col})::TEXT AS classes,
    +                count(*) AS class_count
    +            FROM {source_table}
    +            GROUP BY {grp_by_cols_comma} ({class_col})
    +        ) q
    +        GROUP BY {grp_identifier}
    --- End diff --
    
    Thanks for checking this! I've removed the constant GROUP BY clause to
avoid the issue.


---

Reply via email to