Repository: madlib
Updated Branches:
  refs/heads/master 3af2d703e -> 0bfcaf5c4


Balance sample: Add grouping support

JIRA: MADLIB-1168

This commit adds grouping support for balanced sampling.
Grouping is implemented as a loop over the existing logic,
with the sampling for each group run independently.

Additional changes:

- Ensure bool_to_text returns 'true' instead of 't' and 'false' instead
of 'f'.
- Use bool_to_text only on platforms that lack a native bool-to-text cast.
- Add a collate_plpy_result utility function

Closes #239


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/0bfcaf5c
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/0bfcaf5c
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/0bfcaf5c

Branch: refs/heads/master
Commit: 0bfcaf5c4623224fdcabc75a72dff21c70bfb1f6
Parents: 3af2d70
Author: Rahul Iyer <ri...@apache.org>
Authored: Tue Mar 13 12:59:04 2018 -0700
Committer: Rahul Iyer <ri...@apache.org>
Committed: Tue Mar 13 13:48:46 2018 -0700

----------------------------------------------------------------------
 src/ports/greenplum/cmake/GreenplumUtils.cmake  |   4 +
 src/ports/postgres/cmake/PostgreSQLUtils.cmake  |   1 +
 .../modules/sample/balance_sample.py_in         | 558 +++++++++++--------
 .../modules/sample/balance_sample.sql_in        |  99 +++-
 .../modules/sample/test/balance_sample.sql_in   | 174 +++---
 .../postgres/modules/utilities/utilities.py_in  |  40 +-
 .../postgres/modules/utilities/utilities.sql_in |   9 +-
 .../modules/utilities/validate_args.py_in       |   3 +
 8 files changed, 583 insertions(+), 305 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/greenplum/cmake/GreenplumUtils.cmake
----------------------------------------------------------------------
diff --git a/src/ports/greenplum/cmake/GreenplumUtils.cmake b/src/ports/greenplum/cmake/GreenplumUtils.cmake
index 9e3aa21..0fc1637 100644
--- a/src/ports/greenplum/cmake/GreenplumUtils.cmake
+++ b/src/ports/greenplum/cmake/GreenplumUtils.cmake
@@ -9,6 +9,10 @@ function(define_greenplum_features IN_VERSION OUT_FEATURES)
         list(APPEND ${OUT_FEATURES} __HAS_FUNCTION_PROPERTIES__)
     endif()
 
+    if(${IN_VERSION} VERSION_GREATER "4.3")
+        list(APPEND ${OUT_FEATURES} __HAS_BOOL_TO_TEXT_CAST__)
+    endif()
+
     # Pass values to caller
     set(${OUT_FEATURES} "${${OUT_FEATURES}}" PARENT_SCOPE)
 endfunction(define_greenplum_features)

http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/postgres/cmake/PostgreSQLUtils.cmake
----------------------------------------------------------------------
diff --git a/src/ports/postgres/cmake/PostgreSQLUtils.cmake b/src/ports/postgres/cmake/PostgreSQLUtils.cmake
index 30962bb..11109a6 100644
--- a/src/ports/postgres/cmake/PostgreSQLUtils.cmake
+++ b/src/ports/postgres/cmake/PostgreSQLUtils.cmake
@@ -3,6 +3,7 @@
 function(define_postgresql_features IN_VERSION OUT_FEATURES)
     if(NOT ${IN_VERSION} VERSION_LESS "9.0")
         list(APPEND ${OUT_FEATURES} __HAS_ORDERED_AGGREGATES__)
+        list(APPEND ${OUT_FEATURES} __HAS_BOOL_TO_TEXT_CAST__)
     endif()
 
     # Pass values to caller

http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/postgres/modules/sample/balance_sample.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/sample/balance_sample.py_in b/src/ports/postgres/modules/sample/balance_sample.py_in
index a32e8e7..28cd11c 100644
--- a/src/ports/postgres/modules/sample/balance_sample.py_in
+++ b/src/ports/postgres/modules/sample/balance_sample.py_in
@@ -17,9 +17,10 @@
 # specific language governing permissions and limitations
 # under the License.
 
-m4_changequote(`<!', `!>')
+# m4_changequote(`<!', `!>')
 
 import math
+from collections import defaultdict
 
 if __name__ != "__main__":
     import plpy
@@ -27,7 +28,10 @@ if __name__ != "__main__":
     from utilities.utilities import _assert
     from utilities.utilities import extract_keyvalue_params
     from utilities.utilities import unique_string
+    from utilities.utilities import collate_plpy_result
+    from utilities.utilities import get_grouping_col_str
     from utilities.validate_args import columns_exist_in_table
+    from utilities.validate_args import explicit_bool_to_text
     from utilities.validate_args import get_cols
     from utilities.validate_args import table_exists
     from utilities.validate_args import table_is_empty
@@ -58,28 +62,65 @@ NOSAMPLE = 'nosample'
 NEW_ID_COLUMN = '__madlib_id__'
 NULL_IDENTIFIER = '__madlib_null_id__'
 
-def _get_level_frequency_distribution(source_table, class_col):
-    """ Returns a dict containing the number of rows associated with each class
+
+def _get_level_frequency_distribution(source_table, class_col,
+                                      grp_by_cols=None):
+    """ Count the number of rows for each class, partitioned by the grp_by_cols
+
+        Returns a dict containing the number of rows associated with each class
         level. Each class level count is converted to a string using ::text.
         None is a valid key in this dict, capturing NULL value in the database.
     """
+    if grp_by_cols and grp_by_cols.lower() != 'null':
+        is_grouping = True
+        grp_by_cols_comma = grp_by_cols + ', '
+        array_grp_by_cols_comma = "ARRAY[{0}]".format(grp_by_cols) + " AS 
group_values, "
+    else:
+        is_grouping = False
+        grp_by_cols_comma = array_grp_by_cols_comma = ""
+
+    # In below query, the inner query groups the data using grp_by_cols + 
classes
+    # and obtains the count for each combination. The outer query then groups
+    # again by the grp_by_cols to collect the classes and counts in an array.
     query_result = plpy.execute("""
-                    SELECT {class_col}::text AS classes,
-                           count(*) AS class_count
-                    FROM {source_table}
-                    GROUP BY {class_col}
-                 """.format(**locals()))
-    actual_level_counts = {}
+        SELECT
+            -- For each group get the classes and their rows counts
+            {grp_identifier} as group_values,
+            array_agg(classes) as classes,
+            array_agg(class_count) as class_count
+        FROM(
+            -- for each group and class combination present in source table
+            -- get the count of rows for that combination
+            SELECT
+                {array_grp_by_cols_comma}
+                ({class_col})::TEXT AS classes,
+                count(*) AS class_count
+            FROM {source_table}
+            GROUP BY {grp_by_cols_comma} ({class_col})
+        ) q
+        {meta_grp_by}
+     """.format(grp_identifier="group_values" if is_grouping else "NULL",
+                meta_grp_by="GROUP BY group_values" if is_grouping else "",
+                **locals()))
+    if (len(query_result) > 1) != is_grouping:
+        # if is_grouping then query_result should have more than 1 row
+        # if not is_grouping then query_result should have only 1 row
+        raise RuntimeError("Balance sample: Error during frequency level 
distribution")
+
+    actual_grp_level_counts = {}
     for each_row in query_result:
-        level = each_row['classes']
-        if level:
-            level = level.strip()
-        actual_level_counts[level] = each_row['class_count']
-    return actual_level_counts
+        # group_values is a list for each row; convert it to a tuple to use as
+        # key in a dictionary
+        grp = tuple(each_row['group_values']) if is_grouping else None
+        grp_levels, grp_counts = each_row['classes'], each_row['class_count']
+        actual_grp_level_counts[grp] = dict(zip(grp_levels, grp_counts))
+    return actual_grp_level_counts
+# 
------------------------------------------------------------------------------
 
 
-def _validate_and_get_sampling_strategy(sampling_strategy_str, 
output_table_size,
-                            supported_strategies=None, default=UNIFORM):
+def _validate_and_get_sampling_strategy(sampling_strategy_str,
+                                        output_table_size,
+                                        default=UNIFORM):
     """ Returns the sampling strategy based on the class_sizes input param.
         @param sampling_strategy_str The sampling strategy specified by the
                                          user (class_sizes param)
@@ -94,8 +135,7 @@ def 
_validate_and_get_sampling_strategy(sampling_strategy_str, output_table_size
             # common prefix substring
             plpy.error("Sample: Invalid class_sizes parameter")
 
-        if not supported_strategies:
-            supported_strategies = [UNIFORM, UNDERSAMPLE, OVERSAMPLE]
+        supported_strategies = (UNIFORM, UNDERSAMPLE, OVERSAMPLE)
         try:
             # allow user to specify a prefix substring of
             # supported strategies.
@@ -104,14 +144,16 @@ def 
_validate_and_get_sampling_strategy(sampling_strategy_str, output_table_size
         except StopIteration:
             # next() returns a StopIteration if no element found
             plpy.error("Sample: Invalid class_sizes parameter: "
-                       "{0}. Supported class_size parameters are ({1})"
-                       .format(sampling_strategy_str, 
','.join(sorted(supported_strategies))))
+                       "{0}. Supported class_size parameters are ({1})".
+                       format(sampling_strategy_str,
+                              ','.join(sorted(supported_strategies))))
 
     _assert(sampling_strategy_str.lower() in (UNIFORM, UNDERSAMPLE, 
OVERSAMPLE) or
             (sampling_strategy_str.find('=') > 0),
             "Sample: Invalid class_sizes parameter: "
-            "{0}. Supported class_size parameters are ({1})"
-            .format(sampling_strategy_str, 
','.join(sorted(supported_strategies))))
+            "{0}. Supported class_size parameters are ({1})".
+            format(sampling_strategy_str,
+                   ','.join(sorted(supported_strategies))))
 
     _assert(not(sampling_strategy_str.lower() == 'oversample' and 
output_table_size),
             "Sample: Cannot set output_table_size with oversampling.")
@@ -147,47 +189,75 @@ def _choose_strategy(actual_count, desired_count):
         return UNDERSAMPLE
 # -------------------------------------------------------------------------
 
-def _get_desired_target_level_counts(desired_level_counts, actual_level_counts,
-                                     output_table_size):
-    """ Returns the target level counts for each class, based on the user
-        defined class_size string. The strategy of (under)oversampling for each
-        class level is chosen based on the desired number of counts for a 
level,
-        and the actual number of counts already present in the input table.
 
+def _get_desired_target_level_counts(desired_level_counts,
+                                     actual_grp_level_counts,
+                                     output_table_size):
+    """ Return the target counts for each group and each class level in the 
group
+
+        This function is specifically used when the user has provided the
+        targets for all (or a subset) of the levels. The strategy of either
+        under or oversampling for each class level is chosen based on the
+        desired number of counts for a level, and the actual number of counts
+        already present in the input table. This calculation is performed for
+        each group (or once if no grouping is used).
+
+        @returns: dict: The key of the dictionary is the group value and the
+                        value is another dictionary. The inner dictionary gives
+                        the target count for each level present in the group.
     """
-    target_level_counts = {}
-    for each_level, desired_count in desired_level_counts.items():
-        sample_strategy = _choose_strategy(actual_level_counts[each_level],
-                                           desired_count)
-        target_level_counts[each_level] = (desired_count, sample_strategy)
-
-    remaining_levels = (set(actual_level_counts.keys()) -
-                        set(desired_level_counts.keys()))
-    #   if 'output_table_size' = NULL, unspecified level counts remain as is
-    #   if 'output_table_size' = <Integer>, divide remaining row count
-    #                             uniformly among unspecified level counts
-    if output_table_size:
-        # Uniformly distribute across the remaining class levels
-        remaining_rows = output_table_size - sum(desired_level_counts.values())
-        if remaining_rows > 0:
-            rows_per_level = math.ceil(float(remaining_rows) /
-                                       len(remaining_levels))
+    target_grp_level_counts = defaultdict(dict)
+    n_grp_values = len(actual_grp_level_counts)
+    for each_grp, actual_level_counts in actual_grp_level_counts.items():
+        for each_level, desired_count in desired_level_counts.items():
+            # actual_grp_level_counts are group specific, whereas
+            # desired_level_counts are evenly split among all groups
+            try:
+                per_grp_desired_count = math.ceil(float(desired_count) / 
n_grp_values)
+                sample_strategy = 
_choose_strategy(actual_level_counts[each_level],
+                                                   per_grp_desired_count)
+                target_grp_level_counts[each_grp][each_level] = (
+                    per_grp_desired_count, sample_strategy)
+            except KeyError:
+                plpy.error("Balance sample: Desired class level ({0}) not 
present in the "
+                           "data for each group.".format(each_level))
+
+        # desired levels could contain just a subset of all levels. For the 
remaining
+        # levels in actual_level_counts, compute the desired counts
+        remaining_levels = (set(actual_level_counts.keys()) -
+                            set(desired_level_counts.keys()))
+
+        #   if 'output_table_size' = NULL, remaining level counts remain as is
+        #   if 'output_table_size' = <Integer>, divide remaining count
+        #                             uniformly among remaining levels
+        if output_table_size:
+            # output_table_size is for the whole table and should be split 
evenly
+            # between the groups
+            remaining_rows = math.ceil(float(output_table_size -
+                                             
sum(desired_level_counts.values())) /
+                                       n_grp_values)
+            if remaining_rows > 0:
+                # Uniformly distribute the remaining class levels
+                rows_per_level = math.ceil(float(remaining_rows) /
+                                           len(remaining_levels))
+                for each_level in remaining_levels:
+                    sample_strategy = _choose_strategy(
+                        actual_level_counts[each_level], rows_per_level)
+                    target_grp_level_counts[each_grp][each_level] = (
+                        rows_per_level, sample_strategy)
+        else:
+            # When output_table_size is unspecified, rows from the input table
+            # are sampled as is for remaining class levels. This is called as 
the
+            # NOSAMPLE strategy.
             for each_level in remaining_levels:
-                sample_strategy = _choose_strategy(
-                    actual_level_counts[each_level], rows_per_level)
-                target_level_counts[each_level] = (rows_per_level,
-                                                   sample_strategy)
-    else:
-        # When output_table_size is unspecified, rows from the input table
-        # are sampled as is for remaining class levels. This is same as the
-        # NOSAMPLE strategy.
-        for each_level in remaining_levels:
-            target_level_counts[each_level] = (actual_level_counts[each_level],
-                                                    NOSAMPLE)
-    return target_level_counts
+                target_grp_level_counts[each_grp][each_level] = (
+                    actual_level_counts[each_level], NOSAMPLE)
+    return target_grp_level_counts
 # -------------------------------------------------------------------------
 
-def _get_supported_target_level_counts(sampling_strategy_str, 
actual_level_counts,
+
+def _get_supported_target_level_counts(sampling_strategy_str,
+                                       actual_grp_level_counts,
                                        output_table_size):
     """ Returns the target level counts for all levels when the class_size 
param
         is one of [uniform, undersample, oversample]. The strategy of
@@ -196,33 +266,40 @@ def 
_get_supported_target_level_counts(sampling_strategy_str, actual_level_count
         already present in the input table.
 
     """
-    def ceil_of_mean(numbers):
-        return math.ceil(float(sum(numbers)) / max(len(numbers), 1))
-
-    target_level_counts = {}
-    # UNIFORM: Ensure all level counts are same (size determined by 
output_table_size)
-    # UNDERSAMPLE: Ensure all level counts are same as the minimum count
-    # OVERSAMPLE: Ensure all level counts are same as the maximum count
-    size_function = {UNDERSAMPLE: min,
-                     OVERSAMPLE: max,
-                     UNIFORM: ceil_of_mean
-                     }[sampling_strategy_str]
-    if sampling_strategy_str == UNIFORM and output_table_size:
-        # Ignore actual counts for computing target sizes
-        # if output_table_size is specified
-        target_size_per_level = math.ceil(float(output_table_size) /
-                                          len(actual_level_counts))
-    else:
-        target_size_per_level = size_function(actual_level_counts.values())
-    for each_level, actual_count in actual_level_counts.items():
-        sample_strategy = _choose_strategy(actual_count, target_size_per_level)
-        target_level_counts[each_level] = (target_size_per_level,
-                                           sample_strategy)
-    return target_level_counts
+    target_grp_level_counts = defaultdict(dict)
+    n_grp_values = len(actual_grp_level_counts)
+    for each_grp, actual_level_counts in actual_grp_level_counts.items():
+        if sampling_strategy_str == UNIFORM:
+            # UNIFORM: Ensure all level counts are same
+            if output_table_size:
+                # Ignore actual counts for computing target sizes
+                # if output_table_size is specified
+                total_ = float(output_table_size) / n_grp_values
+            else:
+                total_ = sum(actual_level_counts.values())
+            target_size_per_level = math.ceil(float(total_) /
+                                              len(actual_level_counts))
+        else:
+            # UNDERSAMPLE: Ensure all level counts are same as the minimum 
count
+            # OVERSAMPLE: Ensure all level counts are same as the maximum count
+            if sampling_strategy_str == UNDERSAMPLE:
+                target_size_per_level = min(actual_level_counts.values())
+            elif sampling_strategy_str == OVERSAMPLE:
+                target_size_per_level = max(actual_level_counts.values())
+            else:
+                raise RuntimeError("Balance sample: Invalid "
+                                   "sampling_strategy_str encountered")
+
+        for each_level, actual_count in actual_level_counts.items():
+            sample_strategy = _choose_strategy(actual_count, 
target_size_per_level)
+            target_grp_level_counts[each_grp][each_level] = (
+                target_size_per_level, sample_strategy)
+    return target_grp_level_counts
 # -------------------------------------------------------------------------
 
+
 def _get_target_level_counts(sampling_strategy_str, desired_level_counts,
-                             actual_level_counts, output_table_size):
+                             actual_grp_level_counts, output_table_size):
     """
     @param sampling_strategy_str: one of [UNIFORM, UNDERSAMPLE, OVERSAMPLE, 
None].
                                This is 'None' only if this is user-defined, 
i.e.,
@@ -233,8 +310,9 @@ def _get_target_level_counts(sampling_strategy_str, 
desired_level_counts,
                                  then contain the class levels and the
                                  corresponding number of rows specified by
                                  the user.
-    @param actual_level_counts: Dict of various class levels and number of rows
-                                  in each of them in the input table
+    @param actual_grp_level_counts: Dictionary that provides for each group the
+                                  count of the number of rows for each class
+                                  present in the group
     @param output_table_size: Size of the desired output table (NULL or 
Integer)
 
     @returns:
@@ -244,60 +322,64 @@ def _get_target_level_counts(sampling_strategy_str, 
desired_level_counts,
 
     if not sampling_strategy_str:
         # This case implies user has provided a desired count for one or more
-        # levels. Counts for the rest of the levels depend on 
'output_table_size'.
-        target_level_counts = 
_get_desired_target_level_counts(desired_level_counts,
-                                                               
actual_level_counts,
-                                                               
output_table_size)
+        # levels. Counts for rest of the levels depend on 'output_table_size'.
+        target_level_counts = _get_desired_target_level_counts(
+            desired_level_counts, actual_grp_level_counts, output_table_size)
     else:
-        # This case imples the user has chosen one of [uniform, undersample,
-        # oversample] for the class_size param.
-        target_level_counts = 
_get_supported_target_level_counts(sampling_strategy_str,
-                                                                 
actual_level_counts,
-                                                                 
output_table_size)
+        # This case implies the user has chosen one of
+        # [uniform, undersample, oversample] for the class_size parameter.
+        target_level_counts = _get_supported_target_level_counts(
+            sampling_strategy_str, actual_grp_level_counts, output_table_size)
     return target_level_counts
-
 # -------------------------------------------------------------------------
 
 
-def _get_sampling_strategy_specific_dict(target_class_sizes):
+def _get_sampling_strategy_counts(target_class_sizes):
     """ Return three dicts, one each for undersampling, oversampling, and
         nosampling. The dict contains the number of samples to be drawn for
         each class level.
     """
-    undersample_level_dict = {}
-    oversample_level_dict = {}
-    nosample_level_dict = {}
+    undersample_level_counts = {}
+    oversample_level_counts = {}
+    nosample_level_counts = {}
     for level, (count, strategy) in target_class_sizes.items():
         if strategy == UNDERSAMPLE:
-            undersample_level_dict[level] = count
+            undersample_level_counts[level] = count
         elif strategy == OVERSAMPLE:
-            oversample_level_dict[level] = count
+            oversample_level_counts[level] = count
         else:
-            nosample_level_dict[level] = count
-    return (undersample_level_dict, oversample_level_dict, nosample_level_dict)
+            nosample_level_counts[level] = count
+    return (undersample_level_counts, oversample_level_counts, 
nosample_level_counts)
 # 
------------------------------------------------------------------------------
 
 
-def _get_nosample_subquery(source_table, class_col, nosample_levels):
+def _get_nosample_subquery(source_table, class_col, nosample_levels,
+                           grp_dict=None):
     """ Return the subquery for fetching all rows as is from the input table
         for specific class levels.
     """
     if not nosample_levels:
         return ''
-    subquery = """
-                SELECT *
-                FROM {0}
-                WHERE {1} in ({2}) OR {1} IS NULL
-            """.format(source_table, class_col,
-                       ','.join(["'{0}'".format(level)
-                                for level in nosample_levels if level]))
-    return subquery
+    nosample_level_str = ','.join(["'{0}'".format(level)
+                                   for level in nosample_levels if level])
+    if grp_dict:
+        grp_filter = ' AND ' + ' AND '.join("{0} = '{1}'".format(k, v)
+                                            for k, v in grp_dict.items())
+    else:
+        grp_filter = ''
+    return """
+        SELECT *
+        FROM {source_table}
+        WHERE ({class_col} in ({nosample_level_str}) OR
+               {class_col} IS NULL)
+              {grp_filter}
+        """.format(**locals())
 # 
------------------------------------------------------------------------------
 
 
-def _get_without_replacement_subquery(schema_madlib, source_table,
-                                      source_table_columns, class_col,
-                                      actual_level_counts, 
desired_level_counts):
+def _get_without_replacement_subquery(schema_madlib, source_table, class_col,
+                                      actual_level_counts, 
desired_level_counts,
+                                      grp_dict=None):
     """ Return the subquery for sampling without replacement for specific
         class levels.
     """
@@ -306,47 +388,50 @@ def _get_without_replacement_subquery(schema_madlib, 
source_table,
     class_col_tmp = unique_string(desp='class_col')
     row_number_col = unique_string(desp='row_number')
     desired_count_col = unique_string(desp='desired_count')
-
+    source_table_columns = ','.join(get_cols(source_table))
     null_value_string = "'{0}'".format(NULL_IDENTIFIER)
-
-    desired_level_counts_str = "VALUES " + \
-            ','.join("({0}, {1})".
-            format("'{0}'::text".format(k) if k else null_value_string, v)
-            for k, v in desired_level_counts.items())
+    desired_level_count_pairs = (','.join("({0}, {1})".
+                                 format("'{0}'::text".format(k) if k else 
null_value_string, v)
+                                 for k, v in desired_level_counts.items()))
+    desired_level_counts_str = "VALUES " + desired_level_count_pairs
     # Subquery q2 is used to figure out the number of rows to select for each
     # class level. That is used with row_number to order the input rows 
randomly,
-    # and, trim the results to the_desired_number_of_rows for each class level.
-    # q2:
+    # and trim the results to the desired number of rows for each class level.
+    # q1:
     #    The FROM clause contains information that can be used to figure out 
the
     #    number of rows to generate for a class level. The tuple basically
     #    contains: (class_level, the_desired_number_of_rows)
+    if grp_dict:
+        grp_filter = ' AND ' + ' AND '.join("{0} = '{1}'".format(k, v)
+                                            for k, v in grp_dict.items())
+    else:
+        grp_filter = ''
     subquery = """
             SELECT {source_table_columns}
-            FROM
-                (
+            FROM (
                     SELECT {source_table_columns},
                            row_number() OVER (PARTITION BY {class_col} ORDER 
BY random()) AS {row_number_col},
                            {desired_count_col}
-                    FROM
-                    (
+                    FROM (
                         SELECT {source_table_columns},
                                {desired_count_col}
                         FROM
                             {source_table} s,
                             ({desired_level_counts_str})
-                                q({class_col_tmp}, {desired_count_col})
-                        WHERE {class_col_tmp} = coalesce({class_col}::text, 
'{null_level_val}')
+                                q1({class_col_tmp}, {desired_count_col})
+                        WHERE {class_col_tmp} = coalesce({class_col}::text, 
'{null_identifier}')
+                              {grp_filter}
                     ) q2
                 ) q3
             WHERE {row_number_col} <= {desired_count_col}
-        """.format(null_level_val=NULL_IDENTIFIER, **locals())
+        """.format(null_identifier=NULL_IDENTIFIER, **locals())
     return subquery
 # 
------------------------------------------------------------------------------
 
 
-def _get_with_replacement_subquery(schema_madlib, source_table,
-                                   source_table_columns, class_col,
-                                   actual_level_counts, desired_level_counts):
+def _get_with_replacement_subquery(schema_madlib, source_table, class_col,
+                                   actual_level_counts, desired_level_counts,
+                                   grp_dict=None):
     """ Return the query for sampling with replacement for specific class
         levels. Always used for oversampling since oversampling will always 
need
         to use replacement. Used for under sampling only if with_replacement
@@ -355,6 +440,7 @@ def _get_with_replacement_subquery(schema_madlib, 
source_table,
     if not desired_level_counts:
         return ''
 
+    source_table_columns = ','.join(get_cols(source_table))
     class_col_tmp = unique_string(desp='class_col_with_rep')
     desired_count_col = unique_string(desp='desired_count_with_rep')
     actual_count_col = unique_string(desp='actual_count')
@@ -362,12 +448,12 @@ def _get_with_replacement_subquery(schema_madlib, 
source_table,
     q2_row_no = unique_string(desp='q2_row')
 
     null_value_string = "'{0}'".format(NULL_IDENTIFIER)
+    desired_and_actual_count = (','.join("({0}, {1}, {2})".
+                                format("'{0}'::text".format(k) if k else 
null_value_string,
+                                       v, actual_level_counts[k])
+                                for k, v in desired_level_counts.items()))
+    desired_and_actual_level_count_str = "VALUES " + desired_and_actual_count
 
-    desired_and_actual_level_counts = "VALUES " + \
-    ','.join("({0}, {1}, {2})".
-             format("'{0}'::text".format(k) if k else null_value_string,
-                v, actual_level_counts[k])
-             for k, v in desired_level_counts.items())
     # q1 and q2 are two sub queries we create to generate the required number 
of
     # rows per class level.
     # q1:
@@ -382,6 +468,11 @@ def _get_with_replacement_subquery(schema_madlib, 
source_table,
     #   actual_rows_for_level_in_input_table.
     #
     # The WHERE clause is used to join the two subqueries to obtain the result.
+    if grp_dict:
+        grp_filter = 'WHERE ' + ' AND '.join("{0} = '{1}'".format(k, v)
+                                             for k, v in grp_dict.items())
+    else:
+        grp_filter = ''
     subquery = """
             SELECT {source_table_columns}
             FROM
@@ -391,7 +482,7 @@ def _get_with_replacement_subquery(schema_madlib, 
source_table,
                          generate_series(1, {desired_count_col}::int) AS _i,
                          ((random()*({actual_count_col}-1)+1)::int) AS 
{q1_row_no}
                     FROM
-                        ({desired_and_actual_level_counts})
+                        ({desired_and_actual_level_count_str})
                             q({class_col_tmp}, {desired_count_col}, 
{actual_count_col})
                 ) q1,
                 (
@@ -400,6 +491,7 @@ def _get_with_replacement_subquery(schema_madlib, 
source_table,
                         row_number() OVER(PARTITION BY {class_col}) AS 
{q2_row_no}
                     FROM
                          {source_table}
+                    {grp_filter}
                 ) q2
             WHERE {class_col_tmp} = coalesce({class_col}::text, 
'{null_level_val}') AND
                   q1.{q1_row_no} = q2.{q2_row_no}
@@ -407,6 +499,7 @@ def _get_with_replacement_subquery(schema_madlib, 
source_table,
     return subquery
 # 
------------------------------------------------------------------------------
 
+
 def balance_sample(schema_madlib, source_table, output_table, class_col,
                    class_sizes, output_table_size, grouping_cols,
                    with_replacement, keep_null, **kwargs):
@@ -443,20 +536,18 @@ def balance_sample(schema_madlib, source_table, 
output_table, class_col,
 
         _validate_strs(source_table, output_table, class_col,
                        output_table_size, grouping_cols)
-        source_table_columns = ','.join(get_cols(source_table))
 
-        new_source_table = source_table
         # If keep_null=False, create a view of the input table ignoring NULL
         # values for class levels.
-        if not keep_null:
+        if keep_null:
+            new_source_table = source_table
+        else:
             new_source_table = unique_string(desp='source_table')
             plpy.execute("""
-                        CREATE VIEW {new_source_table} AS
-                        SELECT * FROM {source_table}
-                        WHERE {class_col} IS NOT NULL
-                    """.format(**locals()))
-        actual_level_counts = 
_get_level_frequency_distribution(new_source_table,
-                                                                class_col)
+                CREATE VIEW {new_source_table} AS
+                SELECT * FROM {source_table}
+                WHERE {class_col} IS NOT NULL
+            """.format(**locals()))
         # class_sizes can be of two forms:
         #   1. A string describing sampling strategy (as described in
         #       _validate_and_get_sampling_strategy).
@@ -468,22 +559,34 @@ def balance_sample(schema_madlib, source_table, 
output_table, class_col,
         parsed_class_sizes = extract_keyvalue_params(class_sizes,
                                                      allow_duplicates=False,
                                                      lower_case_names=False)
+
+        # GREENPLUM 4.3.X does not support bool-to-text cast which is relied
+        #  upon for class_col in multiple queries. The explicit_bool_to_text
+        #  function wraps the class_col with a MADlib function that provides 
the
+        #  cast just for those platforms that don't provide the cast. For
+        #  platforms that provide the cast, class_col is unchanged below.
+        class_col = explicit_bool_to_text(source_table,
+                                          [class_col],
+                                          schema_madlib)[0]
+        distinct_sql = ("SELECT DISTINCT ({0})::TEXT as levels FROM {1} ".
+                        format(class_col,
+                               source_table))
+        distinct_levels = 
collate_plpy_result(plpy.execute(distinct_sql))['levels']
         if not parsed_class_sizes:
-            sampling_strategy_str = 
_validate_and_get_sampling_strategy(class_sizes,
-                                        output_table_size)
+            sampling_strategy_str = _validate_and_get_sampling_strategy(
+                class_sizes, output_table_size)
         else:
             sampling_strategy_str = None
             try:
                 for each_level, each_class_size in parsed_class_sizes.items():
-                    _assert(each_level in actual_level_counts,
+                    _assert(each_level in distinct_levels,
                             "Sample: Invalid class value specified ({0})".
-                                       format(each_level))
+                            format(each_level))
                     each_class_size = int(each_class_size)
                     _assert(each_class_size >= 1,
                             "Sample: Class size has to be greater than zero")
                     parsed_class_sizes[each_level] = each_class_size
-
-            except TypeError:
+            except ValueError:
                 plpy.error("Sample: Invalid value for class_sizes ({0})".
                            format(class_sizes))
 
@@ -491,58 +594,82 @@ def balance_sample(schema_madlib, source_table, 
output_table, class_col,
         # the input table, class_sizes, and output_table_size params. This also
         # includes info about the resulting sampling strategy, i.e., one of
         # UNDERSAMPLE, OVERSAMPLE, or NOSAMPLE for each level.
-        target_class_sizes = _get_target_level_counts(sampling_strategy_str,
-                                                      parsed_class_sizes,
-                                                      actual_level_counts,
-                                                      output_table_size)
-
-        undersample_level_dict, oversample_level_dict, nosample_level_dict = \
-            _get_sampling_strategy_specific_dict(target_class_sizes)
-
-        # Get subqueries for each sampling strategy, so that they can be used
-        # together in one big query.
-
-        # Subquery that will be used to get rows as is for those class levels
-        # that need no sampling.
-        nosample_subquery = _get_nosample_subquery(
-            new_source_table, class_col, nosample_level_dict.keys())
-        # Subquery that will be used to sample those class levels that
-        # have to be oversampled.
-        oversample_subquery = _get_with_replacement_subquery(
-            schema_madlib, new_source_table, source_table_columns, class_col,
-            actual_level_counts, oversample_level_dict)
-        # Subquery that will be used to sample those class levels that
-        # have to be undersampled. Undersampling supports both with and without
-        # replacement, so fetch the appropriate subquery.
-        if with_replacement:
-            undersample_subquery = _get_with_replacement_subquery(
-                schema_madlib, new_source_table, source_table_columns, 
class_col,
-                actual_level_counts, undersample_level_dict)
-        else:
-            undersample_subquery = _get_without_replacement_subquery(
-                schema_madlib, new_source_table, source_table_columns, 
class_col,
-                actual_level_counts, undersample_level_dict)
-
-        # Merge the three subqueries using a UNION ALL clause.
-        union_all_subquery = ' UNION ALL '.join(
-            ['({0})'.format(subquery)
-             for subquery in [undersample_subquery, oversample_subquery, 
nosample_subquery]
-             if subquery])
-
-        final_query = """
-                CREATE TABLE {output_table} AS
-                SELECT row_number() OVER() AS {new_col_name}, *
+        grp_col_str, grp_cols = get_grouping_col_str(
+            schema_madlib, 'Balance sample', [NEW_ID_COLUMN, class_col],
+            source_table, grouping_cols)
+        actual_grp_level_counts = _get_level_frequency_distribution(
+            new_source_table, class_col, grp_col_str)
+        target_grp_class_sizes = _get_target_level_counts(
+            sampling_strategy_str, parsed_class_sizes,
+            actual_grp_level_counts, output_table_size)
+
+        is_output_created = False
+        grp_cols_list = grp_col_str.split(',') if grp_cols else []
+
+        # for each group
+        for grp_vals, actual_level_counts in actual_grp_level_counts.items():
+            target_class_sizes = target_grp_class_sizes[grp_vals]
+
+            (undersample_level_counts,
+             oversample_level_counts,
+             nosample_level_counts) = 
_get_sampling_strategy_counts(target_class_sizes)
+
+            # grp_dict represents each grouping column and its value
+            # for current iteration
+            grp_dict = dict(zip(grp_cols_list, grp_vals)) if grp_cols else None
+
+            # Get subqueries for each sampling strategy and union together in
+            # one big query.
+
+            # NOSAMPLE are levels that are to be retained in output without 
sampling
+            nosample_subquery = _get_nosample_subquery(
+                new_source_table, class_col, nosample_level_counts.keys(), 
grp_dict)
+            # OVERSAMPLE are levels that are to be sampled more than existing count
+            # (always with replacement)
+            oversample_subquery = _get_with_replacement_subquery(
+                schema_madlib, new_source_table, class_col,
+                actual_level_counts, oversample_level_counts, grp_dict)
+            # UNDERSAMPLE are levels that are to be sampled to less than
+            # existing count. Undersampling supports both with and without
+            # replacement.
+            if with_replacement:
+                undersample_subquery = _get_with_replacement_subquery(
+                    schema_madlib, new_source_table, class_col,
+                    actual_level_counts, undersample_level_counts, grp_dict)
+            else:
+                undersample_subquery = _get_without_replacement_subquery(
+                    schema_madlib, new_source_table, class_col,
+                    actual_level_counts, undersample_level_counts, grp_dict)
+
+            # Merge the three subqueries using a UNION ALL clause.
+            sampling_queries = (undersample_subquery, oversample_subquery, 
nosample_subquery)
+            union_all_subquery = ' UNION ALL '.join(
+                ['({0})'.format(subquery)
+                 for subquery in sampling_queries if subquery])
+
+            # Populate the output table
+            if is_output_created:
+                table_header = "INSERT INTO {0}".format(output_table)
+            else:
+                table_header = "CREATE TABLE {0} AS ".format(output_table)
+                is_output_created = True
+            final_query = """
+                {table_header}
+                SELECT row_number() OVER() AS {id_col_name}, *
                 FROM (
                     {union_all_subquery}
                 ) union_query
-            """.format(new_col_name=NEW_ID_COLUMN, **locals())
-        plpy.execute(final_query)
+            """.format(id_col_name=NEW_ID_COLUMN, **locals())
+            plpy.execute(final_query)
+        # end of grouping loop
+
         if not keep_null:
-            plpy.execute("DROP VIEW {0}".format(new_source_table))
+            plpy.execute("DROP VIEW IF EXISTS {0}".format(new_source_table))
+# 
------------------------------------------------------------------------------
 
 
-def _validate_strs(source_table, output_table, class_col, output_table_size,
-                    grouping_cols):
+def _validate_strs(source_table, output_table, class_col,
+                   output_table_size, grouping_cols):
     _assert(source_table and table_exists(source_table),
             "Sample: Source table ({source_table}) does not 
exist.".format(**locals()))
     _assert(not table_is_empty(source_table),
@@ -566,10 +693,7 @@ def _validate_strs(source_table, output_table, class_col, 
output_table_size,
     _assert((not output_table_size) or (output_table_size > 0),
             "Sample: Invalid output table size ({output_table_size}).".format(
             **locals()))
-
-    _assert(grouping_cols is None,
-            "grouping_cols is not supported at the moment."
-            .format(**locals()))
+# 
------------------------------------------------------------------------------
 
 
 def balance_sample_help(schema_madlib, message, **kwargs):
@@ -715,7 +839,7 @@ class UtilitiesTestCase(unittest.TestCase):
     """
 
     def setUp(self):
-        self.input_class_level_counts1 = {'a': 20, 'b': 30, 'c': 25}
+        self.input_class_level_counts1 = {None: {'a': 20, 'b': 30, 'c': 25}}
         self.level1a = 'a'
         self.level1a_cnt1 = 15
         self.level1a_cnt2 = 25
@@ -740,55 +864,55 @@ class UtilitiesTestCase(unittest.TestCase):
 
     def test__get_target_level_counts(self):
         # Test cases for user defined class level samples, without output 
table size
-        self.assertEqual({'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': 
(25, NOSAMPLE)},
+        self.assertEqual({None: {'a': (25, OVERSAMPLE), 'b': (25, 
UNDERSAMPLE), 'c': (25, NOSAMPLE)}},
                          _get_target_level_counts(self.sampling_strategy_str0,
                                                   
self.user_specified_class_size1,
                                                   
self.input_class_level_counts1,
                                                   self.output_table_size1))
-        self.assertEqual({'a': (20, NOSAMPLE), 'b': (25, UNDERSAMPLE), 'c': 
(25, NOSAMPLE)},
+        self.assertEqual({None: {'a': (20, NOSAMPLE), 'b': (25, UNDERSAMPLE), 
'c': (25, NOSAMPLE)}},
                          _get_target_level_counts(self.sampling_strategy_str0,
                                                   
self.user_specified_class_size2,
                                                   
self.input_class_level_counts1,
                                                   self.output_table_size1))
-        self.assertEqual({'a': (30, OVERSAMPLE), 'b': (30, NOSAMPLE), 'c': 
(25, NOSAMPLE)},
+        self.assertEqual({None: {'a': (30, OVERSAMPLE), 'b': (30, NOSAMPLE), 
'c': (25, NOSAMPLE)}},
                          _get_target_level_counts(self.sampling_strategy_str0,
                                                   
self.user_specified_class_size3,
                                                   
self.input_class_level_counts1,
                                                   self.output_table_size1))
         # Test cases for user defined class level samples, with output table 
size
-        self.assertEqual({'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': 
(10, UNDERSAMPLE)},
+        self.assertEqual({None: {'a': (25, OVERSAMPLE), 'b': (25, 
UNDERSAMPLE), 'c': (10, UNDERSAMPLE)}},
                          _get_target_level_counts(self.sampling_strategy_str0,
                                                   
self.user_specified_class_size1,
                                                   
self.input_class_level_counts1,
                                                   self.output_table_size2))
-        self.assertEqual({'a': (18, UNDERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': 
(18, UNDERSAMPLE)},
+        self.assertEqual({None: {'a': (18, UNDERSAMPLE), 'b': (25, 
UNDERSAMPLE), 'c': (18, UNDERSAMPLE)}},
                          _get_target_level_counts(self.sampling_strategy_str0,
                                                   
self.user_specified_class_size2,
                                                   
self.input_class_level_counts1,
                                                   self.output_table_size2))
-        self.assertEqual({'a': (30, OVERSAMPLE), 'b': (15, UNDERSAMPLE), 'c': 
(15, UNDERSAMPLE)},
+        self.assertEqual({None: {'a': (30, OVERSAMPLE), 'b': (15, 
UNDERSAMPLE), 'c': (15, UNDERSAMPLE)}},
                          _get_target_level_counts(self.sampling_strategy_str0,
                                                   
self.user_specified_class_size3,
                                                   
self.input_class_level_counts1,
                                                   self.output_table_size2))
         # Test cases for UNIFORM, OVERSAMPLE, and UNDERSAMPLE without any 
output table size
-        self.assertEqual({'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': 
(25, UNDERSAMPLE)},
+        self.assertEqual({None: {'a': (25, OVERSAMPLE), 'b': (25, 
UNDERSAMPLE), 'c': (25, UNDERSAMPLE)}},
                          _get_target_level_counts(self.sampling_strategy_str1,
                                                   
self.user_specified_class_size0,
                                                   
self.input_class_level_counts1,
                                                   self.output_table_size1))
-        self.assertEqual({'a': (30, OVERSAMPLE), 'b': (30, UNDERSAMPLE), 'c': 
(30, OVERSAMPLE)},
+        self.assertEqual({None: {'a': (30, OVERSAMPLE), 'b': (30, 
UNDERSAMPLE), 'c': (30, OVERSAMPLE)}},
                          _get_target_level_counts(self.sampling_strategy_str2,
                                                   
self.user_specified_class_size0,
                                                   
self.input_class_level_counts1,
                                                   self.output_table_size1))
-        self.assertEqual({'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 'c': 
(20, UNDERSAMPLE)},
+        self.assertEqual({None: {'a': (20, UNDERSAMPLE), 'b': (20, 
UNDERSAMPLE), 'c': (20, UNDERSAMPLE)}},
                          _get_target_level_counts(self.sampling_strategy_str3,
                                                   
self.user_specified_class_size0,
                                                   
self.input_class_level_counts1,
                                                   self.output_table_size1))
         # Test cases for UNIFORM with output table size
-        self.assertEqual({'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 'c': 
(20, UNDERSAMPLE)},
+        self.assertEqual({None: {'a': (20, UNDERSAMPLE), 'b': (20, 
UNDERSAMPLE), 'c': (20, UNDERSAMPLE)}},
                          _get_target_level_counts(self.sampling_strategy_str1,
                                                   
self.user_specified_class_size0,
                                                   
self.input_class_level_counts1,
@@ -800,11 +924,11 @@ class UtilitiesTestCase(unittest.TestCase):
         target_level_counts_2 = {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE)}
         target_level_counts_3 = {'a': (25, OVERSAMPLE), 'b': (25, NOSAMPLE), 
'c': (25, NOSAMPLE)}
         self.assertEqual(({'b': 25}, {'a': 25}, {'c': 25}),
-                         
_get_sampling_strategy_specific_dict(target_level_counts_1))
+                         _get_sampling_strategy_counts(target_level_counts_1))
         self.assertEqual(({'b': 25}, {'a': 25}, {}),
-                         
_get_sampling_strategy_specific_dict(target_level_counts_2))
+                         _get_sampling_strategy_counts(target_level_counts_2))
         self.assertEqual(({}, {'a': 25}, {'c': 25, 'b': 25}),
-                         
_get_sampling_strategy_specific_dict(target_level_counts_3))
+                         _get_sampling_strategy_counts(target_level_counts_3))
 
 
 if __name__ == '__main__':

http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/postgres/modules/sample/balance_sample.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/sample/balance_sample.sql_in 
b/src/ports/postgres/modules/sample/balance_sample.sql_in
index cd70961..eea73aa 100644
--- a/src/ports/postgres/modules/sample/balance_sample.sql_in
+++ b/src/ports/postgres/modules/sample/balance_sample.sql_in
@@ -149,8 +149,11 @@ non-stratified, that is, the whole table is treated as a
 single group.
 
 @note
-Current implementation does not support grouping_cols.  It
-will be added in an upcoming release.
+The 'output_table_size' and 'class_sizes' parameters are defined for the whole
+table. When grouping is used, these values are split evenly across the groups.
+Further, if a specific class value is specified in the 'class_sizes' parameter,
+that class value must be present in every group; otherwise, an error is
+raised.
 </dd>
 
 <dt>with_replacement  (optional)</dt>
@@ -393,8 +396,7 @@ SELECT * FROM output_table ORDER BY mainhue, name;
              7 | 20 | USA         |        1 |    4 | 9363 |        231 |      
  1 |       3 | white
 (8 rows)
 </pre>
-In the case of bootstrapping, we may want to undersample with replacement,
-so we set the 'with_replacement' parameter to TRUE:
+We may also want to undersample with replacement, so we set the 
'with_replacement' parameter to TRUE:
 <pre class="syntax">
 DROP TABLE IF EXISTS output_table;
 SELECT madlib.balance_sample(
@@ -543,6 +545,95 @@ SELECT * FROM output_table ORDER BY mainhue, name;
 (25 rows)
 </pre>
 
+-# To perform balanced sampling independently within each group, use the
'grouping_cols'
+parameter. Note below that the class (mainhue) counts differ across groups
+(zone), and some groups do not contain every class value.
+<pre class="syntax">
+DROP TABLE IF EXISTS output_table;
+SELECT madlib.balance_sample(
+    'flags',          -- Source table
+    'output_table',   -- Output table
+    'mainhue',        -- Class column
+    NULL,             -- Uniform
+    NULL,             -- Output table size
+    'zone'            -- Grouping by zone
+);
+SELECT * FROM output_table ORDER BY zone, mainhue;
+</pre>
+<pre class="result">
+ __madlib_id__ | id |    name     | landmass | zone | area | population | 
language | colours | mainhue
+---------------+----+-------------+----------+------+------+------------+----------+---------+---------
+             6 |  8 | Greece      |        3 |    1 |  132 |         10 |      
  6 |       2 | blue
+             5 |  8 | Greece      |        3 |    1 |  132 |         10 |      
  6 |       2 | blue
+             8 | 17 | Sweden      |        3 |    1 |  450 |          8 |      
  6 |       2 | blue
+             7 |  8 | Greece      |        3 |    1 |  132 |         10 |      
  6 |       2 | blue
+             2 |  7 | Denmark     |        3 |    1 |   43 |          5 |      
  6 |       2 | red
+             1 |  6 | China       |        5 |    1 | 9561 |       1008 |      
  7 |       2 | red
+             4 | 12 | Luxembourg  |        3 |    1 |    3 |          0 |      
  4 |       3 | red
+             3 | 18 | Switzerland |        3 |    1 |   41 |          6 |      
  4 |       2 | red
+             1 |  2 | Australia   |        6 |    2 | 7690 |         15 |      
  1 |       3 | blue
+             1 |  1 | Argentina   |        2 |    3 | 2777 |         28 |      
  2 |       2 | blue
+             2 |  4 | Brazil      |        2 |    3 | 8512 |        119 |      
  6 |       4 | green
+             6 |  9 | Guatemala   |        1 |    4 |  109 |          8 |      
  2 |       2 | blue
+             5 |  9 | Guatemala   |        1 |    4 |  109 |          8 |      
  2 |       2 | blue
+             4 |  9 | Guatemala   |        1 |    4 |  109 |          8 |      
  2 |       2 | blue
+            12 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+            10 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+            11 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+             1 | 19 | UK          |        3 |    4 |  245 |         56 |      
  1 |       3 | red
+             3 |  5 | Canada      |        1 |    4 | 9976 |         24 |      
  1 |       2 | red
+             2 | 15 | Portugal    |        3 |    4 |   92 |         10 |      
  6 |       5 | red
+             8 | 20 | USA         |        1 |    4 | 9363 |        231 |      
  1 |       3 | white
+             7 | 20 | USA         |        1 |    4 | 9363 |        231 |      
  1 |       3 | white
+             9 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+(23 rows)
+</pre>
+
+-# Grouping can be used with class size specification as well. Note below that
+'blue=<Integer>' is the only valid class value since 'blue' is the only class
+value that is present in each group. Further, 'blue=8' will be split between 
the
+four groups, resulting in two blue rows for each group.
+<pre class="syntax">
+DROP TABLE IF EXISTS output_table;
+SELECT madlib.balance_sample(
+    'flags',          -- Source table
+    'output_table',   -- Output table
+    'mainhue',        -- Class column
+    'blue=8',         -- Specified class value size. Rest of the values are 
output as is.
+    NULL,             -- Output table size
+    'zone'            -- Group by zone
+);
+SELECT * FROM output_table ORDER BY zone, mainhue;
+</pre>
+<pre class="result">
+ __madlib_id__ | id |    name     | landmass | zone | area | population | 
language | colours | mainhue
+---------------+----+-------------+----------+------+------+------------+----------+---------+---------
+             2 | 17 | Sweden      |        3 |    1 |  450 |          8 |      
  6 |       2 | blue
+             1 |  8 | Greece      |        3 |    1 |  132 |         10 |      
  6 |       2 | blue
+             3 |  3 | Austria     |        3 |    1 |   84 |          8 |      
  4 |       2 | red
+             5 |  7 | Denmark     |        3 |    1 |   43 |          5 |      
  6 |       2 | red
+             4 |  6 | China       |        5 |    1 | 9561 |       1008 |      
  7 |       2 | red
+             8 | 18 | Switzerland |        3 |    1 |   41 |          6 |      
  4 |       2 | red
+             7 | 14 | Norway      |        3 |    1 |  324 |          4 |      
  6 |       3 | red
+             6 | 12 | Luxembourg  |        3 |    1 |    3 |          0 |      
  4 |       3 | red
+             1 |  2 | Australia   |        6 |    2 | 7690 |         15 |      
  1 |       3 | blue
+             2 |  2 | Australia   |        6 |    2 | 7690 |         15 |      
  1 |       3 | blue
+             1 |  1 | Argentina   |        2 |    3 | 2777 |         28 |      
  2 |       2 | blue
+             2 |  1 | Argentina   |        2 |    3 | 2777 |         28 |      
  2 |       2 | blue
+             3 |  4 | Brazil      |        2 |    3 | 8512 |        119 |      
  6 |       4 | green
+             2 |  9 | Guatemala   |        1 |    4 |  109 |          8 |      
  2 |       2 | blue
+             1 |  9 | Guatemala   |        1 |    4 |  109 |          8 |      
  2 |       2 | blue
+             5 | 11 | Jamaica     |        1 |    4 |   11 |          2 |      
  1 |       3 | green
+             6 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+             3 |  5 | Canada      |        1 |    4 | 9976 |         24 |      
  1 |       2 | red
+             7 | 15 | Portugal    |        3 |    4 |   92 |         10 |      
  6 |       5 | red
+             8 | 16 | Spain       |        3 |    4 |  505 |         38 |      
  2 |       2 | red
+             9 | 19 | UK          |        3 |    4 |  245 |         56 |      
  1 |       3 | red
+            10 | 20 | USA         |        1 |    4 | 9363 |        231 |      
  1 |       3 | white
+             4 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+(23 rows)
+</pre>
+
 @anchor literature
 @par Literature
 

http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/postgres/modules/sample/test/balance_sample.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/sample/test/balance_sample.sql_in 
b/src/ports/postgres/modules/sample/test/balance_sample.sql_in
index 2249fd6..e99b706 100644
--- a/src/ports/postgres/modules/sample/test/balance_sample.sql_in
+++ b/src/ports/postgres/modules/sample/test/balance_sample.sql_in
@@ -25,58 +25,66 @@ CREATE TABLE "TEST_s"(
     id1 INTEGER,
     id2 INTEGER,
     gr1 INTEGER,
-    gr2 INTEGER
+    gr2 INTEGER,
+    gr3 TEXT
 );
 
 INSERT INTO "TEST_s" VALUES
-(1,0,1,1),
-(2,0,1,1),
-(3,0,1,1),
-(4,0,1,1),
-(5,0,1,1),
-(6,0,1,1),
-(7,0,1,1),
-(8,0,1,1),
-(9,0,1,1),
-(19,0,1,1),
-(29,0,1,1),
-(39,0,1,1),
-(0,1,1,2),
-(0,2,1,2),
-(0,3,1,2),
-(0,4,1,2),
-(0,5,1,2),
-(0,6,1,2),
-(10,10,2,2),
-(20,20,2,2),
-(30,30,2,2),
-(40,40,2,2),
-(50,50,2,2),
-(60,60,2,2),
-(70,70,2,2),
-(10,10,5,5),
-(50,50,5,5),
-(88,88,5,5),
-(40,40,5,6),
-(50,50,5,6),
-(60,60,5,6),
-(70,70,5,6),
-(10,10,6,6),
-(60,60,6,6),
-(30,30,6,6),
-(40,40,6,6),
-(50,50,6,6),
-(60,60,6,6),
-(70,70,6,6),
-(50,50,4,2),
-(60,60,4,2),
-(70,70,4,2),
-(50,50,3,2),
-(60,60,3,2),
-(70,70,3,2),
-(500,50,NULL,2),
-(600,60,NULL,2),
-(700,70,NULL,2)
+(1,0,1,1,'a'),
+(1,0,1,1,'b'),
+(1,0,1,2,'b'),
+(1,0,1,2,'c'),
+(1,0,1,2,'d'),
+(1,0,1,2,'e'),
+(1,0,1,5,'c'),
+(1,0,1,6,'c'),
+(2,0,1,6,'d'),
+(3,0,1,1,'a'),
+(4,0,1,1,'a'),
+(5,0,1,1,'a'),
+(6,0,1,1,'a'),
+(7,0,1,1,'b'),
+(8,0,1,1,'b'),
+(9,0,1,1,'b'),
+(19,0,1,1,'b'),
+(29,0,1,1,'b'),
+(39,0,1,1,'b'),
+(0,1,1,2,'b'),
+(0,2,1,2,'b'),
+(0,3,1,2,'b'),
+(0,4,1,2,'b'),
+(0,5,1,2,'b'),
+(0,6,1,2,'b'),
+(10,10,2,2,'c'),
+(20,20,2,2,'c'),
+(30,30,2,2,'c'),
+(40,40,2,2,'c'),
+(50,50,2,2,'c'),
+(60,60,2,2,'c'),
+(70,70,2,2,'c'),
+(10,10,5,5,'c'),
+(50,50,5,5,'c'),
+(88,88,5,5,'c'),
+(40,40,5,6,'c'),
+(50,50,5,6,'c'),
+(60,60,5,6,'c'),
+(70,70,5,6,'c'),
+(10,10,6,6,'c'),
+(60,60,6,6,'c'),
+(30,30,6,6,'d'),
+(40,40,6,6,'d'),
+(50,50,6,6,'d'),
+(60,60,6,6,'d'),
+(70,70,6,6,'d'),
+(50,50,4,2,'d'),
+(60,60,4,2,'d'),
+(70,70,4,2,'d'),
+(50,50,3,2,'d'),
+(60,60,3,2,'d'),
+(70,70,3,2,'d'),
+(500,50,NULL,2,'e'),
+(600,60,NULL,2,'e'),
+(700,70,NULL,2,'e')
 ;
 
 -- SELECT gr1, count(*) AS c FROM "TEST_s" GROUP BY gr1;
@@ -92,14 +100,6 @@ INSERT INTO "TEST_s" VALUES
 -- (6 rows)
 
 SELECT gr1, count(*) AS c FROM "TEST_s" GROUP BY gr1;
-
---- Test for random undersampling without replacement
-DROP TABLE IF EXISTS out_s;
-SELECT balance_sample('"TEST_s"', 'out_s', 'gr1', ' undersample', NULL, NULL, 
FALSE);
-SELECT gr1, count(*) AS c FROM out_s GROUP BY gr1;
-SELECT assert(count(*) = 0, 'Wrong number of samples on undersampling gr1') 
FROM
-        (SELECT gr1, count(*) AS c FROM out_s GROUP BY gr1) AS foo WHERE foo.c 
!= 3;
-
 -- --- Test for random undersampling with replacement
 DROP TABLE IF EXISTS out_sr2;
 SELECT balance_sample('"TEST_s"', 'out_sr2', 'gr1', 'undersample ', NULL, 
NULL, TRUE, TRUE);
@@ -108,19 +108,39 @@ SELECT gr1, count(*) AS c FROM out_sr2 GROUP BY gr1;
 SELECT assert(count(*) = 0, 'Wrong number of samples on undersampling with 
replacement on gr1') FROM
         (SELECT gr1, count(*) AS c FROM out_sr2 GROUP BY gr1) AS foo WHERE 
foo.c != 3;
 
+--- Test for random undersampling without replacement
+DROP TABLE IF EXISTS out_s;
+SELECT balance_sample('"TEST_s"', 'out_s', 'gr1', 'undersample', NULL, 'gr2, 
gr3', FALSE);
+SELECT * FROM out_s;
+SELECT gr2, gr3, count(*) AS c FROM out_s GROUP BY gr3, gr2 ORDER BY gr2, gr3;
+DROP TABLE IF EXISTS out_s;
+SELECT balance_sample('"TEST_s"', 'out_s', 'gr1', 'oversample', NULL, 'gr2, 
gr3', FALSE);
+SELECT * FROM out_s;
+SELECT gr2, gr3, count(*) AS c FROM out_s GROUP BY gr3, gr2 ORDER BY gr2, gr3;
+DROP TABLE IF EXISTS out_s;
+SELECT balance_sample('"TEST_s"', 'out_s', 'gr1', 'uniform', NULL, 'gr2, gr3', 
FALSE);
+SELECT * FROM out_s;
+SELECT gr2, gr3, count(*) AS c FROM out_s GROUP BY gr3, gr2 ORDER BY gr2, gr3;
+DROP TABLE IF EXISTS out_s;
+SELECT balance_sample('"TEST_s"', 'out_s', 'gr1', '1=3', NULL, 'gr2, gr3', 
FALSE);
+SELECT * FROM out_s;
+SELECT gr2, gr3, count(*) AS c FROM out_s GROUP BY gr3, gr2 ORDER BY gr2, gr3;
+-- SELECT assert(count(*) = 0, 'Wrong number of samples on undersampling gr1') 
FROM
+--         (SELECT gr1, gr2, count(*) AS c FROM out_s GROUP BY gr1, gr2) AS 
foo WHERE foo.c != 3;
+
 -- --- Test for random oversampling
 DROP TABLE IF EXISTS out_or3;
 SELECT balance_sample('"TEST_s"', 'out_or3', 'gr1', ' oVEr   ', NULL, NULL);
 SELECT gr1, count(*) AS c FROM out_or3 GROUP BY gr1;
 SELECT assert(count(*) = 0, 'Wrong number of samples on oversampling') FROM
-        (SELECT gr1, count(*) AS c FROM out_or3 GROUP BY gr1) AS foo WHERE 
foo.c != 18;
+        (SELECT gr1, count(*) AS c FROM out_or3 GROUP BY gr1) AS foo WHERE 
foo.c != 25;
 
 --- UNIFORM sampling
 DROP TABLE IF EXISTS out_cd2;
 SELECT balance_sample('"TEST_s"', 'out_cd2', 'gr1', 'Uniform', NULL, NULL);
 SELECT gr1, count(*) AS c FROM out_cd2 GROUP BY gr1;
 SELECT assert(count(*) = 0, 'Wrong number of samples on uniform sampling for 
gr1') FROM
-        (SELECT gr1, count(*) AS c FROM out_cd2 GROUP BY gr1) AS foo WHERE 
foo.c != 8;
+        (SELECT gr1, count(*) AS c FROM out_cd2 GROUP BY gr1) AS foo WHERE 
foo.c != 9;
 
 --- Default sampling should be uniform
 DROP TABLE IF EXISTS out_cd3;
@@ -131,27 +151,27 @@ SELECT assert(count(*) = 0, 'Wrong number of samples on 
uniform sampling for gr1
 
 --- Only one class size is specified
 DROP TABLE IF EXISTS out_cd4;
-SELECT balance_sample('"TEST_s"', 'out_cd4', 'gr1', ' 2 =10', NULL, NULL, 
TRUE);
-SELECT assert(count(*) = 48, 'Wrong number of samples on oversampling with 
comma-delimited list') from out_cd4;
-SELECT assert(count(*) = 0, 'Wrong number of samples on oversampling with 
comma-delimited list') from
-        (SELECT count(*) AS c FROM out_cd4 where gr1 = 2) AS foo WHERE foo.c 
!= 10;
-SELECT assert(count(*) = 0, 'Wrong number of samples on oversampling with 
comma-delimited list') from
-        (SELECT count(*) AS c FROM out_cd4 where gr1 = 3) AS foo WHERE foo.c 
!= 3;
-SELECT assert(count(*) = 0, 'Wrong number of samples on oversampling with 
comma-delimited list') from
-        (SELECT count(*) AS c FROM out_cd4 where gr1 = 1) AS foo WHERE foo.c 
!= 18;
+SELECT balance_sample('"TEST_s"', 'out_cd4', 'gr1', '2=10', NULL, NULL, TRUE);
+SELECT gr1, count(*) AS c FROM out_cd4 GROUP BY gr1;
+SELECT assert(count(*) = 10, 'Wrong number of samples on sampling with 
specified class sizes') from
+out_cd4 where gr1 = 2;
+SELECT assert(count(*) = 3, 'Wrong number of samples on sampling with 
specified class sizes') from
+out_cd4 where gr1 = 3;
+SELECT assert(count(*) = 25, 'Wrong number of samples on sampling with 
specified class sizes') from
+out_cd4 where gr1 = 1;
 
 --- Multiple class sizes with comma delimited string
 DROP TABLE IF EXISTS out_cd5;
 SELECT balance_sample('"TEST_s"', 'out_cd5', 'gr1', '2= 10, 3=6, 1 = 10', 100, 
NULL);
 select gr1, count(*) from out_cd5 group by gr1;
 SELECT assert(count(*) >= 100, 'Wrong number of samples on sampling with 
comma-delimited list') from out_cd5;
-SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with 
comma-delimited list') from
-        (SELECT count(*) AS c FROM out_cd5 where gr1 = 2) AS foo WHERE foo.c 
!= 10;
-SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with 
comma-delimited list') from
-        (SELECT count(*) AS c FROM out_cd5 where gr1 = 4) AS foo WHERE foo.c 
!= 25;
-SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with 
comma-delimited list') from
-        (SELECT count(*) AS c FROM out_cd5 where gr1 = 1) AS foo WHERE foo.c 
!= 10;
-SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with 
comma-delimited list') from
-        (SELECT count(*) AS c FROM out_cd5 where gr1 = 5) AS foo WHERE foo.c 
!= 25;
-SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with 
comma-delimited list') from
-        (SELECT count(*) AS c FROM out_cd5 where gr1 = 3) AS foo WHERE foo.c 
!= 6;
+SELECT assert(count(*) = 10, 'Wrong number of samples on sampling with 
comma-delimited list') from
+out_cd5 where gr1 = 2;
+SELECT assert(count(*) = 25, 'Wrong number of samples on sampling with 
comma-delimited list') from
+out_cd5 where gr1 = 4;
+SELECT assert(count(*) = 10, 'Wrong number of samples on sampling with 
comma-delimited list') from
+out_cd5 where gr1 = 1;
+SELECT assert(count(*) = 25, 'Wrong number of samples on sampling with 
comma-delimited list') from
+out_cd5 where gr1 = 5;
+SELECT assert(count(*) = 6, 'Wrong number of samples on sampling with 
comma-delimited list') from
+out_cd5 where gr1 = 3;

http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/postgres/modules/utilities/utilities.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in 
b/src/ports/postgres/modules/utilities/utilities.py_in
index 8beb701..135404f 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -686,6 +686,7 @@ def strip_end_quotes(input_str, quote='"'):
         return input_str
 # 
------------------------------------------------------------------------------
 
+
 def _grp_null_checks(grp_list):
     """
     Helper function for generating NULL checks for grouping columns
@@ -695,6 +696,7 @@ def _grp_null_checks(grp_list):
     """
     return ' AND '.join([" {i} IS NOT NULL ".format(**locals())
                          for i in grp_list])
+# 
------------------------------------------------------------------------------
 
 
 def _check_groups(tbl1, tbl2, grp_list):
@@ -708,6 +710,8 @@ def _check_groups(tbl1, tbl2, grp_list):
 
     return ' AND '.join([" {tbl1}.{i} = {tbl2}.{i} ".format(**locals())
                          for i in grp_list])
+# 
------------------------------------------------------------------------------
+
 
 def get_filtered_cols_subquery_str(include_from_table, exclude_from_table,
                                    filter_cols_list):
@@ -727,6 +731,8 @@ def get_filtered_cols_subquery_str(include_from_table, 
exclude_from_table,
                 (SELECT {cols}
                  FROM {exclude_from_table})
            """.format(**locals())
+# 
------------------------------------------------------------------------------
+
 
 def get_table_qualified_col_str(tbl_name, col_list):
     """
@@ -746,9 +752,7 @@ def get_grouping_col_str(schema_madlib, module_name, 
reserved_cols,
                           _string_to_array_with_quotes(grouping_col),
                           module_name)
         intersect = frozenset(
-            _string_to_array(grouping_col)).intersection(
-                frozenset(
-                    (reserved_cols)))
+            
_string_to_array(grouping_col)).intersection(frozenset(reserved_cols))
         _assert(len(intersect) == 0,
                 "{0} error: Conflicting grouping column name.\n"
                 "Some predefined keyword(s) ({1}) are not allowed "
@@ -767,6 +771,20 @@ def get_grouping_col_str(schema_madlib, module_name, 
reserved_cols,
     return grouping_str, grouping_col
 # 
------------------------------------------------------------------------------
 
+
+def collate_plpy_result(plpy_result_rows):
+    if not plpy_result_rows:
+        return {}
+    else:
+        all_keys = plpy_result_rows[0].keys()
+        result = collections.defaultdict(list)
+        for each_row in plpy_result_rows:
+            for each_key in all_keys:
+                result[each_key].append(each_row[each_key])
+    return result
+# 
------------------------------------------------------------------------------
+
+
 import unittest
 
 
@@ -827,6 +845,22 @@ class UtilitiesTestCase(unittest.TestCase):
         self.assertEqual(['"a^5,6"', 'b', 'c'], 
split_quoted_delimited_str('"a^5,6",    b, c', quote='"'))
         self.assertEqual(['"A""^5,6"', 'b', 'c'], 
split_quoted_delimited_str('"A""^5,6",    b, c', quote='"'))
 
+    def test_collate_plpy_result(self):
+        plpy_result1 = [{'classes': '4', 'class_count': 3},
+                        {'classes': '1', 'class_count': 18},
+                        {'classes': '5', 'class_count': 7},
+                        {'classes': '3', 'class_count': 3},
+                        {'classes': '6', 'class_count': 7},
+                        {'classes': '2', 'class_count': 7}]
+        self.assertEqual(collate_plpy_result(plpy_result1),
+                         {'classes': ['4', '1', '5', '3', '6', '2'],
+                          'class_count': [3, 18, 7, 3, 7, 7]})
+        self.assertEqual(collate_plpy_result([]), {})
+        self.assertEqual(collate_plpy_result([{'class': 'a'},
+                                              {'class': 'b'},
+                                              {'class': 'c'}]),
+                         {'class': ['a', 'b', 'c']})
+
 
 if __name__ == '__main__':
     unittest.main()

http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/postgres/modules/utilities/utilities.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/utilities.sql_in 
b/src/ports/postgres/modules/utilities/utilities.sql_in
index 0ec864d..6fdb463 100644
--- a/src/ports/postgres/modules/utilities/utilities.sql_in
+++ b/src/ports/postgres/modules/utilities/utilities.sql_in
@@ -396,12 +396,13 @@ m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', 
`');
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.bool_to_text (BOOLEAN)
 RETURNS TEXT
 STRICT
-LANGUAGE SQL AS '
+LANGUAGE SQL AS $$
     SELECT CASE
-        WHEN $1 THEN ''t''
-        ELSE ''f''
+        WHEN $1
+        THEN 'true'
+        ELSE 'false'
     END;
-'m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `');
+$$m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `');
 
 ------------------------------------------------------------------------
 

http://git-wip-us.apache.org/repos/asf/madlib/blob/0bfcaf5c/src/ports/postgres/modules/utilities/validate_args.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in 
b/src/ports/postgres/modules/utilities/validate_args.py_in
index a2f43fd..ec1f410 100644
--- a/src/ports/postgres/modules/utilities/validate_args.py_in
+++ b/src/ports/postgres/modules/utilities/validate_args.py_in
@@ -21,6 +21,8 @@ object, as well as in any subsequent references to that 
object (e.g., in SELECT,
 DELETE, or UPDATE statements).
 """
 
+m4_changequote(`<!', `!>')
+
 
 def unquote_ident(input_str):
     """
@@ -504,6 +506,7 @@ def explicit_bool_to_text(tbl, cols, schema_madlib):
     """
     Patch madlib.bool_to_text for columns that are of type boolean.
     """
+    m4_ifdef(<!__HAS_BOOL_TO_TEXT_CAST__!>, <!return cols!>, <!!>)
     col_to_type = dict(get_cols_and_types(tbl))
     patched = []
     for col in cols:

Reply via email to