Github user iyerr3 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/295#discussion_r203816091
--- Diff:
src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in ---
@@ -2327,6 +2328,110 @@ def _tree_error(schema_madlib, source_table,
dependent_varname,
plpy.execute(sql)
# ------------------------------------------------------------
+def _validate_var_importance_input(model_table, summary_table,
output_table):
+ _assert(table_exists(model_table),
+ "Recursive Partitioning: Model table does not exist.")
+ _assert(table_exists(summary_table),
+ "Recursive Partitioning: Model summary table does not exist.")
+ _assert(not table_exists(output_table),
+ "Recursive Partitioning: Output table already exists.")
+
+def _is_model_for_RF(summary_table):
+ # Only an RF model (and not DT) would have num_trees column in summary
+ return columns_exist_in_table(summary_table, ['num_trees'])
+
+def _is_RF_model_with_imp_pre_1_15(group_table, summary_table):
+ """
+ Check if the RF model is from MADlib < 1.15. The group table for
+ >= 1.15 RF models should have a column named
impurity_var_importance
+ if it was learnt with importance param True.
+ """
+ _assert(table_exists(group_table),
+ "Recursive Partitioning: Model group table does not exist.")
+ # this flag has to be set to true for RF to report importance scores.
+ isImportance = plpy.execute("SELECT importance FROM {summary_table}".
+ format(**locals()))[0]['importance']
+ _assert(isImportance, """Recursive Partitioning: The model does """
+ + """not have the importance information.""")
+ if columns_exist_in_table(group_table, ['impurity_var_importance']):
+ # If this column exists, then the RF model is >=1.15.
+ return False
+ else:
+ return True
+
+def get_var_importance(schema_madlib, model_table, output_table, **kwargs):
+ """ Create table capturing importance scores for each feature.
+ For DT, this function will record the impurity importance score if
it exists.
+ For RF, this function will record the oob variable importance and
impurity
+ importance (only for models learnt with 1.15 onwards) for each
variable.
+
+ Args:
+ @param schema_madlib: str, MADlib schema name
+ @param model_table: str, Model table name
+ @param output_table: str, Output table name
+
+ """
+ # Validate parameters
+ summary_table = add_postfix(model_table, "_summary")
+ _validate_var_importance_input(model_table, summary_table,
output_table)
+ grouping_cols = plpy.execute(
+ "SELECT grouping_cols FROM {summary_table}".
+ format(**locals()))[0]['grouping_cols']
+ grouping_cols_comma = ''
+ if grouping_cols :
+ grouping_cols_comma = add_postfix(grouping_cols,", ")
+ is_RF = _is_model_for_RF(summary_table)
+ if is_RF:
+ group_table = add_postfix(model_table, "_group")
+ is_pre_1_15_RF_model = _is_RF_model_with_imp_pre_1_15(
+ group_table, summary_table)
+
+ # Query to add oob variable importance for categorical vars
+ if is_pre_1_15_RF_model:
+ # In < 1.15 RF model, the variable importance was captured
using two
+ # different columns named 'cat_var_importance'
+ plpy.execute(
+ """ CREATE TABLE {output_table} AS
+ SELECT {grouping_cols_comma}
+ unnest(regexp_split_to_array(cat_features, ','))
AS feature,
+ unnest(cat_var_importance) AS var_importance
+ FROM {group_table}, {summary_table}
+ """.format(**locals()))
--- End diff --
IMO, it would be cleaner to `UNION` the two queries together to produce
the output.
---