This is an automated email from the ASF dual-hosted git repository. jingyimei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 9fd12682fb1d6688482ce66aeddc09cded5e1dd4 Author: Jingyi Mei <[email protected]> AuthorDate: Fri Mar 15 17:38:42 2019 -0700 Add class_values and normalizing_const column in summary table Co-authored-by: Ekta Khanna <[email protected]> --- .../utilities/minibatch_preprocessing.py_in | 55 +++++++++++----------- .../utilities/minibatch_preprocessing_dl.sql_in | 19 ++------ .../test/minibatch_preprocessing_dl.sql_in | 30 +++++++++++- 3 files changed, 59 insertions(+), 45 deletions(-) diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in index 36fdd05..f3dc462 100644 --- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in +++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in @@ -348,8 +348,7 @@ class MiniBatchPreProcessor: class MiniBatchPreProcessorDL(MiniBatchPreProcessor): def __init__(self, schema_madlib, source_table, output_table, dependent_varname, independent_varname, buffer_size, - normalizing_const=1.0, one_hot_encode_int_dep_var=False, - **kwargs): + normalizing_const=1.0, **kwargs): self.schema_madlib = schema_madlib self.source_table = source_table self.output_table = output_table @@ -361,18 +360,13 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor): self.output_summary_table = add_postfix(self.output_table, "_summary") self._validate_args() self.num_of_buffers = self._get_num_buffers() - self.to_one_hot_encode = one_hot_encode_int_dep_var - # self.to_one_hot_encode = self.should_one_hot_encode(one_hot_encode_int_dep_var) - if self.to_one_hot_encode: - if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY): - # Assuming the input NUMERIC[] is already on_hot_encoded so casting to INTEGER[] - # self.dependent_varname = self.dependent_varname + '::INTEGER[]' - self.dependent_levels = None - else: - self.dependent_levels = get_distinct_col_levels( - self.source_table, self.dependent_varname, 
self.dependent_vartype) - else: + self.to_one_hot_encode = True + if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY): self.dependent_levels = None + else: + self.dependent_levels = get_distinct_col_levels( + self.source_table, self.dependent_varname, self.dependent_vartype) + def minibatch_preprocessor_dl(self): # Create a temp table that has independent var normalized. @@ -380,6 +374,7 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor): dependent_varname_with_offset = self.dependent_varname dep_var_array_expr = self.get_dep_var_array_expr() + # Assuming the input NUMERIC[] is already on_hot_encoded, so casting to INTEGER[] if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY): dep_var_array_expr = dep_var_array_expr + '::INTEGER[]' scalar_mult_sql = """ @@ -418,6 +413,11 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor): self._create_output_summary_table() def _create_output_summary_table(self): + class_level_str='NULL::TEXT' + if self.dependent_levels: + class_level_str=py_list_to_sql_string( + self.dependent_levels, array_type=self.dependent_vartype, + long_format=True) query = """ CREATE TABLE {self.output_summary_table} AS SELECT @@ -426,8 +426,10 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor): $__madlib__${self.dependent_varname}$__madlib__$::TEXT AS dependent_varname, $__madlib__${self.independent_varname}$__madlib__$::TEXT AS independent_varname, $__madlib__${self.dependent_vartype}$__madlib__$::TEXT AS dependent_vartype, - {self.buffer_size} AS buffer_size - """.format(self=self) + {class_level_str} AS class_values, + {self.buffer_size} AS buffer_size, + {self.normalizing_const} AS normalizing_const + """.format(self=self, class_level_str=class_level_str) plpy.execute(query) def _validate_args(self): @@ -443,13 +445,14 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor): "one of {0}".format(','.join(NUMERIC))) self.dependent_vartype = get_expr_type( self.dependent_varname, self.source_table) - if 
is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY): - pass - else: - dep_valid_types = NUMERIC | TEXT | BOOLEAN - _assert(is_valid_psql_type(self.dependent_vartype, dep_valid_types), - "Invalid dependent variable type, should be one of {0}". - format(','.join(dep_valid_types))) + # The dependent variable needs to be either: + # 1. NUMERIC, TEXT OR BOOLEAN, which we always one-hot encode + # 2. NUMERIC ARRAY, which we assume is already one-hot encoded, and we + # just cast it to INTEGER ARRAY + _assert((is_valid_psql_type(self.dependent_vartype, NUMERIC | TEXT | BOOLEAN) or + is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY)), + """Invalid dependent variable type, should be one of the type in this list: + numeric, text, boolean, or numeric array""") if self.buffer_size is not None: _assert(self.buffer_size > 0, "minibatch_preprocessor_dl: The buffer size has to be a " @@ -467,12 +470,6 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor): self.buffer_size, num_rows_in_tbl, indepdent_var_dim[0]) return ceil((1.0 * num_rows_in_tbl) / self.buffer_size) - # def should_one_hot_encode(self, one_hot_encode_int_dep_var): - # return (is_psql_char_type(self.dependent_vartype) or - # is_psql_boolean_type(self.dependent_vartype) or - # (is_psql_numeric_type(self.dependent_vartype) and - # one_hot_encode_int_dep_var)) - class MiniBatchStandardizer: """ @@ -787,6 +784,8 @@ class MiniBatchDocumentation: dependent_vartype -- Type of the dependent variable from the original table. buffer_size -- Buffer size used in preprocessing step. + normalizing_const -- Normalizing constant used for standardizing + arrays in independent_varname. 
--------------------------------------------------------------------------- """.format(**locals()) diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in index 89afa20..7a45db2 100644 --- a/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in +++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in @@ -501,8 +501,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor_dl( dependent_varname VARCHAR, independent_varname VARCHAR, buffer_size INTEGER, - normalizing_const DOUBLE PRECISION, - one_hot_encode_int_dep_var BOOLEAN + normalizing_const DOUBLE PRECISION ) RETURNS VOID AS $$ PythonFunctionBodyOnly(utilities, minibatch_preprocessing) from utilities.control import MinWarning @@ -518,21 +517,9 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor_dl( output_table VARCHAR, dependent_varname VARCHAR, independent_varname VARCHAR, - buffer_size INTEGER, - normalizing_const DOUBLE PRECISION -) RETURNS VOID AS $$ - SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, $6, FALSE); -$$ LANGUAGE sql VOLATILE -m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); - -CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor_dl( - source_table VARCHAR, - output_table VARCHAR, - dependent_varname VARCHAR, - independent_varname VARCHAR, buffer_size INTEGER ) RETURNS VOID AS $$ - SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, 1.0, FALSE); + SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, 1.0); $$ LANGUAGE sql VOLATILE m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); @@ -542,7 +529,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor_dl( dependent_varname VARCHAR, independent_varname VARCHAR ) RETURNS VOID AS $$ - SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, NULL, 1.0, FALSE); + SELECT 
MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, NULL, 1.0); $$ LANGUAGE sql VOLATILE m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); diff --git a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in index cd0a7e6..31bfdd7 100644 --- a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in +++ b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in @@ -101,6 +101,19 @@ SELECT minibatch_preprocessor_dl( SELECT assert(relative_error(MIN(x),0.2) < 0.00001, 'Independent var not normalized properly!') FROM (SELECT UNNEST(independent_var) as x FROM minibatch_preprocessor_dl_batch) a; SELECT assert(relative_error(MAX(x),46.6) < 0.00001, 'Independent var not normalized properly!') FROM (SELECT UNNEST(independent_var) as x FROM minibatch_preprocessor_dl_batch) a; +-- Test summary table +SELECT assert + ( + source_table = 'minibatch_preprocessor_dl_input' AND + output_table = 'minibatch_preprocessor_dl_batch' AND + dependent_varname = 'y' AND + independent_varname = 'x' AND + dependent_vartype = 'integer' AND + class_values = '{-6,-3,-1,0,2,3,4,5,6,7,8,9,10,12}' AND + buffer_size = 4 AND -- we sort the class values in python + normalizing_const = 5, + 'Summary Validation failed. 
Actual:' || __to_char(summary) + ) from (select * from minibatch_preprocessor_dl_batch_summary) summary; -- Test one-hot encoding for dependent_var -- test boolean type @@ -116,6 +129,10 @@ SELECT assert(pg_typeof(dependent_var) = 'integer[]'::regtype, 'One-hot encode d SELECT assert(array_upper(dependent_var, 2) = 2, 'Incorrect one-hot encode dimension') FROM minibatch_preprocessor_dl_batch WHERE buffer_id = 0; SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT buffer_id, UNNEST(dependent_var[1:1]) as y FROM minibatch_preprocessor_dl_batch) a WHERE buffer_id = 0; +SELECT assert (dependent_vartype = 'boolean' AND + class_values = '{f,t}', + 'Summary Validation failed. Actual:' || __to_char(summary) + ) from (select * from minibatch_preprocessor_dl_batch_summary) summary; -- test text type DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, minibatch_preprocessor_dl_batch_summary; @@ -130,6 +147,10 @@ SELECT assert(pg_typeof(dependent_var) = 'integer[]'::regtype, 'One-hot encode d SELECT assert(array_upper(dependent_var, 2) = 3, 'Incorrect one-hot encode dimension') FROM minibatch_preprocessor_dl_batch WHERE buffer_id = 0; SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT buffer_id, UNNEST(dependent_var[1:1]) as y FROM minibatch_preprocessor_dl_batch) a WHERE buffer_id = 0; +SELECT assert (dependent_vartype = 'text' AND + class_values = '{a,b,c}', + 'Summary Validation failed. 
Actual:' || __to_char(summary) + ) from (select * from minibatch_preprocessor_dl_batch_summary) summary; -- test double precision type DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, minibatch_preprocessor_dl_batch_summary; @@ -144,7 +165,10 @@ SELECT assert(pg_typeof(dependent_var) = 'integer[]'::regtype, 'One-hot encode d SELECT assert(array_upper(dependent_var, 2) = 3, 'Incorrect one-hot encode dimension') FROM minibatch_preprocessor_dl_batch WHERE buffer_id = 0; SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT buffer_id, UNNEST(dependent_var[1:1]) as y FROM minibatch_preprocessor_dl_batch) a WHERE buffer_id = 0; - +SELECT assert (dependent_vartype = 'double precision' AND + class_values = '{4.0,4.2,5.0}', + 'Summary Validation failed. Actual:' || __to_char(summary) + ) from (select * from minibatch_preprocessor_dl_batch_summary) summary; -- test double precision array type DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, minibatch_preprocessor_dl_batch_summary; @@ -159,6 +183,10 @@ SELECT assert(pg_typeof(dependent_var) = 'integer[]'::regtype, 'One-hot encode d SELECT assert(array_upper(dependent_var, 2) = 2, 'Incorrect one-hot encode dimension') FROM minibatch_preprocessor_dl_batch WHERE buffer_id = 0; SELECT assert(relative_error(SUM(y), SUM(y4)) < 0.000001, 'Incorrect one-hot encode value') FROM (SELECT UNNEST(dependent_var) AS y FROM minibatch_preprocessor_dl_batch) a, (SELECT UNNEST(y4) as y4 FROM minibatch_preprocessor_dl_input) b; +SELECT assert (dependent_vartype = 'double precision[]' AND + class_values IS NULL, + 'Summary Validation failed. Actual:' || __to_char(summary) + ) from (select * from minibatch_preprocessor_dl_batch_summary) summary; -- test integer array type DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, minibatch_preprocessor_dl_batch_summary;
