This is an automated email from the ASF dual-hosted git repository. jingyimei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 16814687fa328dd7b9db79d8bd629af8529879b6 Author: Jingyi Mei <[email protected]> AuthorDate: Wed Mar 20 16:55:08 2019 -0700 DL: Add unit test and refactor Closes: #357 Co-authored-by: Ekta Khanna <[email protected]> --- .../utilities/minibatch_preprocessing.py_in | 39 +++--- .../unit_tests/test_minibatch_preprocessing.py_in | 140 +++++++++++++++++++++ .../postgres/modules/utilities/validate_args.py_in | 2 +- 3 files changed, 164 insertions(+), 17 deletions(-) diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in index f3dc462..8bfa978 100644 --- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in +++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in @@ -55,7 +55,7 @@ MINIBATCH_OUTPUT_INDEPENDENT_COLNAME = "independent_varname" # These are readonly variables, do not modify #MADLIB-1300 Adding these variables for DL only at this time. # For release 2.0 These will be removed and above variables can -# used for regular and DL minibatch. +# be used for regular and DL minibatch. MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL = "dependent_var" MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL = "independent_var" @@ -134,8 +134,6 @@ class MiniBatchPreProcessor: [self.output_summary_table, self.output_standardization_table]) - - num_of_dependent_cols = split_quoted_delimited_str(self.dependent_varname) valid_types = NUMERIC | TEXT | BOOLEAN _assert(is_valid_psql_type(self.dependent_vartype, @@ -355,9 +353,14 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor): self.dependent_varname = dependent_varname self.independent_varname = independent_varname self.buffer_size = buffer_size - self.normalizing_const = normalizing_const if normalizing_const else 1.0 + self.normalizing_const = normalizing_const if normalizing_const is not None else 1.0 self.module_name = "minibatch_preprocessor_DL" self.output_summary_table = add_postfix(self.output_table, "_summary") + self.independent_vartype = get_expr_type(self.independent_varname, + self.source_table) + self.dependent_vartype = get_expr_type(self.dependent_varname, + self.source_table) + self._validate_args() self.num_of_buffers = self._get_num_buffers() self.to_one_hot_encode = True @@ -367,14 +370,12 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor): self.dependent_levels = get_distinct_col_levels( self.source_table, self.dependent_varname, self.dependent_vartype) - def minibatch_preprocessor_dl(self): # Create a temp table that has independent var normalized. norm_tbl = unique_string(desp='normalized') - dependent_varname_with_offset = self.dependent_varname dep_var_array_expr = self.get_dep_var_array_expr() - # Assuming the input NUMERIC[] is already on_hot_encoded, so casting to INTEGER[] + # Assuming the input NUMERIC[] is already one_hot_encoded, so casting to INTEGER[] if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY): dep_var_array_expr = dep_var_array_expr + '::INTEGER[]' scalar_mult_sql = """ @@ -383,7 +384,7 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor): {self.independent_varname}::REAL[], (1/{self.normalizing_const})::REAL) AS x_norm, {dep_var_array_expr} AS y, row_number() over() AS row_id - FROM {self.source_table} + FROM {self.source_table} order by random() """.format(**locals()) plpy.execute(scalar_mult_sql) # Create the mini-batched output table @@ -437,18 +438,22 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor): self.source_table, self.output_table, self.independent_varname, self.dependent_varname, self.module_name, None, [self.output_summary_table]) - self.independent_vartype = get_expr_type( - self.independent_varname, self.source_table) + num_of_independent_cols = split_quoted_delimited_str(self.independent_varname) + _assert(len(num_of_independent_cols) == 1, + "Invalid independent_varname: only one column name is allowed " + "as input.") _assert(is_valid_psql_type(self.independent_vartype, NUMERIC | ONLY_ARRAY), "Invalid independent variable type, should be an array of " "one of {0}".format(','.join(NUMERIC))) - self.dependent_vartype = get_expr_type( - self.dependent_varname, self.source_table) # The denpendent variable needs to be either: # 1. NUMERIC, TEXT OR BOOLEAN, which we always one-hot encode # 2. NUMERIC ARRAY, which we assume it is already one-hot encoded, and we # just cast it the INTEGER ARRAY + num_of_dependent_cols = split_quoted_delimited_str(self.dependent_varname) + _assert(len(num_of_dependent_cols) == 1, + "Invalid dependent_varname: only one column name is allowed " + "as input.") _assert((is_valid_psql_type(self.dependent_vartype, NUMERIC | TEXT | BOOLEAN) or is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY)), """Invalid dependent variable type, should be one of the type in this list: @@ -457,6 +462,9 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor): _assert(self.buffer_size > 0, "minibatch_preprocessor_dl: The buffer size has to be a " "positive integer or NULL.") + _assert(self.normalizing_const > 0, + "minibatch_preprocessor_dl: The normalizing constant has to be a " + "positive integer or NULL.") def _get_num_buffers(self): num_rows_in_tbl = plpy.execute(""" @@ -754,9 +762,6 @@ class MiniBatchDocumentation: normalizing_const -- DOUBLE PRECISON. Default 1.0. The normalizing constant to use for standardizing arrays in independent_varname. - one_hot_encode_int_dep_var -- BOOLEAN. Default FALSE. Flag to one-hot - encode dependent variables that are - scalar integers ); @@ -783,6 +788,9 @@ class MiniBatchDocumentation: table. dependent_vartype -- Type of the dependent variable from the original table. + class_values -- Class values of the dependent variable + (‘NULL’(as TEXT type) for non + categorical vars). buffer_size -- Buffer size used in preprocessing step. normalizing_const -- Normalizing constant used for standardizing arrays in independent_varname. @@ -790,7 +798,6 @@ class MiniBatchDocumentation: --------------------------------------------------------------------------- """.format(**locals()) - if not message: return summary elif message.lower() in ('usage', 'help', '?'): diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in index b262964..2fdf082 100644 --- a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in +++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in @@ -322,7 +322,147 @@ class MiniBatchBufferSizeCalculatorTestCase(unittest.TestCase): class AnyStringWith(str): def __eq__(self, other): return self in other +class MiniBatchPreProcessingDLTestCase(unittest.TestCase): + def setUp(self): + self.plpy_mock = Mock(spec='error') + patches = { + 'plpy': plpy, + 'utilities.mean_std_dev_calculator': Mock(), + } + # we need to use MagicMock() instead of Mock() for the plpy.execute mock + # to be able to iterate on the return value + self.plpy_mock_execute = MagicMock() + plpy.execute = self.plpy_mock_execute + + self.module_patcher = patch.dict('sys.modules', patches) + self.module_patcher.start() + + self.default_schema_madlib = "madlib" + self.default_source_table = "source" + self.default_output_table = "output" + self.default_dep_var = "depvar" + self.default_ind_var = "indvar" + self.default_buffer_size = 5 + #self.dependent_vartype = "integer[]" + #self.independent_vartype = "integer[]" + + import utilities.minibatch_preprocessing + self.module = utilities.minibatch_preprocessing + self.module.get_expr_type = Mock(side_effect = ['integer[]', 'integer[]']) + self.module.validate_module_input_params = Mock() + self.module.get_distinct_col_levels = Mock(return_value = [0,22,100]) + self.subject = self.module.MiniBatchPreProcessorDL(self.default_schema_madlib, + self.default_source_table, + self.default_output_table, + self.default_dep_var, + self.default_ind_var, + self.default_buffer_size) + + def tearDown(self): + self.module_patcher.stop() + + def test_minibatch_preprocessor_dl_executes_query(self): + self.module.get_expr_type = Mock(side_effect = ['integer[]', 'integer[]']) + preprocessor_obj = self.module.MiniBatchPreProcessorDL(self.default_schema_madlib, + "input", + "out", + self.default_dep_var, + self.default_ind_var, + self.default_buffer_size) + preprocessor_obj.minibatch_preprocessor_dl() + + def test_minibatch_preprocessor_null_buffer_size_executes_query(self): + self.module.get_expr_type = Mock(side_effect = ['integer[]', 'integer[]']) + preprocessor_obj = self.module.MiniBatchPreProcessorDL(self.default_schema_madlib, + "input", + "out", + self.default_dep_var, + self.default_ind_var, + None) + self.module.MiniBatchBufferSizeCalculator.calculate_default_buffer_size = Mock(return_value = 5) + preprocessor_obj.minibatch_preprocessor_dl() + + def test_minibatch_preprocessor_multiple_dep_var_raises_exception(self): + self.module.get_expr_type = Mock(side_effect = ['integer[]', 'integer[]']) + with self.assertRaises(plpy.PLPYException): + self.module.MiniBatchPreProcessorDL(self.default_schema_madlib, + self.default_source_table, + self.default_output_table, + "y1,y2", + self.default_ind_var, + self.default_buffer_size) + + def test_minibatch_preprocessor_multiple_indep_var_raises_exception(self): + self.module.get_expr_type = Mock(side_effect = ['integer[]', 'integer[]']) + with self.assertRaises(plpy.PLPYException): + self.module.MiniBatchPreProcessorDL(self.default_schema_madlib, + self.default_source_table, + self.default_output_table, + self.default_dep_var, + "x1,x2", + self.default_buffer_size) + + def test_minibatch_preprocessor_buffer_size_zero_fails(self): + self.module.get_expr_type = Mock(side_effect = ['integer[]', 'integer[]']) + with self.assertRaises(plpy.PLPYException): + self.module.MiniBatchPreProcessorDL(self.default_schema_madlib, + self.default_source_table, + self.default_output_table, + self.default_dep_var, + self.default_ind_var, + 0) + + def test_minibatch_preprocessor_negative_buffer_size_fails(self): + self.module.get_expr_type = Mock(side_effect = ['integer[]', 'integer[]']) + with self.assertRaises(plpy.PLPYException): + self.module.MiniBatchPreProcessorDL(self.default_schema_madlib, + self.default_source_table, + self.default_output_table, + self.default_dep_var, + self.default_ind_var, + -1) + + def test_minibatch_preprocessor_invalid_indep_vartype_raises_exception(self): + self.module.get_expr_type = Mock(side_effect = ['integer', 'integer[]']) + with self.assertRaises(plpy.PLPYException): + self.module.MiniBatchPreProcessorDL(self.default_schema_madlib, + self.default_source_table, + self.default_output_table, + self.default_dep_var, + self.default_ind_var, + self.default_buffer_size) + def test_minibatch_preprocessor_invalid_dep_vartype_raises_exception(self): + self.module.get_expr_type = Mock(side_effect = ['integer[]', 'text[]']) + with self.assertRaises(plpy.PLPYException): + self.module.MiniBatchPreProcessorDL(self.default_schema_madlib, + self.default_source_table, + self.default_output_table, + self.default_dep_var, + self.default_ind_var, + self.default_buffer_size) + + def test_minibatch_preprocessor_normalizing_const_zero_fails(self): + self.module.get_expr_type = Mock(side_effect = ['integer[]', 'integer[]']) + with self.assertRaises(plpy.PLPYException): + self.module.MiniBatchPreProcessorDL(self.default_schema_madlib, + self.default_source_table, + self.default_output_table, + self.default_dep_var, + self.default_ind_var, + self.default_buffer_size, + 0) + + def test_minibatch_preprocessor_negative_normalizing_const_fails(self): + self.module.get_expr_type = Mock(side_effect = ['integer[]', 'integer[]']) + with self.assertRaises(plpy.PLPYException): + self.module.MiniBatchPreProcessorDL(self.default_schema_madlib, + self.default_source_table, + self.default_output_table, + self.default_dep_var, + self.default_ind_var, + self.default_buffer_size, + -1) if __name__ == '__main__': unittest.main() diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in index e685bc4..ba7e960 100644 --- a/src/ports/postgres/modules/utilities/validate_args.py_in +++ b/src/ports/postgres/modules/utilities/validate_args.py_in @@ -394,7 +394,7 @@ def get_expr_type(expressions, tbl): """.format("ARRAY[{0}]".format(','.join(pg_type_expressions)), tbl)) if not expr_types: - plpy.error("Unable toget type of expression ({0}). " + plpy.error("Unable to get type of expression ({0}). " "Table {1} may not contain any valid tuples". format(expressions, tbl)) expr_types = expr_types[0]["all_types"]
