This is an automated email from the ASF dual-hosted git repository.

jingyimei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 9fd12682fb1d6688482ce66aeddc09cded5e1dd4
Author: Jingyi Mei <[email protected]>
AuthorDate: Fri Mar 15 17:38:42 2019 -0700

    Add class_values and normalizing_const columns to the summary table
    
    Co-authored-by: Ekta Khanna <[email protected]>
---
 .../utilities/minibatch_preprocessing.py_in        | 55 +++++++++++-----------
 .../utilities/minibatch_preprocessing_dl.sql_in    | 19 ++------
 .../test/minibatch_preprocessing_dl.sql_in         | 30 +++++++++++-
 3 files changed, 59 insertions(+), 45 deletions(-)

diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in 
b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
index 36fdd05..f3dc462 100644
--- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
+++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
@@ -348,8 +348,7 @@ class MiniBatchPreProcessor:
 class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
     def __init__(self, schema_madlib, source_table, output_table,
                  dependent_varname, independent_varname, buffer_size,
-                 normalizing_const=1.0, one_hot_encode_int_dep_var=False,
-                 **kwargs):
+                 normalizing_const=1.0, **kwargs):
         self.schema_madlib = schema_madlib
         self.source_table = source_table
         self.output_table = output_table
@@ -361,18 +360,13 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
         self.output_summary_table = add_postfix(self.output_table, "_summary")
         self._validate_args()
         self.num_of_buffers = self._get_num_buffers()
-        self.to_one_hot_encode = one_hot_encode_int_dep_var
-        # self.to_one_hot_encode = 
self.should_one_hot_encode(one_hot_encode_int_dep_var)
-        if self.to_one_hot_encode:
-            if is_valid_psql_type(self.dependent_vartype, NUMERIC | 
ONLY_ARRAY):
-                # Assuming the input NUMERIC[] is already on_hot_encoded so 
casting to INTEGER[]
-                # self.dependent_varname = self.dependent_varname + 
'::INTEGER[]'
-                self.dependent_levels = None
-            else:
-                self.dependent_levels = get_distinct_col_levels(
-                self.source_table, self.dependent_varname, 
self.dependent_vartype)
-        else:
+        self.to_one_hot_encode = True
+        if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
             self.dependent_levels = None
+        else:
+            self.dependent_levels = get_distinct_col_levels(
+                self.source_table, self.dependent_varname, 
self.dependent_vartype)
+
 
     def minibatch_preprocessor_dl(self):
         # Create a temp table that has independent var normalized.
@@ -380,6 +374,7 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
 
         dependent_varname_with_offset = self.dependent_varname
         dep_var_array_expr = self.get_dep_var_array_expr()
+        # Assuming the input NUMERIC[] is already one-hot encoded, so casting 
to INTEGER[]
         if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
             dep_var_array_expr = dep_var_array_expr + '::INTEGER[]'
         scalar_mult_sql = """
@@ -418,6 +413,11 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
         self._create_output_summary_table()
 
     def _create_output_summary_table(self):
+        class_level_str='NULL::TEXT'
+        if self.dependent_levels:
+            class_level_str=py_list_to_sql_string(
+                self.dependent_levels, array_type=self.dependent_vartype,
+                long_format=True)
         query = """
             CREATE TABLE {self.output_summary_table} AS
             SELECT
@@ -426,8 +426,10 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
                 $__madlib__${self.dependent_varname}$__madlib__$::TEXT AS 
dependent_varname,
                 $__madlib__${self.independent_varname}$__madlib__$::TEXT AS 
independent_varname,
                 $__madlib__${self.dependent_vartype}$__madlib__$::TEXT AS 
dependent_vartype,
-                {self.buffer_size} AS buffer_size
-            """.format(self=self)
+                {class_level_str} AS class_values,
+                {self.buffer_size} AS buffer_size,
+                {self.normalizing_const} AS normalizing_const
+            """.format(self=self, class_level_str=class_level_str)
         plpy.execute(query)
 
     def _validate_args(self):
@@ -443,13 +445,14 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
                 "one of {0}".format(','.join(NUMERIC)))
         self.dependent_vartype = get_expr_type(
             self.dependent_varname, self.source_table)
-        if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
-            pass
-        else:
-            dep_valid_types = NUMERIC | TEXT | BOOLEAN
-            _assert(is_valid_psql_type(self.dependent_vartype, 
dep_valid_types),
-                    "Invalid dependent variable type, should be one of {0}".
-                    format(','.join(dep_valid_types)))
+        # The dependent variable needs to be either:
+        # 1. NUMERIC, TEXT OR BOOLEAN, which we always one-hot encode
+        # 2. NUMERIC ARRAY, which we assume is already one-hot encoded, and 
we
+        #    just cast it to INTEGER ARRAY
+        _assert((is_valid_psql_type(self.dependent_vartype, NUMERIC | TEXT | 
BOOLEAN) or
+                is_valid_psql_type(self.dependent_vartype, NUMERIC | 
ONLY_ARRAY)),
+                """Invalid dependent variable type, should be one of the type 
in this list:
+                numeric, text, boolean, or numeric array""")
         if self.buffer_size is not None:
             _assert(self.buffer_size > 0,
                     "minibatch_preprocessor_dl: The buffer size has to be a "
@@ -467,12 +470,6 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
             self.buffer_size, num_rows_in_tbl, indepdent_var_dim[0])
         return ceil((1.0 * num_rows_in_tbl) / self.buffer_size)
 
-    # def should_one_hot_encode(self, one_hot_encode_int_dep_var):
-    #     return (is_psql_char_type(self.dependent_vartype) or
-    #             is_psql_boolean_type(self.dependent_vartype) or
-    #             (is_psql_numeric_type(self.dependent_vartype) and
-    #                 one_hot_encode_int_dep_var))
-
 
 class MiniBatchStandardizer:
     """
@@ -787,6 +784,8 @@ class MiniBatchDocumentation:
         dependent_vartype         -- Type of the dependent variable from the
                                      original table.
         buffer_size               -- Buffer size used in preprocessing step.
+        normalizing_const         -- Normalizing constant used for 
standardizing
+                                     arrays in independent_varname.
 
         
---------------------------------------------------------------------------
         """.format(**locals())
diff --git 
a/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in 
b/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in
index 89afa20..7a45db2 100644
--- a/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in
+++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in
@@ -501,8 +501,7 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.minibatch_preprocessor_dl(
     dependent_varname           VARCHAR,
     independent_varname         VARCHAR,
     buffer_size                 INTEGER,
-    normalizing_const           DOUBLE PRECISION,
-    one_hot_encode_int_dep_var  BOOLEAN
+    normalizing_const           DOUBLE PRECISION
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(utilities, minibatch_preprocessing)
     from utilities.control import MinWarning
@@ -518,21 +517,9 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.minibatch_preprocessor_dl(
     output_table            VARCHAR,
     dependent_varname       VARCHAR,
     independent_varname     VARCHAR,
-    buffer_size             INTEGER,
-    normalizing_const       DOUBLE PRECISION
-) RETURNS VOID AS $$
-  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, $6, 
FALSE);
-$$ LANGUAGE sql VOLATILE
-m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
-
-CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor_dl(
-    source_table            VARCHAR,
-    output_table            VARCHAR,
-    dependent_varname       VARCHAR,
-    independent_varname     VARCHAR,
     buffer_size             INTEGER
 ) RETURNS VOID AS $$
-  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, 1.0, 
FALSE);
+  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, 1.0);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
@@ -542,7 +529,7 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.minibatch_preprocessor_dl(
     dependent_varname       VARCHAR,
     independent_varname     VARCHAR
 ) RETURNS VOID AS $$
-  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, NULL, 1.0, 
FALSE);
+  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, NULL, 1.0);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
diff --git 
a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in 
b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in
index cd0a7e6..31bfdd7 100644
--- 
a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in
+++ 
b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in
@@ -101,6 +101,19 @@ SELECT minibatch_preprocessor_dl(
 SELECT assert(relative_error(MIN(x),0.2) < 0.00001, 'Independent var not 
normalized properly!') FROM (SELECT UNNEST(independent_var) as x FROM 
minibatch_preprocessor_dl_batch) a;
 SELECT assert(relative_error(MAX(x),46.6) < 0.00001, 'Independent var not 
normalized properly!') FROM (SELECT UNNEST(independent_var) as x FROM 
minibatch_preprocessor_dl_batch) a;
 
+-- Test summary table
+SELECT assert
+        (
+        source_table        = 'minibatch_preprocessor_dl_input' AND
+        output_table        = 'minibatch_preprocessor_dl_batch' AND
+        dependent_varname   = 'y' AND
+        independent_varname = 'x' AND
+        dependent_vartype   = 'integer' AND
+        class_values        = '{-6,-3,-1,0,2,3,4,5,6,7,8,9,10,12}' AND  -- we sort the class values in python
+        buffer_size         = 4 AND
+        normalizing_const   = 5,
+        'Summary Validation failed. Actual:' || __to_char(summary)
+        ) from (select * from minibatch_preprocessor_dl_batch_summary) summary;
 
 -- Test one-hot encoding for dependent_var
 -- test boolean type
@@ -116,6 +129,10 @@ SELECT assert(pg_typeof(dependent_var) = 
'integer[]'::regtype, 'One-hot encode d
 SELECT assert(array_upper(dependent_var, 2) = 2, 'Incorrect one-hot encode 
dimension') FROM
   minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
 SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT 
buffer_id, UNNEST(dependent_var[1:1]) as y FROM 
minibatch_preprocessor_dl_batch) a WHERE buffer_id = 0;
+SELECT assert (dependent_vartype   = 'boolean' AND
+               class_values        = '{f,t}',
+               'Summary Validation failed. Actual:' || __to_char(summary)
+              ) from (select * from minibatch_preprocessor_dl_batch_summary) 
summary;
 
 -- test text type
 DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, 
minibatch_preprocessor_dl_batch_summary;
@@ -130,6 +147,10 @@ SELECT assert(pg_typeof(dependent_var) = 
'integer[]'::regtype, 'One-hot encode d
 SELECT assert(array_upper(dependent_var, 2) = 3, 'Incorrect one-hot encode 
dimension') FROM
   minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
 SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT 
buffer_id, UNNEST(dependent_var[1:1]) as y FROM 
minibatch_preprocessor_dl_batch) a WHERE buffer_id = 0;
+SELECT assert (dependent_vartype   = 'text' AND
+               class_values        = '{a,b,c}',
+               'Summary Validation failed. Actual:' || __to_char(summary)
+              ) from (select * from minibatch_preprocessor_dl_batch_summary) 
summary;
 
 -- test double precision type
 DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, 
minibatch_preprocessor_dl_batch_summary;
@@ -144,7 +165,10 @@ SELECT assert(pg_typeof(dependent_var) = 
'integer[]'::regtype, 'One-hot encode d
 SELECT assert(array_upper(dependent_var, 2) = 3, 'Incorrect one-hot encode 
dimension') FROM
   minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
 SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT 
buffer_id, UNNEST(dependent_var[1:1]) as y FROM 
minibatch_preprocessor_dl_batch) a WHERE buffer_id = 0;
-
+SELECT assert (dependent_vartype   = 'double precision' AND
+               class_values        = '{4.0,4.2,5.0}',
+               'Summary Validation failed. Actual:' || __to_char(summary)
+              ) from (select * from minibatch_preprocessor_dl_batch_summary) 
summary;
 
 -- test double precision array type
 DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, 
minibatch_preprocessor_dl_batch_summary;
@@ -159,6 +183,10 @@ SELECT assert(pg_typeof(dependent_var) = 
'integer[]'::regtype, 'One-hot encode d
 SELECT assert(array_upper(dependent_var, 2) = 2, 'Incorrect one-hot encode 
dimension') FROM
   minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
 SELECT assert(relative_error(SUM(y), SUM(y4)) < 0.000001, 'Incorrect one-hot 
encode value') FROM (SELECT UNNEST(dependent_var) AS y FROM 
minibatch_preprocessor_dl_batch) a, (SELECT UNNEST(y4) as y4 FROM 
minibatch_preprocessor_dl_input) b;
+SELECT assert (dependent_vartype   = 'double precision[]' AND
+               class_values        IS NULL,
+               'Summary Validation failed. Actual:' || __to_char(summary)
+              ) from (select * from minibatch_preprocessor_dl_batch_summary) 
summary;
 
 -- test integer array type
 DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, 
minibatch_preprocessor_dl_batch_summary;

Reply via email to