Github user njayaram2 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/295#discussion_r204593622
  
    --- Diff: src/ports/postgres/modules/recursive_partitioning/random_forest.py_in ---
    @@ -1564,69 +1578,69 @@ def get_var_importance(schema_madlib, model_table, output_table, **kwargs):
         """
         # Validate parameters
         summary_table = add_postfix(model_table, "_summary")
    -    _validate_var_importance_input(model_table, summary_table, 
output_table)
    -    grouping_cols = plpy.execute(
    -            "SELECT grouping_cols FROM {summary_table}".
    -                format(**locals()))[0]['grouping_cols']
    -    grouping_cols_comma = ''
    -    if grouping_cols :
    -        grouping_cols_comma = add_postfix(grouping_cols,", ")
    -    is_RF = _is_model_for_RF(summary_table)
    +    _validate_var_importance_input(model_table,
    +                                   summary_table,
    +                                   output_table)
    +    is_RF = _is_random_forest_model(summary_table)
    +
    +    # 'importance_model_table' is the table containing variable importance 
values.
    +    # For RF, it is placed in <model_table>_group as opposed to 
<model_table>
    +    # for DT.
    +    importance_model_table = (model_table if not is_RF else
    +                              add_postfix(model_table, "_group"))
    +    grouping_cols = plpy.execute("SELECT grouping_cols FROM 
{summary_table}".
    +                                 format(**locals()))[0]['grouping_cols']
    +    if grouping_cols:
    +        grouping_cols_comma = add_postfix(grouping_cols, ", ")
    +    else:
    +        grouping_cols_comma = ''
    +    is_impurity_imp_col_present = _is_impurity_importance_in_model(
    +        importance_model_table, summary_table, is_RF=is_RF)
    +
    +    # convert importance to percentages
    +    normalization_target = 100.0
    +
    +    def _unnest_normalize(input_array_str):
    +        return ("""
    +            unnest({0}.normalize_sum_array({1}::double precision[],
    +                                           {2}::double precision))
    +            """.format(schema_madlib, input_array_str, 
normalization_target))
    +
         if is_RF:
    -        group_table = add_postfix(model_table, "_group")
    -        is_impurity_imp_col_present = 
_is_impurity_importance_in_group_table(
    -                                        group_table, summary_table)
    -        # The group table for >= 1.15 RF models should have a column named
    -        # impurity_var_importance if it was learnt with importance param 
True.
    -        # So set is_pre_1_15_RF_model to False if the column exists, and to
    -        # True if the column does not exist.
    -        is_pre_1_15_RF_model = False if is_impurity_imp_col_present else 
True
    -
    -        # Query to add oob variable importance for categorical vars
    -        if is_pre_1_15_RF_model:
    -            # In < 1.15 RF model, the variable importance was captured 
using two
    -            # different columns named 'cat_var_importance'
    -            plpy.execute(
    -                """ CREATE TABLE {output_table} AS
    -                    -- Add oob variable importance for categorical vars
    -                    SELECT {grouping_cols_comma}
    -                        unnest(regexp_split_to_array(cat_features, ',')) 
AS feature,
    -                        unnest(cat_var_importance) AS var_importance
    -                    FROM {group_table}, {summary_table}
    -                    UNION
    -                    -- Add oob variable importance for continuous vars
    -                    SELECT {grouping_cols_comma}
    -                        unnest(regexp_split_to_array(con_features, ',')) 
AS feature,
    -                        unnest(con_var_importance) AS var_importance
    -                    FROM {group_table}, {summary_table}
    -                """.format(**locals()))
    +        if is_impurity_imp_col_present:
    +            # In versions >= 1.15, the OOB variable importance is captured
    +            # in a single column: 'oob_var_importance'.
    +            oob_var_importance_str = (
    +                "{0} AS oob_var_importance,".
    +                format(_unnest_normalize('oob_var_importance')))
    +            impurity_var_importance_str = (
    +                "{0} AS impurity_var_importance".
    +                format(_unnest_normalize('impurity_var_importance')))
             else:
    -            # In >= 1.15 RF models, the variable importance and impurity
    -            # importance scores are captured in columns 
'oob_var_importance'
    -            # and 'impurity_var_importance' respectively.
    -            plpy.execute(
    -                """ CREATE TABLE {output_table} AS
    -                    SELECT {grouping_cols_comma}
    -                        unnest(regexp_split_to_array(independent_varnames, 
',')) AS feature,
    -                        unnest(oob_var_importance) AS oob_var_importance,
    -                        unnest(impurity_var_importance) AS 
impurity_var_importance
    -                        FROM {group_table}, {summary_table}
    -                """.format(**locals()))
    -    else :
    -        # Fail if impurity importance does not exist in model table.
    -        _assert(columns_exist_in_table(model_table, 
['impurity_var_importance']),
    -                """Recursive Partitioning: Impurity variable importance """
    -                + """information does not exist in the model.""")
    -        # Query to add impurity variable importance
    -        plpy.execute(
    -            """ CREATE TABLE {output_table} AS
    -                SELECT {grouping_cols_comma}
    -                    unnest(regexp_split_to_array(independent_varnames, 
',')) AS feature,
    -                    unnest(impurity_var_importance) AS 
impurity_var_importance
    -                FROM {model_table}, {summary_table}
    -            """.format(**locals()))
    +            # In versions < 1.15, the OOB variable importance was captured 
in
    +            # two different columns: 'cat_var_importance' and 
'con_var_importance'
    +            oob_var_importance_str = (
    +                "{0} AS oob_var_importance".
    +                format(_unnest_normalize(
    +                    "array_cat(cat_var_importance, con_var_importance")))
    +            impurity_var_importance_str = ''
    +    else:
    +        # Decision tree models don't have a OOB variable importance
    +        oob_var_importance_str = ''
    +        impurity_var_importance_str = (
    +            "{0} AS impurity_var_importance".
    +            format(_unnest_normalize('impurity_var_importance')))
    +
    +    plpy.execute("""
    +        CREATE TABLE {output_table} AS
    +            SELECT {grouping_cols_comma}
    +                unnest(regexp_split_to_array(independent_varnames, ',')) 
AS feature,
    +                {oob_var_importance_str}
    +                {impurity_var_importance_str}
    +            FROM {importance_model_table}, {summary_table}
    +        """.format(**locals()))
    +# ------------------------------------------------------------------------------
    --- End diff --
    
    +1


---

Reply via email to