MiniBatch Preprocessor: Check for all character types for dependent col. This commit allows the dependent column to be of any Postgres character type instead of just `text`.
Project: http://git-wip-us.apache.org/repos/asf/madlib/repo Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/b8863813 Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/b8863813 Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/b8863813 Branch: refs/heads/master Commit: b886381303c0cd4deff9468558a95455a64f4699 Parents: c902cb6 Author: Nikhil Kak <n...@pivotal.io> Authored: Fri Apr 6 11:42:41 2018 -0700 Committer: Nandish Jayaram <njaya...@apache.org> Committed: Fri Apr 13 17:16:50 2018 -0700 ---------------------------------------------------------------------- .../modules/utilities/minibatch_preprocessing.py_in | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/madlib/blob/b8863813/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in index 1c53a59..856c7e4 100644 --- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in +++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in @@ -30,6 +30,8 @@ from utilities import add_postfix from utilities import _assert from utilities import get_seg_number from utilities import is_platform_pg +from utilities import is_psql_boolean_type +from utilities import is_psql_char_type from utilities import is_psql_numeric_type from utilities import is_psql_int_type from utilities import is_string_formatted_as_array_expression @@ -288,7 +290,8 @@ class MiniBatchQueryFormatter: """ dep_var_class_value_str = 'NULL::TEXT' is_dep_var_int_type = is_psql_int_type(dependent_var_dbtype) - to_one_hot_encode = (dependent_var_dbtype in ("text", "boolean") or + to_one_hot_encode = (is_psql_char_type(dependent_var_dbtype) or + 
is_psql_boolean_type(dependent_var_dbtype) or (to_one_hot_encode_int and is_dep_var_int_type)) if to_one_hot_encode: @@ -314,8 +317,9 @@ class MiniBatchQueryFormatter: elif is_psql_numeric_type(dependent_var_dbtype): dep_var_array_str = 'ARRAY[{0}]'.format(dependent_varname) else: - plpy.error("Invalid dependent variable type. It should be text, " - "boolean, numeric, or array.") + plpy.error("""Invalid dependent variable type. It should be character, + boolean, numeric, or array.""") + return dep_var_array_str, dep_var_class_value_str def _get_one_hot_encoded_str(self, var_name, var_classes, to_quote=True):