Repository: madlib Updated Branches: refs/heads/master 3af2d703e -> 0bfcaf5c4
Balance sample: Add grouping support JIRA: MADLIB-1168 This commit adds grouping support for balanced sampling. Grouping is implemented as a loop over the existing logic, with the sampling for each group run independently. Additional changes: - Ensure bool_to_text returns 'true' instead of 't' and 'false' instead of 'f'. - Use bool_to_text only for platforms that don't have bool-to-text cast. - Add a Collate plpy results function Closes #239 Project: http://git-wip-us.apache.org/repos/asf/madlib/repo Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/0bfcaf5c Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/0bfcaf5c Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/0bfcaf5c Branch: refs/heads/master Commit: 0bfcaf5c4623224fdcabc75a72dff21c70bfb1f6 Parents: 3af2d70 Author: Rahul Iyer <ri...@apache.org> Authored: Tue Mar 13 12:59:04 2018 -0700 Committer: Rahul Iyer <ri...@apache.org> Committed: Tue Mar 13 13:48:46 2018 -0700 ---------------------------------------------------------------------- src/ports/greenplum/cmake/GreenplumUtils.cmake | 4 + src/ports/postgres/cmake/PostgreSQLUtils.cmake | 1 + .../modules/sample/balance_sample.py_in | 558 +++++++++++-------- .../modules/sample/balance_sample.sql_in | 99 +++- .../modules/sample/test/balance_sample.sql_in | 174 +++--- .../postgres/modules/utilities/utilities.py_in | 40 +- .../postgres/modules/utilities/utilities.sql_in | 9 +- .../modules/utilities/validate_args.py_in | 3 + 8 files changed, 583 insertions(+), 305 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/greenplum/cmake/GreenplumUtils.cmake ---------------------------------------------------------------------- diff --git a/src/ports/greenplum/cmake/GreenplumUtils.cmake b/src/ports/greenplum/cmake/GreenplumUtils.cmake index 9e3aa21..0fc1637 100644 --- a/src/ports/greenplum/cmake/GreenplumUtils.cmake +++ 
b/src/ports/greenplum/cmake/GreenplumUtils.cmake @@ -9,6 +9,10 @@ function(define_greenplum_features IN_VERSION OUT_FEATURES) list(APPEND ${OUT_FEATURES} __HAS_FUNCTION_PROPERTIES__) endif() + if(${IN_VERSION} VERSION_GREATER "4.3") + list(APPEND ${OUT_FEATURES} __HAS_BOOL_TO_TEXT_CAST__) + endif() + # Pass values to caller set(${OUT_FEATURES} "${${OUT_FEATURES}}" PARENT_SCOPE) endfunction(define_greenplum_features) http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/postgres/cmake/PostgreSQLUtils.cmake ---------------------------------------------------------------------- diff --git a/src/ports/postgres/cmake/PostgreSQLUtils.cmake b/src/ports/postgres/cmake/PostgreSQLUtils.cmake index 30962bb..11109a6 100644 --- a/src/ports/postgres/cmake/PostgreSQLUtils.cmake +++ b/src/ports/postgres/cmake/PostgreSQLUtils.cmake @@ -3,6 +3,7 @@ function(define_postgresql_features IN_VERSION OUT_FEATURES) if(NOT ${IN_VERSION} VERSION_LESS "9.0") list(APPEND ${OUT_FEATURES} __HAS_ORDERED_AGGREGATES__) + list(APPEND ${OUT_FEATURES} __HAS_BOOL_TO_TEXT_CAST__) endif() # Pass values to caller http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/postgres/modules/sample/balance_sample.py_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/sample/balance_sample.py_in b/src/ports/postgres/modules/sample/balance_sample.py_in index a32e8e7..28cd11c 100644 --- a/src/ports/postgres/modules/sample/balance_sample.py_in +++ b/src/ports/postgres/modules/sample/balance_sample.py_in @@ -17,9 +17,10 @@ # specific language governing permissions and limitations # under the License. 
-m4_changequote(`<!', `!>') +# m4_changequote(`<!', `!>') import math +from collections import defaultdict if __name__ != "__main__": import plpy @@ -27,7 +28,10 @@ if __name__ != "__main__": from utilities.utilities import _assert from utilities.utilities import extract_keyvalue_params from utilities.utilities import unique_string + from utilities.utilities import collate_plpy_result + from utilities.utilities import get_grouping_col_str from utilities.validate_args import columns_exist_in_table + from utilities.validate_args import explicit_bool_to_text from utilities.validate_args import get_cols from utilities.validate_args import table_exists from utilities.validate_args import table_is_empty @@ -58,28 +62,65 @@ NOSAMPLE = 'nosample' NEW_ID_COLUMN = '__madlib_id__' NULL_IDENTIFIER = '__madlib_null_id__' -def _get_level_frequency_distribution(source_table, class_col): - """ Returns a dict containing the number of rows associated with each class + +def _get_level_frequency_distribution(source_table, class_col, + grp_by_cols=None): + """ Count the number of rows for each class, partitioned by the grp_by_cols + + Returns a dict containing the number of rows associated with each class level. Each class level count is converted to a string using ::text. None is a valid key in this dict, capturing NULL value in the database. """ + if grp_by_cols and grp_by_cols.lower() != 'null': + is_grouping = True + grp_by_cols_comma = grp_by_cols + ', ' + array_grp_by_cols_comma = "ARRAY[{0}]".format(grp_by_cols) + " AS group_values, " + else: + is_grouping = False + grp_by_cols_comma = array_grp_by_cols_comma = "" + + # In below query, the inner query groups the data using grp_by_cols + classes + # and obtains the count for each combination. The outer query then groups + # again by the grp_by_cols to collect the classes and counts in an array. 
query_result = plpy.execute(""" - SELECT {class_col}::text AS classes, - count(*) AS class_count - FROM {source_table} - GROUP BY {class_col} - """.format(**locals())) - actual_level_counts = {} + SELECT + -- For each group get the classes and their rows counts + {grp_identifier} as group_values, + array_agg(classes) as classes, + array_agg(class_count) as class_count + FROM( + -- for each group and class combination present in source table + -- get the count of rows for that combination + SELECT + {array_grp_by_cols_comma} + ({class_col})::TEXT AS classes, + count(*) AS class_count + FROM {source_table} + GROUP BY {grp_by_cols_comma} ({class_col}) + ) q + {meta_grp_by} + """.format(grp_identifier="group_values" if is_grouping else "NULL", + meta_grp_by="GROUP BY group_values" if is_grouping else "", + **locals())) + if (len(query_result) > 1) != is_grouping: + # if is_grouping then query_result should have more than 1 row + # if not is_grouping then query_result should have only 1 row + raise RuntimeError("Balance sample: Error during frequency level distribution") + + actual_grp_level_counts = {} for each_row in query_result: - level = each_row['classes'] - if level: - level = level.strip() - actual_level_counts[level] = each_row['class_count'] - return actual_level_counts + # group_values is a list for each row; convert it to a tuple to use as + # key in a dictionary + grp = tuple(each_row['group_values']) if is_grouping else None + grp_levels, grp_counts = each_row['classes'], each_row['class_count'] + actual_grp_level_counts[grp] = dict(zip(grp_levels, grp_counts)) + return actual_grp_level_counts +# ------------------------------------------------------------------------------ -def _validate_and_get_sampling_strategy(sampling_strategy_str, output_table_size, - supported_strategies=None, default=UNIFORM): +def _validate_and_get_sampling_strategy(sampling_strategy_str, + output_table_size, + default=UNIFORM): """ Returns the sampling strategy based on the 
class_sizes input param. @param sampling_strategy_str The sampling strategy specified by the user (class_sizes param) @@ -94,8 +135,7 @@ def _validate_and_get_sampling_strategy(sampling_strategy_str, output_table_size # common prefix substring plpy.error("Sample: Invalid class_sizes parameter") - if not supported_strategies: - supported_strategies = [UNIFORM, UNDERSAMPLE, OVERSAMPLE] + supported_strategies = (UNIFORM, UNDERSAMPLE, OVERSAMPLE) try: # allow user to specify a prefix substring of # supported strategies. @@ -104,14 +144,16 @@ def _validate_and_get_sampling_strategy(sampling_strategy_str, output_table_size except StopIteration: # next() returns a StopIteration if no element found plpy.error("Sample: Invalid class_sizes parameter: " - "{0}. Supported class_size parameters are ({1})" - .format(sampling_strategy_str, ','.join(sorted(supported_strategies)))) + "{0}. Supported class_size parameters are ({1})". + format(sampling_strategy_str, + ','.join(sorted(supported_strategies)))) _assert(sampling_strategy_str.lower() in (UNIFORM, UNDERSAMPLE, OVERSAMPLE) or (sampling_strategy_str.find('=') > 0), "Sample: Invalid class_sizes parameter: " - "{0}. Supported class_size parameters are ({1})" - .format(sampling_strategy_str, ','.join(sorted(supported_strategies)))) + "{0}. Supported class_size parameters are ({1})". + format(sampling_strategy_str, + ','.join(sorted(supported_strategies)))) _assert(not(sampling_strategy_str.lower() == 'oversample' and output_table_size), "Sample: Cannot set output_table_size with oversampling.") @@ -147,47 +189,75 @@ def _choose_strategy(actual_count, desired_count): return UNDERSAMPLE # ------------------------------------------------------------------------- -def _get_desired_target_level_counts(desired_level_counts, actual_level_counts, - output_table_size): - """ Returns the target level counts for each class, based on the user - defined class_size string. 
The strategy of (under)oversampling for each - class level is chosen based on the desired number of counts for a level, - and the actual number of counts already present in the input table. +def _get_desired_target_level_counts(desired_level_counts, + actual_grp_level_counts, + output_table_size): + """ Return the target counts for each group and each class level in the group + + This function is specifically used when the user has provided the + targets for all (or a subset) of the levels. The strategy of either + under or oversampling for each class level is chosen based on the + desired number of counts for a level, and the actual number of counts + already present in the input table. This calculation is performed for + each group (or once if no grouping is used). + + @returns: dict: The key of the dictionary is the group value and the + value is another dictionary. The inner dictionary gives + the target count for each level present in the group. """ - target_level_counts = {} - for each_level, desired_count in desired_level_counts.items(): - sample_strategy = _choose_strategy(actual_level_counts[each_level], - desired_count) - target_level_counts[each_level] = (desired_count, sample_strategy) - - remaining_levels = (set(actual_level_counts.keys()) - - set(desired_level_counts.keys())) - # if 'output_table_size' = NULL, unspecified level counts remain as is - # if 'output_table_size' = <Integer>, divide remaining row count - # uniformly among unspecified level counts - if output_table_size: - # Uniformly distribute across the remaining class levels - remaining_rows = output_table_size - sum(desired_level_counts.values()) - if remaining_rows > 0: - rows_per_level = math.ceil(float(remaining_rows) / - len(remaining_levels)) + target_grp_level_counts = defaultdict(dict) + n_grp_values = len(actual_grp_level_counts) + for each_grp, actual_level_counts in actual_grp_level_counts.items(): + for each_level, desired_count in desired_level_counts.items(): + # 
actual_grp_level_counts are group specific, whereas + # desired_level_counts are evenly split among all groups + try: + per_grp_desired_count = math.ceil(float(desired_count) / n_grp_values) + sample_strategy = _choose_strategy(actual_level_counts[each_level], + per_grp_desired_count) + target_grp_level_counts[each_grp][each_level] = ( + per_grp_desired_count, sample_strategy) + except KeyError: + plpy.error("Balance sample: Desired class level ({0}) not present in the " + "data for each group.".format(each_level)) + + # desired levels could contain just a subset of all levels. For the remaining + # levels in actual_level_counts, compute the desired counts + remaining_levels = (set(actual_level_counts.keys()) - + set(desired_level_counts.keys())) + + # if 'output_table_size' = NULL, remaining level counts remain as is + # if 'output_table_size' = <Integer>, divide remaining count + # uniformly among reamining levels + if output_table_size: + # output_table_size is for the whole table and should be split evenly + # between the groups + remaining_rows = math.ceil(float(output_table_size - + sum(desired_level_counts.values())) / + n_grp_values) + if remaining_rows > 0: + # Uniformly distribute the remaining class levels + rows_per_level = math.ceil(float(remaining_rows) / + len(remaining_levels)) + for each_level in remaining_levels: + sample_strategy = _choose_strategy( + actual_level_counts[each_level], rows_per_level) + target_grp_level_counts[each_grp][each_level] = ( + rows_per_level, sample_strategy) + else: + # When output_table_size is unspecified, rows from the input table + # are sampled as is for remaining class levels. This is called as the + # NOSAMPLE strategy. 
for each_level in remaining_levels: - sample_strategy = _choose_strategy( - actual_level_counts[each_level], rows_per_level) - target_level_counts[each_level] = (rows_per_level, - sample_strategy) - else: - # When output_table_size is unspecified, rows from the input table - # are sampled as is for remaining class levels. This is same as the - # NOSAMPLE strategy. - for each_level in remaining_levels: - target_level_counts[each_level] = (actual_level_counts[each_level], - NOSAMPLE) - return target_level_counts + target_grp_level_counts[each_grp][each_level] = ( + actual_level_counts[each_level], NOSAMPLE) + return target_grp_level_counts # ------------------------------------------------------------------------- -def _get_supported_target_level_counts(sampling_strategy_str, actual_level_counts, + +def _get_supported_target_level_counts(sampling_strategy_str, + actual_grp_level_counts, output_table_size): """ Returns the target level counts for all levels when the class_size param is one of [uniform, undersample, oversample]. The strategy of @@ -196,33 +266,40 @@ def _get_supported_target_level_counts(sampling_strategy_str, actual_level_count already present in the input table. 
""" - def ceil_of_mean(numbers): - return math.ceil(float(sum(numbers)) / max(len(numbers), 1)) - - target_level_counts = {} - # UNIFORM: Ensure all level counts are same (size determined by output_table_size) - # UNDERSAMPLE: Ensure all level counts are same as the minimum count - # OVERSAMPLE: Ensure all level counts are same as the maximum count - size_function = {UNDERSAMPLE: min, - OVERSAMPLE: max, - UNIFORM: ceil_of_mean - }[sampling_strategy_str] - if sampling_strategy_str == UNIFORM and output_table_size: - # Ignore actual counts for computing target sizes - # if output_table_size is specified - target_size_per_level = math.ceil(float(output_table_size) / - len(actual_level_counts)) - else: - target_size_per_level = size_function(actual_level_counts.values()) - for each_level, actual_count in actual_level_counts.items(): - sample_strategy = _choose_strategy(actual_count, target_size_per_level) - target_level_counts[each_level] = (target_size_per_level, - sample_strategy) - return target_level_counts + target_grp_level_counts = defaultdict(dict) + n_grp_values = len(actual_grp_level_counts) + for each_grp, actual_level_counts in actual_grp_level_counts.items(): + if sampling_strategy_str == UNIFORM: + # UNIFORM: Ensure all level counts are same + if output_table_size: + # Ignore actual counts for computing target sizes + # if output_table_size is specified + total_ = float(output_table_size) / n_grp_values + else: + total_ = sum(actual_level_counts.values()) + target_size_per_level = math.ceil(float(total_) / + len(actual_level_counts)) + else: + # UNDERSAMPLE: Ensure all level counts are same as the minimum count + # OVERSAMPLE: Ensure all level counts are same as the maximum count + if sampling_strategy_str == UNDERSAMPLE: + target_size_per_level = min(actual_level_counts.values()) + elif sampling_strategy_str == OVERSAMPLE: + target_size_per_level = max(actual_level_counts.values()) + else: + raise RuntimeError("Balance sample: Invalid " + 
"sampling_strategy_str encountered") + + for each_level, actual_count in actual_level_counts.items(): + sample_strategy = _choose_strategy(actual_count, target_size_per_level) + target_grp_level_counts[each_grp][each_level] = ( + target_size_per_level, sample_strategy) + return target_grp_level_counts # ------------------------------------------------------------------------- + def _get_target_level_counts(sampling_strategy_str, desired_level_counts, - actual_level_counts, output_table_size): + actual_grp_level_counts, output_table_size): """ @param sampling_strategy_str: one of [UNIFORM, UNDERSAMPLE, OVERSAMPLE, None]. This is 'None' only if this is user-defined, i.e., @@ -233,8 +310,9 @@ def _get_target_level_counts(sampling_strategy_str, desired_level_counts, then contain the class levels and the corresponding number of rows specified by the user. - @param actual_level_counts: Dict of various class levels and number of rows - in each of them in the input table + @param actual_grp_level_counts: Dictionary that provides for each group the + the count of number of rows for each class + present in the group @param output_table_size: Size of the desired output table (NULL or Integer) @returns: @@ -244,60 +322,64 @@ def _get_target_level_counts(sampling_strategy_str, desired_level_counts, if not sampling_strategy_str: # This case implies user has provided a desired count for one or more - # levels. Counts for the rest of the levels depend on 'output_table_size'. - target_level_counts = _get_desired_target_level_counts(desired_level_counts, - actual_level_counts, - output_table_size) + # levels. Counts for rest of the levels depend on 'output_table_size'. + target_level_counts = _get_desired_target_level_counts( + desired_level_counts, actual_grp_level_counts, output_table_size) else: - # This case imples the user has chosen one of [uniform, undersample, - # oversample] for the class_size param. 
- target_level_counts = _get_supported_target_level_counts(sampling_strategy_str, - actual_level_counts, - output_table_size) + # This case imples the user has chosen one of + # [uniform, undersample, oversample] for the class_size parameter. + target_level_counts = _get_supported_target_level_counts( + sampling_strategy_str, actual_grp_level_counts, output_table_size) return target_level_counts - # ------------------------------------------------------------------------- -def _get_sampling_strategy_specific_dict(target_class_sizes): +def _get_sampling_strategy_counts(target_class_sizes): """ Return three dicts, one each for undersampling, oversampling, and nosampling. The dict contains the number of samples to be drawn for each class level. """ - undersample_level_dict = {} - oversample_level_dict = {} - nosample_level_dict = {} + undersample_level_counts = {} + oversample_level_counts = {} + nosample_level_counts = {} for level, (count, strategy) in target_class_sizes.items(): if strategy == UNDERSAMPLE: - undersample_level_dict[level] = count + undersample_level_counts[level] = count elif strategy == OVERSAMPLE: - oversample_level_dict[level] = count + oversample_level_counts[level] = count else: - nosample_level_dict[level] = count - return (undersample_level_dict, oversample_level_dict, nosample_level_dict) + nosample_level_counts[level] = count + return (undersample_level_counts, oversample_level_counts, nosample_level_counts) # ------------------------------------------------------------------------------ -def _get_nosample_subquery(source_table, class_col, nosample_levels): +def _get_nosample_subquery(source_table, class_col, nosample_levels, + grp_dict=None): """ Return the subquery for fetching all rows as is from the input table for specific class levels. 
""" if not nosample_levels: return '' - subquery = """ - SELECT * - FROM {0} - WHERE {1} in ({2}) OR {1} IS NULL - """.format(source_table, class_col, - ','.join(["'{0}'".format(level) - for level in nosample_levels if level])) - return subquery + nosample_level_str = ','.join(["'{0}'".format(level) + for level in nosample_levels if level]) + if grp_dict: + grp_filter = ' AND ' + ' AND '.join("{0} = '{1}'".format(k, v) + for k, v in grp_dict.items()) + else: + grp_filter = '' + return """ + SELECT * + FROM {source_table} + WHERE ({class_col} in ({nosample_level_str}) OR + {class_col} IS NULL) + {grp_filter} + """.format(**locals()) # ------------------------------------------------------------------------------ -def _get_without_replacement_subquery(schema_madlib, source_table, - source_table_columns, class_col, - actual_level_counts, desired_level_counts): +def _get_without_replacement_subquery(schema_madlib, source_table, class_col, + actual_level_counts, desired_level_counts, + grp_dict=None): """ Return the subquery for sampling without replacement for specific class levels. """ @@ -306,47 +388,50 @@ def _get_without_replacement_subquery(schema_madlib, source_table, class_col_tmp = unique_string(desp='class_col') row_number_col = unique_string(desp='row_number') desired_count_col = unique_string(desp='desired_count') - + source_table_columns = ','.join(get_cols(source_table)) null_value_string = "'{0}'".format(NULL_IDENTIFIER) - - desired_level_counts_str = "VALUES " + \ - ','.join("({0}, {1})". - format("'{0}'::text".format(k) if k else null_value_string, v) - for k, v in desired_level_counts.items()) + desired_level_count_pairs = (','.join("({0}, {1})". + format("'{0}'::text".format(k) if k else null_value_string, v) + for k, v in desired_level_counts.items())) + desired_level_counts_str = "VALUES " + desired_level_count_pairs # Subquery q2 is used to figure out the number of rows to select for each # class level. 
That is used with row_number to order the input rows randomly, - # and, trim the results to the_desired_number_of_rows for each class level. - # q2: + # and trim the results to the desired number of rows for each class level. + # q1: # The FROM clause contains information that can be used to figure out the # number of rows to generate for a class level. The tuple basically # contains: (class_level, the_desired_number_of_rows) + if grp_dict: + grp_filter = ' AND ' + ' AND '.join("{0} = '{1}'".format(k, v) + for k, v in grp_dict.items()) + else: + grp_filter = '' subquery = """ SELECT {source_table_columns} - FROM - ( + FROM ( SELECT {source_table_columns}, row_number() OVER (PARTITION BY {class_col} ORDER BY random()) AS {row_number_col}, {desired_count_col} - FROM - ( + FROM ( SELECT {source_table_columns}, {desired_count_col} FROM {source_table} s, ({desired_level_counts_str}) - q({class_col_tmp}, {desired_count_col}) - WHERE {class_col_tmp} = coalesce({class_col}::text, '{null_level_val}') + q1({class_col_tmp}, {desired_count_col}) + WHERE {class_col_tmp} = coalesce({class_col}::text, '{null_identifier}') + {grp_filter} ) q2 ) q3 WHERE {row_number_col} <= {desired_count_col} - """.format(null_level_val=NULL_IDENTIFIER, **locals()) + """.format(null_identifier=NULL_IDENTIFIER, **locals()) return subquery # ------------------------------------------------------------------------------ -def _get_with_replacement_subquery(schema_madlib, source_table, - source_table_columns, class_col, - actual_level_counts, desired_level_counts): +def _get_with_replacement_subquery(schema_madlib, source_table, class_col, + actual_level_counts, desired_level_counts, + grp_dict=None): """ Return the query for sampling with replacement for specific class levels. Always used for oversampling since oversampling will always need to use replacement. 
Used for under sampling only if with_replacement @@ -355,6 +440,7 @@ def _get_with_replacement_subquery(schema_madlib, source_table, if not desired_level_counts: return '' + source_table_columns = ','.join(get_cols(source_table)) class_col_tmp = unique_string(desp='class_col_with_rep') desired_count_col = unique_string(desp='desired_count_with_rep') actual_count_col = unique_string(desp='actual_count') @@ -362,12 +448,12 @@ def _get_with_replacement_subquery(schema_madlib, source_table, q2_row_no = unique_string(desp='q2_row') null_value_string = "'{0}'".format(NULL_IDENTIFIER) + desired_and_actual_count = (','.join("({0}, {1}, {2})". + format("'{0}'::text".format(k) if k else null_value_string, + v, actual_level_counts[k]) + for k, v in desired_level_counts.items())) + desired_and_actual_level_count_str = "VALUES " + desired_and_actual_count - desired_and_actual_level_counts = "VALUES " + \ - ','.join("({0}, {1}, {2})". - format("'{0}'::text".format(k) if k else null_value_string, - v, actual_level_counts[k]) - for k, v in desired_level_counts.items()) # q1 and q2 are two sub queries we create to generate the required number of # rows per class level. # q1: @@ -382,6 +468,11 @@ def _get_with_replacement_subquery(schema_madlib, source_table, # actual_rows_for_level_in_input_table. # # The WHERE clause is used to join the two subqueries to obtain the result. 
+ if grp_dict: + grp_filter = 'WHERE ' + ' AND '.join("{0} = '{1}'".format(k, v) + for k, v in grp_dict.items()) + else: + grp_filter = '' subquery = """ SELECT {source_table_columns} FROM @@ -391,7 +482,7 @@ def _get_with_replacement_subquery(schema_madlib, source_table, generate_series(1, {desired_count_col}::int) AS _i, ((random()*({actual_count_col}-1)+1)::int) AS {q1_row_no} FROM - ({desired_and_actual_level_counts}) + ({desired_and_actual_level_count_str}) q({class_col_tmp}, {desired_count_col}, {actual_count_col}) ) q1, ( @@ -400,6 +491,7 @@ def _get_with_replacement_subquery(schema_madlib, source_table, row_number() OVER(PARTITION BY {class_col}) AS {q2_row_no} FROM {source_table} + {grp_filter} ) q2 WHERE {class_col_tmp} = coalesce({class_col}::text, '{null_level_val}') AND q1.{q1_row_no} = q2.{q2_row_no} @@ -407,6 +499,7 @@ def _get_with_replacement_subquery(schema_madlib, source_table, return subquery # ------------------------------------------------------------------------------ + def balance_sample(schema_madlib, source_table, output_table, class_col, class_sizes, output_table_size, grouping_cols, with_replacement, keep_null, **kwargs): @@ -443,20 +536,18 @@ def balance_sample(schema_madlib, source_table, output_table, class_col, _validate_strs(source_table, output_table, class_col, output_table_size, grouping_cols) - source_table_columns = ','.join(get_cols(source_table)) - new_source_table = source_table # If keep_null=False, create a view of the input table ignoring NULL # values for class levels. 
- if not keep_null: + if keep_null: + new_source_table = source_table + else: new_source_table = unique_string(desp='source_table') plpy.execute(""" - CREATE VIEW {new_source_table} AS - SELECT * FROM {source_table} - WHERE {class_col} IS NOT NULL - """.format(**locals())) - actual_level_counts = _get_level_frequency_distribution(new_source_table, - class_col) + CREATE VIEW {new_source_table} AS + SELECT * FROM {source_table} + WHERE {class_col} IS NOT NULL + """.format(**locals())) # class_sizes can be of two forms: # 1. A string describing sampling strategy (as described in # _validate_and_get_sampling_strategy). @@ -468,22 +559,34 @@ def balance_sample(schema_madlib, source_table, output_table, class_col, parsed_class_sizes = extract_keyvalue_params(class_sizes, allow_duplicates=False, lower_case_names=False) + + # GREENPLUM 4.3.X does not support bool-to-text cast which is relied + # upon for class_col in multiple queries. The explicit_bool_to_text + # function wraps the class_col with a MADlib function that provides the + # cast just for those platforms that don't provide the cast. For + # platforms that provide the cast, class_col is unchanged below. + class_col = explicit_bool_to_text(source_table, + [class_col], + schema_madlib)[0] + distinct_sql = ("SELECT DISTINCT ({0})::TEXT as levels FROM {1} ". + format(class_col, + source_table)) + distinct_levels = collate_plpy_result(plpy.execute(distinct_sql))['levels'] if not parsed_class_sizes: - sampling_strategy_str = _validate_and_get_sampling_strategy(class_sizes, - output_table_size) + sampling_strategy_str = _validate_and_get_sampling_strategy( + class_sizes, output_table_size) else: sampling_strategy_str = None try: for each_level, each_class_size in parsed_class_sizes.items(): - _assert(each_level in actual_level_counts, + _assert(each_level in distinct_levels, "Sample: Invalid class value specified ({0})". 
- format(each_level)) + format(each_level)) each_class_size = int(each_class_size) _assert(each_class_size >= 1, "Sample: Class size has to be greater than zero") parsed_class_sizes[each_level] = each_class_size - - except TypeError: + except ValueError: plpy.error("Sample: Invalid value for class_sizes ({0})". format(class_sizes)) @@ -491,58 +594,82 @@ def balance_sample(schema_madlib, source_table, output_table, class_col, # the input table, class_sizes, and output_table_size params. This also # includes info about the resulting sampling strategy, i.e., one of # UNDERSAMPLE, OVERSAMPLE, or NOSAMPLE for each level. - target_class_sizes = _get_target_level_counts(sampling_strategy_str, - parsed_class_sizes, - actual_level_counts, - output_table_size) - - undersample_level_dict, oversample_level_dict, nosample_level_dict = \ - _get_sampling_strategy_specific_dict(target_class_sizes) - - # Get subqueries for each sampling strategy, so that they can be used - # together in one big query. - - # Subquery that will be used to get rows as is for those class levels - # that need no sampling. - nosample_subquery = _get_nosample_subquery( - new_source_table, class_col, nosample_level_dict.keys()) - # Subquery that will be used to sample those class levels that - # have to be oversampled. - oversample_subquery = _get_with_replacement_subquery( - schema_madlib, new_source_table, source_table_columns, class_col, - actual_level_counts, oversample_level_dict) - # Subquery that will be used to sample those class levels that - # have to be undersampled. Undersampling supports both with and without - # replacement, so fetch the appropriate subquery. 
- if with_replacement: - undersample_subquery = _get_with_replacement_subquery( - schema_madlib, new_source_table, source_table_columns, class_col, - actual_level_counts, undersample_level_dict) - else: - undersample_subquery = _get_without_replacement_subquery( - schema_madlib, new_source_table, source_table_columns, class_col, - actual_level_counts, undersample_level_dict) - - # Merge the three subqueries using a UNION ALL clause. - union_all_subquery = ' UNION ALL '.join( - ['({0})'.format(subquery) - for subquery in [undersample_subquery, oversample_subquery, nosample_subquery] - if subquery]) - - final_query = """ - CREATE TABLE {output_table} AS - SELECT row_number() OVER() AS {new_col_name}, * + grp_col_str, grp_cols = get_grouping_col_str( + schema_madlib, 'Balance sample', [NEW_ID_COLUMN, class_col], + source_table, grouping_cols) + actual_grp_level_counts = _get_level_frequency_distribution( + new_source_table, class_col, grp_col_str) + target_grp_class_sizes = _get_target_level_counts( + sampling_strategy_str, parsed_class_sizes, + actual_grp_level_counts, output_table_size) + + is_output_created = False + grp_cols_list = grp_col_str.split(',') if grp_cols else [] + + # for each group + for grp_vals, actual_level_counts in actual_grp_level_counts.items(): + target_class_sizes = target_grp_class_sizes[grp_vals] + + (undersample_level_counts, + oversample_level_counts, + nosample_level_counts) = _get_sampling_strategy_counts(target_class_sizes) + + # grp_dict represents each grouping column and its value + # for current iteration + grp_dict = dict(zip(grp_cols_list, grp_vals)) if grp_cols else None + + # Get subqueries for each sampling strategy and union together in + # one big query. 
+ + # NOSAMPLE are levels that are to be retained in output without sampling + nosample_subquery = _get_nosample_subquery( + new_source_table, class_col, nosample_level_counts.keys(), grp_dict) + # OVERSAMPLE are levels that are to be sampled more than existing count + # (always with replacement) + oversample_subquery = _get_with_replacement_subquery( + schema_madlib, new_source_table, class_col, + actual_level_counts, oversample_level_counts, grp_dict) + # UNDERSAMPLE are levels that are to be sampled to less than + # existing count. Undersampling supports both with and without + # replacement. + if with_replacement: + undersample_subquery = _get_with_replacement_subquery( + schema_madlib, new_source_table, class_col, + actual_level_counts, undersample_level_counts, grp_dict) + else: + undersample_subquery = _get_without_replacement_subquery( + schema_madlib, new_source_table, class_col, + actual_level_counts, undersample_level_counts, grp_dict) + + # Merge the three subqueries using a UNION ALL clause. 
+ sampling_queries = (undersample_subquery, oversample_subquery, nosample_subquery) + union_all_subquery = ' UNION ALL '.join( + ['({0})'.format(subquery) + for subquery in sampling_queries if subquery]) + + # Populate the output table + if is_output_created: + table_header = "INSERT INTO {0}".format(output_table) + else: + table_header = "CREATE TABLE {0} AS ".format(output_table) + is_output_created = True + final_query = """ + {table_header} + SELECT row_number() OVER() AS {id_col_name}, * FROM ( {union_all_subquery} ) union_query - """.format(new_col_name=NEW_ID_COLUMN, **locals()) - plpy.execute(final_query) + """.format(id_col_name=NEW_ID_COLUMN, **locals()) + plpy.execute(final_query) + # end of grouping loop + if not keep_null: - plpy.execute("DROP VIEW {0}".format(new_source_table)) + plpy.execute("DROP VIEW IF EXISTS {0}".format(new_source_table)) +# ------------------------------------------------------------------------------ -def _validate_strs(source_table, output_table, class_col, output_table_size, - grouping_cols): +def _validate_strs(source_table, output_table, class_col, + output_table_size, grouping_cols): _assert(source_table and table_exists(source_table), "Sample: Source table ({source_table}) does not exist.".format(**locals())) _assert(not table_is_empty(source_table), @@ -566,10 +693,7 @@ def _validate_strs(source_table, output_table, class_col, output_table_size, _assert((not output_table_size) or (output_table_size > 0), "Sample: Invalid output table size ({output_table_size}).".format( **locals())) - - _assert(grouping_cols is None, - "grouping_cols is not supported at the moment." 
- .format(**locals())) +# ------------------------------------------------------------------------------ def balance_sample_help(schema_madlib, message, **kwargs): @@ -715,7 +839,7 @@ class UtilitiesTestCase(unittest.TestCase): """ def setUp(self): - self.input_class_level_counts1 = {'a': 20, 'b': 30, 'c': 25} + self.input_class_level_counts1 = {None: {'a': 20, 'b': 30, 'c': 25}} self.level1a = 'a' self.level1a_cnt1 = 15 self.level1a_cnt2 = 25 @@ -740,55 +864,55 @@ class UtilitiesTestCase(unittest.TestCase): def test__get_target_level_counts(self): # Test cases for user defined class level samples, without output table size - self.assertEqual({'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)}, + self.assertEqual({None: {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)}}, _get_target_level_counts(self.sampling_strategy_str0, self.user_specified_class_size1, self.input_class_level_counts1, self.output_table_size1)) - self.assertEqual({'a': (20, NOSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)}, + self.assertEqual({None: {'a': (20, NOSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, NOSAMPLE)}}, _get_target_level_counts(self.sampling_strategy_str0, self.user_specified_class_size2, self.input_class_level_counts1, self.output_table_size1)) - self.assertEqual({'a': (30, OVERSAMPLE), 'b': (30, NOSAMPLE), 'c': (25, NOSAMPLE)}, + self.assertEqual({None: {'a': (30, OVERSAMPLE), 'b': (30, NOSAMPLE), 'c': (25, NOSAMPLE)}}, _get_target_level_counts(self.sampling_strategy_str0, self.user_specified_class_size3, self.input_class_level_counts1, self.output_table_size1)) # Test cases for user defined class level samples, with output table size - self.assertEqual({'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (10, UNDERSAMPLE)}, + self.assertEqual({None: {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (10, UNDERSAMPLE)}}, _get_target_level_counts(self.sampling_strategy_str0, self.user_specified_class_size1, self.input_class_level_counts1, 
self.output_table_size2)) - self.assertEqual({'a': (18, UNDERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (18, UNDERSAMPLE)}, + self.assertEqual({None: {'a': (18, UNDERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (18, UNDERSAMPLE)}}, _get_target_level_counts(self.sampling_strategy_str0, self.user_specified_class_size2, self.input_class_level_counts1, self.output_table_size2)) - self.assertEqual({'a': (30, OVERSAMPLE), 'b': (15, UNDERSAMPLE), 'c': (15, UNDERSAMPLE)}, + self.assertEqual({None: {'a': (30, OVERSAMPLE), 'b': (15, UNDERSAMPLE), 'c': (15, UNDERSAMPLE)}}, _get_target_level_counts(self.sampling_strategy_str0, self.user_specified_class_size3, self.input_class_level_counts1, self.output_table_size2)) # Test cases for UNIFORM, OVERSAMPLE, and UNDERSAMPLE without any output table size - self.assertEqual({'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, UNDERSAMPLE)}, + self.assertEqual({None: {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': (25, UNDERSAMPLE)}}, _get_target_level_counts(self.sampling_strategy_str1, self.user_specified_class_size0, self.input_class_level_counts1, self.output_table_size1)) - self.assertEqual({'a': (30, OVERSAMPLE), 'b': (30, UNDERSAMPLE), 'c': (30, OVERSAMPLE)}, + self.assertEqual({None: {'a': (30, OVERSAMPLE), 'b': (30, UNDERSAMPLE), 'c': (30, OVERSAMPLE)}}, _get_target_level_counts(self.sampling_strategy_str2, self.user_specified_class_size0, self.input_class_level_counts1, self.output_table_size1)) - self.assertEqual({'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 'c': (20, UNDERSAMPLE)}, + self.assertEqual({None: {'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 'c': (20, UNDERSAMPLE)}}, _get_target_level_counts(self.sampling_strategy_str3, self.user_specified_class_size0, self.input_class_level_counts1, self.output_table_size1)) # Test cases for UNIFORM with output table size - self.assertEqual({'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 'c': (20, UNDERSAMPLE)}, + self.assertEqual({None: {'a': (20, UNDERSAMPLE), 'b': (20, 
UNDERSAMPLE), 'c': (20, UNDERSAMPLE)}}, _get_target_level_counts(self.sampling_strategy_str1, self.user_specified_class_size0, self.input_class_level_counts1, @@ -800,11 +924,11 @@ class UtilitiesTestCase(unittest.TestCase): target_level_counts_2 = {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE)} target_level_counts_3 = {'a': (25, OVERSAMPLE), 'b': (25, NOSAMPLE), 'c': (25, NOSAMPLE)} self.assertEqual(({'b': 25}, {'a': 25}, {'c': 25}), - _get_sampling_strategy_specific_dict(target_level_counts_1)) + _get_sampling_strategy_counts(target_level_counts_1)) self.assertEqual(({'b': 25}, {'a': 25}, {}), - _get_sampling_strategy_specific_dict(target_level_counts_2)) + _get_sampling_strategy_counts(target_level_counts_2)) self.assertEqual(({}, {'a': 25}, {'c': 25, 'b': 25}), - _get_sampling_strategy_specific_dict(target_level_counts_3)) + _get_sampling_strategy_counts(target_level_counts_3)) if __name__ == '__main__': http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/postgres/modules/sample/balance_sample.sql_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/sample/balance_sample.sql_in b/src/ports/postgres/modules/sample/balance_sample.sql_in index cd70961..eea73aa 100644 --- a/src/ports/postgres/modules/sample/balance_sample.sql_in +++ b/src/ports/postgres/modules/sample/balance_sample.sql_in @@ -149,8 +149,11 @@ non-stratified, that is, the whole table is treated as a single group. @note -Current implementation does not support grouping_cols. It -will be added in an upcoming release. +The 'output_table_size' and the 'class_sizes' are defined for the whole table. +When grouping is used, these parameters are split evenly for each group. +Further, if a specific class value is specified in the 'class_sizes' parameter, +that particular class value should be present in each group. If not, an error +will be thrown. 
</dd> <dt>with_replacement (optional)</dt> @@ -393,8 +396,7 @@ SELECT * FROM output_table ORDER BY mainhue, name; 7 | 20 | USA | 1 | 4 | 9363 | 231 | 1 | 3 | white (8 rows) </pre> -In the case of bootstrapping, we may want to undersample with replacement, -so we set the 'with_replacement' parameter to TRUE: +We may also want to undersample with replacement, so we set the 'with_replacement' parameter to TRUE: <pre class="syntax"> DROP TABLE IF EXISTS output_table; SELECT madlib.balance_sample( @@ -543,6 +545,95 @@ SELECT * FROM output_table ORDER BY mainhue, name; (25 rows) </pre> +-# To perform the balance sampling for independent groups, use the 'grouping_cols' +parameter. Note below that each group (zone) has a different count of the +classes (mainhue), with some groups not containing some class values. +<pre class="syntax"> +DROP TABLE IF EXISTS output_table; +SELECT madlib.balance_sample( + 'flags', -- Source table + 'output_table', -- Output table + 'mainhue', -- Class column + NULL, -- Uniform + NULL, -- Output table size + 'zone' -- Grouping by zone +); +SELECT * FROM output_table ORDER BY zone, mainhue; +</pre> +<pre class="result"> + __madlib_id__ | id | name | landmass | zone | area | population | language | colours | mainhue +---------------+----+-------------+----------+------+------+------------+----------+---------+--------- + 6 | 8 | Greece | 3 | 1 | 132 | 10 | 6 | 2 | blue + 5 | 8 | Greece | 3 | 1 | 132 | 10 | 6 | 2 | blue + 8 | 17 | Sweden | 3 | 1 | 450 | 8 | 6 | 2 | blue + 7 | 8 | Greece | 3 | 1 | 132 | 10 | 6 | 2 | blue + 2 | 7 | Denmark | 3 | 1 | 43 | 5 | 6 | 2 | red + 1 | 6 | China | 5 | 1 | 9561 | 1008 | 7 | 2 | red + 4 | 12 | Luxembourg | 3 | 1 | 3 | 0 | 4 | 3 | red + 3 | 18 | Switzerland | 3 | 1 | 41 | 6 | 4 | 2 | red + 1 | 2 | Australia | 6 | 2 | 7690 | 15 | 1 | 3 | blue + 1 | 1 | Argentina | 2 | 3 | 2777 | 28 | 2 | 2 | blue + 2 | 4 | Brazil | 2 | 3 | 8512 | 119 | 6 | 4 | green + 6 | 9 | Guatemala | 1 | 4 | 109 | 8 | 2 | 2 | blue + 5 | 9 | 
Guatemala | 1 | 4 | 109 | 8 | 2 | 2 | blue + 4 | 9 | Guatemala | 1 | 4 | 109 | 8 | 2 | 2 | blue + 12 | 13 | Mexico | 1 | 4 | 1973 | 77 | 2 | 4 | green + 10 | 13 | Mexico | 1 | 4 | 1973 | 77 | 2 | 4 | green + 11 | 13 | Mexico | 1 | 4 | 1973 | 77 | 2 | 4 | green + 1 | 19 | UK | 3 | 4 | 245 | 56 | 1 | 3 | red + 3 | 5 | Canada | 1 | 4 | 9976 | 24 | 1 | 2 | red + 2 | 15 | Portugal | 3 | 4 | 92 | 10 | 6 | 5 | red + 8 | 20 | USA | 1 | 4 | 9363 | 231 | 1 | 3 | white + 7 | 20 | USA | 1 | 4 | 9363 | 231 | 1 | 3 | white + 9 | 10 | Ireland | 3 | 4 | 70 | 3 | 1 | 3 | white +(23 rows) +</pre> + +-# Grouping can be used with class size specification as well. Note below that +'blue=<Integer>' is the only valid class value since 'blue' is the only class +value that is present in each group. Further, 'blue=8' will be split between the +four groups, resulting in two blue rows for each group. +<pre class="syntax"> +DROP TABLE IF EXISTS output_table; +SELECT madlib.balance_sample( + 'flags', -- Source table + 'output_table', -- Output table + 'mainhue', -- Class column + 'blue=8', -- Specified class value size. Rest of the values are output as is. 
+ NULL, -- Output table size + 'zone' -- Group by zone +); +SELECT * FROM output_table ORDER BY zone, mainhue; +</pre> +<pre class="result"> + __madlib_id__ | id | name | landmass | zone | area | population | language | colours | mainhue +---------------+----+-------------+----------+------+------+------------+----------+---------+--------- + 2 | 17 | Sweden | 3 | 1 | 450 | 8 | 6 | 2 | blue + 1 | 8 | Greece | 3 | 1 | 132 | 10 | 6 | 2 | blue + 3 | 3 | Austria | 3 | 1 | 84 | 8 | 4 | 2 | red + 5 | 7 | Denmark | 3 | 1 | 43 | 5 | 6 | 2 | red + 4 | 6 | China | 5 | 1 | 9561 | 1008 | 7 | 2 | red + 8 | 18 | Switzerland | 3 | 1 | 41 | 6 | 4 | 2 | red + 7 | 14 | Norway | 3 | 1 | 324 | 4 | 6 | 3 | red + 6 | 12 | Luxembourg | 3 | 1 | 3 | 0 | 4 | 3 | red + 1 | 2 | Australia | 6 | 2 | 7690 | 15 | 1 | 3 | blue + 2 | 2 | Australia | 6 | 2 | 7690 | 15 | 1 | 3 | blue + 1 | 1 | Argentina | 2 | 3 | 2777 | 28 | 2 | 2 | blue + 2 | 1 | Argentina | 2 | 3 | 2777 | 28 | 2 | 2 | blue + 3 | 4 | Brazil | 2 | 3 | 8512 | 119 | 6 | 4 | green + 2 | 9 | Guatemala | 1 | 4 | 109 | 8 | 2 | 2 | blue + 1 | 9 | Guatemala | 1 | 4 | 109 | 8 | 2 | 2 | blue + 5 | 11 | Jamaica | 1 | 4 | 11 | 2 | 1 | 3 | green + 6 | 13 | Mexico | 1 | 4 | 1973 | 77 | 2 | 4 | green + 3 | 5 | Canada | 1 | 4 | 9976 | 24 | 1 | 2 | red + 7 | 15 | Portugal | 3 | 4 | 92 | 10 | 6 | 5 | red + 8 | 16 | Spain | 3 | 4 | 505 | 38 | 2 | 2 | red + 9 | 19 | UK | 3 | 4 | 245 | 56 | 1 | 3 | red + 10 | 20 | USA | 1 | 4 | 9363 | 231 | 1 | 3 | white + 4 | 10 | Ireland | 3 | 4 | 70 | 3 | 1 | 3 | white +(23 rows) +</pre> + @anchor literature @par Literature http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/postgres/modules/sample/test/balance_sample.sql_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/sample/test/balance_sample.sql_in b/src/ports/postgres/modules/sample/test/balance_sample.sql_in index 2249fd6..e99b706 100644 --- 
a/src/ports/postgres/modules/sample/test/balance_sample.sql_in +++ b/src/ports/postgres/modules/sample/test/balance_sample.sql_in @@ -25,58 +25,66 @@ CREATE TABLE "TEST_s"( id1 INTEGER, id2 INTEGER, gr1 INTEGER, - gr2 INTEGER + gr2 INTEGER, + gr3 TEXT ); INSERT INTO "TEST_s" VALUES -(1,0,1,1), -(2,0,1,1), -(3,0,1,1), -(4,0,1,1), -(5,0,1,1), -(6,0,1,1), -(7,0,1,1), -(8,0,1,1), -(9,0,1,1), -(19,0,1,1), -(29,0,1,1), -(39,0,1,1), -(0,1,1,2), -(0,2,1,2), -(0,3,1,2), -(0,4,1,2), -(0,5,1,2), -(0,6,1,2), -(10,10,2,2), -(20,20,2,2), -(30,30,2,2), -(40,40,2,2), -(50,50,2,2), -(60,60,2,2), -(70,70,2,2), -(10,10,5,5), -(50,50,5,5), -(88,88,5,5), -(40,40,5,6), -(50,50,5,6), -(60,60,5,6), -(70,70,5,6), -(10,10,6,6), -(60,60,6,6), -(30,30,6,6), -(40,40,6,6), -(50,50,6,6), -(60,60,6,6), -(70,70,6,6), -(50,50,4,2), -(60,60,4,2), -(70,70,4,2), -(50,50,3,2), -(60,60,3,2), -(70,70,3,2), -(500,50,NULL,2), -(600,60,NULL,2), -(700,70,NULL,2) +(1,0,1,1,'a'), +(1,0,1,1,'b'), +(1,0,1,2,'b'), +(1,0,1,2,'c'), +(1,0,1,2,'d'), +(1,0,1,2,'e'), +(1,0,1,5,'c'), +(1,0,1,6,'c'), +(2,0,1,6,'d'), +(3,0,1,1,'a'), +(4,0,1,1,'a'), +(5,0,1,1,'a'), +(6,0,1,1,'a'), +(7,0,1,1,'b'), +(8,0,1,1,'b'), +(9,0,1,1,'b'), +(19,0,1,1,'b'), +(29,0,1,1,'b'), +(39,0,1,1,'b'), +(0,1,1,2,'b'), +(0,2,1,2,'b'), +(0,3,1,2,'b'), +(0,4,1,2,'b'), +(0,5,1,2,'b'), +(0,6,1,2,'b'), +(10,10,2,2,'c'), +(20,20,2,2,'c'), +(30,30,2,2,'c'), +(40,40,2,2,'c'), +(50,50,2,2,'c'), +(60,60,2,2,'c'), +(70,70,2,2,'c'), +(10,10,5,5,'c'), +(50,50,5,5,'c'), +(88,88,5,5,'c'), +(40,40,5,6,'c'), +(50,50,5,6,'c'), +(60,60,5,6,'c'), +(70,70,5,6,'c'), +(10,10,6,6,'c'), +(60,60,6,6,'c'), +(30,30,6,6,'d'), +(40,40,6,6,'d'), +(50,50,6,6,'d'), +(60,60,6,6,'d'), +(70,70,6,6,'d'), +(50,50,4,2,'d'), +(60,60,4,2,'d'), +(70,70,4,2,'d'), +(50,50,3,2,'d'), +(60,60,3,2,'d'), +(70,70,3,2,'d'), +(500,50,NULL,2,'e'), +(600,60,NULL,2,'e'), +(700,70,NULL,2,'e') ; -- SELECT gr1, count(*) AS c FROM "TEST_s" GROUP BY gr1; @@ -92,14 +100,6 @@ INSERT INTO "TEST_s" VALUES -- (6 
rows) SELECT gr1, count(*) AS c FROM "TEST_s" GROUP BY gr1; - ---- Test for random undersampling without replacement -DROP TABLE IF EXISTS out_s; -SELECT balance_sample('"TEST_s"', 'out_s', 'gr1', ' undersample', NULL, NULL, FALSE); -SELECT gr1, count(*) AS c FROM out_s GROUP BY gr1; -SELECT assert(count(*) = 0, 'Wrong number of samples on undersampling gr1') FROM - (SELECT gr1, count(*) AS c FROM out_s GROUP BY gr1) AS foo WHERE foo.c != 3; - -- --- Test for random undersampling with replacement DROP TABLE IF EXISTS out_sr2; SELECT balance_sample('"TEST_s"', 'out_sr2', 'gr1', 'undersample ', NULL, NULL, TRUE, TRUE); @@ -108,19 +108,39 @@ SELECT gr1, count(*) AS c FROM out_sr2 GROUP BY gr1; SELECT assert(count(*) = 0, 'Wrong number of samples on undersampling with replacement on gr1') FROM (SELECT gr1, count(*) AS c FROM out_sr2 GROUP BY gr1) AS foo WHERE foo.c != 3; +--- Test for random undersampling without replacement +DROP TABLE IF EXISTS out_s; +SELECT balance_sample('"TEST_s"', 'out_s', 'gr1', 'undersample', NULL, 'gr2, gr3', FALSE); +SELECT * FROM out_s; +SELECT gr2, gr3, count(*) AS c FROM out_s GROUP BY gr3, gr2 ORDER BY gr2, gr3; +DROP TABLE IF EXISTS out_s; +SELECT balance_sample('"TEST_s"', 'out_s', 'gr1', 'oversample', NULL, 'gr2, gr3', FALSE); +SELECT * FROM out_s; +SELECT gr2, gr3, count(*) AS c FROM out_s GROUP BY gr3, gr2 ORDER BY gr2, gr3; +DROP TABLE IF EXISTS out_s; +SELECT balance_sample('"TEST_s"', 'out_s', 'gr1', 'uniform', NULL, 'gr2, gr3', FALSE); +SELECT * FROM out_s; +SELECT gr2, gr3, count(*) AS c FROM out_s GROUP BY gr3, gr2 ORDER BY gr2, gr3; +DROP TABLE IF EXISTS out_s; +SELECT balance_sample('"TEST_s"', 'out_s', 'gr1', '1=3', NULL, 'gr2, gr3', FALSE); +SELECT * FROM out_s; +SELECT gr2, gr3, count(*) AS c FROM out_s GROUP BY gr3, gr2 ORDER BY gr2, gr3; +-- SELECT assert(count(*) = 0, 'Wrong number of samples on undersampling gr1') FROM +-- (SELECT gr1, gr2, count(*) AS c FROM out_s GROUP BY gr1, gr2) AS foo WHERE foo.c != 3; + -- --- 
Test for random oversampling DROP TABLE IF EXISTS out_or3; SELECT balance_sample('"TEST_s"', 'out_or3', 'gr1', ' oVEr ', NULL, NULL); SELECT gr1, count(*) AS c FROM out_or3 GROUP BY gr1; SELECT assert(count(*) = 0, 'Wrong number of samples on oversampling') FROM - (SELECT gr1, count(*) AS c FROM out_or3 GROUP BY gr1) AS foo WHERE foo.c != 18; + (SELECT gr1, count(*) AS c FROM out_or3 GROUP BY gr1) AS foo WHERE foo.c != 25; --- UNIFORM sampling DROP TABLE IF EXISTS out_cd2; SELECT balance_sample('"TEST_s"', 'out_cd2', 'gr1', 'Uniform', NULL, NULL); SELECT gr1, count(*) AS c FROM out_cd2 GROUP BY gr1; SELECT assert(count(*) = 0, 'Wrong number of samples on uniform sampling for gr1') FROM - (SELECT gr1, count(*) AS c FROM out_cd2 GROUP BY gr1) AS foo WHERE foo.c != 8; + (SELECT gr1, count(*) AS c FROM out_cd2 GROUP BY gr1) AS foo WHERE foo.c != 9; --- Default sampling should be uniform DROP TABLE IF EXISTS out_cd3; @@ -131,27 +151,27 @@ SELECT assert(count(*) = 0, 'Wrong number of samples on uniform sampling for gr1 --- Only one class size is specified DROP TABLE IF EXISTS out_cd4; -SELECT balance_sample('"TEST_s"', 'out_cd4', 'gr1', ' 2 =10', NULL, NULL, TRUE); -SELECT assert(count(*) = 48, 'Wrong number of samples on oversampling with comma-delimited list') from out_cd4; -SELECT assert(count(*) = 0, 'Wrong number of samples on oversampling with comma-delimited list') from - (SELECT count(*) AS c FROM out_cd4 where gr1 = 2) AS foo WHERE foo.c != 10; -SELECT assert(count(*) = 0, 'Wrong number of samples on oversampling with comma-delimited list') from - (SELECT count(*) AS c FROM out_cd4 where gr1 = 3) AS foo WHERE foo.c != 3; -SELECT assert(count(*) = 0, 'Wrong number of samples on oversampling with comma-delimited list') from - (SELECT count(*) AS c FROM out_cd4 where gr1 = 1) AS foo WHERE foo.c != 18; +SELECT balance_sample('"TEST_s"', 'out_cd4', 'gr1', '2=10', NULL, NULL, TRUE); +SELECT gr1, count(*) AS c FROM out_cd4 GROUP BY gr1; +SELECT assert(count(*) = 10, 
'Wrong number of samples on sampling with specified class sizes') from +out_cd4 where gr1 = 2; +SELECT assert(count(*) = 3, 'Wrong number of samples on sampling with specified class sizes') from +out_cd4 where gr1 = 3; +SELECT assert(count(*) = 25, 'Wrong number of samples on sampling with specified class sizes') from +out_cd4 where gr1 = 1; --- Multiple class sizes with comma delimited string DROP TABLE IF EXISTS out_cd5; SELECT balance_sample('"TEST_s"', 'out_cd5', 'gr1', '2= 10, 3=6, 1 = 10', 100, NULL); select gr1, count(*) from out_cd5 group by gr1; SELECT assert(count(*) >= 100, 'Wrong number of samples on sampling with comma-delimited list') from out_cd5; -SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with comma-delimited list') from - (SELECT count(*) AS c FROM out_cd5 where gr1 = 2) AS foo WHERE foo.c != 10; -SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with comma-delimited list') from - (SELECT count(*) AS c FROM out_cd5 where gr1 = 4) AS foo WHERE foo.c != 25; -SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with comma-delimited list') from - (SELECT count(*) AS c FROM out_cd5 where gr1 = 1) AS foo WHERE foo.c != 10; -SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with comma-delimited list') from - (SELECT count(*) AS c FROM out_cd5 where gr1 = 5) AS foo WHERE foo.c != 25; -SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with comma-delimited list') from - (SELECT count(*) AS c FROM out_cd5 where gr1 = 3) AS foo WHERE foo.c != 6; +SELECT assert(count(*) = 10, 'Wrong number of samples on sampling with comma-delimited list') from +out_cd5 where gr1 = 2; +SELECT assert(count(*) = 25, 'Wrong number of samples on sampling with comma-delimited list') from +out_cd5 where gr1 = 4; +SELECT assert(count(*) = 10, 'Wrong number of samples on sampling with comma-delimited list') from +out_cd5 where gr1 = 1; +SELECT assert(count(*) = 25, 'Wrong number of samples on sampling with 
comma-delimited list') from +out_cd5 where gr1 = 5; +SELECT assert(count(*) = 6, 'Wrong number of samples on sampling with comma-delimited list') from +out_cd5 where gr1 = 3; http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/postgres/modules/utilities/utilities.py_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in index 8beb701..135404f 100644 --- a/src/ports/postgres/modules/utilities/utilities.py_in +++ b/src/ports/postgres/modules/utilities/utilities.py_in @@ -686,6 +686,7 @@ def strip_end_quotes(input_str, quote='"'): return input_str # ------------------------------------------------------------------------------ + def _grp_null_checks(grp_list): """ Helper function for generating NULL checks for grouping columns @@ -695,6 +696,7 @@ def _grp_null_checks(grp_list): """ return ' AND '.join([" {i} IS NOT NULL ".format(**locals()) for i in grp_list]) +# ------------------------------------------------------------------------------ def _check_groups(tbl1, tbl2, grp_list): @@ -708,6 +710,8 @@ def _check_groups(tbl1, tbl2, grp_list): return ' AND '.join([" {tbl1}.{i} = {tbl2}.{i} ".format(**locals()) for i in grp_list]) +# ------------------------------------------------------------------------------ + def get_filtered_cols_subquery_str(include_from_table, exclude_from_table, filter_cols_list): @@ -727,6 +731,8 @@ def get_filtered_cols_subquery_str(include_from_table, exclude_from_table, (SELECT {cols} FROM {exclude_from_table}) """.format(**locals()) +# ------------------------------------------------------------------------------ + def get_table_qualified_col_str(tbl_name, col_list): """ @@ -746,9 +752,7 @@ def get_grouping_col_str(schema_madlib, module_name, reserved_cols, _string_to_array_with_quotes(grouping_col), module_name) intersect = frozenset( - _string_to_array(grouping_col)).intersection( - 
frozenset( - (reserved_cols))) + _string_to_array(grouping_col)).intersection(frozenset(reserved_cols)) _assert(len(intersect) == 0, "{0} error: Conflicting grouping column name.\n" "Some predefined keyword(s) ({1}) are not allowed " @@ -767,6 +771,20 @@ def get_grouping_col_str(schema_madlib, module_name, reserved_cols, return grouping_str, grouping_col # ------------------------------------------------------------------------------ + +def collate_plpy_result(plpy_result_rows): + if not plpy_result_rows: + return {} + else: + all_keys = plpy_result_rows[0].keys() + result = collections.defaultdict(list) + for each_row in plpy_result_rows: + for each_key in all_keys: + result[each_key].append(each_row[each_key]) + return result +# ------------------------------------------------------------------------------ + + import unittest @@ -827,6 +845,22 @@ class UtilitiesTestCase(unittest.TestCase): self.assertEqual(['"a^5,6"', 'b', 'c'], split_quoted_delimited_str('"a^5,6", b, c', quote='"')) self.assertEqual(['"A""^5,6"', 'b', 'c'], split_quoted_delimited_str('"A""^5,6", b, c', quote='"')) + def test_collate_plpy_result(self): + plpy_result1 = [{'classes': '4', 'class_count': 3}, + {'classes': '1', 'class_count': 18}, + {'classes': '5', 'class_count': 7}, + {'classes': '3', 'class_count': 3}, + {'classes': '6', 'class_count': 7}, + {'classes': '2', 'class_count': 7}] + self.assertEqual(collate_plpy_result(plpy_result1), + {'classes': ['4', '1', '5', '3', '6', '2'], + 'class_count': [3, 18, 7, 3, 7, 7]}) + self.assertEqual(collate_plpy_result([]), {}) + self.assertEqual(collate_plpy_result([{'class': 'a'}, + {'class': 'b'}, + {'class': 'c'}]), + {'class': ['a', 'b', 'c']}) + if __name__ == '__main__': unittest.main() http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/postgres/modules/utilities/utilities.sql_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/utilities/utilities.sql_in 
b/src/ports/postgres/modules/utilities/utilities.sql_in index 0ec864d..6fdb463 100644 --- a/src/ports/postgres/modules/utilities/utilities.sql_in +++ b/src/ports/postgres/modules/utilities/utilities.sql_in @@ -396,12 +396,13 @@ m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `'); CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.bool_to_text (BOOLEAN) RETURNS TEXT STRICT -LANGUAGE SQL AS ' +LANGUAGE SQL AS $$ SELECT CASE - WHEN $1 THEN ''t'' - ELSE ''f'' + WHEN $1 + THEN 'true' + ELSE 'false' END; -'m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `'); +$$m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `'); ------------------------------------------------------------------------ http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/postgres/modules/utilities/validate_args.py_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in index a2f43fd..ec1f410 100644 --- a/src/ports/postgres/modules/utilities/validate_args.py_in +++ b/src/ports/postgres/modules/utilities/validate_args.py_in @@ -21,6 +21,8 @@ object, as well as in any subsequent references to that object (e.g., in SELECT, DELETE, or UPDATE statements). """ +m4_changequote(`<!', `!>') + def unquote_ident(input_str): """ @@ -504,6 +506,7 @@ def explicit_bool_to_text(tbl, cols, schema_madlib): """ Patch madlib.bool_to_text for columns that are of type boolean. """ + m4_ifdef(<!__HAS_BOOL_TO_TEXT_CAST__!>, <!return cols!>, <!!>) col_to_type = dict(get_cols_and_types(tbl)) patched = [] for col in cols: