Repository: madlib
Updated Branches:
  refs/heads/master a8bbe082c -> c51da40a1


Balanced Datasets: re-sampling technique

JIRA: MADLIB-1168

* Balanced datasets Phase 1 and Phase 2 implementation, which performs balanced
sampling using one of the following strategies:
   1. Undersample: all class levels are sampled as many times as the
                   level with the minimum number of rows. Levels can be
                   sampled with or without replacement.
   2. Oversample: all class levels are sampled as many times as the
                  level with the highest number of rows. This is always
                  with replacement.
   3. Uniform: All class levels will be sampled the same number of
               times. The final output table size is determined by the
               output_size param. Some levels will be sampled with
               replacement, while others will be sampled without.
   4. User-defined class level sizes:
           Re-sampling based on a comma-delimited string of specific class
           levels and their respective sample sizes.
* Install Check
* Documentation

Closes #230

Co-authored-by: Rahul Iyer <ri...@apache.org>
Co-authored-by: Swati Soni <soniswati.2...@gmail.com>
Co-authored-by: Jingyi Mei <j...@pivotal.io>
Co-authored-by: Orhan Kislal <okis...@pivotal.io>


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/c51da40a
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/c51da40a
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/c51da40a

Branch: refs/heads/master
Commit: c51da40a104527e6e184e49b6bfe166e176a9958
Parents: a8bbe08
Author: Nandish Jayaram <njaya...@apache.org>
Authored: Wed Jan 10 12:07:36 2018 -0800
Committer: Nandish Jayaram <njaya...@apache.org>
Committed: Thu Feb 8 17:03:42 2018 -0800

----------------------------------------------------------------------
 doc/mainpage.dox.in                             |   4 +-
 .../modules/sample/balance_sample.py_in         | 811 +++++++++++++++++++
 .../modules/sample/balance_sample.sql_in        | 647 +++++++++++++++
 .../modules/sample/test/balance_sample.sql_in   | 157 ++++
 4 files changed, 1618 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/c51da40a/doc/mainpage.dox.in
----------------------------------------------------------------------
diff --git a/doc/mainpage.dox.in b/doc/mainpage.dox.in
index 4a58e30..ad0a9d0 100644
--- a/doc/mainpage.dox.in
+++ b/doc/mainpage.dox.in
@@ -266,9 +266,11 @@ Contains graph algorithms.
     @defgroup grp_sampling Sampling
     @ingroup grp_utility_functions
     @{A collection of methods for sampling from a population. @}
-        @defgroup grp_strs Stratified Sampling
+        @defgroup grp_balance_sampling Balanced Sampling
         @ingroup grp_sampling
 
+        @defgroup grp_strs Stratified Sampling
+        @ingroup grp_sampling
 
     @defgroup grp_sessionize Sessionize
     @ingroup grp_utility_functions

http://git-wip-us.apache.org/repos/asf/madlib/blob/c51da40a/src/ports/postgres/modules/sample/balance_sample.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/sample/balance_sample.py_in 
b/src/ports/postgres/modules/sample/balance_sample.py_in
new file mode 100644
index 0000000..a32e8e7
--- /dev/null
+++ b/src/ports/postgres/modules/sample/balance_sample.py_in
@@ -0,0 +1,811 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file EXCEPT in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+m4_changequote(`<!', `!>')
+
+import math
+
if __name__ != "__main__":
    import plpy
    from utilities.control import MinWarning
    from utilities.utilities import _assert
    from utilities.utilities import extract_keyvalue_params
    from utilities.utilities import unique_string
    from utilities.validate_args import columns_exist_in_table
    from utilities.validate_args import get_cols
    from utilities.validate_args import table_exists
    from utilities.validate_args import table_is_empty
else:
    # Used only for unit testing.
    # FIXME: this repeats a function from utilities that the unit test needs.
    # Remove it once a unittest framework is used for testing.
    import random
    import time

    def unique_string(desp='', **kwargs):
        """
        Generate a random name for temporary tables and other objects.

        Local stand-in for utilities.utilities.unique_string, defined only
        when this module is executed directly.

        @param desp    Tag embedded in the generated name.
        @param kwargs  Accepted for signature compatibility; ignored.
        @returns       Str of the form '__madlib_temp_<desp><n1>_<n2>_<n3>__'.
        """
        rand_part = random.randint(1, 100000000)
        epoch_part = int(time.time())
        mixed_part = int(time.time()) % random.randint(1, 100000000)
        return "".join(["__madlib_temp_", desp, str(rand_part), "_",
                        str(epoch_part), "_", str(mixed_part), "__"])
# ------------------------------------------------------------------------------
+
# Sampling strategies that class_sizes may name (and that per-level sampling
# decisions are expressed in).
UNIFORM, UNDERSAMPLE, OVERSAMPLE = 'uniform', 'undersample', 'oversample'
# Strategy for levels whose rows are copied from the input as is.
NOSAMPLE = 'nosample'

# Name of the row-id column prepended to the output table.
NEW_ID_COLUMN = '__madlib_id__'
# Sentinel text standing in for a NULL class level in the generated SQL.
NULL_IDENTIFIER = '__madlib_null_id__'
+
def _get_level_frequency_distribution(source_table, class_col):
    """ Count the rows of source_table that belong to each class level.

    @param source_table: Name of the table (or view) to scan.
    @param class_col: Column holding the class level of each row.
    @returns:
        Dict mapping each class level (its ::text representation, with
        surrounding whitespace stripped when non-empty) to its row count.
        A NULL class level is reported under the key None.
    """
    sql = """
                    SELECT {class_col}::text AS classes,
                           count(*) AS class_count
                    FROM {source_table}
                    GROUP BY {class_col}
                 """.format(source_table=source_table, class_col=class_col)
    rows = plpy.execute(sql)
    # NULL comes back as None and the empty string as ''; both are kept
    # unchanged, every other level is stripped.
    return dict(((row['classes'].strip() if row['classes'] else row['classes']),
                 row['class_count'])
                for row in rows)
+
+
def _validate_and_get_sampling_strategy(sampling_strategy_str, output_table_size,
                                        supported_strategies=None,
                                        default=UNIFORM):
    """ Returns the sampling strategy based on the class_sizes input param.

        A case-insensitive prefix (at least 3 characters long) of a supported
        strategy is accepted, e.g. 'over' resolves to OVERSAMPLE.

        @param sampling_strategy_str  The sampling strategy specified by the
                                      user (class_sizes param). If empty or
                                      None, `default` is used.
        @param output_table_size      Desired output table size. Must not be
                                      set together with undersampling or
                                      oversampling.
        @param supported_strategies   Strategy names to match against.
                                      Defaults to [UNIFORM, UNDERSAMPLE,
                                      OVERSAMPLE].
        @param default                Strategy used when none is given.
        @returns:
            Str. One of [UNIFORM, UNDERSAMPLE, OVERSAMPLE]. Default is UNIFORM.
        Errors are raised (via plpy.error / _assert) for an unknown strategy,
        a prefix shorter than 3 characters, or an output_table_size combined
        with under/oversampling.
    """
    # Resolve the supported list before branching. The validation message
    # below lists it even when the caller relied on `default`; previously it
    # stayed None on that path and sorted(None) raised a TypeError.
    supported_strategies = list(supported_strategies or
                                [UNIFORM, UNDERSAMPLE, OVERSAMPLE])
    supported_str = ','.join(sorted(supported_strategies))

    if not sampling_strategy_str:
        sampling_strategy_str = default
    else:
        if len(sampling_strategy_str) < 3:
            # Require at least 3 characters since UNIFORM and UNDERSAMPLE have
            # a common prefix substring.
            plpy.error("Sample: Invalid class_sizes parameter")

        user_prefix = sampling_strategy_str.lower()
        try:
            # Allow the user to specify a prefix substring of a supported
            # strategy; the first match wins.
            sampling_strategy_str = next(x for x in supported_strategies
                                         if x.startswith(user_prefix))
        except StopIteration:
            # next() raises StopIteration if no element is found
            plpy.error("Sample: Invalid class_sizes parameter: "
                       "{0}. Supported class_size parameters are ({1})"
                       .format(sampling_strategy_str, supported_str))

    _assert(sampling_strategy_str.lower() in (UNIFORM, UNDERSAMPLE, OVERSAMPLE) or
            (sampling_strategy_str.find('=') > 0),
            "Sample: Invalid class_sizes parameter: "
            "{0}. Supported class_size parameters are ({1})"
            .format(sampling_strategy_str, supported_str))

    _assert(not (sampling_strategy_str.lower() == OVERSAMPLE and output_table_size),
            "Sample: Cannot set output_table_size with oversampling.")

    _assert(not (sampling_strategy_str.lower() == UNDERSAMPLE and output_table_size),
            "Sample: Cannot set output_table_size with undersampling.")

    return sampling_strategy_str
# ------------------------------------------------------------------------------
+
+
def _choose_strategy(actual_count, desired_count):
    """ Pick the sampling strategy for a single class level.

    @param actual_count: Number of rows the level has in the input table
    @param desired_count: Number of rows wanted for the level in the output
    @returns:
        Str. OVERSAMPLE when the level has fewer rows than desired, otherwise
        UNDERSAMPLE.
    """
    # Equal counts deliberately map to UNDERSAMPLE rather than NOSAMPLE: the
    # input rows could be returned unchanged, but that would rule out
    # bootstrapping (same number of rows drawn with replacement). UNDERSAMPLE
    # supports both 'with' and 'without' replacement, whereas OVERSAMPLE is
    # always with replacement and NOSAMPLE always without.
    return OVERSAMPLE if actual_count < desired_count else UNDERSAMPLE
# -------------------------------------------------------------------------
+
def _get_desired_target_level_counts(desired_level_counts, actual_level_counts,
                                     output_table_size):
    """ Returns the target level counts for each class, based on the user
        defined class_size string. The strategy of (under)oversampling for each
        class level is chosen based on the desired number of counts for a level,
        and the actual number of counts already present in the input table.

        @param desired_level_counts: Dict of class level -> number of rows
                                     requested by the user.
        @param actual_level_counts: Dict of class level -> number of rows in
                                    the input table. Must contain every key of
                                    desired_level_counts.
        @param output_table_size: Desired output table size, or None/0 if
                                  unspecified.
        @returns:
            Dict of class level -> (target_count, sampling_strategy).
            If output_table_size is set and the requested counts already use
            it up (or every level was given a count), levels without a
            requested count get no entry and contribute no rows.
    """
    target_level_counts = {}
    for each_level, desired_count in desired_level_counts.items():
        sample_strategy = _choose_strategy(actual_level_counts[each_level],
                                           desired_count)
        target_level_counts[each_level] = (desired_count, sample_strategy)

    remaining_levels = (set(actual_level_counts.keys()) -
                        set(desired_level_counts.keys()))
    #   if 'output_table_size' = NULL, unspecified level counts remain as is
    #   if 'output_table_size' = <Integer>, divide remaining row count
    #                             uniformly among unspecified level counts
    if output_table_size:
        remaining_rows = output_table_size - sum(desired_level_counts.values())
        # Also require remaining levels: when the user sized every level there
        # is nothing to distribute the leftover rows over (this previously
        # raised ZeroDivisionError).
        if remaining_rows > 0 and remaining_levels:
            # math.ceil() returns a float on Python 2; keep counts integral.
            rows_per_level = int(math.ceil(float(remaining_rows) /
                                           len(remaining_levels)))
            for each_level in remaining_levels:
                sample_strategy = _choose_strategy(
                    actual_level_counts[each_level], rows_per_level)
                target_level_counts[each_level] = (rows_per_level,
                                                   sample_strategy)
    else:
        # When output_table_size is unspecified, rows from the input table
        # are sampled as is for remaining class levels. This is same as the
        # NOSAMPLE strategy.
        for each_level in remaining_levels:
            target_level_counts[each_level] = (actual_level_counts[each_level],
                                               NOSAMPLE)
    return target_level_counts
# -------------------------------------------------------------------------
+
def _get_supported_target_level_counts(sampling_strategy_str, actual_level_counts,
                                       output_table_size):
    """ Returns the target level counts for all levels when the class_size param
        is one of [uniform, undersample, oversample]. The strategy of
        (under)oversampling for a specific class level is chosen based on the
        computed number of counts for a level, and the actual number of counts
        already present in the input table.

        @param sampling_strategy_str: One of UNIFORM, UNDERSAMPLE, OVERSAMPLE.
        @param actual_level_counts: Dict of class level -> number of rows in
                                    the input table.
        @param output_table_size: Desired output table size; only consulted
                                  for UNIFORM. None/0 means unspecified.
        @returns:
            Dict of class level -> (target_count, sampling_strategy), using
            the same integer target_count for every level. Empty when
            actual_level_counts is empty.
    """
    def ceil_of_mean(numbers):
        return math.ceil(float(sum(numbers)) / max(len(numbers), 1))

    # UNIFORM: Ensure all level counts are same (size determined by output_table_size)
    # UNDERSAMPLE: Ensure all level counts are same as the minimum count
    # OVERSAMPLE: Ensure all level counts are same as the maximum count
    size_function = {UNDERSAMPLE: min,
                     OVERSAMPLE: max,
                     UNIFORM: ceil_of_mean
                     }[sampling_strategy_str]
    if not actual_level_counts:
        # No class levels: nothing to sample. Without this guard min()/max()
        # fail on an empty sequence and UNIFORM would divide by zero.
        return {}
    if sampling_strategy_str == UNIFORM and output_table_size:
        # Ignore actual counts for computing target sizes
        # if output_table_size is specified
        target_size_per_level = math.ceil(float(output_table_size) /
                                          len(actual_level_counts))
    else:
        target_size_per_level = size_function(list(actual_level_counts.values()))
    # math.ceil() returns a float on Python 2; keep the count integral so all
    # strategies produce the same type.
    target_size_per_level = int(target_size_per_level)

    target_level_counts = {}
    for each_level, actual_count in actual_level_counts.items():
        sample_strategy = _choose_strategy(actual_count, target_size_per_level)
        target_level_counts[each_level] = (target_size_per_level,
                                           sample_strategy)
    return target_level_counts
# -------------------------------------------------------------------------
+
def _get_target_level_counts(sampling_strategy_str, desired_level_counts,
                             actual_level_counts, output_table_size):
    """ Compute how many rows to draw, and how, for every class level.

    @param sampling_strategy_str: One of [UNIFORM, UNDERSAMPLE, OVERSAMPLE],
                                  or a falsy value (None) when class_sizes
                                  was a user-defined, comma-separated list of
                                  class level and row count pairs.
    @param desired_level_counts: Dict of class level -> requested row count.
                                 Only used when sampling_strategy_str is
                                 falsy.
    @param actual_level_counts: Dict of class level -> number of rows in the
                                input table.
    @param output_table_size: Size of the desired output table (NULL or
                              Integer).
    @returns:
        Dict. Class level -> (number of samples to draw, sampling strategy).
    """
    if sampling_strategy_str:
        # The user picked one of [uniform, undersample, oversample] for the
        # class_size param.
        return _get_supported_target_level_counts(sampling_strategy_str,
                                                  actual_level_counts,
                                                  output_table_size)
    # The user gave a desired count for one or more levels; counts for the
    # remaining levels depend on 'output_table_size'.
    return _get_desired_target_level_counts(desired_level_counts,
                                            actual_level_counts,
                                            output_table_size)

# -------------------------------------------------------------------------
+
+
def _get_sampling_strategy_specific_dict(target_class_sizes):
    """ Split the per-level targets by sampling strategy.

    @param target_class_sizes: Dict of class level -> (count, strategy).
    @returns:
        Tuple (undersample, oversample, nosample) of dicts, each mapping a
        class level to the number of rows to draw for it. Levels whose
        strategy is neither UNDERSAMPLE nor OVERSAMPLE go to the nosample
        dict.
    """
    nosample_level_dict = {}
    buckets = {UNDERSAMPLE: {}, OVERSAMPLE: {}}
    for level, (count, strategy) in target_class_sizes.items():
        buckets.get(strategy, nosample_level_dict)[level] = count
    return (buckets[UNDERSAMPLE], buckets[OVERSAMPLE], nosample_level_dict)
# ------------------------------------------------------------------------------
+
+
+def _get_nosample_subquery(source_table, class_col, nosample_levels):
+    """ Return the subquery for fetching all rows as is from the input table
+        for specific class levels.
+    """
+    if not nosample_levels:
+        return ''
+    subquery = """
+                SELECT *
+                FROM {0}
+                WHERE {1} in ({2}) OR {1} IS NULL
+            """.format(source_table, class_col,
+                       ','.join(["'{0}'".format(level)
+                                for level in nosample_levels if level]))
+    return subquery
+# 
------------------------------------------------------------------------------
+
+
def _get_without_replacement_subquery(schema_madlib, source_table,
                                      source_table_columns, class_col,
                                      actual_level_counts, desired_level_counts):
    """ Return the subquery for sampling without replacement for specific
        class levels.

        @param schema_madlib: Unused here.
        @param source_table: Table (or view) to sample from.
        @param source_table_columns: Comma-separated column list to project.
        @param class_col: Column holding the class level.
        @param actual_level_counts: Unused here.
        @param desired_level_counts: Dict of class level (None for NULL) ->
                                     number of rows to keep.
        @returns:
            Str. SQL subquery returning, per level, at most the desired number
            of rows in random order; '' if desired_level_counts is empty.
    """
    if not desired_level_counts:
        return ''
    class_col_tmp = unique_string(desp='class_col')
    row_number_col = unique_string(desp='row_number')
    desired_count_col = unique_string(desp='desired_count')

    null_value_string = "'{0}'".format(NULL_IDENTIFIER)

    def level_literal(level):
        # The NULL level (None) is encoded as the NULL_IDENTIFIER sentinel,
        # matched through coalesce() below. Test for None explicitly so an
        # empty-string level is not mistaken for NULL, and double single
        # quotes so the level is a valid SQL literal.
        if level is None:
            return null_value_string
        return "'{0}'::text".format(level.replace("'", "''"))

    desired_level_counts_str = "VALUES " + ','.join(
        "({0}, {1})".format(level_literal(k), v)
        for k, v in desired_level_counts.items())
    # Subquery q2 is used to figure out the number of rows to select for each
    # class level. That is used with row_number to order the input rows randomly,
    # and, trim the results to the_desired_number_of_rows for each class level.
    # q2:
    #    The FROM clause contains information that can be used to figure out the
    #    number of rows to generate for a class level. The tuple basically
    #    contains: (class_level, the_desired_number_of_rows)
    subquery = """
            SELECT {source_table_columns}
            FROM
                (
                    SELECT {source_table_columns},
                           row_number() OVER (PARTITION BY {class_col} ORDER BY random()) AS {row_number_col},
                           {desired_count_col}
                    FROM
                    (
                        SELECT {source_table_columns},
                               {desired_count_col}
                        FROM
                            {source_table} s,
                            ({desired_level_counts_str})
                                q({class_col_tmp}, {desired_count_col})
                        WHERE {class_col_tmp} = coalesce({class_col}::text, '{null_level_val}')
                    ) q2
                ) q3
            WHERE {row_number_col} <= {desired_count_col}
        """.format(null_level_val=NULL_IDENTIFIER, **locals())
    return subquery
# ------------------------------------------------------------------------------
+
+
def _get_with_replacement_subquery(schema_madlib, source_table,
                                   source_table_columns, class_col,
                                   actual_level_counts, desired_level_counts):
    """ Return the query for sampling with replacement for specific class
        levels. Always used for oversampling since oversampling will always need
        to use replacement. Used for under sampling only if with_replacement
        flag is set to TRUE.

        @param schema_madlib: Unused here.
        @param source_table: Table (or view) to sample from.
        @param source_table_columns: Comma-separated column list to project.
        @param class_col: Column holding the class level.
        @param actual_level_counts: Dict of class level -> rows in the input
                                    table; must contain every key of
                                    desired_level_counts.
        @param desired_level_counts: Dict of class level (None for NULL) ->
                                     number of rows to draw.
        @returns:
            Str. SQL subquery, or '' if desired_level_counts is empty.
    """
    if not desired_level_counts:
        return ''

    class_col_tmp = unique_string(desp='class_col_with_rep')
    desired_count_col = unique_string(desp='desired_count_with_rep')
    actual_count_col = unique_string(desp='actual_count')
    q1_row_no = unique_string(desp='q1_row')
    q2_row_no = unique_string(desp='q2_row')

    null_value_string = "'{0}'".format(NULL_IDENTIFIER)

    def level_literal(level):
        # The NULL level (None) is encoded as the NULL_IDENTIFIER sentinel,
        # matched through coalesce() below. Test for None explicitly so an
        # empty-string level is not mistaken for NULL, and double single
        # quotes so the level is a valid SQL literal.
        if level is None:
            return null_value_string
        return "'{0}'::text".format(level.replace("'", "''"))

    desired_and_actual_level_counts = "VALUES " + ','.join(
        "({0}, {1}, {2})".format(level_literal(k), v, actual_level_counts[k])
        for k, v in desired_level_counts.items())
    # q1 and q2 are two sub queries we create to generate the required number of
    # rows per class level.
    # q1:
    #    The FROM clause contains information that can be used to figure out the
    #    number of rows to generate for a class level. The tuple basically
    #    contains:
    #      (class_level, the_desired_number_of_rows, actual_rows_for_level_in_input_table)
    #    The SELECT clause uses generate series to duplicate a row {q1_row_no}
    #    of times, which is a value between 1 and actual_rows_for_level_in_input_table.
    #    random() lies in [0, 1), so floor(random() * N) + 1 is uniform over
    #    1..N. The former round-to-nearest of random()*(N-1)+1 picked rows 1
    #    and N only half as often as the others.
    # q2:
    #   Replicates the source_table with row IDs starting from 1 through
    #   actual_rows_for_level_in_input_table.
    #
    # The WHERE clause is used to join the two subqueries to obtain the result.
    subquery = """
            SELECT {source_table_columns}
            FROM
                (
                    SELECT
                         {class_col_tmp},
                         generate_series(1, {desired_count_col}::int) AS _i,
                         (floor(random() * {actual_count_col})::int + 1) AS {q1_row_no}
                    FROM
                        ({desired_and_actual_level_counts})
                            q({class_col_tmp}, {desired_count_col}, {actual_count_col})
                ) q1,
                (
                    SELECT
                        *,
                        row_number() OVER(PARTITION BY {class_col}) AS {q2_row_no}
                    FROM
                         {source_table}
                ) q2
            WHERE {class_col_tmp} = coalesce({class_col}::text, '{null_level_val}') AND
                  q1.{q1_row_no} = q2.{q2_row_no}
        """.format(null_level_val=NULL_IDENTIFIER, **locals())
    return subquery
# ------------------------------------------------------------------------------
+
def balance_sample(schema_madlib, source_table, output_table, class_col,
                   class_sizes, output_table_size, grouping_cols,
                   with_replacement, keep_null, **kwargs):
    """
    Balance sampling function

    Creates output_table holding rows of source_table re-sampled per class
    level of class_col, with a row-id column NEW_ID_COLUMN prepended.

    Args:
        @param schema_madlib      Schema that MADlib is installed on.
        @param source_table       Input table name.
        @param output_table       Output table name.
        @param class_col          Name of the column containing the class to be
                                  balanced.
        @param class_sizes        Parameter to define the size of the different
                                  class values. Default is UNIFORM.
        @param output_table_size  Desired size of the output data set.
        @param grouping_cols      The columns that define the grouping (not
                                  supported yet).
        @param with_replacement   The sampling method. Default is False.
        @param keep_null          Flag to include rows with class level values
                                  NULL. Default is False.

    """
    with MinWarning("warning"):
        # Set all default values.
        class_sizes = (class_sizes or UNIFORM).strip()
        with_replacement = bool(with_replacement)
        keep_null = bool(keep_null)

        _validate_strs(source_table, output_table, class_col,
                       output_table_size, grouping_cols)
        source_table_columns = ','.join(get_cols(source_table))

        new_source_table = source_table
        # If keep_null=False, create a view of the input table ignoring NULL
        # values for class levels.
        if not keep_null:
            new_source_table = unique_string(desp='source_table')
            plpy.execute("""
                        CREATE VIEW {new_source_table} AS
                        SELECT * FROM {source_table}
                        WHERE {class_col} IS NOT NULL
                    """.format(**locals()))
        actual_level_counts = _get_level_frequency_distribution(new_source_table,
                                                                class_col)
        # With no class level at all (every class value NULL and
        # keep_null=False) there is nothing to sample; fail with a clear
        # message instead of a min()/division error or an empty UNION query.
        _assert(actual_level_counts,
                "Sample: No rows with a non-NULL value in class column ({0}) "
                "found in source table ({1}).".format(class_col, source_table))

        # class_sizes can be of two forms:
        #   1. A string describing sampling strategy (as described in
        #       _validate_and_get_sampling_strategy).
        #       In this case, 'sampling_strategy_str' is set to one of
        #       [UNIFORM, UNDERSAMPLE, OVERSAMPLE]
        #   2. Class sizes for all (or a subset) of the class levels
        #       In this case, sampling_strategy_str = None and parsed_class_sizes
        #       is used for the sampling.
        parsed_class_sizes = extract_keyvalue_params(class_sizes,
                                                     allow_duplicates=False,
                                                     lower_case_names=False)
        if not parsed_class_sizes:
            sampling_strategy_str = _validate_and_get_sampling_strategy(
                class_sizes, output_table_size)
        else:
            sampling_strategy_str = None
            for each_level, each_class_size in parsed_class_sizes.items():
                _assert(each_level in actual_level_counts,
                        "Sample: Invalid class value specified ({0})".
                        format(each_level))
                try:
                    each_class_size = int(each_class_size)
                except (TypeError, ValueError):
                    # int('abc') raises ValueError, which the former
                    # 'except TypeError' let escape as a raw traceback.
                    plpy.error("Sample: Invalid value for class_sizes ({0})".
                               format(class_sizes))
                _assert(each_class_size >= 1,
                        "Sample: Class size has to be greater than zero")
                # Re-assigning an existing key does not resize the dict, so
                # this is safe while iterating.
                parsed_class_sizes[each_level] = each_class_size

        # Get the number of rows to be sampled for each class level, based on
        # the input table, class_sizes, and output_table_size params. This also
        # includes info about the resulting sampling strategy, i.e., one of
        # UNDERSAMPLE, OVERSAMPLE, or NOSAMPLE for each level.
        target_class_sizes = _get_target_level_counts(sampling_strategy_str,
                                                      parsed_class_sizes,
                                                      actual_level_counts,
                                                      output_table_size)

        undersample_level_dict, oversample_level_dict, nosample_level_dict = \
            _get_sampling_strategy_specific_dict(target_class_sizes)

        # Get subqueries for each sampling strategy, so that they can be used
        # together in one big query.

        # Subquery that will be used to get rows as is for those class levels
        # that need no sampling.
        nosample_subquery = _get_nosample_subquery(
            new_source_table, class_col, nosample_level_dict.keys())
        # Subquery that will be used to sample those class levels that
        # have to be oversampled.
        oversample_subquery = _get_with_replacement_subquery(
            schema_madlib, new_source_table, source_table_columns, class_col,
            actual_level_counts, oversample_level_dict)
        # Subquery that will be used to sample those class levels that
        # have to be undersampled. Undersampling supports both with and without
        # replacement, so fetch the appropriate subquery.
        if with_replacement:
            undersample_subquery = _get_with_replacement_subquery(
                schema_madlib, new_source_table, source_table_columns, class_col,
                actual_level_counts, undersample_level_dict)
        else:
            undersample_subquery = _get_without_replacement_subquery(
                schema_madlib, new_source_table, source_table_columns, class_col,
                actual_level_counts, undersample_level_dict)

        # Merge the three subqueries using a UNION ALL clause.
        union_all_subquery = ' UNION ALL '.join(
            ['({0})'.format(subquery)
             for subquery in [undersample_subquery, oversample_subquery,
                              nosample_subquery]
             if subquery])

        final_query = """
                CREATE TABLE {output_table} AS
                SELECT row_number() OVER() AS {new_col_name}, *
                FROM (
                    {union_all_subquery}
                ) union_query
            """.format(new_col_name=NEW_ID_COLUMN, **locals())
        plpy.execute(final_query)
        if not keep_null:
            plpy.execute("DROP VIEW {0}".format(new_source_table))
+
+
def _validate_strs(source_table, output_table, class_col, output_table_size,
                   grouping_cols):
    """ Validate balance_sample's arguments; _assert raises on the first
        invalid one.

        @param source_table: Must name an existing, non-empty table that has
                             class_col and no column named NEW_ID_COLUMN.
        @param output_table: Must be given and must not exist yet.
        @param class_col: Must be given and exist in source_table.
        @param output_table_size: None, or a positive number.
        @param grouping_cols: Not supported yet; must be None or blank.
    """
    _assert(source_table and table_exists(source_table),
            "Sample: Source table ({source_table}) does not exist.".format(**locals()))
    _assert(not table_is_empty(source_table),
            "Sample: Source table ({source_table}) is empty.".format(**locals()))

    _assert(output_table,
            "Sample: Output table name is missing.")
    _assert(not table_exists(output_table),
            "Sample: Output table ({output_table}) already exists.".format(**locals()))

    _assert(class_col,
            "Sample: Class column name is missing.")
    _assert(columns_exist_in_table(source_table, [class_col]),
            ("""Sample: Class column ({class_col}) does not exist in""" +
             """ table ({source_table}).""").format(**locals()))

    _assert(not columns_exist_in_table(source_table, [NEW_ID_COLUMN]),
            ("""Sample: Please ensure the source table ({0})""" +
             """ does not contain a column named {1}""").format(source_table,
                                                               NEW_ID_COLUMN))

    # Test for None explicitly: 0 is falsy, so the former
    # 'not output_table_size' check accepted it and later code then treated
    # it as "unspecified".
    _assert(output_table_size is None or output_table_size > 0,
            "Sample: Invalid output table size ({output_table_size}).".format(
                **locals()))

    # An empty or blank string means "no grouping", just like NULL.
    _assert(not (grouping_cols and grouping_cols.strip()),
            "grouping_cols is not supported at the moment.")
+
+def balance_sample_help(schema_madlib, message, **kwargs):
+    """
+    Help function for balance_sample
+
+    Args:
+        @param schema_madlib
+        @param message: string, Help message string
+        @param kwargs
+
+    Returns:
+        String. Help/usage information
+    """
+    if not message:
+        help_string = """
+-----------------------------------------------------------------------
+                            SUMMARY
+-----------------------------------------------------------------------
+Given a table with varying set of records for each class label,
+this function will create an output table with a varying types (by
+default: uniform) of sampling distributions of each class label. It is
+possible to use with or without replacement sampling methods, specify
+different proportions of each class, multiple grouping columns and/or
+output table size.
+
+For more details on function usage:
+    SELECT {schema_madlib}.balance_sample('usage');
+    SELECT {schema_madlib}.balance_sample('example');
+            """
+    elif message.lower() in ['usage', 'help', '?']:
+        help_string = """
+
+Given a table, stratified sampling returns a proportion of records for
+each group (strata). It is possible to use with or without replacement
+sampling methods, specify a set of target columns, and assume the
+whole table is a single strata.
+
+----------------------------------------------------------------------------
+                            USAGE
+----------------------------------------------------------------------------
+
+ SELECT {schema_madlib}.balance_sample(
+    source_table      TEXT,     -- Input table name.
+    output_table      TEXT,     -- Output table name.
+    class_col         TEXT,     -- Name of column containing the class to be
+                                -- balanced.
+    class_size        TEXT,     -- (Default: NULL) Parameter to define the size
+                                -- of the different class values.
+    output_table_size INTEGER,  -- (Default: NULL) Desired size of the output
+                                -- data set.
+    grouping_cols     TEXT,     -- (Default: NULL) The columns columns that
+                                -- defines the grouping.
+    with_replacement  BOOLEAN   -- (Default: FALSE) The sampling method.
+    keep_null         BOOLEAN   -- (Default: FALSE) Consider class levels with
+                                    NULL values or not.
+
+If class_size is NULL, the source table is uniformly sampled.
+
+If output_table_size is NULL, the resulting output table size will depend on
+the settings for the ‘class_size’ parameter. It is ignored if 
‘class_size’
+parameter is set to either ‘oversample’ or ‘undersample’.
+
+If grouping_cols is NULL, the whole table is treated as a single group and
+sampled accordingly.
+
+If with_replacement is TRUE, each sample is independent (the same row may
+be selected in the sample set more than once). Else (if with_replacement
+is FALSE), a row can be selected at most once.
+);
+
+The output_table would contain the required number of samples, along with a
+new column named __madlib_id__, that contain unique numbers for all
+sampled rows.
+"""
+    elif message.lower() in ("example", "examples"):
+        help_string = """
+----------------------------------------------------------------------------
+                                EXAMPLES
+----------------------------------------------------------------------------
+
+-- Create an input table
+DROP TABLE IF EXISTS test;
+
+CREATE TABLE test(
+    id1 INTEGER,
+    id2 INTEGER,
+    gr1 INTEGER,
+    gr2 INTEGER
+);
+
+INSERT INTO test VALUES
+(1,0,1,1),
+(2,0,1,1),
+(3,0,1,1),
+(4,0,1,1),
+(5,0,1,1),
+(6,0,1,1),
+(7,0,1,1),
+(8,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(0,1,1,2),
+(0,2,1,2),
+(0,3,1,2),
+(0,4,1,2),
+(0,5,1,2),
+(0,6,1,2),
+(10,10,2,2),
+(20,20,2,2),
+(30,30,2,2),
+(40,40,2,2),
+(50,50,2,2),
+(60,60,2,2),
+(70,70,2,2)
+;
+
+-- Sample without replacement
+DROP TABLE IF EXISTS out;
+SELECT balance_sample('test', 'out', 'gr1', 'undersample', NULL, NULL, FALSE);
+SELECT * FROM out;
+
+--- Sample with replacement
+DROP TABLE IF EXISTS out_sr2;
+SELECT balance_sample('test', 'out', 'gr1', 'undersample', NULL, NULL, TRUE);
+SELECT * FROM out;
+"""
+    else:
+        help_string = "No such option. Use {schema_madlib}.balance_sample()"
+
+    return help_string.format(schema_madlib=schema_madlib)
+
+
+import unittest
+
+
+class UtilitiesTestCase(unittest.TestCase):
+    """
+        Comment "import plpy" and replace plpy.error calls with appropriate
+        Python Exceptions to successfully run the test cases
+    """
+
+    def setUp(self):
+        self.input_class_level_counts1 = {'a': 20, 'b': 30, 'c': 25}
+        self.level1a = 'a'
+        self.level1a_cnt1 = 15
+        self.level1a_cnt2 = 25
+        self.level1a_cnt3 = 20
+
+        self.sampling_strategy_str0 = ''
+        self.sampling_strategy_str1 = 'uniform'
+        self.sampling_strategy_str2 = 'oversample'
+        self.sampling_strategy_str3 = 'undersample'
+        self.user_specified_class_size0 = ''
+        self.user_specified_class_size1 = {'a': 25, 'b': 25}
+        self.user_specified_class_size2 = {'b': 25}
+        self.user_specified_class_size3 = {'a': 30}
+        self.output_table_size1 = None
+        self.output_table_size2 = 60
+        # self.input_class_level_counts2 = {'a':100, 'b':100, 'c':100}
+
+    def test__choose_strategy(self):
+        self.assertEqual(UNDERSAMPLE, _choose_strategy(35, 25))
+        self.assertEqual(OVERSAMPLE, _choose_strategy(15, 25))
+        self.assertEqual(UNDERSAMPLE, _choose_strategy(25, 25))
+
+    def test__get_target_level_counts(self):
+        # Test cases for user defined class level samples, without output 
table size
+        self.assertEqual({'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': 
(25, NOSAMPLE)},
+                         _get_target_level_counts(self.sampling_strategy_str0,
+                                                  
self.user_specified_class_size1,
+                                                  
self.input_class_level_counts1,
+                                                  self.output_table_size1))
+        self.assertEqual({'a': (20, NOSAMPLE), 'b': (25, UNDERSAMPLE), 'c': 
(25, NOSAMPLE)},
+                         _get_target_level_counts(self.sampling_strategy_str0,
+                                                  
self.user_specified_class_size2,
+                                                  
self.input_class_level_counts1,
+                                                  self.output_table_size1))
+        self.assertEqual({'a': (30, OVERSAMPLE), 'b': (30, NOSAMPLE), 'c': 
(25, NOSAMPLE)},
+                         _get_target_level_counts(self.sampling_strategy_str0,
+                                                  
self.user_specified_class_size3,
+                                                  
self.input_class_level_counts1,
+                                                  self.output_table_size1))
+        # Test cases for user defined class level samples, with output table 
size
+        self.assertEqual({'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': 
(10, UNDERSAMPLE)},
+                         _get_target_level_counts(self.sampling_strategy_str0,
+                                                  
self.user_specified_class_size1,
+                                                  
self.input_class_level_counts1,
+                                                  self.output_table_size2))
+        self.assertEqual({'a': (18, UNDERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': 
(18, UNDERSAMPLE)},
+                         _get_target_level_counts(self.sampling_strategy_str0,
+                                                  
self.user_specified_class_size2,
+                                                  
self.input_class_level_counts1,
+                                                  self.output_table_size2))
+        self.assertEqual({'a': (30, OVERSAMPLE), 'b': (15, UNDERSAMPLE), 'c': 
(15, UNDERSAMPLE)},
+                         _get_target_level_counts(self.sampling_strategy_str0,
+                                                  
self.user_specified_class_size3,
+                                                  
self.input_class_level_counts1,
+                                                  self.output_table_size2))
+        # Test cases for UNIFORM, OVERSAMPLE, and UNDERSAMPLE without any 
output table size
+        self.assertEqual({'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 'c': 
(25, UNDERSAMPLE)},
+                         _get_target_level_counts(self.sampling_strategy_str1,
+                                                  
self.user_specified_class_size0,
+                                                  
self.input_class_level_counts1,
+                                                  self.output_table_size1))
+        self.assertEqual({'a': (30, OVERSAMPLE), 'b': (30, UNDERSAMPLE), 'c': 
(30, OVERSAMPLE)},
+                         _get_target_level_counts(self.sampling_strategy_str2,
+                                                  
self.user_specified_class_size0,
+                                                  
self.input_class_level_counts1,
+                                                  self.output_table_size1))
+        self.assertEqual({'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 'c': 
(20, UNDERSAMPLE)},
+                         _get_target_level_counts(self.sampling_strategy_str3,
+                                                  
self.user_specified_class_size0,
+                                                  
self.input_class_level_counts1,
+                                                  self.output_table_size1))
+        # Test cases for UNIFORM with output table size
+        self.assertEqual({'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 'c': 
(20, UNDERSAMPLE)},
+                         _get_target_level_counts(self.sampling_strategy_str1,
+                                                  
self.user_specified_class_size0,
+                                                  
self.input_class_level_counts1,
+                                                  self.output_table_size2))
+
+    def test__get_sampling_strategy_specific_dict(self):
+        # Test cases for getting sampling strategy specific counts
+        target_level_counts_1 = {'a': (25, OVERSAMPLE), 'b': (25, 
UNDERSAMPLE), 'c': (25, NOSAMPLE)}
+        target_level_counts_2 = {'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE)}
+        target_level_counts_3 = {'a': (25, OVERSAMPLE), 'b': (25, NOSAMPLE), 
'c': (25, NOSAMPLE)}
+        self.assertEqual(({'b': 25}, {'a': 25}, {'c': 25}),
+                         
_get_sampling_strategy_specific_dict(target_level_counts_1))
+        self.assertEqual(({'b': 25}, {'a': 25}, {}),
+                         
_get_sampling_strategy_specific_dict(target_level_counts_2))
+        self.assertEqual(({}, {'a': 25}, {'c': 25, 'b': 25}),
+                         
_get_sampling_strategy_specific_dict(target_level_counts_3))
+
+
+if __name__ == '__main__':
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/madlib/blob/c51da40a/src/ports/postgres/modules/sample/balance_sample.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/sample/balance_sample.sql_in 
b/src/ports/postgres/modules/sample/balance_sample.sql_in
new file mode 100644
index 0000000..cd70961
--- /dev/null
+++ b/src/ports/postgres/modules/sample/balance_sample.sql_in
@@ -0,0 +1,647 @@
+/* ----------------------------------------------------------------------- 
*//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *
+ * @file balance_sample.sql_in
+ *
+ * @brief SQL functions for balanced data sets sampling.
+ * @date 12/14/2017
+ *
+ * @sa Given a table, balanced sampling returns a sampled data set
+ * with specified proportions for each class (defaults to uniform sampling).
+ *
+ *//* ----------------------------------------------------------------------- 
*/
+
+m4_include(`SQLCommon.m4')
+
+
+/**
+@addtogroup grp_balance_sampling
+
+<div class="toc"><b>Contents</b>
+<ul>
+<li><a href="#strs">Balanced Sampling</a></li>
+<li><a href="#examples">Examples</a></li>
+<li><a href="#literature">Literature</a></li>
+<li><a href="#related">Related Topics</a></li>
+</ul>
+</div>
+
+@brief A method to independently sample classes to produce a
+balanced data set.
+This is commonly used when classes are imbalanced,
+to ensure that subclasses are adequately represented in the sample.
+
+Some classification algorithms only perform optimally
+when the number of samples in each class is roughly the same.
+Highly skewed datasets are common in many domains (e.g., fraud
+detection), so resampling to offset this imbalance can
+produce a better decision boundary.
+
+This module offers a number of resampling techniques
+including undersampling majority classes,
+oversampling minority classes, and
+combinations of the two.
+
+@anchor strs
+@par Balanced Sampling
+
+<pre class="syntax">
+balance_sample( source_table,
+                output_table,
+                class_col,
+                class_sizes,
+                output_table_size,
+                grouping_cols,
+                with_replacement,
+                keep_null
+              )
+</pre>
+
+\b Arguments
+<dl class="arglist">
+<dt>source_table</dt>
+<dd>TEXT. Name of the table containing the input data.</dd>
+
+<dt>output_table</dt>
+<dd>TEXT. Name of output table that contains the sampled data.
+The output table contains all columns present in the source
+table, plus a new generated id called "__madlib_id__" added as
+the first column. </dd>
+
+<dt>class_col</dt>
+<dd>TEXT. Name of the column containing the class to be balanced.
+</dd>
+
+<dt>class_sizes (optional)</dt>
+<dd>VARCHAR, default ‘uniform’.  Parameter to define the size
+of the different class values.
+(Class values are sometimes also called levels).
+Can be set to the following:
+
+<ul>
+<li><b>‘uniform’</b>:
+All class values will be resampled to have the same number of rows.
+</li>
+<li><b>'undersample'</b>:
+Undersample such that all class values end up with the same number of
+observations as the minority class.  Done without replacement by default
+unless the parameter ‘with_replacement’ is set to TRUE.
+</li>
+<li><b>'oversample'</b>:
+Oversample with replacement such that all class values end up with the
+same number of observations as the majority class.  Not affected by the
+parameter ‘with_replacement’ since oversampling is always done with
+replacement.
+</li>
+</ul>
+Short forms of the above will work too, e.g., 'uni' works the same
+as 'uniform'.
+
+Alternatively, you can explicitly set the class sizes in a string containing a
+comma-delimited list. Order does not matter, and not all class values
+need to be specified.  Use the format "class_value_1=x, class_value_2=y, ..."
+where each 'class_value' in the list must exist in the column 'class_col'
+and each of x, y, ... is an integer giving the desired number of
+observations for that class value.
+E.g., 'red=3000, blue=4000' means you want to resample the dataset
+to result in exactly 3000 red and 4000 blue rows in the 'output_table'.
+
+@note
+The allowed names for class values follow object naming rules in
+PostgreSQL [1].  Quoted identifiers are allowed and should be enclosed
+in double quotes in the usual way.  If for some reason the class values
+in the examples above were "ReD" and "BluE", then the comma-delimited
+list for 'class_sizes' would be:  '"ReD"=3000, "BluE"=4000'.
+</dd>
+
+<dt>output_table_size (optional)</dt>
+<dd>INTEGER, default NULL.  Desired size of the output data set.
+This parameter is ignored if the 'class_sizes' parameter is set to either
+'oversample' or 'undersample', since the output table size is then already
+determined.
+If NULL, the resulting output table size will depend on the settings
+for the 'class_sizes' parameter (see the table below for more details).
+</dd>
+
+<dt>grouping_cols (optional)</dt>
+<dd>TEXT, default: NULL. A single column or a list of
+comma-separated columns that defines the strata.  When this
+parameter is NULL, no grouping is used so the sampling is
+non-stratified, that is, the whole table is treated as a
+single group.
+
+@note
+Current implementation does not support grouping_cols.  It
+will be added in an upcoming release.
+</dd>
+
+<dt>with_replacement  (optional)</dt>
+<dd>BOOLEAN, default FALSE.  Determines whether to sample
+with replacement or without replacement (default).
+With replacement means that it is possible that the
+same row may appear in the sample set more than once.
+Without replacement means a given row can be selected
+only once. This parameter affects undersampling only since
+oversampling is always done with replacement.</dd>
+
+<dt>keep_null (optional)</dt>
+<dd>BOOLEAN, default FALSE. Determines whether to
+sample rows whose class values are NULL. By default,
+all rows with NULL class values are ignored. If this
+is set to TRUE, then NULL is treated as another class
+value.</dd>
+</dl>
+
+<b>How Output Table Size is Determined</b>
+
+The rule of thumb is that if you specify a value for
+'output_table_size', then you will generally
+get an output table of that size, with some minor
+rounding variations.  If you set 'output_table_size' to NULL,
+then the size of the output table will be calculated
+depending on what you put for the 'class_sizes' parameter.
+The following table shows how the parameters 'class_sizes'
+and 'output_table_size' work together:
+
+| Case | 'class_sizes' | 'output_table_size' | Result |
+| :---- | :---- | :---- | :---- |
+| 1 | 'uniform' | NULL | Resample for uniform class size with output size = input size (i.e., balanced). |
+| 2 | 'uniform' | 10000 | Resample for uniform class size with output size = 10K (i.e., balanced). |
+| 3 | NULL | NULL | Resample for uniform class size with output size = input size (i.e., balanced). class_sizes=NULL has the same behavior as 'uniform'. |
+| 4 | NULL | 10000 | Resample for uniform class size with output size = 10K (i.e., balanced). class_sizes=NULL has the same behavior as 'uniform'. |
+| 5 | 'undersample' | n/a | Undersample such that all class values end up with the same number of observations as the minority class. |
+| 6 | 'oversample' | n/a | Oversample with replacement (always) such that all class values end up with the same number of observations as the majority class. |
+| 7 | 'red=3000' | NULL | Resample red to 3K, leave the rest of the class values (blue, green, etc.) as is. |
+| 8 | 'red=3000, blue=4000' | 10000 | Resample red to 3K and blue to 4K, and divide the remaining 3K rows evenly among the other class values, i.e., 3K/(n-2) each, where n=number of class values.  Note that if red and blue are the only class values, then the output table size will be 7K, not 10K.  (This is the only case where specifying a value for 'output_table_size' may not actually result in an output table of that size.) |
+
+@anchor examples
+@par Examples
+
+Note that due to the random nature of sampling, your
+results may look different from those below.
+
+-# Create an input table using part of the flags
+data set from https://archive.ics.uci.edu/ml/datasets/Flags :
+<pre class="syntax">
+DROP TABLE IF EXISTS flags;
+CREATE TABLE flags (
+    id INTEGER,
+    name TEXT,
+    landmass INTEGER,
+    zone INTEGER,
+    area INTEGER,
+    population INTEGER,
+    language INTEGER,
+    colours INTEGER,
+    mainhue TEXT
+);
+INSERT INTO flags VALUES
+(1, 'Argentina', 2, 3, 2777, 28, 2, 2, 'blue'),
+(2, 'Australia', 6, 2, 7690, 15, 1, 3, 'blue'),
+(3, 'Austria', 3, 1, 84, 8, 4, 2, 'red'),
+(4, 'Brazil', 2, 3, 8512, 119, 6, 4, 'green'),
+(5, 'Canada', 1, 4, 9976, 24, 1, 2, 'red'),
+(6, 'China', 5, 1, 9561, 1008, 7, 2, 'red'),
+(7, 'Denmark', 3, 1, 43, 5, 6, 2, 'red'),
+(8, 'Greece', 3, 1, 132, 10, 6, 2, 'blue'),
+(9, 'Guatemala', 1, 4, 109, 8, 2, 2, 'blue'),
+(10, 'Ireland', 3, 4, 70, 3, 1, 3, 'white'),
+(11, 'Jamaica', 1, 4, 11, 2, 1, 3, 'green'),
+(12, 'Luxembourg', 3, 1, 3, 0, 4, 3, 'red'),
+(13, 'Mexico', 1, 4, 1973, 77, 2, 4, 'green'),
+(14, 'Norway', 3, 1, 324, 4, 6, 3, 'red'),
+(15, 'Portugal', 3, 4, 92, 10, 6, 5, 'red'),
+(16, 'Spain', 3, 4, 505, 38, 2, 2, 'red'),
+(17, 'Sweden', 3, 1, 450, 8, 6, 2, 'blue'),
+(18, 'Switzerland', 3, 1, 41, 6, 4, 2, 'red'),
+(19, 'UK', 3, 4, 245, 56, 1, 3, 'red'),
+(20, 'USA', 1, 4, 9363, 231, 1, 3, 'white'),
+(21, 'xElba', 3, 1, 1, 1, 6, NULL, NULL),
+(22, 'xPrussia', 3, 1, 249, 61, 4, NULL, NULL);
+</pre>
+
+-# Uniform sampling.  All class values will be resampled
+so that they have the same number of rows. The output data
+size will be the same as the input data size, ignoring
+NULL values.  Uniform sampling
+is the default for the 'class_sizes' parameter, so we do not
+need to explicitly set it:
+<pre class="syntax">
+DROP TABLE IF EXISTS output_table;
+SELECT madlib.balance_sample(
+                              'flags',             -- Source table
+                              'output_table',      -- Output table
+                              'mainhue');          -- Class column
+SELECT * FROM output_table ORDER BY mainhue, name;
+</pre>
+<pre class="result">
+ __madlib_id__ | id |    name     | landmass | zone | area | population | 
language | colours | mainhue
+---------------+----+-------------+----------+------+------+------------+----------+---------+---------
+             5 |  1 | Argentina   |        2 |    3 | 2777 |         28 |      
  2 |       2 | blue
+             2 |  2 | Australia   |        6 |    2 | 7690 |         15 |      
  1 |       3 | blue
+             3 |  8 | Greece      |        3 |    1 |  132 |         10 |      
  6 |       2 | blue
+             4 |  9 | Guatemala   |        1 |    4 |  109 |          8 |      
  2 |       2 | blue
+             1 | 17 | Sweden      |        3 |    1 |  450 |          8 |      
  6 |       2 | blue
+            11 |  4 | Brazil      |        2 |    3 | 8512 |        119 |      
  6 |       4 | green
+            12 |  4 | Brazil      |        2 |    3 | 8512 |        119 |      
  6 |       4 | green
+            14 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+            15 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+            13 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+             8 |  3 | Austria     |        3 |    1 |   84 |          8 |      
  4 |       2 | red
+            10 |  5 | Canada      |        1 |    4 | 9976 |         24 |      
  1 |       2 | red
+             9 |  7 | Denmark     |        3 |    1 |   43 |          5 |      
  6 |       2 | red
+             6 | 15 | Portugal    |        3 |    4 |   92 |         10 |      
  6 |       5 | red
+             7 | 18 | Switzerland |        3 |    1 |   41 |          6 |      
  4 |       2 | red
+            19 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            20 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            18 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            16 | 20 | USA         |        1 |    4 | 9363 |        231 |      
  1 |       3 | white
+            17 | 20 | USA         |        1 |    4 | 9363 |        231 |      
  1 |       3 | white
+(20 rows)
+</pre>
+Next we do uniform sampling again, but this time we specify a
+size for the output table:
+<pre class="syntax">
+DROP TABLE IF EXISTS output_table;
+SELECT madlib.balance_sample(
+                              'flags',             -- Source table
+                              'output_table',      -- Output table
+                              'mainhue',           -- Class column
+                              'uniform',           -- Uniform sample
+                               12);                -- Desired output table size
+SELECT * FROM output_table ORDER BY mainhue, name;
+</pre>
+<pre class="result">
+ __madlib_id__ | id |   name    | landmass | zone | area | population | 
language | colours | mainhue
+---------------+----+-----------+----------+------+------+------------+----------+---------+---------
+            10 |  1 | Argentina |        2 |    3 | 2777 |         28 |        
2 |       2 | blue
+            12 |  2 | Australia |        6 |    2 | 7690 |         15 |        
1 |       3 | blue
+            11 |  8 | Greece    |        3 |    1 |  132 |         10 |        
6 |       2 | blue
+             2 |  4 | Brazil    |        2 |    3 | 8512 |        119 |        
6 |       4 | green
+             3 | 11 | Jamaica   |        1 |    4 |   11 |          2 |        
1 |       3 | green
+             1 | 13 | Mexico    |        1 |    4 | 1973 |         77 |        
2 |       4 | green
+             5 |  7 | Denmark   |        3 |    1 |   43 |          5 |        
6 |       2 | red
+             6 | 14 | Norway    |        3 |    1 |  324 |          4 |        
6 |       3 | red
+             4 | 15 | Portugal  |        3 |    4 |   92 |         10 |        
6 |       5 | red
+             9 | 10 | Ireland   |        3 |    4 |   70 |          3 |        
1 |       3 | white
+             7 | 20 | USA       |        1 |    4 | 9363 |        231 |        
1 |       3 | white
+             8 | 20 | USA       |        1 |    4 | 9363 |        231 |        
1 |       3 | white
+(12 rows)
+</pre>
+
+-# Oversampling.  Oversample with replacement such that all
+class values except NULLs end up with the same number of observations as
+the majority class. Countries with red flags form the majority
+class with 10 observations, so the other class values will be
+oversampled to 10 observations each:
+<pre class="syntax">
+DROP TABLE IF EXISTS output_table;
+SELECT madlib.balance_sample(
+                              'flags',             -- Source table
+                              'output_table',      -- Output table
+                              'mainhue',           -- Class column
+                              'oversample');       -- Oversample
+SELECT * FROM output_table ORDER BY mainhue, name;
+</pre>
+<pre class="result">
+ __madlib_id__ | id |    name     | landmass | zone | area | population | 
language | colours | mainhue
+---------------+----+-------------+----------+------+------+------------+----------+---------+---------
+            35 |  1 | Argentina   |        2 |    3 | 2777 |         28 |      
  2 |       2 | blue
+            33 |  1 | Argentina   |        2 |    3 | 2777 |         28 |      
  2 |       2 | blue
+            37 |  1 | Argentina   |        2 |    3 | 2777 |         28 |      
  2 |       2 | blue
+            34 |  1 | Argentina   |        2 |    3 | 2777 |         28 |      
  2 |       2 | blue
+            36 |  1 | Argentina   |        2 |    3 | 2777 |         28 |      
  2 |       2 | blue
+            32 |  1 | Argentina   |        2 |    3 | 2777 |         28 |      
  2 |       2 | blue
+            31 |  2 | Australia   |        6 |    2 | 7690 |         15 |      
  1 |       3 | blue
+            39 |  9 | Guatemala   |        1 |    4 |  109 |          8 |      
  2 |       2 | blue
+            38 |  9 | Guatemala   |        1 |    4 |  109 |          8 |      
  2 |       2 | blue
+            40 | 17 | Sweden      |        3 |    1 |  450 |          8 |      
  6 |       2 | blue
+            19 |  4 | Brazil      |        2 |    3 | 8512 |        119 |      
  6 |       4 | green
+            20 |  4 | Brazil      |        2 |    3 | 8512 |        119 |      
  6 |       4 | green
+            12 | 11 | Jamaica     |        1 |    4 |   11 |          2 |      
  1 |       3 | green
+            11 | 11 | Jamaica     |        1 |    4 |   11 |          2 |      
  1 |       3 | green
+            13 | 11 | Jamaica     |        1 |    4 |   11 |          2 |      
  1 |       3 | green
+            17 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+            15 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+            16 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+            18 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+            14 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+             9 |  3 | Austria     |        3 |    1 |   84 |          8 |      
  4 |       2 | red
+             8 |  5 | Canada      |        1 |    4 | 9976 |         24 |      
  1 |       2 | red
+             1 |  6 | China       |        5 |    1 | 9561 |       1008 |      
  7 |       2 | red
+            10 |  7 | Denmark     |        3 |    1 |   43 |          5 |      
  6 |       2 | red
+             2 | 12 | Luxembourg  |        3 |    1 |    3 |          0 |      
  4 |       3 | red
+             4 | 14 | Norway      |        3 |    1 |  324 |          4 |      
  6 |       3 | red
+             6 | 15 | Portugal    |        3 |    4 |   92 |         10 |      
  6 |       5 | red
+             3 | 16 | Spain       |        3 |    4 |  505 |         38 |      
  2 |       2 | red
+             5 | 18 | Switzerland |        3 |    1 |   41 |          6 |      
  4 |       2 | red
+             7 | 19 | UK          |        3 |    4 |  245 |         56 |      
  1 |       3 | red
+            22 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            26 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            24 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            21 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            27 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            25 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            23 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            29 | 20 | USA         |        1 |    4 | 9363 |        231 |      
  1 |       3 | white
+            30 | 20 | USA         |        1 |    4 | 9363 |        231 |      
  1 |       3 | white
+            28 | 20 | USA         |        1 |    4 | 9363 |        231 |      
  1 |       3 | white
+(40 rows)
+</pre>
+
+-# Undersampling.  Undersample such that all class values except NULLs end
+up with the same number of observations as the minority class.
+Countries with white flags form the minority class with 2 observations,
+so the other class values will be undersampled to 2 observations each:
+<pre class="syntax">
+DROP TABLE IF EXISTS output_table;
+SELECT madlib.balance_sample(
+                              'flags',             -- Source table
+                              'output_table',      -- Output table
+                              'mainhue',           -- Class column
+                              'undersample');      -- Undersample
+SELECT * FROM output_table ORDER BY mainhue, name;
+</pre>
+<pre class="result">
+ __madlib_id__ | id |    name     | landmass | zone | area | population | 
language | colours | mainhue
+---------------+----+-------------+----------+------+------+------------+----------+---------+---------
+             1 |  1 | Argentina   |        2 |    3 | 2777 |         28 |      
  2 |       2 | blue
+             2 |  2 | Australia   |        6 |    2 | 7690 |         15 |      
  1 |       3 | blue
+             4 |  4 | Brazil      |        2 |    3 | 8512 |        119 |      
  6 |       4 | green
+             3 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+             5 | 16 | Spain       |        3 |    4 |  505 |         38 |      
  2 |       2 | red
+             6 | 18 | Switzerland |        3 |    1 |   41 |          6 |      
  4 |       2 | red
+             8 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+             7 | 20 | USA         |        1 |    4 | 9363 |        231 |      
  1 |       3 | white
+(8 rows)
+</pre>
+In the case of bootstrapping, we may want to undersample with replacement,
+so we set the 'with_replacement' parameter to TRUE:
+<pre class="syntax">
+DROP TABLE IF EXISTS output_table;
+SELECT madlib.balance_sample(
+                              'flags',             -- Source table
+                              'output_table',      -- Output table
+                              'mainhue',           -- Class column
+                              'undersample',       -- Undersample
+                               NULL,               -- Output table size will be calculated
+                               NULL,               -- No grouping
+                              'TRUE');             -- Sample with replacement
+SELECT * FROM output_table ORDER BY mainhue, name;
+</pre>
+<pre class="result">
+ __madlib_id__ | id |   name    | landmass | zone | area | population | 
language | colours | mainhue
+---------------+----+-----------+----------+------+------+------------+----------+---------+---------
+             2 |  9 | Guatemala |        1 |    4 |  109 |          8 |        
2 |       2 | blue
+             1 |  9 | Guatemala |        1 |    4 |  109 |          8 |        
2 |       2 | blue
+             3 |  4 | Brazil    |        2 |    3 | 8512 |        119 |        
6 |       4 | green
+             4 | 13 | Mexico    |        1 |    4 | 1973 |         77 |        
2 |       4 | green
+             6 |  5 | Canada    |        1 |    4 | 9976 |         24 |        
1 |       2 | red
+             5 | 19 | UK        |        3 |    4 |  245 |         56 |        
1 |       3 | red
+             7 | 20 | USA       |        1 |    4 | 9363 |        231 |        
1 |       3 | white
+             8 | 20 | USA       |        1 |    4 | 9363 |        231 |        
1 |       3 | white
+(8 rows)
+</pre>
+Note that some rows may appear multiple times above, since we sampled with replacement.
+
+-# Setting class size by count.  Here we set the number of rows for
+red and blue flags, and leave green and white flags unchanged:
+<pre class="syntax">
+DROP TABLE IF EXISTS output_table;
+SELECT madlib.balance_sample(
+                              'flags',             -- Source table
+                              'output_table',      -- Output table
+                              'mainhue',           -- Class column
+                              'red=7, blue=7');    -- Want 7 reds and 7 blues
+SELECT * FROM output_table ORDER BY mainhue, name;
+</pre>
+<pre class="result">
+ __madlib_id__ | id |    name    | landmass | zone | area | population | 
language | colours | mainhue
+---------------+----+------------+----------+------+------+------------+----------+---------+---------
+             5 |  2 | Australia  |        6 |    2 | 7690 |         15 |       
 1 |       3 | blue
+             7 |  8 | Greece     |        3 |    1 |  132 |         10 |       
 6 |       2 | blue
+             6 |  8 | Greece     |        3 |    1 |  132 |         10 |       
 6 |       2 | blue
+             1 |  9 | Guatemala  |        1 |    4 |  109 |          8 |       
 2 |       2 | blue
+             3 | 17 | Sweden     |        3 |    1 |  450 |          8 |       
 6 |       2 | blue
+             2 | 17 | Sweden     |        3 |    1 |  450 |          8 |       
 6 |       2 | blue
+             4 | 17 | Sweden     |        3 |    1 |  450 |          8 |       
 6 |       2 | blue
+             8 |  4 | Brazil     |        2 |    3 | 8512 |        119 |       
 6 |       4 | green
+            18 | 11 | Jamaica    |        1 |    4 |   11 |          2 |       
 1 |       3 | green
+            19 | 13 | Mexico     |        1 |    4 | 1973 |         77 |       
 2 |       4 | green
+            13 |  3 | Austria    |        3 |    1 |   84 |          8 |       
 4 |       2 | red
+            14 |  5 | Canada     |        1 |    4 | 9976 |         24 |       
 1 |       2 | red
+            17 |  6 | China      |        5 |    1 | 9561 |       1008 |       
 7 |       2 | red
+            15 | 12 | Luxembourg |        3 |    1 |    3 |          0 |       
 4 |       3 | red
+            16 | 14 | Norway     |        3 |    1 |  324 |          4 |       
 6 |       3 | red
+            11 | 15 | Portugal   |        3 |    4 |   92 |         10 |       
 6 |       5 | red
+            12 | 16 | Spain      |        3 |    4 |  505 |         38 |       
 2 |       2 | red
+             9 | 10 | Ireland    |        3 |    4 |   70 |          3 |       
 1 |       3 | white
+            10 | 20 | USA        |        1 |    4 | 9363 |        231 |       
 1 |       3 | white
+(19 rows)
+</pre>
+Next we set the number of rows for red and blue flags, and also set an
+output table size.  This means that green and white flags will be
+uniformly sampled to get to the desired output table size:
+<pre class="syntax">
+DROP TABLE IF EXISTS output_table;
+SELECT madlib.balance_sample(
+                              'flags',             -- Source table
+                              'output_table',      -- Output table
+                              'mainhue',           -- Class column
+                              'red=7, blue=7',     -- Want 7 reds and 7 blues
+                               22);                -- Desired output table size
+SELECT * FROM output_table ORDER BY mainhue, name;
+</pre>
+<pre class="result">
+ __madlib_id__ | id |    name     | landmass | zone | area | population | 
language | colours | mainhue
+---------------+----+-------------+----------+------+------+------------+----------+---------+---------
+            16 |  1 | Argentina   |        2 |    3 | 2777 |         28 |      
  2 |       2 | blue
+            20 |  2 | Australia   |        6 |    2 | 7690 |         15 |      
  1 |       3 | blue
+            21 |  2 | Australia   |        6 |    2 | 7690 |         15 |      
  1 |       3 | blue
+            22 |  8 | Greece      |        3 |    1 |  132 |         10 |      
  6 |       2 | blue
+            18 | 17 | Sweden      |        3 |    1 |  450 |          8 |      
  6 |       2 | blue
+            19 | 17 | Sweden      |        3 |    1 |  450 |          8 |      
  6 |       2 | blue
+            17 | 17 | Sweden      |        3 |    1 |  450 |          8 |      
  6 |       2 | blue
+             9 |  4 | Brazil      |        2 |    3 | 8512 |        119 |      
  6 |       4 | green
+            10 |  4 | Brazil      |        2 |    3 | 8512 |        119 |      
  6 |       4 | green
+             8 | 11 | Jamaica     |        1 |    4 |   11 |          2 |      
  1 |       3 | green
+            11 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+             6 |  3 | Austria     |        3 |    1 |   84 |          8 |      
  4 |       2 | red
+             7 |  5 | Canada      |        1 |    4 | 9976 |         24 |      
  1 |       2 | red
+             2 |  7 | Denmark     |        3 |    1 |   43 |          5 |      
  6 |       2 | red
+             1 | 12 | Luxembourg  |        3 |    1 |    3 |          0 |      
  4 |       3 | red
+             3 | 15 | Portugal    |        3 |    4 |   92 |         10 |      
  6 |       5 | red
+             5 | 16 | Spain       |        3 |    4 |  505 |         38 |      
  2 |       2 | red
+             4 | 18 | Switzerland |        3 |    1 |   41 |          6 |      
  4 |       2 | red
+            14 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            13 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            15 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            12 | 20 | USA         |        1 |    4 | 9363 |        231 |      
  1 |       3 | white
+(22 rows)
+</pre>
+
+-# To make NULL a valid class value, set the parameter to keep NULLs:
+<pre class="syntax">
+DROP TABLE IF EXISTS output_table;
+SELECT madlib.balance_sample(
+                              'flags',             -- Source table
+                              'output_table',      -- Output table
+                              'mainhue',           -- Class column
+                               NULL,               -- Uniform
+                               NULL,               -- Output table size
+                               NULL,               -- No grouping
+                               NULL,               -- Sample without replacement
+                              'TRUE');             -- Make NULLs a valid class value
+SELECT * FROM output_table ORDER BY mainhue, name;
+</pre>
+<pre class="result">
+ __madlib_id__ | id |    name     | landmass | zone | area | population | 
language | colours | mainhue
+---------------+----+-------------+----------+------+------+------------+----------+---------+---------
+            25 |  1 | Argentina   |        2 |    3 | 2777 |         28 |      
  2 |       2 | blue
+            22 |  2 | Australia   |        6 |    2 | 7690 |         15 |      
  1 |       3 | blue
+            24 |  8 | Greece      |        3 |    1 |  132 |         10 |      
  6 |       2 | blue
+            21 |  9 | Guatemala   |        1 |    4 |  109 |          8 |      
  2 |       2 | blue
+            23 | 17 | Sweden      |        3 |    1 |  450 |          8 |      
  6 |       2 | blue
+             7 |  4 | Brazil      |        2 |    3 | 8512 |        119 |      
  6 |       4 | green
+             6 |  4 | Brazil      |        2 |    3 | 8512 |        119 |      
  6 |       4 | green
+            10 | 11 | Jamaica     |        1 |    4 |   11 |          2 |      
  1 |       3 | green
+             8 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+             9 | 13 | Mexico      |        1 |    4 | 1973 |         77 |      
  2 |       4 | green
+             3 |  3 | Austria     |        3 |    1 |   84 |          8 |      
  4 |       2 | red
+             1 |  5 | Canada      |        1 |    4 | 9976 |         24 |      
  1 |       2 | red
+             2 | 16 | Spain       |        3 |    4 |  505 |         38 |      
  2 |       2 | red
+             4 | 18 | Switzerland |        3 |    1 |   41 |          6 |      
  4 |       2 | red
+             5 | 19 | UK          |        3 |    4 |  245 |         56 |      
  1 |       3 | red
+            13 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            11 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            14 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            12 | 10 | Ireland     |        3 |    4 |   70 |          3 |      
  1 |       3 | white
+            15 | 20 | USA         |        1 |    4 | 9363 |        231 |      
  1 |       3 | white
+            17 | 21 | xElba       |        3 |    1 |    1 |          1 |      
  6 |         |
+            18 | 21 | xElba       |        3 |    1 |    1 |          1 |      
  6 |         |
+            16 | 21 | xElba       |        3 |    1 |    1 |          1 |      
  6 |         |
+            20 | 22 | xPrussia    |        3 |    1 |  249 |         61 |      
  4 |         |
+            19 | 22 | xPrussia    |        3 |    1 |  249 |         61 |      
  4 |         |
+(25 rows)
+</pre>
+
+@anchor literature
+@par Literature
+
+[1] Object naming in PostgreSQL
+https://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
+
+@anchor related
+@par Related Topics
+
+See the file balance_sample.sql_in for the list of functions and their usage.
+
+*/
+
+-- Re-sample source_table into output_table so that the class levels found in
+-- class_col are balanced. The work is done by the Python function
+-- balance_sample in sample/balance_sample.py_in; the shorter overloads that
+-- follow call this 8-argument form, filling in defaults for omitted arguments.
+--
+--   class_sizes       'undersample', 'oversample', 'uniform', NULL (uniform),
+--                     or per-class sizes such as 'red=7, blue=7'
+--   output_table_size desired number of output rows; NULL to have it computed
+--   grouping_cols     NULL for no grouping
+--   with_replacement  TRUE to sample with replacement
+--   keep_null         TRUE to treat NULL as a valid class value
+--
+-- Declared VOLATILE, and MODIFIES SQL DATA where the platform supports
+-- function data access properties, because the call creates output_table.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.balance_sample(
+  source_table       TEXT,
+  output_table       TEXT,
+  class_col          TEXT,
+  class_sizes        VARCHAR,
+  output_table_size  INTEGER,
+  grouping_cols      TEXT,
+  with_replacement   BOOLEAN,
+  keep_null          BOOLEAN
+) RETURNS VOID AS $$
+    PythonFunction(sample, balance_sample, balance_sample)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
+
+-------------------------------------------------------------------------------
+
+-- Convenience overload without keep_null: forwards to the 8-argument
+-- balance_sample, passing NULL for keep_null.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.balance_sample(
+  source_table       TEXT,
+  output_table       TEXT,
+  class_col          TEXT,
+  class_sizes        VARCHAR,
+  output_table_size  INTEGER,
+  grouping_cols      TEXT,
+  with_replacement   BOOLEAN
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.balance_sample($1, $2, $3, $4, $5, $6, $7, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
+
+-- Convenience overload: forwards to the 8-argument balance_sample, passing
+-- NULL for with_replacement and keep_null.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.balance_sample(
+  source_table       TEXT,
+  output_table       TEXT,
+  class_col          TEXT,
+  class_sizes        VARCHAR,
+  output_table_size  INTEGER,
+  grouping_cols      TEXT
+) RETURNS VOID AS $$
+     SELECT MADLIB_SCHEMA.balance_sample($1, $2, $3, $4, $5, $6, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
+
+-- Convenience overload: forwards to the 8-argument balance_sample, passing
+-- NULL for grouping_cols, with_replacement and keep_null.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.balance_sample(
+  source_table       TEXT,
+  output_table       TEXT,
+  class_col          TEXT,
+  class_sizes        VARCHAR,
+  output_table_size  INTEGER
+) RETURNS VOID AS $$
+     SELECT MADLIB_SCHEMA.balance_sample($1, $2, $3, $4, $5, NULL, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
+
+-- Convenience overload: forwards to the 8-argument balance_sample, passing
+-- NULL for output_table_size, grouping_cols, with_replacement and keep_null.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.balance_sample(
+  source_table       TEXT,
+  output_table       TEXT,
+  class_col          TEXT,
+  class_sizes        VARCHAR
+) RETURNS VOID AS $$
+     SELECT MADLIB_SCHEMA.balance_sample($1, $2, $3, $4, NULL, NULL, NULL, 
NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
+
+-- Minimal three-argument form: forwards to the 8-argument balance_sample with
+-- class_sizes set to 'uniform' and every remaining optional argument NULL.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.balance_sample(
+  source_table       TEXT,
+  output_table       TEXT,
+  class_col          TEXT
+) RETURNS VOID AS $$
+     SELECT MADLIB_SCHEMA.balance_sample($1, $2, $3, 'uniform', NULL, NULL, 
NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
+
+-------------------------------------------------------------------------------
+
+-- Online help
+-- Returns help text produced by the Python function balance_sample_help in
+-- sample/balance_sample.py_in. The message argument presumably selects the
+-- help topic (e.g. 'usage'); confirm against the Python implementation.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.balance_sample(
+    message VARCHAR
+) RETURNS VARCHAR AS $$
+    PythonFunction(sample, balance_sample, balance_sample_help)
+$$ LANGUAGE plpythonu IMMUTABLE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `CONTAINS SQL', `');
+
+-------------------------------------------------------------------------------
+
+-- Help with no argument: same as calling the one-argument help function with
+-- an empty message string.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.balance_sample()
+RETURNS VARCHAR AS $$
+    SELECT MADLIB_SCHEMA.balance_sample('');
+$$ LANGUAGE sql IMMUTABLE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `CONTAINS SQL', `');
+-------------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/madlib/blob/c51da40a/src/ports/postgres/modules/sample/test/balance_sample.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/sample/test/balance_sample.sql_in 
b/src/ports/postgres/modules/sample/test/balance_sample.sql_in
new file mode 100644
index 0000000..2249fd6
--- /dev/null
+++ b/src/ports/postgres/modules/sample/test/balance_sample.sql_in
@@ -0,0 +1,157 @@
+/* ----------------------------------------------------------------------- 
*//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *//* ----------------------------------------------------------------------- 
*/
+
+-- Test fixture: 48 rows. gr1 is the class column used by every test below:
+-- class 1 has 18 rows, classes 2, 5 and 6 have 7 rows each, classes 3 and 4
+-- have 3 rows each, and 3 rows have a NULL class. The quoted mixed-case
+-- table name also exercises identifier quoting.
+DROP TABLE IF EXISTS "TEST_s" CASCADE;
+
+CREATE TABLE "TEST_s"(
+    id1 INTEGER,
+    id2 INTEGER,
+    gr1 INTEGER,
+    gr2 INTEGER
+);
+
+INSERT INTO "TEST_s" VALUES
+(1,0,1,1),
+(2,0,1,1),
+(3,0,1,1),
+(4,0,1,1),
+(5,0,1,1),
+(6,0,1,1),
+(7,0,1,1),
+(8,0,1,1),
+(9,0,1,1),
+(19,0,1,1),
+(29,0,1,1),
+(39,0,1,1),
+(0,1,1,2),
+(0,2,1,2),
+(0,3,1,2),
+(0,4,1,2),
+(0,5,1,2),
+(0,6,1,2),
+(10,10,2,2),
+(20,20,2,2),
+(30,30,2,2),
+(40,40,2,2),
+(50,50,2,2),
+(60,60,2,2),
+(70,70,2,2),
+(10,10,5,5),
+(50,50,5,5),
+(88,88,5,5),
+(40,40,5,6),
+(50,50,5,6),
+(60,60,5,6),
+(70,70,5,6),
+(10,10,6,6),
+(60,60,6,6),
+(30,30,6,6),
+(40,40,6,6),
+(50,50,6,6),
+(60,60,6,6),
+(70,70,6,6),
+(50,50,4,2),
+(60,60,4,2),
+(70,70,4,2),
+(50,50,3,2),
+(60,60,3,2),
+(70,70,3,2),
+(500,50,NULL,2),
+(600,60,NULL,2),
+(700,70,NULL,2)
+;
+
+-- Reference class counts for the fixture (group order may vary):
+-- SELECT gr1, count(*) AS c FROM "TEST_s" GROUP BY gr1;
+--  gr1  | c
+-- ------+----
+--     4 |  3
+--     1 | 18
+--     5 |  7
+--     3 |  3
+--     6 |  7
+--     2 |  7
+--  NULL |  3
+-- (7 rows)
+
+SELECT gr1, count(*) AS c FROM "TEST_s" GROUP BY gr1;
+
+--- Test for random undersampling without replacement
+-- keep_null is not given (the 7-argument overload passes NULL), so NULL is
+-- not treated as a class. The smallest classes (gr1 = 3 and gr1 = 4) have
+-- 3 rows, so every class should end up with exactly 3 rows.
+-- The leading space in ' undersample' presumably checks that the strategy
+-- name is trimmed.
+DROP TABLE IF EXISTS out_s;
+SELECT balance_sample('"TEST_s"', 'out_s', 'gr1', ' undersample', NULL, NULL, 
FALSE);
+SELECT gr1, count(*) AS c FROM out_s GROUP BY gr1;
+SELECT assert(count(*) = 0, 'Wrong number of samples on undersampling gr1') 
FROM
+        (SELECT gr1, count(*) AS c FROM out_s GROUP BY gr1) AS foo WHERE foo.c 
!= 3;
+
+--- Test for random undersampling with replacement
+-- keep_null is TRUE, so NULL (3 rows) counts as a class. The smallest class
+-- still has 3 rows, so every group, NULL included, should end up with 3.
+-- The trailing space in 'undersample ' presumably checks trimming.
+DROP TABLE IF EXISTS out_sr2;
+SELECT balance_sample('"TEST_s"', 'out_sr2', 'gr1', 'undersample ', NULL, 
NULL, TRUE, TRUE);
+SELECT gr1, count(*) AS c FROM out_sr2 GROUP BY gr1;
+SELECT assert(count(*) = 0, 'Wrong number of samples on undersampling with 
replacement on gr1') FROM
+        (SELECT gr1, count(*) AS c FROM out_sr2 GROUP BY gr1) AS foo WHERE 
foo.c != 3;
+
+--- Test for random oversampling
+-- ' oVEr   ' is a padded, mixed-case abbreviation of 'oversample', presumably
+-- to check that strategy names are trimmed and matched case-insensitively by
+-- prefix. Every class should be grown to the size of the largest class
+-- (gr1 = 1, 18 rows).
+DROP TABLE IF EXISTS out_or3;
+SELECT balance_sample('"TEST_s"', 'out_or3', 'gr1', ' oVEr   ', NULL, NULL);
+SELECT gr1, count(*) AS c FROM out_or3 GROUP BY gr1;
+SELECT assert(count(*) = 0, 'Wrong number of samples on oversampling') FROM
+        (SELECT gr1, count(*) AS c FROM out_or3 GROUP BY gr1) AS foo WHERE 
foo.c != 18;
+
+--- UNIFORM sampling
+-- No output size is given and NULL is not kept as a class; the test expects
+-- 8 rows for each of the 6 non-NULL classes. How that per-class size is
+-- derived is decided by the Python implementation.
+DROP TABLE IF EXISTS out_cd2;
+SELECT balance_sample('"TEST_s"', 'out_cd2', 'gr1', 'Uniform', NULL, NULL);
+SELECT gr1, count(*) AS c FROM out_cd2 GROUP BY gr1;
+SELECT assert(count(*) = 0, 'Wrong number of samples on uniform sampling for 
gr1') FROM
+        (SELECT gr1, count(*) AS c FROM out_cd2 GROUP BY gr1) AS foo WHERE 
foo.c != 8;
+
+--- Default sampling should be uniform
+-- class_sizes is NULL, output_table_size is 100 and keep_null is TRUE, so
+-- 7 classes (NULL included) share 100 rows; the test expects 15 rows per
+-- class, i.e. 100 / 7 rounded up.
+DROP TABLE IF EXISTS out_cd3;
+SELECT balance_sample('"TEST_s"', 'out_cd3', 'gr1', NULL, 100, NULL, NULL, 
TRUE);
+SELECT gr1, count(*) AS c FROM out_cd3 GROUP BY gr1;
+SELECT assert(count(*) = 0, 'Wrong number of samples on uniform sampling for 
gr1') FROM
+        (SELECT gr1, count(*) AS c FROM out_cd3 GROUP BY gr1) AS foo WHERE 
foo.c != 15;
+
+--- Only one class size is specified
+-- Only class 2 is resized (7 -> 10 rows, with replacement). With no output
+-- size given, the other non-NULL classes keep their original counts, so the
+-- table should hold 45 - 7 + 10 = 48 rows.
+DROP TABLE IF EXISTS out_cd4;
+SELECT balance_sample('"TEST_s"', 'out_cd4', 'gr1', ' 2 =10', NULL, NULL, 
TRUE);
+SELECT assert(count(*) = 48, 'Wrong number of samples on oversampling with 
comma-delimited list') from out_cd4;
+SELECT assert(count(*) = 0, 'Wrong number of samples on oversampling with 
comma-delimited list') from
+        (SELECT count(*) AS c FROM out_cd4 where gr1 = 2) AS foo WHERE foo.c 
!= 10;
+SELECT assert(count(*) = 0, 'Wrong number of samples on oversampling with 
comma-delimited list') from
+        (SELECT count(*) AS c FROM out_cd4 where gr1 = 3) AS foo WHERE foo.c 
!= 3;
+SELECT assert(count(*) = 0, 'Wrong number of samples on oversampling with 
comma-delimited list') from
+        (SELECT count(*) AS c FROM out_cd4 where gr1 = 1) AS foo WHERE foo.c 
!= 18;
+
+--- Multiple class sizes with comma delimited string
+-- Classes 2, 3 and 1 get explicit sizes (10 + 6 + 10 = 26 rows). The other
+-- 100 - 26 = 74 rows are spread uniformly over the remaining non-NULL classes
+-- 4, 5 and 6, each expected to get 25 rows (74 / 3 rounded up), so the
+-- output may hold slightly more than 100 rows.
+DROP TABLE IF EXISTS out_cd5;
+SELECT balance_sample('"TEST_s"', 'out_cd5', 'gr1', '2= 10, 3=6, 1 = 10', 100, 
NULL);
+select gr1, count(*) from out_cd5 group by gr1;
+SELECT assert(count(*) >= 100, 'Wrong number of samples on sampling with 
comma-delimited list') from out_cd5;
+SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with 
comma-delimited list') from
+        (SELECT count(*) AS c FROM out_cd5 where gr1 = 2) AS foo WHERE foo.c 
!= 10;
+SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with 
comma-delimited list') from
+        (SELECT count(*) AS c FROM out_cd5 where gr1 = 4) AS foo WHERE foo.c 
!= 25;
+SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with 
comma-delimited list') from
+        (SELECT count(*) AS c FROM out_cd5 where gr1 = 1) AS foo WHERE foo.c 
!= 10;
+SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with 
comma-delimited list') from
+        (SELECT count(*) AS c FROM out_cd5 where gr1 = 5) AS foo WHERE foo.c 
!= 25;
+SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with 
comma-delimited list') from
+        (SELECT count(*) AS c FROM out_cd5 where gr1 = 3) AS foo WHERE foo.c 
!= 6;
+-- Class 6 is the third uniformly filled class and was previously unchecked;
+-- it must match classes 4 and 5.
+SELECT assert(count(*) = 0, 'Wrong number of samples on sampling with comma-delimited list') from
+        (SELECT count(*) AS c FROM out_cd5 where gr1 = 6) AS foo WHERE foo.c != 25;

Reply via email to