This is an automated email from the ASF dual-hosted git repository.

jingyimei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 096dc52132baa580854ffc07fcc51cae7be4df8f
Author: Jingyi Mei <[email protected]>
AuthorDate: Thu Mar 14 16:57:37 2019 -0700

    DL: Add one-hot encoding support to the minibatch preprocessor
    
    Co-authored-by: Ekta Khanna <[email protected]>
---
 .../utilities/minibatch_preprocessing.py_in        | 244 ++++++++++++---------
 .../utilities/minibatch_preprocessing_dl.sql_in    |  22 +-
 .../test/minibatch_preprocessing_dl.sql_in         | 126 +++++++++--
 3 files changed, 259 insertions(+), 133 deletions(-)

diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in 
b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
index e1967a8..78e4ba1 100644
--- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
+++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
@@ -34,6 +34,7 @@ from utilities import is_platform_pg
 from utilities import is_psql_boolean_type
 from utilities import is_psql_char_type
 from utilities import is_psql_int_type
+from utilities import is_psql_numeric_type
 from utilities import is_valid_psql_type
 from utilities import py_list_to_sql_string
 from utilities import split_quoted_delimited_str
@@ -58,112 +59,6 @@ MINIBATCH_OUTPUT_INDEPENDENT_COLNAME = "independent_varname"
 MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL = "dependent_var"
 MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL = "independent_var"
 
-class MiniBatchPreProcessorDL:
-    def __init__(self, schema_madlib, source_table, output_table,
-                 dependent_varname, independent_varname, buffer_size,
-                 normalizing_const=1.0, dependent_offset=None, **kwargs):
-        self.schema_madlib = schema_madlib
-        self.source_table = source_table
-        self.output_table = output_table
-        self.dependent_varname = dependent_varname
-        self.independent_varname = independent_varname
-        self.buffer_size = buffer_size
-        self.normalizing_const = normalizing_const if normalizing_const else 
1.0
-        self.dependent_offset = dependent_offset
-        self.module_name = "minibatch_preprocessor_DL"
-        self.output_summary_table = add_postfix(self.output_table, "_summary")
-        self._validate_args()
-        self.num_of_buffers = self._get_num_buffers()
-
-    def minibatch_preprocessor_dl(self):
-        # Create a temp table that has independent var normalized.
-        norm_tbl = unique_string(desp='normalized')
-
-        dependent_varname_with_offset = self.dependent_varname
-        if self.dependent_offset:
-            dependent_varname_with_offset = '{0} + 
{1}'.format(self.dependent_varname, self.dependent_offset)
-
-        scalar_mult_sql = """
-            CREATE TEMP TABLE {norm_tbl} AS
-            SELECT {self.schema_madlib}.array_scalar_mult(
-                {self.independent_varname}::REAL[], 
(1/{self.normalizing_const})::REAL) AS x_norm,
-                {dependent_varname_with_offset} AS y,
-                row_number() over() AS row_id
-            FROM {self.source_table} order by random()
-        """.format(**locals())
-        plpy.execute(scalar_mult_sql)
-        # Create the mini-batched output table
-        if is_platform_pg():
-            distributed_by_clause = ''
-        else:
-            distributed_by_clause= ' DISTRIBUTED BY (buffer_id) '
-        sql = """
-            CREATE TABLE {self.output_table} AS
-            SELECT * FROM
-            (
-                SELECT {self.schema_madlib}.agg_array_concat(
-                    ARRAY[{norm_tbl}.x_norm::REAL[]]) AS {x},
-                    array_agg({norm_tbl}.y) AS {y},
-                    ({norm_tbl}.row_id%{self.num_of_buffers})::smallint AS 
buffer_id
-                FROM {norm_tbl}
-                GROUP BY buffer_id
-            ) b
-            {distributed_by_clause}
-            """.format(x=MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL,
-                       y=MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL,
-                       **locals())
-        plpy.execute(sql)
-        plpy.execute("DROP TABLE IF EXISTS {0}".format(norm_tbl))
-        # Create summary table
-        self._create_output_summary_table()
-
-    def _create_output_summary_table(self):
-        query = """
-            CREATE TABLE {self.output_summary_table} AS
-            SELECT
-                $__madlib__${self.source_table}$__madlib__$::TEXT AS 
source_table,
-                $__madlib__${self.output_table}$__madlib__$::TEXT AS 
output_table,
-                $__madlib__${self.dependent_varname}$__madlib__$::TEXT AS 
dependent_varname,
-                $__madlib__${self.independent_varname}$__madlib__$::TEXT AS 
independent_varname,
-                $__madlib__${self.dependent_vartype}$__madlib__$::TEXT AS 
dependent_vartype,
-                {self.buffer_size} AS buffer_size
-            """.format(self=self)
-        plpy.execute(query)
-
-    def _validate_args(self):
-        validate_module_input_params(
-            self.source_table, self.output_table, self.independent_varname,
-            self.dependent_varname, self.module_name, None,
-            [self.output_summary_table])
-        self.independent_vartype = get_expr_type(
-            self.independent_varname, self.source_table)
-        _assert(is_valid_psql_type(self.independent_vartype,
-                                   NUMERIC | ONLY_ARRAY),
-                "Invalid independent variable type, should be an array of "
-                "one of {0}".format(','.join(NUMERIC)))
-        self.dependent_vartype = get_expr_type(
-            self.dependent_varname, self.source_table)
-        dep_valid_types = NUMERIC | TEXT | BOOLEAN
-        _assert(is_valid_psql_type(self.dependent_vartype, dep_valid_types),
-                "Invalid dependent variable type, should be one of {0}".
-                format(','.join(dep_valid_types)))
-        if self.buffer_size is not None:
-            _assert(self.buffer_size > 0,
-                    "minibatch_preprocessor_dl: The buffer size has to be a "
-                    "positive integer or NULL.")
-
-    def _get_num_buffers(self):
-        num_rows_in_tbl = plpy.execute("""
-                SELECT count(*) AS cnt FROM {0}
-            """.format(self.source_table))[0]['cnt']
-        buffer_size_calculator = MiniBatchBufferSizeCalculator()
-        indepdent_var_dim = _tbl_dimension_rownum(
-            self.schema_madlib, self.source_table,
-            self.independent_varname, skip_row_count=True)
-        self.buffer_size = 
buffer_size_calculator.calculate_default_buffer_size(
-            self.buffer_size, num_rows_in_tbl, indepdent_var_dim[0])
-        return ceil((1.0 * num_rows_in_tbl) / self.buffer_size)
-
 class MiniBatchPreProcessor:
     """
     This class is responsible for executing the main logic of mini batch
@@ -333,6 +228,8 @@ class MiniBatchPreProcessor:
         else:
             return "ARRAY[({0})]".format(self.dependent_varname)
 
+
+
     def get_indep_var_array_expr(self):
         """ we assume that all the independent features are either numeric or
         already encoded by the user.
@@ -448,6 +345,141 @@ class MiniBatchPreProcessor:
         plpy.execute(query)
 
 
+class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
+    def __init__(self, schema_madlib, source_table, output_table,
+                 dependent_varname, independent_varname, buffer_size,
+                 normalizing_const=1.0, dependent_offset=None,
+                 one_hot_encode_int_dep_var=False, **kwargs):
+        self.schema_madlib = schema_madlib
+        self.source_table = source_table
+        self.output_table = output_table
+        self.dependent_varname = dependent_varname
+        self.independent_varname = independent_varname
+        self.buffer_size = buffer_size
+        self.normalizing_const = normalizing_const if normalizing_const else 
1.0
+        self.dependent_offset = dependent_offset
+        self.module_name = "minibatch_preprocessor_DL"
+        self.output_summary_table = add_postfix(self.output_table, "_summary")
+        self._validate_args()
+        self.num_of_buffers = self._get_num_buffers()
+        self.to_one_hot_encode = one_hot_encode_int_dep_var
+        # self.to_one_hot_encode = 
self.should_one_hot_encode(one_hot_encode_int_dep_var)
+        if self.to_one_hot_encode:
+            if is_valid_psql_type(self.dependent_vartype, NUMERIC | 
ONLY_ARRAY):
+                # Assuming the input NUMERIC[] is already on_hot_encoded so 
casting to INTEGER[]
+                # self.dependent_varname = self.dependent_varname + 
'::INTEGER[]'
+                self.dependent_levels = None
+            else:
+                self.dependent_levels = get_distinct_col_levels(
+                self.source_table, self.dependent_varname, 
self.dependent_vartype)
+        else:
+            self.dependent_levels = None
+
+    def minibatch_preprocessor_dl(self):
+        # Create a temp table that has independent var normalized.
+        norm_tbl = unique_string(desp='normalized')
+
+        dependent_varname_with_offset = self.dependent_varname
+        if self.dependent_offset:
+            dependent_varname_with_offset = '{0} + 
{1}'.format(self.dependent_varname, self.dependent_offset)
+        dep_var_array_expr = self.get_dep_var_array_expr()
+        if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
+            dep_var_array_expr = dep_var_array_expr + '::INTEGER[]'
+        scalar_mult_sql = """
+            CREATE TEMP TABLE {norm_tbl} AS
+            SELECT {self.schema_madlib}.array_scalar_mult(
+                {self.independent_varname}::REAL[], 
(1/{self.normalizing_const})::REAL) AS x_norm,
+                {dependent_varname_with_offset} AS y1,
+                {dep_var_array_expr} AS y,
+                row_number() over() AS row_id
+            FROM {self.source_table}
+            """.format(**locals())
+        plpy.execute(scalar_mult_sql)
+        # Create the mini-batched output table
+        if is_platform_pg():
+            distributed_by_clause = ''
+        else:
+            distributed_by_clause= ' DISTRIBUTED BY (buffer_id) '
+        sql = """
+            CREATE TABLE {self.output_table} AS
+            SELECT * FROM
+            (
+                SELECT {self.schema_madlib}.agg_array_concat(
+                    ARRAY[{norm_tbl}.x_norm::REAL[]]) AS {x},
+                    {self.schema_madlib}.agg_array_concat(
+                    ARRAY[{norm_tbl}.y1]) AS y1,
+                    {self.schema_madlib}.agg_array_concat(
+                    ARRAY[{norm_tbl}.y]) AS {y},
+                    ({norm_tbl}.row_id%{self.num_of_buffers})::smallint AS 
buffer_id
+                FROM {norm_tbl}
+                GROUP BY buffer_id
+            ) b
+            {distributed_by_clause}
+            """.format(x=MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL,
+                       y=MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL,
+                       **locals())
+        plpy.execute(sql)
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(norm_tbl))
+        # Create summary table
+        self._create_output_summary_table()
+
+    def _create_output_summary_table(self):
+        query = """
+            CREATE TABLE {self.output_summary_table} AS
+            SELECT
+                $__madlib__${self.source_table}$__madlib__$::TEXT AS 
source_table,
+                $__madlib__${self.output_table}$__madlib__$::TEXT AS 
output_table,
+                $__madlib__${self.dependent_varname}$__madlib__$::TEXT AS 
dependent_varname,
+                $__madlib__${self.independent_varname}$__madlib__$::TEXT AS 
independent_varname,
+                $__madlib__${self.dependent_vartype}$__madlib__$::TEXT AS 
dependent_vartype,
+                {self.buffer_size} AS buffer_size
+            """.format(self=self)
+        plpy.execute(query)
+
+    def _validate_args(self):
+        validate_module_input_params(
+            self.source_table, self.output_table, self.independent_varname,
+            self.dependent_varname, self.module_name, None,
+            [self.output_summary_table])
+        self.independent_vartype = get_expr_type(
+            self.independent_varname, self.source_table)
+        _assert(is_valid_psql_type(self.independent_vartype,
+                                   NUMERIC | ONLY_ARRAY),
+                "Invalid independent variable type, should be an array of "
+                "one of {0}".format(','.join(NUMERIC)))
+        self.dependent_vartype = get_expr_type(
+            self.dependent_varname, self.source_table)
+        if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
+            pass
+        else:
+            dep_valid_types = NUMERIC | TEXT | BOOLEAN
+            _assert(is_valid_psql_type(self.dependent_vartype, 
dep_valid_types),
+                    "Invalid dependent variable type, should be one of {0}".
+                    format(','.join(dep_valid_types)))
+        if self.buffer_size is not None:
+            _assert(self.buffer_size > 0,
+                    "minibatch_preprocessor_dl: The buffer size has to be a "
+                    "positive integer or NULL.")
+
+    def _get_num_buffers(self):
+        num_rows_in_tbl = plpy.execute("""
+                SELECT count(*) AS cnt FROM {0}
+            """.format(self.source_table))[0]['cnt']
+        buffer_size_calculator = MiniBatchBufferSizeCalculator()
+        indepdent_var_dim = _tbl_dimension_rownum(
+            self.schema_madlib, self.source_table,
+            self.independent_varname, skip_row_count=True)
+        self.buffer_size = 
buffer_size_calculator.calculate_default_buffer_size(
+            self.buffer_size, num_rows_in_tbl, indepdent_var_dim[0])
+        return ceil((1.0 * num_rows_in_tbl) / self.buffer_size)
+
+    # def should_one_hot_encode(self, one_hot_encode_int_dep_var):
+    #     return (is_psql_char_type(self.dependent_vartype) or
+    #             is_psql_boolean_type(self.dependent_vartype) or
+    #             (is_psql_numeric_type(self.dependent_vartype) and
+    #                 one_hot_encode_int_dep_var))
+
+
 class MiniBatchStandardizer:
     """
     This class is responsible for
diff --git 
a/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in 
b/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in
index 1a13d35..5bf4767 100644
--- a/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in
+++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in
@@ -494,7 +494,8 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.minibatch_preprocessor_dl(
     independent_varname         VARCHAR,
     buffer_size                 INTEGER,
     normalizing_const           DOUBLE PRECISION,
-    dependent_offset            INTEGER
+    dependent_offset            INTEGER,
+    one_hot_encode_int_dep_var  BOOLEAN
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(utilities, minibatch_preprocessing)
     from utilities.control import MinWarning
@@ -511,9 +512,22 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.minibatch_preprocessor_dl(
     dependent_varname       VARCHAR,
     independent_varname     VARCHAR,
     buffer_size             INTEGER,
+    normalizing_const       DOUBLE PRECISION,
+    dependent_offset        INTEGER
+) RETURNS VOID AS $$
+  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, $6, $7, 
FALSE);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor_dl(
+    source_table            VARCHAR,
+    output_table            VARCHAR,
+    dependent_varname       VARCHAR,
+    independent_varname     VARCHAR,
+    buffer_size             INTEGER,
     normalizing_const       DOUBLE PRECISION
 ) RETURNS VOID AS $$
-  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, $6, NULL);
+  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, $6, NULL, 
FALSE);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
@@ -524,7 +538,7 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.minibatch_preprocessor_dl(
     independent_varname     VARCHAR,
     buffer_size             INTEGER
 ) RETURNS VOID AS $$
-  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, 1.0, 
NULL);
+  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, 1.0, 
NULL, FALSE);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
@@ -534,7 +548,7 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.minibatch_preprocessor_dl(
     dependent_varname       VARCHAR,
     independent_varname     VARCHAR
 ) RETURNS VOID AS $$
-  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, NULL, 1.0, 
NULL);
+  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, NULL, 1.0, 
NULL, FALSE);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
diff --git 
a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in 
b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in
index 45da10f..7130044 100644
--- 
a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in
+++ 
b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in
@@ -51,11 +51,14 @@ SELECT minibatch_preprocessor_dl(
 SELECT assert(count(*)=4, 'Incorrect number of buffers in 
minibatch_preprocessor_dl_batch.')
 FROM minibatch_preprocessor_dl_batch;
 
+SELECT assert(array_upper(independent_var, 2)=6, 'Incorrect buffer size.')
+FROM minibatch_preprocessor_dl_batch WHERE buffer_id=0;
+
 SELECT assert(array_upper(independent_var, 1)=5, 'Incorrect buffer size.')
 FROM minibatch_preprocessor_dl_batch WHERE buffer_id=1;
 
-SELECT assert(array_upper(independent_var, 1)=2, 'Incorrect buffer size.')
-FROM minibatch_preprocessor_dl_batch WHERE buffer_id=4;
+SELECT assert(array_upper(independent_var, 1)=4, 'Incorrect buffer size.')
+FROM minibatch_preprocessor_dl_batch WHERE buffer_id=3;
 
 DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, 
minibatch_preprocessor_dl_batch_summary;
 SELECT minibatch_preprocessor_dl(
@@ -65,25 +68,25 @@ SELECT minibatch_preprocessor_dl(
   'x');
 
 DROP TABLE IF EXISTS minibatch_preprocessor_dl_input;
-CREATE TABLE minibatch_preprocessor_dl_input(id serial, x double precision[], 
y INTEGER);
-INSERT INTO minibatch_preprocessor_dl_input(x, y) VALUES
-(ARRAY[1,2,3,4,5,6], 4),
-(ARRAY[11,2,3,4,5,6], 3),
-(ARRAY[11,22,33,4,5,6], 8),
-(ARRAY[11,22,33,44,5,6], 2),
-(ARRAY[11,22,33,44,65,6], 5),
-(ARRAY[11,22,33,44,65,56], 6),
-(ARRAY[11,22,33,44,65,56], 2),
-(ARRAY[11,22,33,44,65,56], 10),
-(ARRAY[11,22,33,44,65,56], 3),
-(ARRAY[11,22,33,44,65,56], 7),
-(ARRAY[11,22,33,44,65,56], 6),
-(ARRAY[11,22,33,44,65,56], -6),
-(ARRAY[11,22,33,144,65,56], 9),
-(ARRAY[11,22,233,44,65,56], 0),
-(ARRAY[11,22,33,44,65,56], 12),
-(ARRAY[11,22,33,44,65,56], -3),
-(ARRAY[11,22,33,44,65,56], -1);
+CREATE TABLE minibatch_preprocessor_dl_input(id serial, x double precision[], 
y INTEGER, y1 BOOLEAN, y2 TEXT, y3 DOUBLE PRECISION, y4 DOUBLE PRECISION[], y5 
INTEGER[]);
+INSERT INTO minibatch_preprocessor_dl_input(x, y, y1, y2, y3, y4, y5) VALUES
+(ARRAY[1,2,3,4,5,6], 4, TRUE, 'a', 4.0, ARRAY[1.0, 0.0], ARRAY[1,0]),
+(ARRAY[11,2,3,4,5,6], 3, TRUE, 'c', 4.2, ARRAY[0.0, 1.0], ARRAY[1,0]),
+(ARRAY[11,22,33,4,5,6], 8, TRUE, 'a', 4.0, ARRAY[0.0, 1.0], ARRAY[1,0]),
+(ARRAY[11,22,33,44,5,6], 2, FALSE, 'a', 4.2, ARRAY[0.0, 1.0], ARRAY[1,0]),
+(ARRAY[11,22,33,44,65,6], 5, TRUE, 'b', 4.0, ARRAY[0.0, 1.0], ARRAY[0,1]),
+(ARRAY[11,22,33,44,65,56], 6, TRUE, 'a', 5.0, ARRAY[1.0, 0.0], ARRAY[1,0]),
+(ARRAY[11,22,33,44,65,56], 2, TRUE, 'a', 4.0, ARRAY[1.0, 0.0], ARRAY[1,0]),
+(ARRAY[11,22,33,44,65,56], 10, TRUE, 'a', 4.0, ARRAY[1.0, 0.0], ARRAY[1,0]),
+(ARRAY[11,22,33,44,65,56], 3, TRUE, 'b', 4.0, ARRAY[1.0, 0.0], ARRAY[1,0]),
+(ARRAY[11,22,33,44,65,56], 7, FALSE, 'a', 5.0, ARRAY[1.0, 0.0], ARRAY[1,0]),
+(ARRAY[11,22,33,44,65,56], 6, TRUE, 'a', 4.0, ARRAY[0.0, 1.0], ARRAY[1,0]),
+(ARRAY[11,22,33,44,65,56], -6, TRUE, 'a', 4.0, ARRAY[1.0, 0.0], ARRAY[1,0]),
+(ARRAY[11,22,33,144,65,56], 9, TRUE, 'c', 4.0, ARRAY[0.0, 1.0], ARRAY[1,0]),
+(ARRAY[11,22,233,44,65,56], 0, TRUE, 'a', 5.0, ARRAY[1.0, 0.0], ARRAY[0,1]),
+(ARRAY[11,22,33,44,65,56], 12, TRUE, 'a', 4.0, ARRAY[1.0, 0.0], ARRAY[1,0]),
+(ARRAY[11,22,33,44,65,56], -3, FALSE, 'a', 4.2, ARRAY[1.0, 0.0], ARRAY[1,0]),
+(ARRAY[11,22,33,44,65,56], -1, TRUE, 'b', 4.0, ARRAY[0.0, 1.0], ARRAY[0,1]);
 
 DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, 
minibatch_preprocessor_dl_batch_summary;
 SELECT minibatch_preprocessor_dl(
@@ -109,7 +112,7 @@ SELECT minibatch_preprocessor_dl(
   6);
 
 -- Test that dependent vars gets shifted by +6, by verifying minimum value 
goes from -6 to 0
-SELECT assert(abs(MIN(y))<0.00001, 'Dependent var not shifted properly!') FROM 
(SELECT UNNEST(dependent_var) as y FROM minibatch_preprocessor_dl_batch) a;
+SELECT assert(abs(MIN(y))<0.00001, 'Dependent var not shifted properly!') FROM 
(SELECT UNNEST(y1) as y FROM minibatch_preprocessor_dl_batch) a;
 
 DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, 
minibatch_preprocessor_dl_batch_summary;
 SELECT minibatch_preprocessor_dl(
@@ -122,4 +125,81 @@ SELECT minibatch_preprocessor_dl(
   -6);
 
 -- Test that dependent vars gets shifted by -6, by verifying minimum value 
goes from -6 to -12
-SELECT assert(relative_error(MIN(y), -12)<0.00001, 'Dependent var not shifted 
properly!') FROM (SELECT UNNEST(dependent_var) as y FROM 
minibatch_preprocessor_dl_batch) a;
\ No newline at end of file
+SELECT assert(relative_error(MIN(y), -12)<0.00001, 'Dependent var not shifted 
properly!') FROM (SELECT UNNEST(y1) as y FROM minibatch_preprocessor_dl_batch) 
a;
+
+
+-- Test one-hot encoding for dependent_var
+-- test boolean type
+DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, 
minibatch_preprocessor_dl_batch_summary;
+SELECT minibatch_preprocessor_dl(
+  'minibatch_preprocessor_dl_input',
+  'minibatch_preprocessor_dl_batch',
+  'y1',
+  'x',
+  4,
+  5);
+SELECT assert(pg_typeof(dependent_var) = 'integer[]'::regtype, 'One-hot encode 
doesn''t convert into integer array format') FROM 
minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
+SELECT assert(array_upper(dependent_var, 2) = 2, 'Incorrect one-hot encode 
dimension') FROM
+  minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
+SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT 
buffer_id, UNNEST(dependent_var[1:1]) as y FROM 
minibatch_preprocessor_dl_batch) a WHERE buffer_id = 0;
+
+-- test text type
+DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, 
minibatch_preprocessor_dl_batch_summary;
+SELECT minibatch_preprocessor_dl(
+  'minibatch_preprocessor_dl_input',
+  'minibatch_preprocessor_dl_batch',
+  'y2',
+  'x',
+  4,
+  5);
+SELECT assert(pg_typeof(dependent_var) = 'integer[]'::regtype, 'One-hot encode 
doesn''t convert into integer array format') FROM 
minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
+SELECT assert(array_upper(dependent_var, 2) = 3, 'Incorrect one-hot encode 
dimension') FROM
+  minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
+SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT 
buffer_id, UNNEST(dependent_var[1:1]) as y FROM 
minibatch_preprocessor_dl_batch) a WHERE buffer_id = 0;
+
+-- test double precision type
+DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, 
minibatch_preprocessor_dl_batch_summary;
+SELECT minibatch_preprocessor_dl(
+  'minibatch_preprocessor_dl_input',
+  'minibatch_preprocessor_dl_batch',
+  'y3',
+  'x',
+  4,
+  5);
+SELECT assert(pg_typeof(dependent_var) = 'integer[]'::regtype, 'One-hot encode 
doesn''t convert into integer array format') FROM 
minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
+SELECT assert(array_upper(dependent_var, 2) = 3, 'Incorrect one-hot encode 
dimension') FROM
+  minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
+SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT 
buffer_id, UNNEST(dependent_var[1:1]) as y FROM 
minibatch_preprocessor_dl_batch) a WHERE buffer_id = 0;
+
+
+-- test double precision array type
+DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, 
minibatch_preprocessor_dl_batch_summary;
+SELECT minibatch_preprocessor_dl(
+  'minibatch_preprocessor_dl_input',
+  'minibatch_preprocessor_dl_batch',
+  'y4',
+  'x',
+  4,
+  5);
+SELECT assert(pg_typeof(dependent_var) = 'integer[]'::regtype, 'One-hot encode 
doesn''t convert into integer array format') FROM 
minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
+SELECT assert(array_upper(dependent_var, 2) = 2, 'Incorrect one-hot encode 
dimension') FROM
+  minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
+SELECT assert(relative_error(SUM(y), SUM(y4)) < 0.000001, 'Incorrect one-hot 
encode value') FROM (SELECT UNNEST(dependent_var) AS y FROM 
minibatch_preprocessor_dl_batch) a, (SELECT UNNEST(y4) as y4 FROM 
minibatch_preprocessor_dl_input) b;
+
+-- test integer array type
+DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, 
minibatch_preprocessor_dl_batch_summary;
+SELECT minibatch_preprocessor_dl(
+  'minibatch_preprocessor_dl_input',
+  'minibatch_preprocessor_dl_batch',
+  'y5',
+  'x',
+  4,
+  5);
+SELECT assert(pg_typeof(dependent_var) = 'integer[]'::regtype, 'One-hot encode 
doesn''t convert into integer array format') FROM 
minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
+SELECT assert(array_upper(dependent_var, 2) = 2, 'Incorrect one-hot encode 
dimension') FROM
+  minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
+SELECT assert(relative_error(SUM(y), SUM(y5)) < 0.000001, 'Incorrect one-hot 
encode value') FROM (SELECT UNNEST(dependent_var) AS y FROM 
minibatch_preprocessor_dl_batch) a, (SELECT UNNEST(y5) as y5 FROM 
minibatch_preprocessor_dl_input) b;
+SELECT assert (dependent_vartype   = 'integer[]' AND
+               class_values        IS NULL,
+               'Summary Validation failed. Actual:' || __to_char(summary)
+              ) from (select * from minibatch_preprocessor_dl_batch_summary) 
summary;

Reply via email to