Github user njayaram2 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/295#discussion_r204593622
--- Diff: src/ports/postgres/modules/recursive_partitioning/random_forest.py_in ---
@@ -1564,69 +1578,69 @@ def get_var_importance(schema_madlib, model_table, output_table, **kwargs):
"""
# Validate parameters
summary_table = add_postfix(model_table, "_summary")
- _validate_var_importance_input(model_table, summary_table, output_table)
- grouping_cols = plpy.execute(
- "SELECT grouping_cols FROM {summary_table}".
- format(**locals()))[0]['grouping_cols']
- grouping_cols_comma = ''
- if grouping_cols :
- grouping_cols_comma = add_postfix(grouping_cols,", ")
- is_RF = _is_model_for_RF(summary_table)
+ _validate_var_importance_input(model_table,
+ summary_table,
+ output_table)
+ is_RF = _is_random_forest_model(summary_table)
+
+ # 'importance_model_table' is the table containing variable importance values.
+ # For RF, it is placed in <model_table>_group as opposed to <model_table>
+ # for DT.
+ importance_model_table = (model_table if not is_RF else
+ add_postfix(model_table, "_group"))
+ grouping_cols = plpy.execute("SELECT grouping_cols FROM {summary_table}".
+ format(**locals()))[0]['grouping_cols']
+ if grouping_cols:
+ grouping_cols_comma = add_postfix(grouping_cols, ", ")
+ else:
+ grouping_cols_comma = ''
+ is_impurity_imp_col_present = _is_impurity_importance_in_model(
+ importance_model_table, summary_table, is_RF=is_RF)
+
+ # convert importance to percentages
+ normalization_target = 100.0
+
+ def _unnest_normalize(input_array_str):
+ return ("""
+ unnest({0}.normalize_sum_array({1}::double precision[],
+ {2}::double precision))
+ """.format(schema_madlib, input_array_str, normalization_target))
+
if is_RF:
- group_table = add_postfix(model_table, "_group")
- is_impurity_imp_col_present = _is_impurity_importance_in_group_table(
- group_table, summary_table)
- # The group table for >= 1.15 RF models should have a column named
- # impurity_var_importance if it was learnt with importance param True.
- # So set is_pre_1_15_RF_model to False if the column exists, and to
- # True if the column does not exist.
- is_pre_1_15_RF_model = False if is_impurity_imp_col_present else True
-
- # Query to add oob variable importance for categorical vars
- if is_pre_1_15_RF_model:
- # In < 1.15 RF model, the variable importance was captured using two
- # different columns named 'cat_var_importance'
- plpy.execute(
- """ CREATE TABLE {output_table} AS
- -- Add oob variable importance for categorical vars
- SELECT {grouping_cols_comma}
- unnest(regexp_split_to_array(cat_features, ',')) AS feature,
- unnest(cat_var_importance) AS var_importance
- FROM {group_table}, {summary_table}
- UNION
- -- Add oob variable importance for continuous vars
- SELECT {grouping_cols_comma}
- unnest(regexp_split_to_array(con_features, ',')) AS feature,
- unnest(con_var_importance) AS var_importance
- FROM {group_table}, {summary_table}
- """.format(**locals()))
+ if is_impurity_imp_col_present:
+ # In versions >= 1.15, the OOB variable importance is captured
+ # in a single column: 'oob_var_importance'.
+ oob_var_importance_str = (
+ "{0} AS oob_var_importance,".
+ format(_unnest_normalize('oob_var_importance')))
+ impurity_var_importance_str = (
+ "{0} AS impurity_var_importance".
+ format(_unnest_normalize('impurity_var_importance')))
else:
- # In >= 1.15 RF models, the variable importance and impurity
- # importance scores are captured in columns 'oob_var_importance'
- # and 'impurity_var_importance' respectively.
- plpy.execute(
- """ CREATE TABLE {output_table} AS
- SELECT {grouping_cols_comma}
- unnest(regexp_split_to_array(independent_varnames, ',')) AS feature,
- unnest(oob_var_importance) AS oob_var_importance,
- unnest(impurity_var_importance) AS impurity_var_importance
- FROM {group_table}, {summary_table}
- """.format(**locals()))
- else :
- # Fail if impurity importance does not exist in model table.
- _assert(columns_exist_in_table(model_table, ['impurity_var_importance']),
- """Recursive Partitioning: Impurity variable importance """
- + """information does not exist in the model.""")
- # Query to add impurity variable importance
- plpy.execute(
- """ CREATE TABLE {output_table} AS
- SELECT {grouping_cols_comma}
- unnest(regexp_split_to_array(independent_varnames, ',')) AS feature,
- unnest(impurity_var_importance) AS impurity_var_importance
- FROM {model_table}, {summary_table}
- """.format(**locals()))
+ # In versions < 1.15, the OOB variable importance was captured in
+ # two different columns: 'cat_var_importance' and 'con_var_importance'
+ oob_var_importance_str = (
+ "{0} AS oob_var_importance".
+ format(_unnest_normalize(
+ "array_cat(cat_var_importance, con_var_importance")))
+ impurity_var_importance_str = ''
+ else:
+ # Decision tree models don't have an OOB variable importance
+ oob_var_importance_str = ''
+ impurity_var_importance_str = (
+ "{0} AS impurity_var_importance".
+ format(_unnest_normalize('impurity_var_importance')))
+
+ plpy.execute("""
+ CREATE TABLE {output_table} AS
+ SELECT {grouping_cols_comma}
+ unnest(regexp_split_to_array(independent_varnames, ',')) AS feature,
+ {oob_var_importance_str}
+ {impurity_var_importance_str}
+ FROM {importance_model_table}, {summary_table}
+ """.format(**locals()))
+# ------------------------------------------------------------------------------
--- End diff --
+1
---