GitHub user iyerr3 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/239#discussion_r173080726
  
    --- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
    @@ -468,81 +544,107 @@ def balance_sample(schema_madlib, source_table, 
output_table, class_col,
             parsed_class_sizes = extract_keyvalue_params(class_sizes,
                                                          
allow_duplicates=False,
                                                          
lower_case_names=False)
    +        distinct_levels = collate_plpy_result(
    +            plpy.execute("SELECT DISTINCT ({0})::TEXT as levels FROM {1} ".
    +                         format(class_col, source_table)))['levels']
             if not parsed_class_sizes:
    -            sampling_strategy_str = 
_validate_and_get_sampling_strategy(class_sizes,
    -                                        output_table_size)
    +            sampling_strategy_str = _validate_and_get_sampling_strategy(
    +                class_sizes, output_table_size)
             else:
                 sampling_strategy_str = None
                 try:
                     for each_level, each_class_size in 
parsed_class_sizes.items():
    -                    _assert(each_level in actual_level_counts,
    +                    _assert(each_level in distinct_levels,
                                 "Sample: Invalid class value specified ({0})".
    -                                       format(each_level))
    +                            format(each_level))
                         each_class_size = int(each_class_size)
                         _assert(each_class_size >= 1,
                                 "Sample: Class size has to be greater than 
zero")
                         parsed_class_sizes[each_level] = each_class_size
    -
    -            except TypeError:
    +            except ValueError:
                     plpy.error("Sample: Invalid value for class_sizes ({0})".
                                format(class_sizes))
     
             # Get the number of rows to be sampled for each class level, based 
on
             # the input table, class_sizes, and output_table_size params. This 
also
             # includes info about the resulting sampling strategy, i.e., one of
             # UNDERSAMPLE, OVERSAMPLE, or NOSAMPLE for each level.
    -        target_class_sizes = 
_get_target_level_counts(sampling_strategy_str,
    -                                                      parsed_class_sizes,
    -                                                      actual_level_counts,
    -                                                      output_table_size)
    -
    -        undersample_level_dict, oversample_level_dict, nosample_level_dict 
= \
    -            _get_sampling_strategy_specific_dict(target_class_sizes)
    -
    -        # Get subqueries for each sampling strategy, so that they can be 
used
    -        # together in one big query.
    -
    -        # Subquery that will be used to get rows as is for those class 
levels
    -        # that need no sampling.
    -        nosample_subquery = _get_nosample_subquery(
    -            new_source_table, class_col, nosample_level_dict.keys())
    -        # Subquery that will be used to sample those class levels that
    -        # have to be oversampled.
    -        oversample_subquery = _get_with_replacement_subquery(
    -            schema_madlib, new_source_table, source_table_columns, 
class_col,
    -            actual_level_counts, oversample_level_dict)
    -        # Subquery that will be used to sample those class levels that
    -        # have to be undersampled. Undersampling supports both with and 
without
    -        # replacement, so fetch the appropriate subquery.
    -        if with_replacement:
    -            undersample_subquery = _get_with_replacement_subquery(
    -                schema_madlib, new_source_table, source_table_columns, 
class_col,
    -                actual_level_counts, undersample_level_dict)
    -        else:
    -            undersample_subquery = _get_without_replacement_subquery(
    -                schema_madlib, new_source_table, source_table_columns, 
class_col,
    -                actual_level_counts, undersample_level_dict)
    -
    -        # Merge the three subqueries using a UNION ALL clause.
    -        union_all_subquery = ' UNION ALL '.join(
    -            ['({0})'.format(subquery)
    -             for subquery in [undersample_subquery, oversample_subquery, 
nosample_subquery]
    -             if subquery])
    -
    -        final_query = """
    -                CREATE TABLE {output_table} AS
    -                SELECT row_number() OVER() AS {new_col_name}, *
    +        grp_col_str, grp_cols = get_grouping_col_str(
    +            schema_madlib, 'Balance sample', [NEW_ID_COLUMN, class_col],
    +            source_table, grouping_cols)
    +        actual_grp_level_counts = _get_level_frequency_distribution(
    +            new_source_table, class_col, grp_col_str)
    +
    +        is_output_created = False
    +        n_grp_values = len(actual_grp_level_counts)
    +        grp_cols_list = grp_col_str.split(',') if grp_cols else []
    +
    +        # for each group
    +        for grp_vals, actual_level_counts in 
actual_grp_level_counts.items():
    +            target_class_sizes = _get_target_level_counts(
    +                sampling_strategy_str, parsed_class_sizes,
    +                actual_level_counts, output_table_size, n_grp_values)
    --- End diff --
    
    I've moved `_get_target_level_counts` outside the loop so that the error is
    caught earlier. Hopefully the added grouping logic doesn't complicate the
    code too much.


---

Reply via email to