Github user kaknikhil commented on a diff in the pull request: https://github.com/apache/madlib/pull/243#discussion_r175894372 --- Diff: src/ports/postgres/modules/convex/mlp_igd.py_in --- @@ -1457,3 +1660,85 @@ def mlp_predict_help(schema_madlib, message): return """ No such option. Use "SELECT {schema_madlib}.mlp_predict()" for help. """.format(**args) + + +def check_if_minibatch_enabled(source_table, independent_varname): + """ + Function to validate if the source_table is converted to a format that + can be used for mini-batching. It checks for the dimensionalities of + the independent variable to determine the same. + """ + query = """ + SELECT array_upper({0}, 1) AS n_x, + array_upper({0}, 2) AS n_y, + array_upper({0}, 3) AS n_z + FROM {1} + LIMIT 1 + """.format(independent_varname, source_table) + result = plpy.execute(query) + + if not result: + plpy.error("MLP: Input table could be empty.") + + has_x_dim, has_y_dim, has_z_dim = [bool(result[0][i]) + for i in ('n_x', 'n_y', 'n_z')] + if not has_x_dim: + plpy.error("MLP: {0} is empty.".format(independent_varname)) + + # error out if >2d matrix + if has_z_dim: + plpy.error("MLP: Input table is not in the right format.") + return has_y_dim + + +class MLPPreProcessor: + """ + This class consumes and validates the pre-processed source table used for + MLP mini-batch. This also populates values from the pre-processed summary + table which is used by MLP mini-batch + + """ + # summary table columns names + DEPENDENT_VARNAME = "dependent_varname" + INDEPENDENT_VARNAME = "independent_varname" + GROUPING_COL = "grouping_cols" + CLASS_VALUES = "class_values" + MODEL_TYPE_CLASSIFICATION = "classification" + MODEL_TYPE_REGRESSION = "regression" + + def __init__(self, source_table): + self.source_table = source_table + self.preprocessed_summary_dict = None + self.summary_table = add_postfix(self.source_table, "_summary") + self.std_table = add_postfix(self.source_table, "_standardization") + + self._validate_and_set_preprocessed_summary() + + def _validate_and_set_preprocessed_summary(self): + input_tbl_valid(self.source_table, 'MLP') + + if not table_exists(self.summary_table) or not table_exists(self.std_table): + plpy.error("Tables {0} and/or {1} do not exist. These tables are" + " needed for using minibatch during training.".format( + self.summary_table, + self.std_table)) + + query = "SELECT * FROM {0}".format(self.summary_table) + summary_table_columns = plpy.execute(query) + if not summary_table_columns or len(summary_table_columns) == 0: + plpy.error("No columns in table {0}.".format(self.summary_table)) + else: + summary_table_columns = summary_table_columns[0] + + required_columns = (self.DEPENDENT_VARNAME, self.INDEPENDENT_VARNAME, + self.CLASS_VALUES) + if set(required_columns) <= set(summary_table_columns): + self.preprocessed_summary_dict = summary_table_columns + else: + plpy.error("Expected columns ({0}, {1} and/or {2}) not present in" --- End diff -- We can use the `required_columns` to format the error message so that we don't have to repeat the column names. Something like ``` plpy.error("One or more of the expected columns {0} not present in {1}".format(required_columns, self.summary_table)) ```
---