MiniBatch Preprocessor: Check for all character types for dependent col

This commit allows the dependent column to be of any PostgreSQL
character type (e.g. `char`, `varchar`, `text`) instead of just `text`.


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/b8863813
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/b8863813
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/b8863813

Branch: refs/heads/master
Commit: b886381303c0cd4deff9468558a95455a64f4699
Parents: c902cb6
Author: Nikhil Kak <n...@pivotal.io>
Authored: Fri Apr 6 11:42:41 2018 -0700
Committer: Nandish Jayaram <njaya...@apache.org>
Committed: Fri Apr 13 17:16:50 2018 -0700

----------------------------------------------------------------------
 .../modules/utilities/minibatch_preprocessing.py_in       | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/b8863813/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
index 1c53a59..856c7e4 100644
--- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
+++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
@@ -30,6 +30,8 @@ from utilities import add_postfix
 from utilities import _assert
 from utilities import get_seg_number
 from utilities import is_platform_pg
+from utilities import is_psql_boolean_type
+from utilities import is_psql_char_type
 from utilities import is_psql_numeric_type
 from utilities import is_psql_int_type
 from utilities import is_string_formatted_as_array_expression
@@ -288,7 +290,8 @@ class MiniBatchQueryFormatter:
         """
         dep_var_class_value_str = 'NULL::TEXT'
         is_dep_var_int_type = is_psql_int_type(dependent_var_dbtype)
-        to_one_hot_encode = (dependent_var_dbtype in ("text", "boolean") or
+        to_one_hot_encode = (is_psql_char_type(dependent_var_dbtype) or
+                             is_psql_boolean_type(dependent_var_dbtype) or
                                 (to_one_hot_encode_int and
                                     is_dep_var_int_type))
         if to_one_hot_encode:
@@ -314,8 +317,9 @@ class MiniBatchQueryFormatter:
         elif is_psql_numeric_type(dependent_var_dbtype):
             dep_var_array_str = 'ARRAY[{0}]'.format(dependent_varname)
         else:
-            plpy.error("Invalid dependent variable type. It should be text, "
-                       "boolean, numeric, or array.")
+            plpy.error("""Invalid dependent variable type. It should be character,
+                boolean, numeric, or array.""")
+
         return dep_var_array_str, dep_var_class_value_str
 
     def _get_one_hot_encoded_str(self, var_name, var_classes, to_quote=True):

Reply via email to