MLP: Fix bug in array dep var for regression

MLP training for regression fails if the dependent variable's type is an
array. This is because the variable dependent_varname was not updated to
refer to the new column created in the standardized table. This commit
fixes that issue.

Co-authored-by: Nikhil Kak <n...@pivotal.io>


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/47eefc1c
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/47eefc1c
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/47eefc1c

Branch: refs/heads/master
Commit: 47eefc1c91a33db0a788c2c145b8016b892de5ad
Parents: 5a71ff6
Author: Nandish Jayaram <njaya...@apache.org>
Authored: Wed Apr 4 14:42:38 2018 -0700
Committer: Nandish Jayaram <njaya...@apache.org>
Committed: Tue Apr 10 11:14:21 2018 -0700

----------------------------------------------------------------------
 src/ports/postgres/modules/convex/mlp_igd.py_in | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/47eefc1c/src/ports/postgres/modules/convex/mlp_igd.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/convex/mlp_igd.py_in b/src/ports/postgres/modules/convex/mlp_igd.py_in
index 8010579..2799355 100644
--- a/src/ports/postgres/modules/convex/mlp_igd.py_in
+++ b/src/ports/postgres/modules/convex/mlp_igd.py_in
@@ -155,21 +155,21 @@ def mlp(schema_madlib, source_table, output_table, independent_varname,
         normalize_data(locals())
         dependent_vartype = get_expr_type(dependent_varname, source_table)
 
+        # We are now using tbl_data_scaled, so change the dependent
+        # varname accordingly.
+        dependent_varname = col_dep_var_norm_new
         if is_classification:
             # If dependent variable is an array during classification, assume
             # that it is already one-hot-encoded.
             if "[]" in dependent_vartype:
-                # We are now using tbl_data_scaled, so change the dependent
-                # varname accordingly.
-                dependent_varname = col_dep_var_norm_new
                 num_output_nodes = get_col_dimension(tbl_data_scaled,
                                                      dependent_varname)
             else:
                 labels = plpy.execute("SELECT DISTINCT {0} FROM {1}".
-                                      format(dependent_varname, source_table))
+                                      format(dependent_varname_backup, source_table))
                 num_output_nodes = len(labels)
                 for label_obj in labels:
-                    label = _format_label(label_obj[dependent_varname])
+                    label = _format_label(label_obj[dependent_varname_backup])
                     classes.append(label)
                 classes.sort()
                 level_vals_str = ','.join(["{0}={1}".format(

Reply via email to