This is an automated email from the ASF dual-hosted git repository. jingyimei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 096dc52132baa580854ffc07fcc51cae7be4df8f Author: Jingyi Mei <[email protected]> AuthorDate: Thu Mar 14 16:57:37 2019 -0700 DL: Add one-hot encode support for minibatch preprocessor Co-authored-by: Ekta Khanna <[email protected]> --- .../utilities/minibatch_preprocessing.py_in | 244 ++++++++++++--------- .../utilities/minibatch_preprocessing_dl.sql_in | 22 +- .../test/minibatch_preprocessing_dl.sql_in | 126 +++++++++-- 3 files changed, 259 insertions(+), 133 deletions(-) diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in index e1967a8..78e4ba1 100644 --- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in +++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in @@ -34,6 +34,7 @@ from utilities import is_platform_pg from utilities import is_psql_boolean_type from utilities import is_psql_char_type from utilities import is_psql_int_type +from utilities import is_psql_numeric_type from utilities import is_valid_psql_type from utilities import py_list_to_sql_string from utilities import split_quoted_delimited_str @@ -58,112 +59,6 @@ MINIBATCH_OUTPUT_INDEPENDENT_COLNAME = "independent_varname" MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL = "dependent_var" MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL = "independent_var" -class MiniBatchPreProcessorDL: - def __init__(self, schema_madlib, source_table, output_table, - dependent_varname, independent_varname, buffer_size, - normalizing_const=1.0, dependent_offset=None, **kwargs): - self.schema_madlib = schema_madlib - self.source_table = source_table - self.output_table = output_table - self.dependent_varname = dependent_varname - self.independent_varname = independent_varname - self.buffer_size = buffer_size - self.normalizing_const = normalizing_const if normalizing_const else 1.0 - self.dependent_offset = dependent_offset - self.module_name = "minibatch_preprocessor_DL" - self.output_summary_table = add_postfix(self.output_table, "_summary") - self._validate_args() - self.num_of_buffers = self._get_num_buffers() - - def minibatch_preprocessor_dl(self): - # Create a temp table that has independent var normalized. - norm_tbl = unique_string(desp='normalized') - - dependent_varname_with_offset = self.dependent_varname - if self.dependent_offset: - dependent_varname_with_offset = '{0} + {1}'.format(self.dependent_varname, self.dependent_offset) - - scalar_mult_sql = """ - CREATE TEMP TABLE {norm_tbl} AS - SELECT {self.schema_madlib}.array_scalar_mult( - {self.independent_varname}::REAL[], (1/{self.normalizing_const})::REAL) AS x_norm, - {dependent_varname_with_offset} AS y, - row_number() over() AS row_id - FROM {self.source_table} order by random() - """.format(**locals()) - plpy.execute(scalar_mult_sql) - # Create the mini-batched output table - if is_platform_pg(): - distributed_by_clause = '' - else: - distributed_by_clause= ' DISTRIBUTED BY (buffer_id) ' - sql = """ - CREATE TABLE {self.output_table} AS - SELECT * FROM - ( - SELECT {self.schema_madlib}.agg_array_concat( - ARRAY[{norm_tbl}.x_norm::REAL[]]) AS {x}, - array_agg({norm_tbl}.y) AS {y}, - ({norm_tbl}.row_id%{self.num_of_buffers})::smallint AS buffer_id - FROM {norm_tbl} - GROUP BY buffer_id - ) b - {distributed_by_clause} - """.format(x=MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL, - y=MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL, - **locals()) - plpy.execute(sql) - plpy.execute("DROP TABLE IF EXISTS {0}".format(norm_tbl)) - # Create summary table - self._create_output_summary_table() - - def _create_output_summary_table(self): - query = """ - CREATE TABLE {self.output_summary_table} AS - SELECT - $__madlib__${self.source_table}$__madlib__$::TEXT AS source_table, - $__madlib__${self.output_table}$__madlib__$::TEXT AS output_table, - $__madlib__${self.dependent_varname}$__madlib__$::TEXT AS dependent_varname, - $__madlib__${self.independent_varname}$__madlib__$::TEXT AS independent_varname, - $__madlib__${self.dependent_vartype}$__madlib__$::TEXT AS dependent_vartype, - {self.buffer_size} AS buffer_size - """.format(self=self) - plpy.execute(query) - - def _validate_args(self): - validate_module_input_params( - self.source_table, self.output_table, self.independent_varname, - self.dependent_varname, self.module_name, None, - [self.output_summary_table]) - self.independent_vartype = get_expr_type( - self.independent_varname, self.source_table) - _assert(is_valid_psql_type(self.independent_vartype, - NUMERIC | ONLY_ARRAY), - "Invalid independent variable type, should be an array of " - "one of {0}".format(','.join(NUMERIC))) - self.dependent_vartype = get_expr_type( - self.dependent_varname, self.source_table) - dep_valid_types = NUMERIC | TEXT | BOOLEAN - _assert(is_valid_psql_type(self.dependent_vartype, dep_valid_types), - "Invalid dependent variable type, should be one of {0}". - format(','.join(dep_valid_types))) - if self.buffer_size is not None: - _assert(self.buffer_size > 0, - "minibatch_preprocessor_dl: The buffer size has to be a " - "positive integer or NULL.") - - def _get_num_buffers(self): - num_rows_in_tbl = plpy.execute(""" - SELECT count(*) AS cnt FROM {0} - """.format(self.source_table))[0]['cnt'] - buffer_size_calculator = MiniBatchBufferSizeCalculator() - indepdent_var_dim = _tbl_dimension_rownum( - self.schema_madlib, self.source_table, - self.independent_varname, skip_row_count=True) - self.buffer_size = buffer_size_calculator.calculate_default_buffer_size( - self.buffer_size, num_rows_in_tbl, indepdent_var_dim[0]) - return ceil((1.0 * num_rows_in_tbl) / self.buffer_size) - class MiniBatchPreProcessor: """ This class is responsible for executing the main logic of mini batch @@ -333,6 +228,8 @@ class MiniBatchPreProcessor: else: return "ARRAY[({0})]".format(self.dependent_varname) + + def get_indep_var_array_expr(self): """ we assume that all the independent features are either numeric or already encoded by the user. @@ -448,6 +345,141 @@ class MiniBatchPreProcessor: plpy.execute(query) +class MiniBatchPreProcessorDL(MiniBatchPreProcessor): + def __init__(self, schema_madlib, source_table, output_table, + dependent_varname, independent_varname, buffer_size, + normalizing_const=1.0, dependent_offset=None, + one_hot_encode_int_dep_var=False, **kwargs): + self.schema_madlib = schema_madlib + self.source_table = source_table + self.output_table = output_table + self.dependent_varname = dependent_varname + self.independent_varname = independent_varname + self.buffer_size = buffer_size + self.normalizing_const = normalizing_const if normalizing_const else 1.0 + self.dependent_offset = dependent_offset + self.module_name = "minibatch_preprocessor_DL" + self.output_summary_table = add_postfix(self.output_table, "_summary") + self._validate_args() + self.num_of_buffers = self._get_num_buffers() + self.to_one_hot_encode = one_hot_encode_int_dep_var + # self.to_one_hot_encode = self.should_one_hot_encode(one_hot_encode_int_dep_var) + if self.to_one_hot_encode: + if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY): + # Assuming the input NUMERIC[] is already on_hot_encoded so casting to INTEGER[] + # self.dependent_varname = self.dependent_varname + '::INTEGER[]' + self.dependent_levels = None + else: + self.dependent_levels = get_distinct_col_levels( + self.source_table, self.dependent_varname, self.dependent_vartype) + else: + self.dependent_levels = None + + def minibatch_preprocessor_dl(self): + # Create a temp table that has independent var normalized. + norm_tbl = unique_string(desp='normalized') + + dependent_varname_with_offset = self.dependent_varname + if self.dependent_offset: + dependent_varname_with_offset = '{0} + {1}'.format(self.dependent_varname, self.dependent_offset) + dep_var_array_expr = self.get_dep_var_array_expr() + if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY): + dep_var_array_expr = dep_var_array_expr + '::INTEGER[]' + scalar_mult_sql = """ + CREATE TEMP TABLE {norm_tbl} AS + SELECT {self.schema_madlib}.array_scalar_mult( + {self.independent_varname}::REAL[], (1/{self.normalizing_const})::REAL) AS x_norm, + {dependent_varname_with_offset} AS y1, + {dep_var_array_expr} AS y, + row_number() over() AS row_id + FROM {self.source_table} + """.format(**locals()) + plpy.execute(scalar_mult_sql) + # Create the mini-batched output table + if is_platform_pg(): + distributed_by_clause = '' + else: + distributed_by_clause= ' DISTRIBUTED BY (buffer_id) ' + sql = """ + CREATE TABLE {self.output_table} AS + SELECT * FROM + ( + SELECT {self.schema_madlib}.agg_array_concat( + ARRAY[{norm_tbl}.x_norm::REAL[]]) AS {x}, + {self.schema_madlib}.agg_array_concat( + ARRAY[{norm_tbl}.y1]) AS y1, + {self.schema_madlib}.agg_array_concat( + ARRAY[{norm_tbl}.y]) AS {y}, + ({norm_tbl}.row_id%{self.num_of_buffers})::smallint AS buffer_id + FROM {norm_tbl} + GROUP BY buffer_id + ) b + {distributed_by_clause} + """.format(x=MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL, + y=MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL, + **locals()) + plpy.execute(sql) + plpy.execute("DROP TABLE IF EXISTS {0}".format(norm_tbl)) + # Create summary table + self._create_output_summary_table() + + def _create_output_summary_table(self): + query = """ + CREATE TABLE {self.output_summary_table} AS + SELECT + $__madlib__${self.source_table}$__madlib__$::TEXT AS source_table, + $__madlib__${self.output_table}$__madlib__$::TEXT AS output_table, + $__madlib__${self.dependent_varname}$__madlib__$::TEXT AS dependent_varname, + $__madlib__${self.independent_varname}$__madlib__$::TEXT AS independent_varname, + $__madlib__${self.dependent_vartype}$__madlib__$::TEXT AS dependent_vartype, + {self.buffer_size} AS buffer_size + """.format(self=self) + plpy.execute(query) + + def _validate_args(self): + validate_module_input_params( + self.source_table, self.output_table, self.independent_varname, + self.dependent_varname, self.module_name, None, + [self.output_summary_table]) + self.independent_vartype = get_expr_type( + self.independent_varname, self.source_table) + _assert(is_valid_psql_type(self.independent_vartype, + NUMERIC | ONLY_ARRAY), + "Invalid independent variable type, should be an array of " + "one of {0}".format(','.join(NUMERIC))) + self.dependent_vartype = get_expr_type( + self.dependent_varname, self.source_table) + if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY): + pass + else: + dep_valid_types = NUMERIC | TEXT | BOOLEAN + _assert(is_valid_psql_type(self.dependent_vartype, dep_valid_types), + "Invalid dependent variable type, should be one of {0}". + format(','.join(dep_valid_types))) + if self.buffer_size is not None: + _assert(self.buffer_size > 0, + "minibatch_preprocessor_dl: The buffer size has to be a " + "positive integer or NULL.") + + def _get_num_buffers(self): + num_rows_in_tbl = plpy.execute(""" + SELECT count(*) AS cnt FROM {0} + """.format(self.source_table))[0]['cnt'] + buffer_size_calculator = MiniBatchBufferSizeCalculator() + indepdent_var_dim = _tbl_dimension_rownum( + self.schema_madlib, self.source_table, + self.independent_varname, skip_row_count=True) + self.buffer_size = buffer_size_calculator.calculate_default_buffer_size( + self.buffer_size, num_rows_in_tbl, indepdent_var_dim[0]) + return ceil((1.0 * num_rows_in_tbl) / self.buffer_size) + + # def should_one_hot_encode(self, one_hot_encode_int_dep_var): + # return (is_psql_char_type(self.dependent_vartype) or + # is_psql_boolean_type(self.dependent_vartype) or + # (is_psql_numeric_type(self.dependent_vartype) and + # one_hot_encode_int_dep_var)) + + class MiniBatchStandardizer: """ This class is responsible for diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in index 1a13d35..5bf4767 100644 --- a/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in +++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in @@ -494,7 +494,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor_dl( independent_varname VARCHAR, buffer_size INTEGER, normalizing_const DOUBLE PRECISION, - dependent_offset INTEGER + dependent_offset INTEGER, + one_hot_encode_int_dep_var BOOLEAN ) RETURNS VOID AS $$ PythonFunctionBodyOnly(utilities, minibatch_preprocessing) from utilities.control import MinWarning @@ -511,9 +512,22 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor_dl( dependent_varname VARCHAR, independent_varname VARCHAR, buffer_size INTEGER, + normalizing_const DOUBLE PRECISION, + dependent_offset INTEGER +) RETURNS VOID AS $$ + SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, $6, $7, FALSE); +$$ LANGUAGE sql VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor_dl( + source_table VARCHAR, + output_table VARCHAR, + dependent_varname VARCHAR, + independent_varname VARCHAR, + buffer_size INTEGER, normalizing_const DOUBLE PRECISION ) RETURNS VOID AS $$ - SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, $6, NULL); + SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, $6, NULL, FALSE); $$ LANGUAGE sql VOLATILE m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); @@ -524,7 +538,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor_dl( independent_varname VARCHAR, buffer_size INTEGER ) RETURNS VOID AS $$ - SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, 1.0, NULL); + SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, 1.0, NULL, FALSE); $$ LANGUAGE sql VOLATILE m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); @@ -534,7 +548,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor_dl( dependent_varname VARCHAR, independent_varname VARCHAR ) RETURNS VOID AS $$ - SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, NULL, 1.0, NULL); + SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, NULL, 1.0, NULL, FALSE); $$ LANGUAGE sql VOLATILE m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); diff --git a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in index 45da10f..7130044 100644 --- a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in +++ b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in @@ -51,11 +51,14 @@ SELECT minibatch_preprocessor_dl( SELECT assert(count(*)=4, 'Incorrect number of buffers in minibatch_preprocessor_dl_batch.') FROM minibatch_preprocessor_dl_batch; +SELECT assert(array_upper(independent_var, 2)=6, 'Incorrect buffer size.') +FROM minibatch_preprocessor_dl_batch WHERE buffer_id=0; + SELECT assert(array_upper(independent_var, 1)=5, 'Incorrect buffer size.') FROM minibatch_preprocessor_dl_batch WHERE buffer_id=1; -SELECT assert(array_upper(independent_var, 1)=2, 'Incorrect buffer size.') -FROM minibatch_preprocessor_dl_batch WHERE buffer_id=4; +SELECT assert(array_upper(independent_var, 1)=4, 'Incorrect buffer size.') +FROM minibatch_preprocessor_dl_batch WHERE buffer_id=3; DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, minibatch_preprocessor_dl_batch_summary; SELECT minibatch_preprocessor_dl( @@ -65,25 +68,25 @@ SELECT minibatch_preprocessor_dl( 'x'); DROP TABLE IF EXISTS minibatch_preprocessor_dl_input; -CREATE TABLE minibatch_preprocessor_dl_input(id serial, x double precision[], y INTEGER); -INSERT INTO minibatch_preprocessor_dl_input(x, y) VALUES -(ARRAY[1,2,3,4,5,6], 4), -(ARRAY[11,2,3,4,5,6], 3), -(ARRAY[11,22,33,4,5,6], 8), -(ARRAY[11,22,33,44,5,6], 2), -(ARRAY[11,22,33,44,65,6], 5), -(ARRAY[11,22,33,44,65,56], 6), -(ARRAY[11,22,33,44,65,56], 2), -(ARRAY[11,22,33,44,65,56], 10), -(ARRAY[11,22,33,44,65,56], 3), -(ARRAY[11,22,33,44,65,56], 7), -(ARRAY[11,22,33,44,65,56], 6), -(ARRAY[11,22,33,44,65,56], -6), -(ARRAY[11,22,33,144,65,56], 9), -(ARRAY[11,22,233,44,65,56], 0), -(ARRAY[11,22,33,44,65,56], 12), -(ARRAY[11,22,33,44,65,56], -3), -(ARRAY[11,22,33,44,65,56], -1); +CREATE TABLE minibatch_preprocessor_dl_input(id serial, x double precision[], y INTEGER, y1 BOOLEAN, y2 TEXT, y3 DOUBLE PRECISION, y4 DOUBLE PRECISION[], y5 INTEGER[]); +INSERT INTO minibatch_preprocessor_dl_input(x, y, y1, y2, y3, y4, y5) VALUES +(ARRAY[1,2,3,4,5,6], 4, TRUE, 'a', 4.0, ARRAY[1.0, 0.0], ARRAY[1,0]), +(ARRAY[11,2,3,4,5,6], 3, TRUE, 'c', 4.2, ARRAY[0.0, 1.0], ARRAY[1,0]), +(ARRAY[11,22,33,4,5,6], 8, TRUE, 'a', 4.0, ARRAY[0.0, 1.0], ARRAY[1,0]), +(ARRAY[11,22,33,44,5,6], 2, FALSE, 'a', 4.2, ARRAY[0.0, 1.0], ARRAY[1,0]), +(ARRAY[11,22,33,44,65,6], 5, TRUE, 'b', 4.0, ARRAY[0.0, 1.0], ARRAY[0,1]), +(ARRAY[11,22,33,44,65,56], 6, TRUE, 'a', 5.0, ARRAY[1.0, 0.0], ARRAY[1,0]), +(ARRAY[11,22,33,44,65,56], 2, TRUE, 'a', 4.0, ARRAY[1.0, 0.0], ARRAY[1,0]), +(ARRAY[11,22,33,44,65,56], 10, TRUE, 'a', 4.0, ARRAY[1.0, 0.0], ARRAY[1,0]), +(ARRAY[11,22,33,44,65,56], 3, TRUE, 'b', 4.0, ARRAY[1.0, 0.0], ARRAY[1,0]), +(ARRAY[11,22,33,44,65,56], 7, FALSE, 'a', 5.0, ARRAY[1.0, 0.0], ARRAY[1,0]), +(ARRAY[11,22,33,44,65,56], 6, TRUE, 'a', 4.0, ARRAY[0.0, 1.0], ARRAY[1,0]), +(ARRAY[11,22,33,44,65,56], -6, TRUE, 'a', 4.0, ARRAY[1.0, 0.0], ARRAY[1,0]), +(ARRAY[11,22,33,144,65,56], 9, TRUE, 'c', 4.0, ARRAY[0.0, 1.0], ARRAY[1,0]), +(ARRAY[11,22,233,44,65,56], 0, TRUE, 'a', 5.0, ARRAY[1.0, 0.0], ARRAY[0,1]), +(ARRAY[11,22,33,44,65,56], 12, TRUE, 'a', 4.0, ARRAY[1.0, 0.0], ARRAY[1,0]), +(ARRAY[11,22,33,44,65,56], -3, FALSE, 'a', 4.2, ARRAY[1.0, 0.0], ARRAY[1,0]), +(ARRAY[11,22,33,44,65,56], -1, TRUE, 'b', 4.0, ARRAY[0.0, 1.0], ARRAY[0,1]); DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, minibatch_preprocessor_dl_batch_summary; SELECT minibatch_preprocessor_dl( @@ -109,7 +112,7 @@ SELECT minibatch_preprocessor_dl( 6); -- Test that dependent vars gets shifted by +6, by verifying minimum value goes from -6 to 0 -SELECT assert(abs(MIN(y))<0.00001, 'Dependent var not shifted properly!') FROM (SELECT UNNEST(dependent_var) as y FROM minibatch_preprocessor_dl_batch) a; +SELECT assert(abs(MIN(y))<0.00001, 'Dependent var not shifted properly!') FROM (SELECT UNNEST(y1) as y FROM minibatch_preprocessor_dl_batch) a; DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, minibatch_preprocessor_dl_batch_summary; SELECT minibatch_preprocessor_dl( @@ -122,4 +125,81 @@ SELECT minibatch_preprocessor_dl( -6); -- Test that dependent vars gets shifted by -6, by verifying minimum value goes from -6 to -12 -SELECT assert(relative_error(MIN(y), -12)<0.00001, 'Dependent var not shifted properly!') FROM (SELECT UNNEST(dependent_var) as y FROM minibatch_preprocessor_dl_batch) a; \ No newline at end of file +SELECT assert(relative_error(MIN(y), -12)<0.00001, 'Dependent var not shifted properly!') FROM (SELECT UNNEST(y1) as y FROM minibatch_preprocessor_dl_batch) a; + + +-- Test one-hot encoding for dependent_var +-- test boolean type +DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, minibatch_preprocessor_dl_batch_summary; +SELECT minibatch_preprocessor_dl( + 'minibatch_preprocessor_dl_input', + 'minibatch_preprocessor_dl_batch', + 'y1', + 'x', + 4, + 5); +SELECT assert(pg_typeof(dependent_var) = 'integer[]'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM minibatch_preprocessor_dl_batch WHERE buffer_id = 0; +SELECT assert(array_upper(dependent_var, 2) = 2, 'Incorrect one-hot encode dimension') FROM + minibatch_preprocessor_dl_batch WHERE buffer_id = 0; +SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT buffer_id, UNNEST(dependent_var[1:1]) as y FROM minibatch_preprocessor_dl_batch) a WHERE buffer_id = 0; + +-- test text type +DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, minibatch_preprocessor_dl_batch_summary; +SELECT minibatch_preprocessor_dl( + 'minibatch_preprocessor_dl_input', + 'minibatch_preprocessor_dl_batch', + 'y2', + 'x', + 4, + 5); +SELECT assert(pg_typeof(dependent_var) = 'integer[]'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM minibatch_preprocessor_dl_batch WHERE buffer_id = 0; +SELECT assert(array_upper(dependent_var, 2) = 3, 'Incorrect one-hot encode dimension') FROM + minibatch_preprocessor_dl_batch WHERE buffer_id = 0; +SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT buffer_id, UNNEST(dependent_var[1:1]) as y FROM minibatch_preprocessor_dl_batch) a WHERE buffer_id = 0; + +-- test double precision type +DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, minibatch_preprocessor_dl_batch_summary; +SELECT minibatch_preprocessor_dl( + 'minibatch_preprocessor_dl_input', + 'minibatch_preprocessor_dl_batch', + 'y3', + 'x', + 4, + 5); +SELECT assert(pg_typeof(dependent_var) = 'integer[]'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM minibatch_preprocessor_dl_batch WHERE buffer_id = 0; +SELECT assert(array_upper(dependent_var, 2) = 3, 'Incorrect one-hot encode dimension') FROM + minibatch_preprocessor_dl_batch WHERE buffer_id = 0; +SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT buffer_id, UNNEST(dependent_var[1:1]) as y FROM minibatch_preprocessor_dl_batch) a WHERE buffer_id = 0; + + +-- test double precision array type +DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, minibatch_preprocessor_dl_batch_summary; +SELECT minibatch_preprocessor_dl( + 'minibatch_preprocessor_dl_input', + 'minibatch_preprocessor_dl_batch', + 'y4', + 'x', + 4, + 5); +SELECT assert(pg_typeof(dependent_var) = 'integer[]'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM minibatch_preprocessor_dl_batch WHERE buffer_id = 0; +SELECT assert(array_upper(dependent_var, 2) = 2, 'Incorrect one-hot encode dimension') FROM + minibatch_preprocessor_dl_batch WHERE buffer_id = 0; +SELECT assert(relative_error(SUM(y), SUM(y4)) < 0.000001, 'Incorrect one-hot encode value') FROM (SELECT UNNEST(dependent_var) AS y FROM minibatch_preprocessor_dl_batch) a, (SELECT UNNEST(y4) as y4 FROM minibatch_preprocessor_dl_input) b; + +-- test integer array type +DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, minibatch_preprocessor_dl_batch_summary; +SELECT minibatch_preprocessor_dl( + 'minibatch_preprocessor_dl_input', + 'minibatch_preprocessor_dl_batch', + 'y5', + 'x', + 4, + 5); +SELECT assert(pg_typeof(dependent_var) = 'integer[]'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM minibatch_preprocessor_dl_batch WHERE buffer_id = 0; +SELECT assert(array_upper(dependent_var, 2) = 2, 'Incorrect one-hot encode dimension') FROM + minibatch_preprocessor_dl_batch WHERE buffer_id = 0; +SELECT assert(relative_error(SUM(y), SUM(y5)) < 0.000001, 'Incorrect one-hot encode value') FROM (SELECT UNNEST(dependent_var) AS y FROM minibatch_preprocessor_dl_batch) a, (SELECT UNNEST(y5) as y5 FROM minibatch_preprocessor_dl_input) b; +SELECT assert (dependent_vartype = 'integer[]' AND + class_values IS NULL, + 'Summary Validation failed. Actual:' || __to_char(summary) + ) from (select * from minibatch_preprocessor_dl_batch_summary) summary;
