This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new d5c9a97  Minibatch Preprocessor DL: Add optional num_classes param.
d5c9a97 is described below

commit d5c9a97f5860fa207f77d773ad8042d692dc8deb
Author: Nandish Jayaram <[email protected]>
AuthorDate: Fri Mar 29 16:48:52 2019 -0700

    Minibatch Preprocessor DL: Add optional num_classes param.
    
    JIRA: MADLIB-1314
    
    The current `minibatch_preprocessor_dl()` module looks at the input
    table to find the number of distinct categories (class values) for the
    dependent variable, and uses that number as the size of the
    one-hot-encoded array. This could lead to a failure in the madlib_keras
    fit function if the `num_classes` defined in the model architecture
    differs from (e.g., is greater than) the size of the one-hot encoded array.
    
    This commit adds two functionalities:
    1) A new optional parameter to `minibatch_preprocessor_dl()` that will
    be used to determine the length of the 1-hot encoded vector for the
    dependent var. If the param is set to NULL, the length will be equal to
    the number of distinct class values found in the dataset; otherwise,
    num_classes must be greater than or equal to the number of distinct
    class values.
    The `class_values` column in the summary table contains an array of
    class values associated with the 1-hot encoded vector. Padded positions
    (those with no corresponding class value in the dataset) are given
    NULL as their class value in this array.
    2) We now support NULL as a valid class value for the dependent variable.
    
    Closes #361
    
    Co-authored-by: Ekta Khanna <[email protected]>
---
 src/ports/postgres/modules/internal/db_utils.py_in |  13 +-
 .../utilities/minibatch_preprocessing.py_in        |  98 +++++++++--
 .../utilities/minibatch_preprocessing_dl.sql_in    |  87 +++++++++-
 .../test/minibatch_preprocessing_dl.sql_in         |  87 +++++++++-
 .../unit_tests/test_minibatch_preprocessing.py_in  | 186 ++++++++++++++++-----
 5 files changed, 405 insertions(+), 66 deletions(-)

diff --git a/src/ports/postgres/modules/internal/db_utils.py_in 
b/src/ports/postgres/modules/internal/db_utils.py_in
index 37700a8..934107c 100644
--- a/src/ports/postgres/modules/internal/db_utils.py_in
+++ b/src/ports/postgres/modules/internal/db_utils.py_in
@@ -26,7 +26,8 @@ m4_changequote(`<!', `!>')
 QUOTE_DELIMITER="$__madlib__$"
 
 
-def get_distinct_col_levels(source_table, col_name, col_type=None):
+def get_distinct_col_levels(source_table, col_name, col_type=None,
+                            include_nulls=False):
     """
     Add description here
     :return:
@@ -39,10 +40,13 @@ def get_distinct_col_levels(source_table, col_name, 
col_type=None):
     else:
         dep_var_text_patched = col_name
 
+    where_clause = 'WHERE ({0}) is NOT NULL'.format(col_name)
+    if include_nulls:
+        where_clause = ''
     levels = plpy.execute("""
                 SELECT DISTINCT {dep_var_text_patched} AS levels
                 FROM {source_table}
-                WHERE ({col_name}) is NOT NULL
+                {where_clause}
                 """.format(**locals()))
 
     levels = sorted(l["levels"] for l in levels)
@@ -55,10 +59,11 @@ def get_one_hot_encoded_expr(col_name, col_levels):
     the sql function `quote_literal`.
     :param col_name:
     :param col_levels:
+
     :return:
     """
-    one_hot_encoded_expr = ["({0}) = {1}".format(col_name, c)
-                            for c in col_levels]
+    one_hot_encoded_expr = ["({0}) = {1}".format(
+        col_name, c) for c in col_levels]
     return 'ARRAY[{0}]::INTEGER[]'.format(', '.join(one_hot_encoded_expr))
 # 
------------------------------------------------------------------------------
 
diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in 
b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
index 8bfa978..32eb3e5 100644
--- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
+++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
@@ -86,7 +86,8 @@ class MiniBatchPreProcessor:
         self.to_one_hot_encode = 
self.should_one_hot_encode(one_hot_encode_int_dep_var)
         if self.to_one_hot_encode:
             self.dependent_levels = get_distinct_col_levels(
-            self.source_table, self.dependent_varname, self.dependent_vartype)
+                self.source_table, self.dependent_varname,
+                self.dependent_vartype)
         else:
             self.dependent_levels = None
         self._validate_minibatch_preprocessor_params()
@@ -343,10 +344,10 @@ class MiniBatchPreProcessor:
         plpy.execute(query)
 
 
-class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
+class MiniBatchPreProcessorDL():
     def __init__(self, schema_madlib, source_table, output_table,
                  dependent_varname, independent_varname, buffer_size,
-                 normalizing_const=1.0, **kwargs):
+                 normalizing_const, num_classes, **kwargs):
         self.schema_madlib = schema_madlib
         self.source_table = source_table
         self.output_table = output_table
@@ -354,6 +355,7 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
         self.independent_varname = independent_varname
         self.buffer_size = buffer_size
         self.normalizing_const = normalizing_const if normalizing_const is not 
None else 1.0
+        self.num_classes = num_classes
         self.module_name = "minibatch_preprocessor_DL"
         self.output_summary_table = add_postfix(self.output_table, "_summary")
         self.independent_vartype = get_expr_type(self.independent_varname,
@@ -363,26 +365,87 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
 
         self._validate_args()
         self.num_of_buffers = self._get_num_buffers()
-        self.to_one_hot_encode = True
         if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
             self.dependent_levels = None
         else:
             self.dependent_levels = get_distinct_col_levels(
-                self.source_table, self.dependent_varname, 
self.dependent_vartype)
+                self.source_table, self.dependent_varname,
+                self.dependent_vartype, include_nulls=True)
+            # if any class level was NULL in sql, that would show up as
+            # None in self.dependent_levels. Replace all None with NULL
+            # in the list.
+            self.dependent_levels = ['NULL' if level is None else level
+                for level in self.dependent_levels]
+            self._validate_num_classes()
+        # Find the number of padded zeros to include in 1-hot vector
+        self.padding_size = 0
+        # Try computing padding_size after running all necessary validations.
+        if self.num_classes and self.dependent_levels:
+            self.padding_size = self.num_classes - len(self.dependent_levels)
+
+    def _validate_num_classes(self):
+        if self.num_classes is not None and \
+            self.num_classes < len(self.dependent_levels):
+            plpy.error("{0}: Invalid num_classes value specified. It must "\
+                "be equal to or greater than distinct class values found "\
+                "in table ({1}).".format(
+                    self.module_name, len(self.dependent_levels)))
+
+    def get_one_hot_encoded_dep_var_expr(self):
+        """
+        :param dependent_varname: Name of the dependent variable
+        :param num_classes: Number of class values to consider in 1-hot
+        :return:
+            This function returns a tuple of
+            1. A string with transformed dependent varname depending on it's 
type
+            2. All the distinct dependent class levels encoded as a string
+
+            If dep_type == numeric[] , do not encode
+                    1. dependent_varname = rings
+                        transformed_value = ARRAY[rings]
+                    2. dependent_varname = ARRAY[a, b, c]
+                        transformed_value = ARRAY[a, b, c]
+            else if dep_type in ("text", "boolean"), encode:
+                    3. dependent_varname = rings (encoding)
+                        transformed_value = ARRAY[rings=1, rings=2, rings=3]
+        """
+        # Assuming the input NUMERIC[] is already one_hot_encoded,
+        # so casting to INTEGER[]
+        if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
+            return "{0}::INTEGER[]".format(self.dependent_varname)
+
+        # For DL use case, we want to allow NULL as a valid class value,
+        # so the query must have 'IS NOT DISTINCT FROM' instead of '='
+        # like in the generic get_one_hot_encoded_expr() defined in
+        # db_utils.py_in. We also have this optional 'num_classes' param
+        # that affects the logic of 1-hot encoding. Since this is very
+        # specific to minibatch_preprocessing_dl for now, let's keep
+        # it here instead of refactoring it out to a generic helper function.
+        one_hot_encoded_expr = ["({0}) IS NOT DISTINCT FROM {1}".format(
+            self.dependent_varname, c) for c in self.dependent_levels]
+        if self.num_classes:
+            one_hot_encoded_expr.extend(['false'
+                for i in range(self.padding_size)])
+        return 'ARRAY[{0}]::INTEGER[]'.format(
+            ', '.join(one_hot_encoded_expr))
 
     def minibatch_preprocessor_dl(self):
         # Create a temp table that has independent var normalized.
         norm_tbl = unique_string(desp='normalized')
 
-        dep_var_array_expr = self.get_dep_var_array_expr()
-        # Assuming the input NUMERIC[] is already one_hot_encoded, so casting 
to INTEGER[]
-        if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
-            dep_var_array_expr = dep_var_array_expr + '::INTEGER[]'
+        # Always one-hot encode the dependent var. For now, we are assuming
+        # that minibatch_preprocessor_dl will be used only for deep
+        # learning and mostly for classification. So make a strong
+        # assumption that it is only for classification, so one-hot
+        # encode the dep var, unless it's already a numeric array in
+        # which case we assume it's already one-hot encoded.
+        one_hot_dep_var_array_expr = \
+            self.get_one_hot_encoded_dep_var_expr()
         scalar_mult_sql = """
             CREATE TEMP TABLE {norm_tbl} AS
             SELECT {self.schema_madlib}.array_scalar_mult(
                 {self.independent_varname}::REAL[], 
(1/{self.normalizing_const})::REAL) AS x_norm,
-                {dep_var_array_expr} AS y,
+                {one_hot_dep_var_array_expr} AS y,
                 row_number() over() AS row_id
             FROM {self.source_table} order by random()
             """.format(**locals())
@@ -416,6 +479,11 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
     def _create_output_summary_table(self):
         class_level_str='NULL::TEXT'
         if self.dependent_levels:
+            # Update dependent_levels to include NULL when
+            # num_classes > len(self.dependent_levels)
+            if self.num_classes:
+                self.dependent_levels.extend(['NULL'
+                    for i in range(self.padding_size)])
             class_level_str=py_list_to_sql_string(
                 self.dependent_levels, array_type=self.dependent_vartype,
                 long_format=True)
@@ -741,6 +809,11 @@ class MiniBatchDocumentation:
         The normalizing constant is parameterized, and can be specified based
         on the kind of image data used.
 
+        An optional param named num_classes can be used to specify the length
+        of the one-hot encoded array for the dependent variable. This value if
+        specified must be greater than equal to the total number of distinct
+        class values found in the input table.
+
         For more details on function usage:
         SELECT {schema_madlib}.{method}('usage')
         """.format(**locals())
@@ -762,6 +835,11 @@ class MiniBatchDocumentation:
             normalizing_const      -- DOUBLE PRECISON. Default 1.0. The
                                       normalizing constant to use for
                                       standardizing arrays in 
independent_varname.
+            num_classes            -- INTEGER. Default NULL. Number of class 
labels
+                                      to be considered for 1-hot encoding. If 
NULL,
+                                      the 1-hot encoded array length will be 
equal to
+                                      the number of distinct class values 
found in the
+                                      input table.
         );
 
 
diff --git 
a/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in 
b/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in
index 67e216d..3fe17d0 100644
--- a/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in
+++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in
@@ -55,7 +55,8 @@ minibatch_preprocessor_dl( source_table,
                            dependent_varname,
                            independent_varname,
                            buffer_size,
-                           normalizing_const
+                           normalizing_const,
+                           num_classes
                         )
 </pre>
 
@@ -105,6 +106,12 @@ minibatch_preprocessor_dl( source_table,
   each value in the independent_varname array by.  For example,
   you may need to use 255 for this value if the image data is in the form 
0-255.
   </dd>
+
+  <dt>num_classes (optional)</dt>
+  <dd>INTEGER, default: NULL. Number of class labels to be considered for 1-hot
+  encoding. If NULL, the 1-hot encoded array length will be equal to the number
+  of distinct class values found in the input table.
+  </dd>
 </dl>
 
 <b>Output tables</b>
@@ -118,9 +125,14 @@ minibatch_preprocessor_dl( source_table,
       </tr>
       <tr>
         <th>dependent_var</th>
-        <td>ANYARRAY[]. Packed array of dependent variables. The type
-        of the array is the same as the type of the dependent variable from
-        the source table.
+        <td>ANYARRAY[]. Packed array of dependent variables.
+        The dependent variable is always one-hot encoded as an
+        INTEGER[] array. For now, we are assuming that
+        minibatch_preprocessor_dl will be used
+        only for classification problems using deep learning. So
+        the dependent variable is one-hot encoded, unless it's already a
+        numeric array in which case we assume it's already one-hot
+        encoded and just cast it to an INTEGER[] array.
         </td>
       </tr>
       <tr>
@@ -447,6 +459,54 @@ class_values        | {bird,cat,dog}
 buffer_size         | 10
 </pre>
 
+-#  Run the preprocessor for image data with num_classes greater than 3 
(distinct class values found in table):
+<pre class="example">
+DROP TABLE IF EXISTS image_data_packed, image_data_packed_summary;
+SELECT madlib.minibatch_preprocessor_dl('image_data',         -- Source table
+                                        'image_data_packed',  -- Output table
+                                        'species',            -- Dependent 
variable
+                                        'rgb',                -- Independent 
variable
+                                        NULL,                 -- Buffer size
+                                        255,                  -- Normalizing 
constant
+                                        5                     -- Number of 
desired class values
+                                        );
+</pre>
+Here is a sample of the packed output table with the padded 1-hot vector:
+<pre class="example">
+\\x on
+SELECT * FROM image_data_packed ORDER BY buffer_id;
+</pre>
+<pre class="result">
+-[ RECORD 1 
]---+---------------------------------------------------------------------------------------------------------------------
+independent_var | {{0.639216,0.517647,0.87451,0.0862745,0.784314,...},...}
+dependent_var   | {{0,0,1,0,0},{1,0,0,0,0},{1,0,0,0,0},{1,0,0,0,0},...}
+buffer_id       | 0
+-[ RECORD 2 
]---+---------------------------------------------------------------------------------------------------------------------
+independent_var | {{0.866667,0.0666667,0.803922,0.239216,0.741176,...},...}
+dependent_var   | {{0,0,1,0,0},{0,0,1,0,0},{0,1,0,0,0},{0,1,0,0,0},...}
+buffer_id       | 1
+-[ RECORD 3 
]---+---------------------------------------------------------------------------------------------------------------------
+independent_var | {{0.184314,0.87451,0.227451,0.466667,0.203922,...},...}
+dependent_var   | {{1,0,0,0,0},{0,1,0,0,0},{1,0,0,0,0},{0,0,1,0,0},...}
+buffer_id       | 2
+</pre>
+Review the output summary table:
+<pre class="example">
+\\x on
+SELECT * FROM image_data_packed_summary;
+</pre>
+<pre class="result">
+-[ RECORD 1 ]-------+-------------------------
+source_table        | image_data
+output_table        | image_data_packed
+dependent_varname   | species
+independent_varname | rgb
+dependent_vartype   | text
+class_values        | {bird,cat,dog,NULL,NULL}
+buffer_size         | 18
+normalizing_const   | 255.0
+</pre>
+
 @anchor related
 @par Related Topics
 
@@ -462,7 +522,8 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.minibatch_preprocessor_dl(
     dependent_varname           VARCHAR,
     independent_varname         VARCHAR,
     buffer_size                 INTEGER,
-    normalizing_const           DOUBLE PRECISION
+    normalizing_const           DOUBLE PRECISION,
+    num_classes                 INTEGER
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(utilities, minibatch_preprocessing)
     from utilities.control import MinWarning
@@ -478,9 +539,21 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.minibatch_preprocessor_dl(
     output_table            VARCHAR,
     dependent_varname       VARCHAR,
     independent_varname     VARCHAR,
+    buffer_size             INTEGER,
+    normalizing_const       DOUBLE PRECISION
+) RETURNS VOID AS $$
+  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, $6, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor_dl(
+    source_table            VARCHAR,
+    output_table            VARCHAR,
+    dependent_varname       VARCHAR,
+    independent_varname     VARCHAR,
     buffer_size             INTEGER
 ) RETURNS VOID AS $$
-  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, 1.0);
+  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, 1.0, 
NULL);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
@@ -490,7 +563,7 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.minibatch_preprocessor_dl(
     dependent_varname       VARCHAR,
     independent_varname     VARCHAR
 ) RETURNS VOID AS $$
-  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, NULL, 1.0);
+  SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, NULL, 1.0, 
NULL);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
diff --git 
a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in 
b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in
index 31bfdd7..0e64c07 100644
--- 
a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in
+++ 
b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in
@@ -95,11 +95,16 @@ SELECT minibatch_preprocessor_dl(
   'y',
   'x',
   4,
-  5);
+  5,
+  16 -- num_classes
+  );
 
 -- Test that indepdendent vars get divided by 5, by verifying min value goes 
from 1 to 0.2, and max value from 233 to 46.6
 SELECT assert(relative_error(MIN(x),0.2) < 0.00001, 'Independent var not 
normalized properly!') FROM (SELECT UNNEST(independent_var) as x FROM 
minibatch_preprocessor_dl_batch) a;
 SELECT assert(relative_error(MAX(x),46.6) < 0.00001, 'Independent var not 
normalized properly!') FROM (SELECT UNNEST(independent_var) as x FROM 
minibatch_preprocessor_dl_batch) a;
+-- Test that 1-hot encoded array is of length 16 (num_classes)
+SELECT assert(array_upper(dependent_var, 2) = 16, 'Incorrect one-hot encode 
dimension with num_classes') FROM
+  minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
 
 -- Test summary table
 SELECT assert
@@ -109,7 +114,7 @@ SELECT assert
         dependent_varname   = 'y' AND
         independent_varname = 'x' AND
         dependent_vartype   = 'integer' AND
-        class_values        = '{-6,-3,-1,0,2,3,4,5,6,7,8,9,10,12}' AND
+        class_values        = '{-6,-3,-1,0,2,3,4,5,6,7,8,9,10,12,NULL,NULL}' 
AND
         buffer_size         = 4 AND  -- we sort the class values in python
         normalizing_const   = 5,
         'Summary Validation failed. Actual:' || __to_char(summary)
@@ -205,3 +210,81 @@ SELECT assert (dependent_vartype   = 'integer[]' AND
                class_values        IS NULL,
                'Summary Validation failed. Actual:' || __to_char(summary)
               ) from (select * from minibatch_preprocessor_dl_batch_summary) 
summary;
+
+-- Test cases with NULL in class values
+DROP TABLE IF EXISTS minibatch_preprocessor_dl_input_null;
+CREATE TABLE minibatch_preprocessor_dl_input_null(id serial, x double 
precision[], label TEXT);
+INSERT INTO minibatch_preprocessor_dl_input_null(x, label) VALUES
+(ARRAY[1,2,3,4,5,6], 'a'),
+(ARRAY[11,2,3,4,5,6], 'a'),
+(ARRAY[11,22,33,4,5,6], NULL),
+(ARRAY[11,22,33,44,5,6], 'a'),
+(ARRAY[11,22,33,44,65,6], 'a'),
+(ARRAY[11,22,33,44,65,56], 'a'),
+(ARRAY[11,22,33,44,65,56], 'a'),
+(ARRAY[11,22,33,44,65,56], NULL),
+(ARRAY[11,22,33,44,65,56], 'a'),
+(ARRAY[11,22,33,44,65,56], 'a'),
+(ARRAY[11,22,33,44,65,56], NULL),
+(ARRAY[11,22,33,44,65,56], 'a'),
+(ARRAY[11,22,33,144,65,56], 'b'),
+(ARRAY[11,22,233,44,65,56], 'b'),
+(ARRAY[11,22,33,44,65,56], 'b'),
+(ARRAY[11,22,33,44,65,56], 'b'),
+(ARRAY[11,22,33,44,65,56], NULL);
+
+DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, 
minibatch_preprocessor_dl_batch_summary;
+SELECT minibatch_preprocessor_dl(
+  'minibatch_preprocessor_dl_input_null',
+  'minibatch_preprocessor_dl_batch',
+  'label',
+  'x',
+  4,
+  5,
+  5 -- num_classes
+  );
+-- Test summary table if class_values has NULL as a legitimate
+-- class label, and also two other NULLs because num_classes=5
+-- but table has only 3 distinct class labels (including NULL)
+SELECT assert
+        (
+        class_values        = '{NULL,a,b,NULL,NULL}',
+        'Summary Validation failed with NULL data. Actual:' || 
__to_char(summary)
+        ) from (select * from minibatch_preprocessor_dl_batch_summary) summary;
+
+SELECT assert(array_upper(dependent_var, 2) = 5, 'Incorrect one-hot encode 
dimension with NULL data') FROM
+  minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
+
+-- Test the content of 1-hot encoded dep var when NULL is the
+-- class label.
+DROP TABLE IF EXISTS minibatch_preprocessor_dl_input_null;
+CREATE TABLE minibatch_preprocessor_dl_input_null(id serial, x double 
precision[], label TEXT);
+INSERT INTO minibatch_preprocessor_dl_input_null(x, label) VALUES
+(ARRAY[11,22,33,4,5,6], NULL);
+
+DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch, 
minibatch_preprocessor_dl_batch_summary;
+SELECT minibatch_preprocessor_dl(
+  'minibatch_preprocessor_dl_input_null',
+  'minibatch_preprocessor_dl_batch',
+  'label',
+  'x',
+  4,
+  5,
+  3 -- num_classes
+  );
+
+-- class_values must be '{NULL,NULL,NULL}' where the first NULL
+-- is for the class label seen in data, and the other two NULLs
+-- are added as num_classes=3.
+SELECT assert
+        (
+        class_values        = '{NULL,NULL,NULL}',
+        'Summary Validation failed with NULL data. Actual:' || 
__to_char(summary)
+        ) from (select * from minibatch_preprocessor_dl_batch_summary) summary;
+
+SELECT assert(array_upper(dependent_var, 2) = 3, 'Incorrect one-hot encode 
dimension with NULL data') FROM
+  minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
+-- NULL is treated as a class label, so it should show '1' for the
+-- first index
+SELECT assert(dependent_var = '{{1,0,0}}', 'Incorrect one-hot encode dimension 
with NULL data') FROM
+  minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
diff --git 
a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
 
b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
index 2fdf082..e213562 100644
--- 
a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
+++ 
b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
@@ -343,74 +343,92 @@ class MiniBatchPreProcessingDLTestCase(unittest.TestCase):
         self.default_dep_var = "depvar"
         self.default_ind_var = "indvar"
         self.default_buffer_size = 5
-        #self.dependent_vartype = "integer[]"
-        #self.independent_vartype = "integer[]"
+        self.default_normalizing_const = 1.0
+        self.default_num_classes = None
 
         import utilities.minibatch_preprocessing
         self.module = utilities.minibatch_preprocessing
         self.module.get_expr_type = Mock(side_effect = ['integer[]', 
'integer[]'])
         self.module.validate_module_input_params = Mock()
         self.module.get_distinct_col_levels = Mock(return_value = [0,22,100])
-        self.subject = 
self.module.MiniBatchPreProcessorDL(self.default_schema_madlib,
-                                                self.default_source_table,
-                                                self.default_output_table,
-                                                self.default_dep_var,
-                                                self.default_ind_var,
-                                                self.default_buffer_size)
+        self.subject = self.module.MiniBatchPreProcessorDL(
+            self.default_schema_madlib,
+            self.default_source_table,
+            self.default_output_table,
+            self.default_dep_var,
+            self.default_ind_var,
+            self.default_buffer_size,
+            self.default_normalizing_const,
+            self.default_num_classes)
 
     def tearDown(self):
         self.module_patcher.stop()
 
     def test_minibatch_preprocessor_dl_executes_query(self):
         self.module.get_expr_type = Mock(side_effect = ['integer[]', 
'integer[]'])
-        preprocessor_obj = 
self.module.MiniBatchPreProcessorDL(self.default_schema_madlib,
-                                                             "input",
-                                                             "out",
-                                                             
self.default_dep_var,
-                                                             
self.default_ind_var,
-                                                             
self.default_buffer_size)
+        preprocessor_obj = self.module.MiniBatchPreProcessorDL(
+            self.default_schema_madlib,
+            "input",
+            "out",
+            self.default_dep_var,
+            self.default_ind_var,
+            self.default_buffer_size,
+            self.default_normalizing_const,
+            self.default_num_classes)
         preprocessor_obj.minibatch_preprocessor_dl()
 
     def test_minibatch_preprocessor_null_buffer_size_executes_query(self):
         self.module.get_expr_type = Mock(side_effect = ['integer[]', 
'integer[]'])
-        preprocessor_obj = 
self.module.MiniBatchPreProcessorDL(self.default_schema_madlib,
-                                                             "input",
-                                                             "out",
-                                                             
self.default_dep_var,
-                                                             
self.default_ind_var,
-                                                             None)
+        preprocessor_obj = self.module.MiniBatchPreProcessorDL(
+            self.default_schema_madlib,
+            "input",
+            "out",
+            self.default_dep_var,
+            self.default_ind_var,
+            None,
+            self.default_normalizing_const,
+            self.default_num_classes)
         
self.module.MiniBatchBufferSizeCalculator.calculate_default_buffer_size = 
Mock(return_value = 5)
         preprocessor_obj.minibatch_preprocessor_dl()
 
     def test_minibatch_preprocessor_multiple_dep_var_raises_exception(self):
         self.module.get_expr_type = Mock(side_effect = ['integer[]', 
'integer[]'])
         with self.assertRaises(plpy.PLPYException):
-            self.module.MiniBatchPreProcessorDL(self.default_schema_madlib,
-                                              self.default_source_table,
-                                              self.default_output_table,
-                                              "y1,y2",
-                                              self.default_ind_var,
-                                              self.default_buffer_size)
+            self.module.MiniBatchPreProcessorDL(
+                self.default_schema_madlib,
+                self.default_source_table,
+                self.default_output_table,
+                "y1,y2",
+                self.default_ind_var,
+                self.default_buffer_size,
+                self.default_normalizing_const,
+                self.default_num_classes)
 
     def test_minibatch_preprocessor_multiple_indep_var_raises_exception(self):
         self.module.get_expr_type = Mock(side_effect = ['integer[]', 
'integer[]'])
         with self.assertRaises(plpy.PLPYException):
-            self.module.MiniBatchPreProcessorDL(self.default_schema_madlib,
-                                              self.default_source_table,
-                                              self.default_output_table,
-                                              self.default_dep_var,
-                                              "x1,x2",
-                                              self.default_buffer_size)
+            self.module.MiniBatchPreProcessorDL(
+                self.default_schema_madlib,
+                self.default_source_table,
+                self.default_output_table,
+                self.default_dep_var,
+                "x1,x2",
+                self.default_buffer_size,
+                self.default_normalizing_const,
+                self.default_num_classes)
 
     def test_minibatch_preprocessor_buffer_size_zero_fails(self):
         self.module.get_expr_type = Mock(side_effect = ['integer[]', 
'integer[]'])
         with self.assertRaises(plpy.PLPYException):
-            self.module.MiniBatchPreProcessorDL(self.default_schema_madlib,
-                                              self.default_source_table,
-                                              self.default_output_table,
-                                              self.default_dep_var,
-                                              self.default_ind_var,
-                                              0)
+            self.module.MiniBatchPreProcessorDL(
+                self.default_schema_madlib,
+                self.default_source_table,
+                self.default_output_table,
+                self.default_dep_var,
+                self.default_ind_var,
+                0,
+                self.default_normalizing_const,
+                self.default_num_classes)
 
     def test_minibatch_preprocessor_negative_buffer_size_fails(self):
         self.module.get_expr_type = Mock(side_effect = ['integer[]', 
'integer[]'])
@@ -420,7 +438,9 @@ class MiniBatchPreProcessingDLTestCase(unittest.TestCase):
                                               self.default_output_table,
                                               self.default_dep_var,
                                               self.default_ind_var,
-                                              -1)
+                                              -1,
+                                              self.default_normalizing_const,
+                                              self.default_num_classes)
 
     def 
test_minibatch_preprocessor_invalid_indep_vartype_raises_exception(self):
         self.module.get_expr_type = Mock(side_effect = ['integer', 
'integer[]'])
@@ -430,7 +450,9 @@ class MiniBatchPreProcessingDLTestCase(unittest.TestCase):
                                                 self.default_output_table,
                                                 self.default_dep_var,
                                                 self.default_ind_var,
-                                                self.default_buffer_size)
+                                                self.default_buffer_size,
+                                                self.default_normalizing_const,
+                                                self.default_num_classes)
 
     def test_minibatch_preprocessor_invalid_dep_vartype_raises_exception(self):
         self.module.get_expr_type = Mock(side_effect = ['integer[]', 'text[]'])
@@ -440,7 +462,9 @@ class MiniBatchPreProcessingDLTestCase(unittest.TestCase):
                                                 self.default_output_table,
                                                 self.default_dep_var,
                                                 self.default_ind_var,
-                                                self.default_buffer_size)
+                                                self.default_buffer_size,
+                                                self.default_normalizing_const,
+                                                self.default_num_classes)
 
     def test_minibatch_preprocessor_normalizing_const_zero_fails(self):
         self.module.get_expr_type = Mock(side_effect = ['integer[]', 
'integer[]'])
@@ -451,7 +475,8 @@ class MiniBatchPreProcessingDLTestCase(unittest.TestCase):
                                                 self.default_dep_var,
                                                 self.default_ind_var,
                                                 self.default_buffer_size,
-                                                0)
+                                                0,
+                                                self.default_num_classes)
 
     def test_minibatch_preprocessor_negative_normalizing_const_fails(self):
         self.module.get_expr_type = Mock(side_effect = ['integer[]', 
'integer[]'])
@@ -462,7 +487,82 @@ class MiniBatchPreProcessingDLTestCase(unittest.TestCase):
                                                 self.default_dep_var,
                                                 self.default_ind_var,
                                                 self.default_buffer_size,
-                                                -1)
+                                                -1,
+                                                self.default_num_classes)
+
+    def test_get_one_hot_encoded_dep_var_expr_null_val(self):
+        self.module.get_expr_type = Mock(side_effect = ['integer[]', 'text'])
+        self.module.get_distinct_col_levels = Mock(return_value = ["NULL", 
"'a'"])
+        obj = self.module.MiniBatchPreProcessorDL(
+            self.default_schema_madlib,
+            self.default_source_table,
+            self.default_output_table,
+            self.default_dep_var,
+            self.default_ind_var,
+            self.default_buffer_size,
+            self.default_normalizing_const,
+            self.default_num_classes)
+        dep_var_array_expr = obj.get_one_hot_encoded_dep_var_expr()
+        self.assertEqual("array[({0}) is not distinct from null, ({0}) is not 
distinct from 'a']::integer[]".
+                     format(self.default_dep_var),
+                     dep_var_array_expr.lower())
+
+    def test_get_one_hot_encoded_dep_var_expr_numeric_array_val(self):
+        self.module.get_expr_type = Mock(side_effect = ['integer[]', 
'integer[]'])
+        obj = self.module.MiniBatchPreProcessorDL(
+            self.default_schema_madlib,
+            self.default_source_table,
+            self.default_output_table,
+            self.default_dep_var,
+            self.default_ind_var,
+            self.default_buffer_size,
+            self.default_normalizing_const,
+            self.default_num_classes)
+        dep_var_array_expr = obj.get_one_hot_encoded_dep_var_expr()
+        self.assertEqual("{0}::integer[]".
+                     format(self.default_dep_var),
+                     dep_var_array_expr.lower())
+
+    def test_validate_num_classes_none(self):
+        self.module.get_expr_type = Mock(side_effect = ['integer[]', 'text'])
+        obj = self.module.MiniBatchPreProcessorDL(
+            self.default_schema_madlib,
+            self.default_source_table,
+            self.default_output_table,
+            self.default_dep_var,
+            self.default_ind_var,
+            self.default_buffer_size,
+            self.default_normalizing_const,
+            None)
+        self.assertEqual(0, obj.padding_size)
+
+    def test_validate_num_classes_greater(self):
+        self.module.get_expr_type = Mock(side_effect = ['integer[]', 'text'])
+        self.module.dependent_levels = Mock(return_value = ["'a'", "'b'", 
"'c'"])
+        obj = self.module.MiniBatchPreProcessorDL(
+            self.default_schema_madlib,
+            self.default_source_table,
+            self.default_output_table,
+            self.default_dep_var,
+            self.default_ind_var,
+            self.default_buffer_size,
+            self.default_normalizing_const,
+            5)
+        self.assertEqual(2, obj.padding_size)
+
+    def test_validate_num_classes_lesser(self):
+        self.module.get_expr_type = Mock(side_effect = ['integer[]', 'text'])
+        self.module.dependent_levels = Mock(return_value = ["'a'", "'b'", 
"'c'"])
+        with self.assertRaises(plpy.PLPYException):
+            obj = self.module.MiniBatchPreProcessorDL(
+                self.default_schema_madlib,
+                self.default_source_table,
+                self.default_output_table,
+                self.default_dep_var,
+                self.default_ind_var,
+                self.default_buffer_size,
+                self.default_normalizing_const,
+                2)
 
 if __name__ == '__main__':
     unittest.main()

Reply via email to