Github user njayaram2 commented on a diff in the pull request: https://github.com/apache/madlib/pull/295#discussion_r204593622 --- Diff: src/ports/postgres/modules/recursive_partitioning/random_forest.py_in --- @@ -1564,69 +1578,69 @@ def get_var_importance(schema_madlib, model_table, output_table, **kwargs): """ # Validate parameters summary_table = add_postfix(model_table, "_summary") - _validate_var_importance_input(model_table, summary_table, output_table) - grouping_cols = plpy.execute( - "SELECT grouping_cols FROM {summary_table}". - format(**locals()))[0]['grouping_cols'] - grouping_cols_comma = '' - if grouping_cols : - grouping_cols_comma = add_postfix(grouping_cols,", ") - is_RF = _is_model_for_RF(summary_table) + _validate_var_importance_input(model_table, + summary_table, + output_table) + is_RF = _is_random_forest_model(summary_table) + + # 'importance_model_table' is the table containing variable importance values. + # For RF, it is placed in <model_table>_group as opposed to <model_table> + # for DT. + importance_model_table = (model_table if not is_RF else + add_postfix(model_table, "_group")) + grouping_cols = plpy.execute("SELECT grouping_cols FROM {summary_table}". + format(**locals()))[0]['grouping_cols'] + if grouping_cols: + grouping_cols_comma = add_postfix(grouping_cols, ", ") + else: + grouping_cols_comma = '' + is_impurity_imp_col_present = _is_impurity_importance_in_model( + importance_model_table, summary_table, is_RF=is_RF) + + # convert importance to percentages + normalization_target = 100.0 + + def _unnest_normalize(input_array_str): + return (""" + unnest({0}.normalize_sum_array({1}::double precision[], + {2}::double precision)) + """.format(schema_madlib, input_array_str, normalization_target)) + if is_RF: - group_table = add_postfix(model_table, "_group") - is_impurity_imp_col_present = _is_impurity_importance_in_group_table( - group_table, summary_table) - # The group table for >= 1.15 RF models should have a column named - # impurity_var_importance if it was learnt with importance param True. - # So set is_pre_1_15_RF_model to False if the column exists, and to - # True if the column does not exist. - is_pre_1_15_RF_model = False if is_impurity_imp_col_present else True - - # Query to add oob variable importance for categorical vars - if is_pre_1_15_RF_model: - # In < 1.15 RF model, the variable importance was captured using two - # different columns named 'cat_var_importance' - plpy.execute( - """ CREATE TABLE {output_table} AS - -- Add oob variable importance for categorical vars - SELECT {grouping_cols_comma} - unnest(regexp_split_to_array(cat_features, ',')) AS feature, - unnest(cat_var_importance) AS var_importance - FROM {group_table}, {summary_table} - UNION - -- Add oob variable importance for continuous vars - SELECT {grouping_cols_comma} - unnest(regexp_split_to_array(con_features, ',')) AS feature, - unnest(con_var_importance) AS var_importance - FROM {group_table}, {summary_table} - """.format(**locals())) + if is_impurity_imp_col_present: + # In versions >= 1.15, the OOB variable importance is captured + # in a single column: 'oob_var_importance'. + oob_var_importance_str = ( + "{0} AS oob_var_importance,". + format(_unnest_normalize('oob_var_importance'))) + impurity_var_importance_str = ( + "{0} AS impurity_var_importance". + format(_unnest_normalize('impurity_var_importance'))) else: - # In >= 1.15 RF models, the variable importance and impurity - # importance scores are captured in columns 'oob_var_importance' - # and 'impurity_var_importance' respectively. - plpy.execute( - """ CREATE TABLE {output_table} AS - SELECT {grouping_cols_comma} - unnest(regexp_split_to_array(independent_varnames, ',')) AS feature, - unnest(oob_var_importance) AS oob_var_importance, - unnest(impurity_var_importance) AS impurity_var_importance - FROM {group_table}, {summary_table} - """.format(**locals())) - else : - # Fail if impurity importance does not exist in model table. - _assert(columns_exist_in_table(model_table, ['impurity_var_importance']), - """Recursive Partitioning: Impurity variable importance """ - + """information does not exist in the model.""") - # Query to add impurity variable importance - plpy.execute( - """ CREATE TABLE {output_table} AS - SELECT {grouping_cols_comma} - unnest(regexp_split_to_array(independent_varnames, ',')) AS feature, - unnest(impurity_var_importance) AS impurity_var_importance - FROM {model_table}, {summary_table} - """.format(**locals())) + # In versions < 1.15, the OOB variable importance was captured in + # two different columns: 'cat_var_importance' and 'con_var_importance' + oob_var_importance_str = ( + "{0} AS oob_var_importance". + format(_unnest_normalize( + "array_cat(cat_var_importance, con_var_importance"))) + impurity_var_importance_str = '' + else: + # Decision tree models don't have a OOB variable importance + oob_var_importance_str = '' + impurity_var_importance_str = ( + "{0} AS impurity_var_importance". + format(_unnest_normalize('impurity_var_importance'))) + + plpy.execute(""" + CREATE TABLE {output_table} AS + SELECT {grouping_cols_comma} + unnest(regexp_split_to_array(independent_varnames, ',')) AS feature, + {oob_var_importance_str} + {impurity_var_importance_str} + FROM {importance_model_table}, {summary_table} + """.format(**locals())) +# ------------------------------------------------------------------------------ --- End diff -- +1
---