Github user njayaram2 commented on a diff in the pull request:

    https://github.com/apache/incubator-madlib/pull/77#discussion_r91395290
  
    --- Diff: src/ports/postgres/modules/elastic_net/elastic_net.py_in ---
    @@ -508,54 +758,77 @@ def analyze_single_input_str(schema_madlib, 
tbl_source, col_ind_var,
         else:
             plpy.error("Elastic Net error: Single column name included for "
                        "independent variable is not found in source table.")
    -# ========================================================================
    +# ------------------------------------------------------------------------
     
     
     def elastic_net_predict_all(schema_madlib, tbl_model, tbl_new_source,
                                 col_id, tbl_predict, **kwargs):
         """
         Predict and put the result in a table. Useful for general CV
         """
    -    old_msg_level = set_client_min_messages("error")
    -    regress_family = plpy.execute("SELECT family FROM {tbl_model} ".
    -                                  format(tbl_model=tbl_model))[0]["family"]
    -
    -    if regress_family.lower() in ("gaussian", "linear"):
    -        predict_func = "elastic_net_gaussian_predict"
    -    elif regress_family.lower() in ("binomial", "logistic"):
    -        predict_func = "elastic_net_binomial_predict"
    -    else:
    -        plpy.error("Elastic Net error: Not a supported response family!")
    +    summary_table = add_postfix(tbl_model, "_summary")
    +    grouping_col = plpy.execute("SELECT grouping_col FROM {summary_table}".
    +                                
format(summary_table=summary_table))[0]["grouping_col"]
    +    with MinWarning("error"):
    +        regress_family = plpy.execute("SELECT family FROM {tbl_model} ".
    +                                      
format(tbl_model=tbl_model))[0]["family"]
    +
    +        if regress_family.lower() in ("gaussian", "linear"):
    +            predict_func = "elastic_net_gaussian_predict"
    +        elif regress_family.lower() in ("binomial", "logistic"):
    +            predict_func = "elastic_net_binomial_predict"
    +        else:
    +            plpy.error("Elastic Net error: Not a supported response 
family!")
     
    -    if col_id is None or col_id == '':
    -        plpy.error("Elastic Net error: invalid ID column provided!")
    -    if columns_exist_in_table(tbl_new_source, [col_id], schema_madlib):
    -        elastic_net_predict_id = col_id
    -    else:
    -        elastic_net_predict_id = 'elastic_net_predict_id'
    -
    -    dense_vars = mad_vec(plpy.execute(""" SELECT features AS fs
    -                                          FROM {tbl_model}
    -                                      
""".format(tbl_model=tbl_model))[0]["fs"])
    -    dense_vars_str = "ARRAY[" + ", ".join(dense_vars) + "]"
    -    # Must be careful to avoid possible name conflicts
    -    plpy.execute(
    -        """
    -        DROP TABLE IF EXISTS {tbl_predict};
    -        CREATE TABLE {tbl_predict} AS
    -            SELECT
    -                {elastic_net_predict_id},
    -                {schema_madlib}.{predict_func}(coef_all, intercept, 
ind_var)
    -                     AS prediction
    -            FROM
    -                {tbl_model} as tbl1,
    -                (SELECT
    -                    {col_id} as {elastic_net_predict_id},
    -                    {dense_vars_str} as ind_var
    +        if col_id is None or col_id == '':
    +            plpy.error("Elastic Net error: invalid ID column provided!")
    +        if columns_exist_in_table(tbl_new_source, [col_id], schema_madlib):
    +            elastic_net_predict_id = col_id
    +        else:
    +            elastic_net_predict_id = 'elastic_net_predict_id'
    +
    +        dense_vars = plpy.execute(""" SELECT features AS fs
    +                                  FROM {tbl_model}
    +                                  """.format(tbl_model=tbl_model))[0]["fs"]
    +        dense_vars_str = "ARRAY[" + ", ".join(dense_vars) + "]"
    +        # Must be careful to avoid possible name conflicts
    +
    +        if not grouping_col or grouping_col != 'NULL':
    +            qstr = """
    +                DROP TABLE IF EXISTS {tbl_predict};
    +                CREATE TABLE {tbl_predict} AS
    +                    SELECT
    +                        {elastic_net_predict_id},
    +                        {schema_madlib}.{predict_func}(coef_all, 
intercept, ind_var)
    +                             AS prediction
    +                    FROM
    +                        {tbl_model} as tbl1
    +                        JOIN
    +                        (SELECT
    +                            {grouping_col},
    +                            {col_id} as {elastic_net_predict_id},
    +                            {dense_vars_str} as ind_var
    +                        FROM
    +                            {tbl_new_source}) tbl2
    +                        USING ({grouping_col})
    +                        ORDER BY {grouping_col}, {elastic_net_predict_id}
    +                """.format(**locals())
    +        else:
    +            qstr = """
    +            DROP TABLE IF EXISTS {tbl_predict};
    +            CREATE TABLE {tbl_predict} AS
    +                SELECT
    +                    {elastic_net_predict_id},
    +                    {schema_madlib}.{predict_func}(coef_all, intercept, 
ind_var)
    +                         AS prediction
                     FROM
    -                    {tbl_new_source}) tbl2
    -        """.format(**locals()))
    -
    -    set_client_min_messages(old_msg_level)
    +                    {tbl_model} as tbl1,
    +                    (SELECT
    +                        {col_id} as {elastic_net_predict_id},
    +                        {dense_vars_str} as ind_var
    +                    FROM
    --- End diff --
    
    There seems to be an issue here. I tried running the code with cross
validation and `array[tax, bath, size]` as my independent variables (these
columns exist in the example dataset I used). When `elastic_net_predict_all`
is called from the cross-validation code, I get the following error:
    ```
    ERROR:  spiexceptions.UndefinedColumn: column "tax" does not exist
    LINE 12:                         ARRAY[tax, bath, size] as ind_var
                                           ^
    QUERY:
                DROP TABLE IF EXISTS 
__madlib_temp_output_table83454748_1481146339_4273906__;
                CREATE TABLE 
__madlib_temp_output_table83454748_1481146339_4273906__ AS
                    SELECT
                        __madlib_temp_col_id34438449_1481146339_6272111__,
                        madlib.elastic_net_gaussian_predict(coef_all, 
intercept, ind_var)
                             AS prediction
                    FROM
                        houses_en as tbl1,
                        (SELECT
                            __madlib_temp_col_id34438449_1481146339_6272111__ 
as __madlib_temp_col_id34438449_1481146339_6272111__,
                            ARRAY[tax, bath, size] as ind_var
                        FROM
                            
__madlib_temp_cv_valid_0_12190289_1481146339_59770039__) tbl2
    ``` 
    
    The problem seems to be that the columns in the temp table
`__madlib_temp_cv_valid_0_12190289_1481146339_59770039__` are named `y` and
`x`. Either the temp table should keep the actual column names instead of
`x`, or `dense_vars_str` must be explicitly set to `x`. This applies only when
the function is called via cross validation, not when it is used as a
standalone prediction function. The value of `dense_vars` obtained from the
model table (`tbl_model`) correctly contains `{tax,bath,size}` as the
features; the problem is only that the temp table created during cross
validation renames the independent-variable column to `x`.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

Reply via email to