This is an automated email from the ASF dual-hosted git repository.
njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
The following commit(s) were added to refs/heads/master by this push:
new d5c9a97 Minibatch Preprocessor DL: Add optional num_classes param.
d5c9a97 is described below
commit d5c9a97f5860fa207f77d773ad8042d692dc8deb
Author: Nandish Jayaram <[email protected]>
AuthorDate: Fri Mar 29 16:48:52 2019 -0700
Minibatch Preprocessor DL: Add optional num_classes param.
JIRA: MADLIB-1314
The current `minibatch_preprocessor_dl()` module looks at the input
table to find the number of distinct categories (class values) for the
dependent variable, and uses that number as the size of the
one-hot-encoded array. This could lead to a failure in the madlib_keras fit
function if the `num_classes` defined in the architecture differs from
(e.g. is greater than) the size of the one-hot-encoded array.
This commit adds two functionalities:
1) A new optional parameter to `minibatch_preprocessor_dl()` that will
be used to determine the length of the 1-hot encoded vector for the
dependent var. If the param is set to NULL, the length will be equal to
the number of distinct class values found in the dataset; otherwise,
num_classes must be greater than or equal to the number of distinct class
values.
The `class_values` column in the summary table contains an array of
class values associated with the 1-hot encoded vector. When num_classes
exceeds the number of distinct class values, that array is padded with
NULL entries for the positions that have no representation in the dataset.
2) We now support NULL as a valid class value for dependent variable.
Closes #361
Co-authored-by: Ekta Khanna <[email protected]>
---
src/ports/postgres/modules/internal/db_utils.py_in | 13 +-
.../utilities/minibatch_preprocessing.py_in | 98 +++++++++--
.../utilities/minibatch_preprocessing_dl.sql_in | 87 +++++++++-
.../test/minibatch_preprocessing_dl.sql_in | 87 +++++++++-
.../unit_tests/test_minibatch_preprocessing.py_in | 186 ++++++++++++++++-----
5 files changed, 405 insertions(+), 66 deletions(-)
diff --git a/src/ports/postgres/modules/internal/db_utils.py_in
b/src/ports/postgres/modules/internal/db_utils.py_in
index 37700a8..934107c 100644
--- a/src/ports/postgres/modules/internal/db_utils.py_in
+++ b/src/ports/postgres/modules/internal/db_utils.py_in
@@ -26,7 +26,8 @@ m4_changequote(`<!', `!>')
QUOTE_DELIMITER="$__madlib__$"
-def get_distinct_col_levels(source_table, col_name, col_type=None):
+def get_distinct_col_levels(source_table, col_name, col_type=None,
+ include_nulls=False):
"""
Add description here
:return:
@@ -39,10 +40,13 @@ def get_distinct_col_levels(source_table, col_name,
col_type=None):
else:
dep_var_text_patched = col_name
+ where_clause = 'WHERE ({0}) is NOT NULL'.format(col_name)
+ if include_nulls:
+ where_clause = ''
levels = plpy.execute("""
SELECT DISTINCT {dep_var_text_patched} AS levels
FROM {source_table}
- WHERE ({col_name}) is NOT NULL
+ {where_clause}
""".format(**locals()))
levels = sorted(l["levels"] for l in levels)
@@ -55,10 +59,11 @@ def get_one_hot_encoded_expr(col_name, col_levels):
the sql function `quote_literal`.
:param col_name:
:param col_levels:
+
:return:
"""
- one_hot_encoded_expr = ["({0}) = {1}".format(col_name, c)
- for c in col_levels]
+ one_hot_encoded_expr = ["({0}) = {1}".format(
+ col_name, c) for c in col_levels]
return 'ARRAY[{0}]::INTEGER[]'.format(', '.join(one_hot_encoded_expr))
#
------------------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
index 8bfa978..32eb3e5 100644
--- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
+++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
@@ -86,7 +86,8 @@ class MiniBatchPreProcessor:
self.to_one_hot_encode =
self.should_one_hot_encode(one_hot_encode_int_dep_var)
if self.to_one_hot_encode:
self.dependent_levels = get_distinct_col_levels(
- self.source_table, self.dependent_varname, self.dependent_vartype)
+ self.source_table, self.dependent_varname,
+ self.dependent_vartype)
else:
self.dependent_levels = None
self._validate_minibatch_preprocessor_params()
@@ -343,10 +344,10 @@ class MiniBatchPreProcessor:
plpy.execute(query)
-class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
+class MiniBatchPreProcessorDL():
def __init__(self, schema_madlib, source_table, output_table,
dependent_varname, independent_varname, buffer_size,
- normalizing_const=1.0, **kwargs):
+ normalizing_const, num_classes, **kwargs):
self.schema_madlib = schema_madlib
self.source_table = source_table
self.output_table = output_table
@@ -354,6 +355,7 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
self.independent_varname = independent_varname
self.buffer_size = buffer_size
self.normalizing_const = normalizing_const if normalizing_const is not
None else 1.0
+ self.num_classes = num_classes
self.module_name = "minibatch_preprocessor_DL"
self.output_summary_table = add_postfix(self.output_table, "_summary")
self.independent_vartype = get_expr_type(self.independent_varname,
@@ -363,26 +365,87 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
self._validate_args()
self.num_of_buffers = self._get_num_buffers()
- self.to_one_hot_encode = True
if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
self.dependent_levels = None
else:
self.dependent_levels = get_distinct_col_levels(
- self.source_table, self.dependent_varname,
self.dependent_vartype)
+ self.source_table, self.dependent_varname,
+ self.dependent_vartype, include_nulls=True)
+ # if any class level was NULL in sql, that would show up as
+ # None in self.dependent_levels. Replace all None with NULL
+ # in the list.
+ self.dependent_levels = ['NULL' if level is None else level
+ for level in self.dependent_levels]
+ self._validate_num_classes()
+ # Find the number of padded zeros to include in 1-hot vector
+ self.padding_size = 0
+ # Try computing padding_size after running all necessary validations.
+ if self.num_classes and self.dependent_levels:
+ self.padding_size = self.num_classes - len(self.dependent_levels)
+
+ def _validate_num_classes(self):
+ if self.num_classes is not None and \
+ self.num_classes < len(self.dependent_levels):
+ plpy.error("{0}: Invalid num_classes value specified. It must "\
+ "be equal to or greater than distinct class values found "\
+ "in table ({1}).".format(
+ self.module_name, len(self.dependent_levels)))
+
+ def get_one_hot_encoded_dep_var_expr(self):
+ """
+ :param dependent_varname: Name of the dependent variable
+ :param num_classes: Number of class values to consider in 1-hot
+ :return:
+ This function returns a tuple of
+ 1. A string with transformed dependent varname depending on it's
type
+ 2. All the distinct dependent class levels encoded as a string
+
+ If dep_type == numeric[] , do not encode
+ 1. dependent_varname = rings
+ transformed_value = ARRAY[rings]
+ 2. dependent_varname = ARRAY[a, b, c]
+ transformed_value = ARRAY[a, b, c]
+ else if dep_type in ("text", "boolean"), encode:
+ 3. dependent_varname = rings (encoding)
+ transformed_value = ARRAY[rings=1, rings=2, rings=3]
+ """
+ # Assuming the input NUMERIC[] is already one_hot_encoded,
+ # so casting to INTEGER[]
+ if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
+ return "{0}::INTEGER[]".format(self.dependent_varname)
+
+ # For DL use case, we want to allow NULL as a valid class value,
+ # so the query must have 'IS NOT DISTINCT FROM' instead of '='
+ # like in the generic get_one_hot_encoded_expr() defined in
+ # db_utils.py_in. We also have this optional 'num_classes' param
+ # that affects the logic of 1-hot encoding. Since this is very
+ # specific to minibatch_preprocessing_dl for now, let's keep
+ # it here instead of refactoring it out to a generic helper function.
+ one_hot_encoded_expr = ["({0}) IS NOT DISTINCT FROM {1}".format(
+ self.dependent_varname, c) for c in self.dependent_levels]
+ if self.num_classes:
+ one_hot_encoded_expr.extend(['false'
+ for i in range(self.padding_size)])
+ return 'ARRAY[{0}]::INTEGER[]'.format(
+ ', '.join(one_hot_encoded_expr))
def minibatch_preprocessor_dl(self):
# Create a temp table that has independent var normalized.
norm_tbl = unique_string(desp='normalized')
- dep_var_array_expr = self.get_dep_var_array_expr()
- # Assuming the input NUMERIC[] is already one_hot_encoded, so casting
to INTEGER[]
- if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
- dep_var_array_expr = dep_var_array_expr + '::INTEGER[]'
+ # Always one-hot encode the dependent var. For now, we are assuming
+ # that minibatch_preprocessor_dl will be used only for deep
+ # learning and mostly for classification. So make a strong
+ # assumption that it is only for classification, so one-hot
+ # encode the dep var, unless it's already a numeric array in
+ # which case we assume it's already one-hot encoded.
+ one_hot_dep_var_array_expr = \
+ self.get_one_hot_encoded_dep_var_expr()
scalar_mult_sql = """
CREATE TEMP TABLE {norm_tbl} AS
SELECT {self.schema_madlib}.array_scalar_mult(
{self.independent_varname}::REAL[],
(1/{self.normalizing_const})::REAL) AS x_norm,
- {dep_var_array_expr} AS y,
+ {one_hot_dep_var_array_expr} AS y,
row_number() over() AS row_id
FROM {self.source_table} order by random()
""".format(**locals())
@@ -416,6 +479,11 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
def _create_output_summary_table(self):
class_level_str='NULL::TEXT'
if self.dependent_levels:
+ # Update dependent_levels to include NULL when
+ # num_classes > len(self.dependent_levels)
+ if self.num_classes:
+ self.dependent_levels.extend(['NULL'
+ for i in range(self.padding_size)])
class_level_str=py_list_to_sql_string(
self.dependent_levels, array_type=self.dependent_vartype,
long_format=True)
@@ -741,6 +809,11 @@ class MiniBatchDocumentation:
The normalizing constant is parameterized, and can be specified based
on the kind of image data used.
+ An optional param named num_classes can be used to specify the length
+ of the one-hot encoded array for the dependent variable. This value if
+ specified must be greater than equal to the total number of distinct
+ class values found in the input table.
+
For more details on function usage:
SELECT {schema_madlib}.{method}('usage')
""".format(**locals())
@@ -762,6 +835,11 @@ class MiniBatchDocumentation:
normalizing_const -- DOUBLE PRECISON. Default 1.0. The
normalizing constant to use for
standardizing arrays in
independent_varname.
+ num_classes -- INTEGER. Default NULL. Number of class
labels
+ to be considered for 1-hot encoding. If
NULL,
+ the 1-hot encoded array length will be
equal to
+ the number of distinct class values
found in the
+ input table.
);
diff --git
a/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in
b/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in
index 67e216d..3fe17d0 100644
--- a/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in
+++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing_dl.sql_in
@@ -55,7 +55,8 @@ minibatch_preprocessor_dl( source_table,
dependent_varname,
independent_varname,
buffer_size,
- normalizing_const
+ normalizing_const,
+ num_classes
)
</pre>
@@ -105,6 +106,12 @@ minibatch_preprocessor_dl( source_table,
each value in the independent_varname array by. For example,
you may need to use 255 for this value if the image data is in the form
0-255.
</dd>
+
+ <dt>num_classes (optional)</dt>
+ <dd>INTEGER, default: NULL. Number of class labels to be considered for 1-hot
+ encoding. If NULL, the 1-hot encoded array length will be equal to the number
+ of distinct class values found in the input table.
+ </dd>
</dl>
<b>Output tables</b>
@@ -118,9 +125,14 @@ minibatch_preprocessor_dl( source_table,
</tr>
<tr>
<th>dependent_var</th>
- <td>ANYARRAY[]. Packed array of dependent variables. The type
- of the array is the same as the type of the dependent variable from
- the source table.
+ <td>ANYARRAY[]. Packed array of dependent variables.
+ The dependent variable is always one-hot encoded as an
+ INTEGER[] array. For now, we are assuming that
+ minibatch_preprocessor_dl will be used
+ only for classification problems using deep learning. So
+ the dependent variable is one-hot encoded, unless it's already a
+ numeric array in which case we assume it's already one-hot
+ encoded and just cast it to an INTEGER[] array.
</td>
</tr>
<tr>
@@ -447,6 +459,54 @@ class_values | {bird,cat,dog}
buffer_size | 10
</pre>
+-# Run the preprocessor for image data with num_classes greater than 3
(distinct class values found in table):
+<pre class="example">
+DROP TABLE IF EXISTS image_data_packed, image_data_packed_summary;
+SELECT madlib.minibatch_preprocessor_dl('image_data', -- Source table
+ 'image_data_packed', -- Output table
+ 'species', -- Dependent
variable
+ 'rgb', -- Independent
variable
+ NULL, -- Buffer size
+ 255, -- Normalizing
constant
+ 5 -- Number of
desired class values
+ );
+</pre>
+Here is a sample of the packed output table with the padded 1-hot vector:
+<pre class="example">
+\\x on
+SELECT * FROM image_data_packed ORDER BY buffer_id;
+</pre>
+<pre class="result">
+-[ RECORD 1
]---+---------------------------------------------------------------------------------------------------------------------
+independent_var | {{0.639216,0.517647,0.87451,0.0862745,0.784314,...},...}
+dependent_var | {{0,0,1,0,0},{1,0,0,0,0},{1,0,0,0,0},{1,0,0,0,0},...}
+buffer_id | 0
+-[ RECORD 2
]---+---------------------------------------------------------------------------------------------------------------------
+independent_var | {{0.866667,0.0666667,0.803922,0.239216,0.741176,...},...}
+dependent_var | {{0,0,1,0,0},{0,0,1,0,0},{0,1,0,0,0},{0,1,0,0,0},...}
+buffer_id | 1
+-[ RECORD 3
]---+---------------------------------------------------------------------------------------------------------------------
+independent_var | {{0.184314,0.87451,0.227451,0.466667,0.203922,...},...}
+dependent_var | {{1,0,0,0,0},{0,1,0,0,0},{1,0,0,0,0},{0,0,1,0,0},...}
+buffer_id | 2
+</pre>
+Review the output summary table:
+<pre class="example">
+\\x on
+SELECT * FROM image_data_packed_summary;
+</pre>
+<pre class="result">
+-[ RECORD 1 ]-------+-------------------------
+source_table | image_data
+output_table | image_data_packed
+dependent_varname | species
+independent_varname | rgb
+dependent_vartype | text
+class_values | {bird,cat,dog,NULL,NULL}
+buffer_size | 18
+normalizing_const | 255.0
+</pre>
+
@anchor related
@par Related Topics
@@ -462,7 +522,8 @@ CREATE OR REPLACE FUNCTION
MADLIB_SCHEMA.minibatch_preprocessor_dl(
dependent_varname VARCHAR,
independent_varname VARCHAR,
buffer_size INTEGER,
- normalizing_const DOUBLE PRECISION
+ normalizing_const DOUBLE PRECISION,
+ num_classes INTEGER
) RETURNS VOID AS $$
PythonFunctionBodyOnly(utilities, minibatch_preprocessing)
from utilities.control import MinWarning
@@ -478,9 +539,21 @@ CREATE OR REPLACE FUNCTION
MADLIB_SCHEMA.minibatch_preprocessor_dl(
output_table VARCHAR,
dependent_varname VARCHAR,
independent_varname VARCHAR,
+ buffer_size INTEGER,
+ normalizing_const DOUBLE PRECISION
+) RETURNS VOID AS $$
+ SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, $6, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor_dl(
+ source_table VARCHAR,
+ output_table VARCHAR,
+ dependent_varname VARCHAR,
+ independent_varname VARCHAR,
buffer_size INTEGER
) RETURNS VOID AS $$
- SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, 1.0);
+ SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, $5, 1.0,
NULL);
$$ LANGUAGE sql VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
@@ -490,7 +563,7 @@ CREATE OR REPLACE FUNCTION
MADLIB_SCHEMA.minibatch_preprocessor_dl(
dependent_varname VARCHAR,
independent_varname VARCHAR
) RETURNS VOID AS $$
- SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, NULL, 1.0);
+ SELECT MADLIB_SCHEMA.minibatch_preprocessor_dl($1, $2, $3, $4, NULL, 1.0,
NULL);
$$ LANGUAGE sql VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
diff --git
a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in
b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in
index 31bfdd7..0e64c07 100644
---
a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in
+++
b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing_dl.sql_in
@@ -95,11 +95,16 @@ SELECT minibatch_preprocessor_dl(
'y',
'x',
4,
- 5);
+ 5,
+ 16 -- num_classes
+ );
-- Test that indepdendent vars get divided by 5, by verifying min value goes
from 1 to 0.2, and max value from 233 to 46.6
SELECT assert(relative_error(MIN(x),0.2) < 0.00001, 'Independent var not
normalized properly!') FROM (SELECT UNNEST(independent_var) as x FROM
minibatch_preprocessor_dl_batch) a;
SELECT assert(relative_error(MAX(x),46.6) < 0.00001, 'Independent var not
normalized properly!') FROM (SELECT UNNEST(independent_var) as x FROM
minibatch_preprocessor_dl_batch) a;
+-- Test that 1-hot encoded array is of length 16 (num_classes)
+SELECT assert(array_upper(dependent_var, 2) = 16, 'Incorrect one-hot encode
dimension with num_classes') FROM
+ minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
-- Test summary table
SELECT assert
@@ -109,7 +114,7 @@ SELECT assert
dependent_varname = 'y' AND
independent_varname = 'x' AND
dependent_vartype = 'integer' AND
- class_values = '{-6,-3,-1,0,2,3,4,5,6,7,8,9,10,12}' AND
+ class_values = '{-6,-3,-1,0,2,3,4,5,6,7,8,9,10,12,NULL,NULL}'
AND
buffer_size = 4 AND -- we sort the class values in python
normalizing_const = 5,
'Summary Validation failed. Actual:' || __to_char(summary)
@@ -205,3 +210,81 @@ SELECT assert (dependent_vartype = 'integer[]' AND
class_values IS NULL,
'Summary Validation failed. Actual:' || __to_char(summary)
) from (select * from minibatch_preprocessor_dl_batch_summary)
summary;
+
+-- Test cases with NULL in class values
+DROP TABLE IF EXISTS minibatch_preprocessor_dl_input_null;
+CREATE TABLE minibatch_preprocessor_dl_input_null(id serial, x double
precision[], label TEXT);
+INSERT INTO minibatch_preprocessor_dl_input_null(x, label) VALUES
+(ARRAY[1,2,3,4,5,6], 'a'),
+(ARRAY[11,2,3,4,5,6], 'a'),
+(ARRAY[11,22,33,4,5,6], NULL),
+(ARRAY[11,22,33,44,5,6], 'a'),
+(ARRAY[11,22,33,44,65,6], 'a'),
+(ARRAY[11,22,33,44,65,56], 'a'),
+(ARRAY[11,22,33,44,65,56], 'a'),
+(ARRAY[11,22,33,44,65,56], NULL),
+(ARRAY[11,22,33,44,65,56], 'a'),
+(ARRAY[11,22,33,44,65,56], 'a'),
+(ARRAY[11,22,33,44,65,56], NULL),
+(ARRAY[11,22,33,44,65,56], 'a'),
+(ARRAY[11,22,33,144,65,56], 'b'),
+(ARRAY[11,22,233,44,65,56], 'b'),
+(ARRAY[11,22,33,44,65,56], 'b'),
+(ARRAY[11,22,33,44,65,56], 'b'),
+(ARRAY[11,22,33,44,65,56], NULL);
+
+DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch,
minibatch_preprocessor_dl_batch_summary;
+SELECT minibatch_preprocessor_dl(
+ 'minibatch_preprocessor_dl_input_null',
+ 'minibatch_preprocessor_dl_batch',
+ 'label',
+ 'x',
+ 4,
+ 5,
+ 5 -- num_classes
+ );
+-- Test summary table if class_values has NULL as a legitimate
+-- class label, and also two other NULLs because num_classes=5
+-- but table has only 3 distinct class labels (including NULL)
+SELECT assert
+ (
+ class_values = '{NULL,a,b,NULL,NULL}',
+ 'Summary Validation failed with NULL data. Actual:' ||
__to_char(summary)
+ ) from (select * from minibatch_preprocessor_dl_batch_summary) summary;
+
+SELECT assert(array_upper(dependent_var, 2) = 5, 'Incorrect one-hot encode
dimension with NULL data') FROM
+ minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
+
+-- Test the content of 1-hot encoded dep var when NULL is the
+-- class label.
+DROP TABLE IF EXISTS minibatch_preprocessor_dl_input_null;
+CREATE TABLE minibatch_preprocessor_dl_input_null(id serial, x double
precision[], label TEXT);
+INSERT INTO minibatch_preprocessor_dl_input_null(x, label) VALUES
+(ARRAY[11,22,33,4,5,6], NULL);
+
+DROP TABLE IF EXISTS minibatch_preprocessor_dl_batch,
minibatch_preprocessor_dl_batch_summary;
+SELECT minibatch_preprocessor_dl(
+ 'minibatch_preprocessor_dl_input_null',
+ 'minibatch_preprocessor_dl_batch',
+ 'label',
+ 'x',
+ 4,
+ 5,
+ 3 -- num_classes
+ );
+
+-- class_values must be '{NULL,NULL,NULL}' where the first NULL
+-- is for the class label seen in data, and the other two NULLs
+-- are added as num_classes=3.
+SELECT assert
+ (
+ class_values = '{NULL,NULL,NULL}',
+ 'Summary Validation failed with NULL data. Actual:' ||
__to_char(summary)
+ ) from (select * from minibatch_preprocessor_dl_batch_summary) summary;
+
+SELECT assert(array_upper(dependent_var, 2) = 3, 'Incorrect one-hot encode
dimension with NULL data') FROM
+ minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
+-- NULL is treated as a class label, so it should show '1' for the
+-- first index
+SELECT assert(dependent_var = '{{1,0,0}}', 'Incorrect one-hot encode dimension
with NULL data') FROM
+ minibatch_preprocessor_dl_batch WHERE buffer_id = 0;
diff --git
a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
index 2fdf082..e213562 100644
---
a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
+++
b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
@@ -343,74 +343,92 @@ class MiniBatchPreProcessingDLTestCase(unittest.TestCase):
self.default_dep_var = "depvar"
self.default_ind_var = "indvar"
self.default_buffer_size = 5
- #self.dependent_vartype = "integer[]"
- #self.independent_vartype = "integer[]"
+ self.default_normalizing_const = 1.0
+ self.default_num_classes = None
import utilities.minibatch_preprocessing
self.module = utilities.minibatch_preprocessing
self.module.get_expr_type = Mock(side_effect = ['integer[]',
'integer[]'])
self.module.validate_module_input_params = Mock()
self.module.get_distinct_col_levels = Mock(return_value = [0,22,100])
- self.subject =
self.module.MiniBatchPreProcessorDL(self.default_schema_madlib,
- self.default_source_table,
- self.default_output_table,
- self.default_dep_var,
- self.default_ind_var,
- self.default_buffer_size)
+ self.subject = self.module.MiniBatchPreProcessorDL(
+ self.default_schema_madlib,
+ self.default_source_table,
+ self.default_output_table,
+ self.default_dep_var,
+ self.default_ind_var,
+ self.default_buffer_size,
+ self.default_normalizing_const,
+ self.default_num_classes)
def tearDown(self):
self.module_patcher.stop()
def test_minibatch_preprocessor_dl_executes_query(self):
self.module.get_expr_type = Mock(side_effect = ['integer[]',
'integer[]'])
- preprocessor_obj =
self.module.MiniBatchPreProcessorDL(self.default_schema_madlib,
- "input",
- "out",
-
self.default_dep_var,
-
self.default_ind_var,
-
self.default_buffer_size)
+ preprocessor_obj = self.module.MiniBatchPreProcessorDL(
+ self.default_schema_madlib,
+ "input",
+ "out",
+ self.default_dep_var,
+ self.default_ind_var,
+ self.default_buffer_size,
+ self.default_normalizing_const,
+ self.default_num_classes)
preprocessor_obj.minibatch_preprocessor_dl()
def test_minibatch_preprocessor_null_buffer_size_executes_query(self):
self.module.get_expr_type = Mock(side_effect = ['integer[]',
'integer[]'])
- preprocessor_obj =
self.module.MiniBatchPreProcessorDL(self.default_schema_madlib,
- "input",
- "out",
-
self.default_dep_var,
-
self.default_ind_var,
- None)
+ preprocessor_obj = self.module.MiniBatchPreProcessorDL(
+ self.default_schema_madlib,
+ "input",
+ "out",
+ self.default_dep_var,
+ self.default_ind_var,
+ None,
+ self.default_normalizing_const,
+ self.default_num_classes)
self.module.MiniBatchBufferSizeCalculator.calculate_default_buffer_size =
Mock(return_value = 5)
preprocessor_obj.minibatch_preprocessor_dl()
def test_minibatch_preprocessor_multiple_dep_var_raises_exception(self):
self.module.get_expr_type = Mock(side_effect = ['integer[]',
'integer[]'])
with self.assertRaises(plpy.PLPYException):
- self.module.MiniBatchPreProcessorDL(self.default_schema_madlib,
- self.default_source_table,
- self.default_output_table,
- "y1,y2",
- self.default_ind_var,
- self.default_buffer_size)
+ self.module.MiniBatchPreProcessorDL(
+ self.default_schema_madlib,
+ self.default_source_table,
+ self.default_output_table,
+ "y1,y2",
+ self.default_ind_var,
+ self.default_buffer_size,
+ self.default_normalizing_const,
+ self.default_num_classes)
def test_minibatch_preprocessor_multiple_indep_var_raises_exception(self):
self.module.get_expr_type = Mock(side_effect = ['integer[]',
'integer[]'])
with self.assertRaises(plpy.PLPYException):
- self.module.MiniBatchPreProcessorDL(self.default_schema_madlib,
- self.default_source_table,
- self.default_output_table,
- self.default_dep_var,
- "x1,x2",
- self.default_buffer_size)
+ self.module.MiniBatchPreProcessorDL(
+ self.default_schema_madlib,
+ self.default_source_table,
+ self.default_output_table,
+ self.default_dep_var,
+ "x1,x2",
+ self.default_buffer_size,
+ self.default_normalizing_const,
+ self.default_num_classes)
def test_minibatch_preprocessor_buffer_size_zero_fails(self):
self.module.get_expr_type = Mock(side_effect = ['integer[]',
'integer[]'])
with self.assertRaises(plpy.PLPYException):
- self.module.MiniBatchPreProcessorDL(self.default_schema_madlib,
- self.default_source_table,
- self.default_output_table,
- self.default_dep_var,
- self.default_ind_var,
- 0)
+ self.module.MiniBatchPreProcessorDL(
+ self.default_schema_madlib,
+ self.default_source_table,
+ self.default_output_table,
+ self.default_dep_var,
+ self.default_ind_var,
+ 0,
+ self.default_normalizing_const,
+ self.default_num_classes)
def test_minibatch_preprocessor_negative_buffer_size_fails(self):
self.module.get_expr_type = Mock(side_effect = ['integer[]',
'integer[]'])
@@ -420,7 +438,9 @@ class MiniBatchPreProcessingDLTestCase(unittest.TestCase):
self.default_output_table,
self.default_dep_var,
self.default_ind_var,
- -1)
+ -1,
+ self.default_normalizing_const,
+ self.default_num_classes)
def
test_minibatch_preprocessor_invalid_indep_vartype_raises_exception(self):
self.module.get_expr_type = Mock(side_effect = ['integer',
'integer[]'])
@@ -430,7 +450,9 @@ class MiniBatchPreProcessingDLTestCase(unittest.TestCase):
self.default_output_table,
self.default_dep_var,
self.default_ind_var,
- self.default_buffer_size)
+ self.default_buffer_size,
+ self.default_normalizing_const,
+ self.default_num_classes)
def test_minibatch_preprocessor_invalid_dep_vartype_raises_exception(self):
self.module.get_expr_type = Mock(side_effect = ['integer[]', 'text[]'])
@@ -440,7 +462,9 @@ class MiniBatchPreProcessingDLTestCase(unittest.TestCase):
self.default_output_table,
self.default_dep_var,
self.default_ind_var,
- self.default_buffer_size)
+ self.default_buffer_size,
+ self.default_normalizing_const,
+ self.default_num_classes)
def test_minibatch_preprocessor_normalizing_const_zero_fails(self):
self.module.get_expr_type = Mock(side_effect = ['integer[]',
'integer[]'])
@@ -451,7 +475,8 @@ class MiniBatchPreProcessingDLTestCase(unittest.TestCase):
self.default_dep_var,
self.default_ind_var,
self.default_buffer_size,
- 0)
+ 0,
+ self.default_num_classes)
def test_minibatch_preprocessor_negative_normalizing_const_fails(self):
self.module.get_expr_type = Mock(side_effect = ['integer[]',
'integer[]'])
@@ -462,7 +487,82 @@ class MiniBatchPreProcessingDLTestCase(unittest.TestCase):
self.default_dep_var,
self.default_ind_var,
self.default_buffer_size,
- -1)
+ -1,
+ self.default_num_classes)
+
+ def test_get_one_hot_encoded_dep_var_expr_null_val(self):
+ self.module.get_expr_type = Mock(side_effect = ['integer[]', 'text'])
+ self.module.get_distinct_col_levels = Mock(return_value = ["NULL",
"'a'"])
+ obj = self.module.MiniBatchPreProcessorDL(
+ self.default_schema_madlib,
+ self.default_source_table,
+ self.default_output_table,
+ self.default_dep_var,
+ self.default_ind_var,
+ self.default_buffer_size,
+ self.default_normalizing_const,
+ self.default_num_classes)
+ dep_var_array_expr = obj.get_one_hot_encoded_dep_var_expr()
+ self.assertEqual("array[({0}) is not distinct from null, ({0}) is not
distinct from 'a']::integer[]".
+ format(self.default_dep_var),
+ dep_var_array_expr.lower())
+
+ def test_get_one_hot_encoded_dep_var_expr_numeric_array_val(self):
+ self.module.get_expr_type = Mock(side_effect = ['integer[]',
'integer[]'])
+ obj = self.module.MiniBatchPreProcessorDL(
+ self.default_schema_madlib,
+ self.default_source_table,
+ self.default_output_table,
+ self.default_dep_var,
+ self.default_ind_var,
+ self.default_buffer_size,
+ self.default_normalizing_const,
+ self.default_num_classes)
+ dep_var_array_expr = obj.get_one_hot_encoded_dep_var_expr()
+ self.assertEqual("{0}::integer[]".
+ format(self.default_dep_var),
+ dep_var_array_expr.lower())
+
+ def test_validate_num_classes_none(self):
+ self.module.get_expr_type = Mock(side_effect = ['integer[]', 'text'])
+ obj = self.module.MiniBatchPreProcessorDL(
+ self.default_schema_madlib,
+ self.default_source_table,
+ self.default_output_table,
+ self.default_dep_var,
+ self.default_ind_var,
+ self.default_buffer_size,
+ self.default_normalizing_const,
+ None)
+ self.assertEqual(0, obj.padding_size)
+
+ def test_validate_num_classes_greater(self):
+ self.module.get_expr_type = Mock(side_effect = ['integer[]', 'text'])
+ self.module.dependent_levels = Mock(return_value = ["'a'", "'b'",
"'c'"])
+ obj = self.module.MiniBatchPreProcessorDL(
+ self.default_schema_madlib,
+ self.default_source_table,
+ self.default_output_table,
+ self.default_dep_var,
+ self.default_ind_var,
+ self.default_buffer_size,
+ self.default_normalizing_const,
+ 5)
+ self.assertEqual(2, obj.padding_size)
+
+ def test_validate_num_classes_lesser(self):
+ self.module.get_expr_type = Mock(side_effect = ['integer[]', 'text'])
+ self.module.dependent_levels = Mock(return_value = ["'a'", "'b'",
"'c'"])
+ with self.assertRaises(plpy.PLPYException):
+ obj = self.module.MiniBatchPreProcessorDL(
+ self.default_schema_madlib,
+ self.default_source_table,
+ self.default_output_table,
+ self.default_dep_var,
+ self.default_ind_var,
+ self.default_buffer_size,
+ self.default_normalizing_const,
+ 2)
if __name__ == '__main__':
unittest.main()