Github user iyerr3 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/239#discussion_r173080726
--- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
@@ -468,81 +544,107 @@ def balance_sample(schema_madlib, source_table, output_table, class_col,
         parsed_class_sizes = extract_keyvalue_params(class_sizes,
                                                      allow_duplicates=False,
                                                      lower_case_names=False)
+        distinct_levels = collate_plpy_result(
+            plpy.execute("SELECT DISTINCT ({0})::TEXT as levels FROM {1} ".
+                         format(class_col, source_table)))['levels']
         if not parsed_class_sizes:
-            sampling_strategy_str = _validate_and_get_sampling_strategy(class_sizes,
-                                                                        output_table_size)
+            sampling_strategy_str = _validate_and_get_sampling_strategy(
+                class_sizes, output_table_size)
         else:
             sampling_strategy_str = None
             try:
                 for each_level, each_class_size in parsed_class_sizes.items():
-                    _assert(each_level in actual_level_counts,
+                    _assert(each_level in distinct_levels,
                             "Sample: Invalid class value specified ({0})".
-                            format(each_level))
+                            format(each_level))
                     each_class_size = int(each_class_size)
                     _assert(each_class_size >= 1,
                             "Sample: Class size has to be greater than zero")
                     parsed_class_sizes[each_level] = each_class_size
-
-            except TypeError:
+            except ValueError:
                 plpy.error("Sample: Invalid value for class_sizes ({0})".
                            format(class_sizes))
         # Get the number of rows to be sampled for each class level, based on
         # the input table, class_sizes, and output_table_size params. This also
         # includes info about the resulting sampling strategy, i.e., one of
         # UNDERSAMPLE, OVERSAMPLE, or NOSAMPLE for each level.
-        target_class_sizes = _get_target_level_counts(sampling_strategy_str,
-                                                      parsed_class_sizes,
-                                                      actual_level_counts,
-                                                      output_table_size)
-
-        undersample_level_dict, oversample_level_dict, nosample_level_dict = \
-            _get_sampling_strategy_specific_dict(target_class_sizes)
-
-        # Get subqueries for each sampling strategy, so that they can be used
-        # together in one big query.
-
-        # Subquery that will be used to get rows as is for those class levels
-        # that need no sampling.
-        nosample_subquery = _get_nosample_subquery(
-            new_source_table, class_col, nosample_level_dict.keys())
-        # Subquery that will be used to sample those class levels that
-        # have to be oversampled.
-        oversample_subquery = _get_with_replacement_subquery(
-            schema_madlib, new_source_table, source_table_columns, class_col,
-            actual_level_counts, oversample_level_dict)
-        # Subquery that will be used to sample those class levels that
-        # have to be undersampled. Undersampling supports both with and without
-        # replacement, so fetch the appropriate subquery.
-        if with_replacement:
-            undersample_subquery = _get_with_replacement_subquery(
-                schema_madlib, new_source_table, source_table_columns, class_col,
-                actual_level_counts, undersample_level_dict)
-        else:
-            undersample_subquery = _get_without_replacement_subquery(
-                schema_madlib, new_source_table, source_table_columns, class_col,
-                actual_level_counts, undersample_level_dict)
-
-        # Merge the three subqueries using a UNION ALL clause.
-        union_all_subquery = ' UNION ALL '.join(
-            ['({0})'.format(subquery)
-             for subquery in [undersample_subquery, oversample_subquery, nosample_subquery]
-             if subquery])
-
-        final_query = """
-            CREATE TABLE {output_table} AS
-                SELECT row_number() OVER() AS {new_col_name}, *
+        grp_col_str, grp_cols = get_grouping_col_str(
+            schema_madlib, 'Balance sample', [NEW_ID_COLUMN, class_col],
+            source_table, grouping_cols)
+        actual_grp_level_counts = _get_level_frequency_distribution(
+            new_source_table, class_col, grp_col_str)
+
+        is_output_created = False
+        n_grp_values = len(actual_grp_level_counts)
+        grp_cols_list = grp_col_str.split(',') if grp_cols else []
+
+        # for each group
+        for grp_vals, actual_level_counts in actual_grp_level_counts.items():
+            target_class_sizes = _get_target_level_counts(
+                sampling_strategy_str, parsed_class_sizes,
+                actual_level_counts, output_table_size, n_grp_values)
--- End diff ---
I've moved the `_get_target_level_counts` call outside the loop to catch the error earlier. Hopefully, all the grouping logic doesn't complicate the code too much.
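For reference, a minimal, hypothetical sketch of the fail-fast idea being discussed: do the error-raising computation once, before the per-group loop, so a bad class_sizes spec surfaces before any group has been processed. The helper names below (_compute_target_sizes, _sample_group, _balance_by_group) are made up for illustration and are not the actual balance_sample.py_in functions.

    def _sample_group(rows, target_sizes):
        # Stand-in for the per-group sampling work.
        return rows

    def _compute_target_sizes(requested_sizes, distinct_levels):
        # Raises on bad input, so calling it once before the loop fails fast.
        target = {}
        for level, size in requested_sizes.items():
            if level not in distinct_levels:
                raise ValueError("Invalid class value specified ({0})".format(level))
            size = int(size)
            if size < 1:
                raise ValueError("Class size has to be greater than zero")
            target[level] = size
        return target

    def _balance_by_group(groups, requested_sizes, distinct_levels):
        # Hoisted out of the loop: any error in the requested sizes surfaces
        # before the per-group work (and any output writes) has started.
        target_sizes = _compute_target_sizes(requested_sizes, distinct_levels)
        return {grp: _sample_group(rows, target_sizes)
                for grp, rows in groups.items()}

Usage: _balance_by_group({('a',): [1, 2, 3]}, {'x': 5}, ['x', 'y']) runs normally, while an unknown level such as {'z': 5} raises ValueError before any group is touched.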
---