Github user iyerr3 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/295#discussion_r203816091

    --- Diff: src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in ---
    @@ -2327,6 +2328,110 @@ def _tree_error(schema_madlib, source_table, dependent_varname,
         plpy.execute(sql)
     # ------------------------------------------------------------
    +def _validate_var_importance_input(model_table, summary_table, output_table):
    +    _assert(table_exists(model_table),
    +            "Recursive Partitioning: Model table does not exist.")
    +    _assert(table_exists(summary_table),
    +            "Recursive Partitioning: Model summary table does not exist.")
    +    _assert(not table_exists(output_table),
    +            "Recursive Partitioning: Output table already exists.")
    +
    +def _is_model_for_RF(summary_table):
    +    # Only an RF model (and not DT) would have num_trees column in summary
    +    return columns_exist_in_table(summary_table, ['num_trees'])
    +
    +def _is_RF_model_with_imp_pre_1_15(group_table, summary_table):
    +    """
    +    Check if the RF model is from MADlib < 1.15. The group table for
    +    >= 1.15 RF models should have a column named impurity_var_importance
    +    if it was learnt with importance param True.
    +    """
    +    _assert(table_exists(group_table),
    +            "Recursive Partitioning: Model group table does not exist.")
    +    # this flag has to be set to true for RF to report importance scores.
    +    isImportance = plpy.execute("SELECT importance FROM {summary_table}".
    +                                format(**locals()))[0]['importance']
    +    _assert(isImportance, """Recursive Partitioning: The model does """ +
    +            """not have the importance information.""")
    +    if columns_exist_in_table(group_table, ['impurity_var_importance']):
    +        # If this column exists, then the RF model is >=1.15.
    +        return False
    +    else:
    +        return True
    +
    +def get_var_importance(schema_madlib, model_table, output_table, **kwargs):
    +    """ Create table capturing importance scores for each feature.
    +    For DT, this function will record the impurity importance score if it exists.
    +    For RF, this function will record the oob variable importance and impurity
    +    importance (only for models learnt with 1.15 onwards) for each variable.
    +
    +    Args:
    +        @param schema_madlib: str, MADlib schema name
    +        @param model_table: str, Model table name
    +        @param output_table: str, Output table name
    +
    +    """
    +    # Validate parameters
    +    summary_table = add_postfix(model_table, "_summary")
    +    _validate_var_importance_input(model_table, summary_table, output_table)
    +    grouping_cols = plpy.execute(
    +        "SELECT grouping_cols FROM {summary_table}".
    +        format(**locals()))[0]['grouping_cols']
    +    grouping_cols_comma = ''
    +    if grouping_cols :
    +        grouping_cols_comma = add_postfix(grouping_cols,", ")
    +    is_RF = _is_model_for_RF(summary_table)
    +    if is_RF:
    +        group_table = add_postfix(model_table, "_group")
    +        is_pre_1_15_RF_model = _is_RF_model_with_imp_pre_1_15(
    +            group_table, summary_table)
    +
    +        # Query to add oob variable importance for categorical vars
    +        if is_pre_1_15_RF_model:
    +            # In < 1.15 RF model, the variable importance was captured using two
    +            # different columns named 'cat_var_importance'
    +            plpy.execute(
    +            """ CREATE TABLE {output_table} AS
    +                SELECT {grouping_cols_comma}
    +                    unnest(regexp_split_to_array(cat_features, ',')) AS feature,
    +                    unnest(cat_var_importance) AS var_importance
    +                FROM {group_table}, {summary_table}
    +            """.format(**locals()))
    --- End diff --

    IMO, it's cleaner to see the two queries `UNION`ed together to get the output.
---