GitHub user njayaram2 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/239#discussion_r172958587
  
    --- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
    @@ -58,28 +60,64 @@ NOSAMPLE = 'nosample'
     NEW_ID_COLUMN = '__madlib_id__'
     NULL_IDENTIFIER = '__madlib_null_id__'
     
    -def _get_level_frequency_distribution(source_table, class_col):
    -    """ Returns a dict containing the number of rows associated with each class
    +
    +def _get_level_frequency_distribution(source_table, class_col,
    +                                      grp_by_cols=None):
    +    """ Count the number of rows for each class, partitioned by the grp_by_cols
    +
    +        Returns a dict containing the number of rows associated with each class
             level. Each class level count is converted to a string using ::text.
             None is a valid key in this dict, capturing NULL value in the database.
         """
    +    if grp_by_cols and grp_by_cols.lower() != 'null':
    +        is_grouping = True
    +        grp_by_cols_comma = grp_by_cols + ', '
    +        array_grp_by_cols_comma = "array[{0}]".format(grp_by_cols) + " as group_values, "
    +    else:
    +        is_grouping = False
    +        grp_by_cols_comma = array_grp_by_cols_comma = ""
    +
    +    # In below query, the inner query groups the data using grp_by_cols + classes
    +    # and obtains the count for each combination. The outer query then groups
    +    # again by the grp_by_cols to collect the classes and counts in an array.
         query_result = plpy.execute("""
    -                    SELECT {class_col}::text AS classes,
    -                           count(*) AS class_count
    -                    FROM {source_table}
    -                    GROUP BY {class_col}
    -                 """.format(**locals()))
    +        SELECT
    +            -- For each group get the classes and their rows counts
    +            {grp_identifier} as group_values,
    +            array_agg(classes) as classes,
    +            array_agg(class_count) as class_count
    +        FROM(
    +            -- for each group and class combination present in source table
    +            -- get the count of rows for that combination
    +            SELECT
    +                {array_grp_by_cols_comma}
    +                ({class_col})::TEXT AS classes,
    +                count(*) AS class_count
    +            FROM {source_table}
    +            GROUP BY {grp_by_cols_comma} ({class_col})
    +        ) q
    +        GROUP BY {grp_identifier}
    --- End diff --
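To make the two-level aggregation described in the diff's comment concrete, the following is a minimal in-memory Python analogue of what the new query computes. It is a sketch only: the function name `level_frequencies` and the `(group_values, class_value)` row shape are hypothetical, not MADlib APIs, and unlike SQL `array_agg`, which gives no ordering guarantee, this version keeps first-seen order.

```python
from collections import Counter, defaultdict


def level_frequencies(rows):
    """Hypothetical in-memory analogue of the two-level SQL query.

    rows: iterable of (group_values, class_value) pairs, where group_values
    is a tuple of grouping-column values (use () when not grouping).
    Returns {group_values: (classes, class_counts)}.
    """
    # Inner query: count rows for each (group, class) combination.
    # Class values become text, but NULL (None) stays None, as with ::text.
    pair_counts = Counter(
        (group, None if cls is None else str(cls)) for group, cls in rows
    )

    # Outer query: regroup by group_values, collecting classes and counts
    # into parallel lists (array_agg(classes), array_agg(class_count)).
    collected = defaultdict(lambda: ([], []))
    for (group, cls), count in pair_counts.items():
        classes, counts = collected[group]
        classes.append(cls)
        counts.append(count)
    return dict(collected)


rows = [(("a",), 1), (("a",), 1), (("a",), 2), (("b",), None)]
print(level_frequencies(rows))
# → {('a',): (['1', '2'], [2, 1]), ('b',): ([None], [1])}
```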
    
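    In Greenplum 5.x, the above query fails in install-check when no grouping columns are given. In that case `{grp_identifier}` is rendered as the constant `true`, and Greenplum rejects `GROUP BY true` with the following error: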
    ```
    SELECT balance_sample('"TEST_s"', 'out_sr2', 'gr1', 'undersample ', NULL, NULL, TRUE, TRUE);
    psql:/tmp/madlib.2N5sjK/sample/test/balance_sample.sql_in.tmp:111: ERROR:  plpy.SPIError: non-integer constant in GROUP BY
    LINE 17:         GROUP BY true
                              ^
    QUERY:
            SELECT
                -- For each group get the classes and their rows counts
                true as group_values,
                array_agg(classes) as classes,
                array_agg(class_count) as class_count
            FROM(
                -- for each group and class combination present in source table
                -- get the count of rows for that combination
                SELECT
    
                    (gr1)::TEXT AS classes,
                    count(*) AS class_count
                FROM "TEST_s"
                GROUP BY  (gr1)
            ) q
            GROUP BY true
    
    CONTEXT:  Traceback (most recent call last):
      PL/Python function "balance_sample", line 23, in <module>
        return balance_sample.balance_sample(**globals())
      PL/Python function "balance_sample", line 575, in balance_sample
      PL/Python function "balance_sample", line 100, in _get_level_frequency_distribution
    PL/Python function "balance_sample"
    ```
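One possible fix, sketched below under the assumption that the outer query may simply drop its `GROUP BY` when there are no grouping columns: an aggregate with no `GROUP BY` already collapses the subquery into a single row, so no constant is needed. The builder name `build_level_frequency_query` is hypothetical, and the grouping-column check mirrors the diff's `grp_by_cols` handling rather than any existing MADlib helper.

```python
def build_level_frequency_query(source_table, class_col, grp_by_cols=None):
    """Hypothetical builder: never places a constant in a GROUP BY clause."""
    is_grouping = bool(grp_by_cols) and grp_by_cols.lower() != 'null'
    if is_grouping:
        inner_select = "array[{0}] AS group_values,".format(grp_by_cols)
        inner_group_by = grp_by_cols + ", "
        outer_select = "group_values,"
        outer_group_by = "GROUP BY group_values"
    else:
        inner_select = inner_group_by = ""
        # A constant is fine in the SELECT list; it keeps a group_values
        # column in the result (assumed to be what callers expect).
        outer_select = "NULL::text[] AS group_values,"
        # No outer GROUP BY at all: the aggregates collapse the subquery
        # into a single row, so "GROUP BY true" is never generated.
        outer_group_by = ""
    return """
        SELECT
            {outer_select}
            array_agg(classes) AS classes,
            array_agg(class_count) AS class_count
        FROM (
            SELECT
                {inner_select}
                ({class_col})::TEXT AS classes,
                count(*) AS class_count
            FROM {source_table}
            GROUP BY {inner_group_by} ({class_col})
        ) q
        {outer_group_by}
    """.format(**locals())


print(build_level_frequency_query('"TEST_s"', 'gr1'))
```

Printing the ungrouped query shows a single `GROUP BY` (the inner one). One behavioural difference to weigh: on an empty source table, an aggregate without `GROUP BY` still returns one row with NULL arrays, whereas the grouped form returns no rows, so the caller may need to treat that row as "no classes".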

