Github user njayaram2 commented on a diff in the pull request: https://github.com/apache/madlib/pull/241#discussion_r175593796 --- Diff: src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in --- @@ -0,0 +1,559 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +""" +@file minibatch_preprocessing.py_in + +""" +from math import ceil +import plpy + +from utilities import add_postfix +from utilities import _assert +from utilities import get_seg_number +from utilities import is_platform_pg +from utilities import is_psql_numeric_type +from utilities import is_string_formatted_as_array_expression +from utilities import py_list_to_sql_string +from utilities import split_quoted_delimited_str +from utilities import _string_to_array +from utilities import validate_module_input_params +from mean_std_dev_calculator import MeanStdDevCalculator +from validate_args import get_expr_type +from validate_args import output_tbl_valid +from validate_args import _tbl_dimension_rownum + +m4_changequote(`<!', `!>') + +# These are readonly variables, do not modify +MINIBATCH_OUTPUT_DEPENDENT_COLNAME = "dependent_varname" +MINIBATCH_OUTPUT_INDEPENDENT_COLNAME = "independent_varname" + +class MiniBatchPreProcessor: + """ + This class is responsible for executing the main logic of mini batch + preprocessing, which packs multiple rows of selected columns from the + source table into one row based on the buffer size + """ + def __init__(self, schema_madlib, source_table, output_table, + dependent_varname, independent_varname, buffer_size, **kwargs): + self.schema_madlib = schema_madlib + self.source_table = source_table + self.output_table = output_table + self.dependent_varname = dependent_varname + self.independent_varname = independent_varname + self.buffer_size = buffer_size + + self.module_name = "minibatch_preprocessor" + self.output_standardization_table = add_postfix(self.output_table, + "_standardization") + self.output_summary_table = add_postfix(self.output_table, "_summary") + self._validate_minibatch_preprocessor_params() + + def minibatch_preprocessor(self): + # Get array expressions for both dep and indep variables from the + # MiniBatchQueryFormatter class + dependent_var_dbtype = get_expr_type(self.dependent_varname, + 
self.source_table) + qry_formatter = MiniBatchQueryFormatter(self.source_table) + dep_var_array_str, dep_var_classes_str = qry_formatter.\ + get_dep_var_array_and_classes(self.dependent_varname, + dependent_var_dbtype) + indep_var_array_str = qry_formatter.get_indep_var_array_str( + self.independent_varname) + + standardizer = MiniBatchStandardizer(self.schema_madlib, + self.source_table, + dep_var_array_str, + indep_var_array_str, + self.output_standardization_table) + standardize_query = standardizer.get_query_for_standardizing() + + num_rows_processed, num_missing_rows_skipped = self.\ + _get_skipped_rows_processed_count( + dep_var_array_str, + indep_var_array_str) + calculated_buffer_size = MiniBatchBufferSizeCalculator.\ + calculate_default_buffer_size( + self.buffer_size, + num_rows_processed, + standardizer.independent_var_dimension) + """ + This query does the following: + 1. Standardize the independent variables in the input table + (see MiniBatchStandardizer for more details) + 2. Filter out rows with null values either in dependent/independent + variables + 3. Converts the input dependent/independent variables into arrays + (see MiniBatchQueryFormatter for more details) + 4. Based on the buffer size, pack the dependent/independent arrays into + matrices + + Notes + 1. we are ignoring null in x because + a. matrix_agg does not support null + b. __utils_normalize_data returns null if any element of the array + contains NULL + 2. Please keep the null checking where clause of this query in sync with + the query in _get_skipped_rows_processed_count. We are doing this null + check in two places to prevent another pass of the entire dataset. 
+ """ + + # This ID is the unique row id that get assigned to each row after preprocessing + unique_row_id = "__id__" + sql = """ + CREATE TABLE {output_table} AS + SELECT {row_id}, + {schema_madlib}.matrix_agg({dep_colname}) as {dep_colname}, + {schema_madlib}.matrix_agg({ind_colname}) as {ind_colname} + FROM ( + SELECT (row_number() OVER (ORDER BY random()) - 1) / {buffer_size} + as {row_id}, * FROM + ( + {standardize_query} + ) sub_query_1 + WHERE NOT {schema_madlib}.array_contains_null({dep_colname}) + AND NOT {schema_madlib}.array_contains_null({ind_colname}) + ) sub_query_2 + GROUP BY {row_id} + {distributed_by_clause} + """.format( + schema_madlib=self.schema_madlib, + source_table=self.source_table, + output_table=self.output_table, + dependent_varname=self.dependent_varname, + independent_varname=self.independent_varname, + buffer_size = calculated_buffer_size, + dep_colname=MINIBATCH_OUTPUT_DEPENDENT_COLNAME, + ind_colname=MINIBATCH_OUTPUT_INDEPENDENT_COLNAME, + row_id = unique_row_id, + distributed_by_clause = '' if is_platform_pg() else 'DISTRIBUTED RANDOMLY', + **locals()) + plpy.execute(sql) + + + standardizer.create_output_standardization_table() + MiniBatchSummarizer.create_output_summary_table( + self.source_table, + self.output_table, + self.dependent_varname, + self.independent_varname, + calculated_buffer_size, + dep_var_classes_str, + num_rows_processed, + num_missing_rows_skipped, + self.output_summary_table) + + def _validate_minibatch_preprocessor_params(self): + # Test if the independent variable can be typecasted to a double precision + # array and let postgres validate the expression + + # Note that this will not fail for 2d arrays but the standardizer will + # fail because utils_normalize_data will throw an error + typecasted_ind_varname = "{0}::double precision[]".format( + self.independent_varname) + validate_module_input_params(self.source_table, self.output_table, + typecasted_ind_varname, + self.dependent_varname, self.module_name) 
+ + self._validate_other_output_tables() --- End diff -- This relates to the last comment on this PR. Instead of a separate `_validate_other_output_tables()` call, we could add an optional parameter to `validate_module_input_params` that also validates the additional output tables. For example, it could take a list of output-table suffixes such as `['_summary', '_standardization']`.
---