Github user njayaram2 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/241#discussion_r175585050
  
    --- Diff: src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in ---
    @@ -0,0 +1,559 @@
    +# coding=utf-8
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one
    +# or more contributor license agreements.  See the NOTICE file
    +# distributed with this work for additional information
    +# regarding copyright ownership.  The ASF licenses this file
    +# to you under the Apache License, Version 2.0 (the
    +# "License"); you may not use this file except in compliance
    +# with the License.  You may obtain a copy of the License at
    +#
    +#   http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing,
    +# software distributed under the License is distributed on an
    +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    +# KIND, either express or implied.  See the License for the
    +# specific language governing permissions and limitations
    +# under the License.
    +
    +
    +"""
    +@file minibatch_preprocessing.py_in
    +
    +"""
    +from math import ceil
    +import plpy
    +
    +from utilities import add_postfix
    +from utilities import _assert
    +from utilities import get_seg_number
    +from utilities import is_platform_pg
    +from utilities import is_psql_numeric_type
    +from utilities import is_string_formatted_as_array_expression
    +from utilities import py_list_to_sql_string
    +from utilities import split_quoted_delimited_str
    +from utilities import _string_to_array
    +from utilities import validate_module_input_params
    +from mean_std_dev_calculator import MeanStdDevCalculator
    +from validate_args import get_expr_type
    +from validate_args import output_tbl_valid
    +from validate_args import _tbl_dimension_rownum
    +
    +m4_changequote(`<!', `!>')
    +
    +# These are readonly variables, do not modify
    +MINIBATCH_OUTPUT_DEPENDENT_COLNAME = "dependent_varname"
    +MINIBATCH_OUTPUT_INDEPENDENT_COLNAME = "independent_varname"
    +
    +class MiniBatchPreProcessor:
    +    """
    +    This class is responsible for executing the main logic of mini batch
    +    preprocessing, which packs multiple rows of selected columns from the
    +    source table into one row based on the buffer size
    +    """
    +    def __init__(self, schema_madlib, source_table, output_table,
    +                  dependent_varname, independent_varname, buffer_size, 
**kwargs):
    +        self.schema_madlib = schema_madlib
    +        self.source_table = source_table
    +        self.output_table = output_table
    +        self.dependent_varname = dependent_varname
    +        self.independent_varname = independent_varname
    +        self.buffer_size = buffer_size
    +
    +        self.module_name = "minibatch_preprocessor"
    +        self.output_standardization_table = add_postfix(self.output_table,
    +                                                   "_standardization")
    +        self.output_summary_table = add_postfix(self.output_table, 
"_summary")
    +        self._validate_minibatch_preprocessor_params()
    +
    +    def minibatch_preprocessor(self):
    +        # Get array expressions for both dep and indep variables from the
    +        # MiniBatchQueryFormatter class
    +        dependent_var_dbtype = get_expr_type(self.dependent_varname,
    +                                             self.source_table)
    +        qry_formatter = MiniBatchQueryFormatter(self.source_table)
    +        dep_var_array_str, dep_var_classes_str = qry_formatter.\
    +            get_dep_var_array_and_classes(self.dependent_varname,
    +                                          dependent_var_dbtype)
    +        indep_var_array_str = qry_formatter.get_indep_var_array_str(
    +                                              self.independent_varname)
    +
    +        standardizer = MiniBatchStandardizer(self.schema_madlib,
    +                                             self.source_table,
    +                                             dep_var_array_str,
    +                                             indep_var_array_str,
    +                                             
self.output_standardization_table)
    +        standardize_query = standardizer.get_query_for_standardizing()
    +
    +        num_rows_processed, num_missing_rows_skipped = self.\
    +                                                
_get_skipped_rows_processed_count(
    +                                                dep_var_array_str,
    +                                                indep_var_array_str)
    +        calculated_buffer_size = MiniBatchBufferSizeCalculator.\
    +                                         calculate_default_buffer_size(
    +                                         self.buffer_size,
    +                                         num_rows_processed,
    +                                         
standardizer.independent_var_dimension)
    +        """
    +        This query does the following:
    +        1. Standardize the independent variables in the input table
    +           (see MiniBatchStandardizer for more details)
    +        2. Filter out rows with null values either in dependent/independent
    +           variables
    +        3. Converts the input dependent/independent variables into arrays
    +          (see MiniBatchQueryFormatter for more details)
    +        4. Based on the buffer size, pack the dependent/independent arrays into
    +           matrices
    +
    +        Notes
    +        1. we are ignoring null in x because
    +             a. matrix_agg does not support null
    +             b. __utils_normalize_data returns null if any element of the array
    +                contains NULL
    +        2. Please keep the null checking where clause of this query in sync with
    +        the query in _get_skipped_rows_processed_count. We are doing this null
    +        check in two places to prevent another pass of the entire dataset.
    +        """
    +
    +        # This ID is the unique row id that gets assigned to each row after preprocessing
    +        unique_row_id = "__id__"
    +        sql = """
    +            CREATE TABLE {output_table} AS
    +            SELECT {row_id},
    +                   {schema_madlib}.matrix_agg({dep_colname}) as 
{dep_colname},
    +                   {schema_madlib}.matrix_agg({ind_colname}) as 
{ind_colname}
    +            FROM (
    +                SELECT (row_number() OVER (ORDER BY random()) - 1) / 
{buffer_size}
    +                            as {row_id}, * FROM
    +                (
    +                    {standardize_query}
    +                 ) sub_query_1
    +                 WHERE NOT 
{schema_madlib}.array_contains_null({dep_colname})
    +                 AND NOT {schema_madlib}.array_contains_null({ind_colname})
    +            ) sub_query_2
    +            GROUP BY {row_id}
    +            {distributed_by_clause}
    +            """.format(
    +            schema_madlib=self.schema_madlib,
    +            source_table=self.source_table,
    +            output_table=self.output_table,
    +            dependent_varname=self.dependent_varname,
    +            independent_varname=self.independent_varname,
    +            buffer_size = calculated_buffer_size,
    +            dep_colname=MINIBATCH_OUTPUT_DEPENDENT_COLNAME,
    +            ind_colname=MINIBATCH_OUTPUT_INDEPENDENT_COLNAME,
    +            row_id = unique_row_id,
    +            distributed_by_clause = '' if is_platform_pg() else 
'DISTRIBUTED RANDOMLY',
    +            **locals())
    +        plpy.execute(sql)
    +
    +
    +        standardizer.create_output_standardization_table()
    +        MiniBatchSummarizer.create_output_summary_table(
    +            self.source_table,
    +            self.output_table,
    +            self.dependent_varname,
    +            self.independent_varname,
    +            calculated_buffer_size,
    +            dep_var_classes_str,
    +            num_rows_processed,
    +            num_missing_rows_skipped,
    +            self.output_summary_table)
    +
    +    def _validate_minibatch_preprocessor_params(self):
    +        # Test if the independent variable can be typecast to a double precision
    +        # array and let postgres validate the expression
    +
    +        # Note that this will not fail for 2d arrays but the standardizer will
    +        # fail because utils_normalize_data will throw an error
    +        typecasted_ind_varname = "{0}::double precision[]".format(
    +                                                    
self.independent_varname)
    +        validate_module_input_params(self.source_table, self.output_table,
    +                                     typecasted_ind_varname,
    +                                     self.dependent_varname, 
self.module_name)
    +
    +        self._validate_other_output_tables()
    +
    +        num_of_dependent_cols = 
split_quoted_delimited_str(self.dependent_varname)
    +
    +        _assert(len(num_of_dependent_cols) == 1,
    +                "Invalid dependent_varname: only one column name is 
allowed "
    +                "as input.")
    +
    +        if self.buffer_size is not None:
    +            _assert(self.buffer_size > 0,
    +                """minibatch_preprocessor: The buffer size has to be a 
positive
    +                 integer or NULL.""")
    +
    +    def _validate_other_output_tables(self):
    +        """
    +        Validate that standardization and summary table do not exist.
    +        :return:
    +        """
    +        output_tbl_valid(self.output_standardization_table, 
self.module_name)
    +        output_tbl_valid(self.output_summary_table, self.module_name)
    +
    +    def _get_skipped_rows_processed_count(self, dep_var_array, 
indep_var_array):
    +        # Note: Keep the null checking where clause of this query in sync with the
    +        # main create output table query.
    +        query = """
    +                SELECT COUNT(*) AS source_table_row_count,
    +                sum(CASE WHEN
    +                NOT {schema_madlib}.array_contains_null({dep_var_array})
    +                AND NOT 
{schema_madlib}.array_contains_null({indep_var_array})
    +                THEN 1 ELSE 0 END) AS num_rows_processed
    +                FROM {source_table}
    +        """.format(
    +        schema_madlib = self.schema_madlib,
    +        source_table = self.source_table,
    +        dep_var_array = dep_var_array,
    +        indep_var_array = indep_var_array)
    +        result = plpy.execute(query)
    +
    +        source_table_row_count = result[0]['source_table_row_count']
    +        num_rows_processed = result[0]['num_rows_processed']
    +        if not source_table_row_count or not num_rows_processed:
    +            plpy.error("Error while getting the row count of the source 
table"
    +                       "{0}".format(self.source_table))
    +        num_missing_rows_skipped = source_table_row_count - 
num_rows_processed
    +
    +        return num_rows_processed, num_missing_rows_skipped
    +
    +class MiniBatchQueryFormatter:
    +    """
    +    This class is responsible for formatting the independent and dependent
    +    variables into arrays so that they can be matrix agged by the preprocessor
    +    class.
    +    """
    +    def __init__(self, source_table):
    +        self.source_table = source_table
    +
    +    def get_dep_var_array_and_classes(self, dependent_varname, 
dependent_var_dbtype):
    +        """
    +        This function returns a tuple of
    +        1. A string with transformed dependent varname depending on its type
    +        2. All the distinct dependent class levels encoded as a string
    +
    +        If dep_type == numeric , do not encode
    +                1. dependent_varname = rings
    +                    transformed_value = ARRAY[[rings1], [rings2], []]
    +                    class_level_str = ARRAY[rings = 'rings1',
    +                                            rings = 'rings2']::integer[]
    +                2. dependent_varname = ARRAY[a, b, c]
    +                    transformed_value = ARRAY[[a1, b1, c1], [a2, b2, c2], 
[]]
    +                    class_level_str = 'NULL::TEXT'
    +        else if dep_type in ("text", "boolean"), encode:
    +                3. dependent_varname = rings (encoding)
    +                    transformed_value = ARRAY[[rings1=1, rings1=2], 
[rings2=1,
    +                                                rings2=2], []]
    +                    class_level_str = 'NULL::TEXT'
    +
    +        :param dependent_varname:
    +        :param dependent_var_dbtype:
    +        :return:
    +        """
    +        """
    +        """
    +        dep_var_class_value_str = 'NULL::TEXT'
    +        if dependent_var_dbtype in ("text", "boolean"):
    +            # for encoding, and since boolean can also be a logical 
expression,
    +            # there is a () for {dependent_varname} to make the query work
    +            dep_level_sql = """
    +            SELECT DISTINCT ({dependent_varname}) AS class
    +            FROM {source_table} where ({dependent_varname}) is NOT NULL
    +            """.format(dependent_varname=dependent_varname,
    +                       source_table=self.source_table)
    +            dep_levels = plpy.execute(dep_level_sql)
    +
    +            # this is string sorting
    +            dep_var_classes = sorted(
    +                ["{0}".format(l["class"]) for l in dep_levels])
    +
    +            dep_var_array_str = 
self._get_one_hot_encoded_str(dependent_varname,
    +                                                              
dep_var_classes)
    +            dep_var_class_value_str = 
py_list_to_sql_string(dep_var_classes,
    +                                         array_type=dependent_var_dbtype)
    +
    +        elif "[]" in dependent_var_dbtype:
    +            dep_var_array_str = dependent_varname
    +
    +        elif is_psql_numeric_type(dependent_var_dbtype):
    +            dep_var_array_str = 'ARRAY[{0}]'.format(dependent_varname)
    +
    +        else:
    +            plpy.error("""Invalid dependent variable type. It should be 
text,
    +                boolean, numeric, or an array.""")
    +
    +        return dep_var_array_str, dep_var_class_value_str
    +
    +    def _get_one_hot_encoded_str(self, var_name, var_classes):
    +        one_hot_list = []
    +        for c in var_classes:
    +            one_hot_list.append("({0}) = '{1}'".format(var_name, c))
    +
    +        return 'ARRAY[{0}]::integer[]'.format(','.join(one_hot_list))
    +
    +    def get_indep_var_array_str(self, independent_varname):
    +        """
    +        we assume that all the independent features are either numeric or
    +        already encoded by the user.
    +        Supported formats
    +        1. ‘ARRAY[x1,x2,x3]’ , where x1,x2,x3 are columns in source table with
    +        scalar values
    +        2. ‘x1’, where x1 is a single column in source table, with value as an
    +        array, like ARRAY[1,2,3] or {1,2,3}
    +
    +        we don't deal with a mixture of scalar and array independent variables
    +        """
    +        typecasted_ind_varname = "{0}::double precision[]".format(
    +                                                            
independent_varname)
    +        return typecasted_ind_varname
    +
    +class MiniBatchStandardizer:
    +    """
    +    This class is responsible for
    +    1. Calculating the mean and std dev for independent variables
    +    2. Format the query to standardize the input table based on the
    +       calculated mean/std dev
    +    3. Creating the output standardization table
    +    """
    +    def __init__(self, schema_madlib, source_table, dep_var_array_str,
    +                 indep_var_array_str, output_standardization_table):
    +        self.schema_madlib = schema_madlib
    +        self.source_table = source_table
    +        self.dep_var_array_str = dep_var_array_str
    +        self.indep_var_array_str = indep_var_array_str
    +        self.output_standardization_table = output_standardization_table
    +
    +        self.x_mean_str = None
    +        self.x_std_dev_str = None
    +        self.source_table_row_count = 0
    +        self.grouping_cols = "NULL"
    +        self.independent_var_dimension = None
    +        self._calculate_mean_and_std_dev_str()
    +
    +    def _calculate_mean_and_std_dev_str(self):
    +        self.independent_var_dimension, _ = _tbl_dimension_rownum(
    +                                                        self.schema_madlib,
    +                                                        self.source_table,
    +                                                        
self.indep_var_array_str,
    +                                                        
skip_row_count=True)
    +
    +        calculator = MeanStdDevCalculator(self.schema_madlib,
    +                                          self.source_table,
    +                                          self.indep_var_array_str,
    +                                          self.independent_var_dimension)
    +
    +        self.x_mean_str, self.x_std_dev_str = calculator.\
    +                                              
get_mean_and_std_dev_for_ind_var()
    +
    +        if not self.x_mean_str or not self.x_std_dev_str:
    +            plpy.error("mean/stddev for the independent variable"
    +                       "cannot be null")
    +
    +    def get_query_for_standardizing(self):
    +        query="""
    +        SELECT
    +        {dep_var_array_str} as {dep_colname},
    +        {schema_madlib}.utils_normalize_data
    +        (
    +            {indep_var_array_str},'{x_mean_str}'::double precision[],
    +            '{x_std_dev_str}'::double precision[]
    +        ) as {ind_colname}
    +        FROM {source_table}
    +        """.format(
    +            source_table = self.source_table,
    +            schema_madlib = self.schema_madlib,
    +            dep_var_array_str = self.dep_var_array_str,
    +            indep_var_array_str = self.indep_var_array_str,
    +            dep_colname = MINIBATCH_OUTPUT_DEPENDENT_COLNAME,
    +            ind_colname = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME,
    +            x_mean_str = self.x_mean_str,
    +            x_std_dev_str = self.x_std_dev_str)
    +        return query
    +
    +    def create_output_standardization_table(self):
    +        query = """
    +        CREATE TABLE {output_standardization_table} AS
    +        select {grouping_cols}::TEXT AS grouping_cols,
    +        '{x_mean_str}'::double precision[] AS mean,
    +        '{x_std_dev_str}'::double precision[] AS std
    +        """.format(
    +        output_standardization_table = self.output_standardization_table,
    +        grouping_cols = self.grouping_cols,
    +        x_mean_str = self.x_mean_str,
    +        x_std_dev_str = self.x_std_dev_str)
    +        plpy.execute(query)
    +
    +class MiniBatchSummarizer:
    +    @staticmethod
    +    def create_output_summary_table(source_table, output_table,
    +                                    dep_var_array_str, indep_var_array_str,
    +                                    buffer_size, class_values, 
num_rows_processed,
    +                                    num_missing_rows_skipped, 
output_summary_table):
    +        query = """
    +            CREATE TABLE {output_summary_table} AS
    +            SELECT '{source_table}'::TEXT AS source_table,
    +            '{output_table}'::TEXT AS output_table,
    +            '{dependent_varname}'::TEXT AS dependent_varname,
    +            '{independent_varname}'::TEXT AS independent_varname,
    +            {buffer_size} AS buffer_size,
    +            {class_values} AS class_values,
    +            {num_rows_processed} AS num_rows_processed,
    +            {num_missing_rows_skipped} AS num_missing_rows_skipped,
    +            {grouping_cols}::TEXT AS grouping_cols
    +        """.format(output_summary_table = output_summary_table,
    +                   source_table = source_table,
    +                   output_table = output_table,
    +                   dependent_varname = dep_var_array_str,
    +                   independent_varname = indep_var_array_str,
    +                   buffer_size = buffer_size,
    +                   class_values = class_values,
    +                   num_rows_processed = num_rows_processed,
    +                   num_missing_rows_skipped = num_missing_rows_skipped,
    +                   grouping_cols = "NULL")
    +        plpy.execute(query)
    +
    +class MiniBatchBufferSizeCalculator:
    +    """
    +    This class is responsible for calculating the buffer size.
    +    This is a work in progress, final formula might change.
    +    """
    +    @staticmethod
    +    def calculate_default_buffer_size(buffer_size,
    +                                      num_rows_processed, 
independent_var_dimension):
    +        if buffer_size is not None:
    +            return buffer_size
    +        num_of_segments = get_seg_number()
    +
    +        default_buffer_size = min(75000000.0/independent_var_dimension,
    +                                    
float(num_rows_processed)/num_of_segments)
    +        return ceil(default_buffer_size)
    +
    +class MiniBatchDocumentation:
    +    @staticmethod
    +    def minibatch_preprocessor_help(schema_madlib, message):
    +        method = "minibatch_preprocessor"
    +        summary = """
    +        ----------------------------------------------------------------
    +                            SUMMARY
    +        ----------------------------------------------------------------
    +        MiniBatch Preprocessor is a utility function to pre process the input
    +        data for use with models that support mini-batching as an optimization
    +
    +        #TODO add more here
    --- End diff --
    
    Is the `#TODO add more here` comment still required?


---

Reply via email to