GitHub user iyerr3 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/295#discussion_r203816091
  
    --- Diff: 
src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in ---
    @@ -2327,6 +2328,110 @@ def _tree_error(schema_madlib, source_table, 
dependent_varname,
             plpy.execute(sql)
     # ------------------------------------------------------------
     
    +def _validate_var_importance_input(model_table, summary_table, 
output_table):
    +    _assert(table_exists(model_table),
    +            "Recursive Partitioning: Model table does not exist.")
    +    _assert(table_exists(summary_table),
    +            "Recursive Partitioning: Model summary table does not exist.")
    +    _assert(not table_exists(output_table),
    +            "Recursive Partitioning: Output table already exists.")
    +
    +def _is_model_for_RF(summary_table):
    +    # Only an RF model (and not DT) would have num_trees column in summary
    +    return columns_exist_in_table(summary_table, ['num_trees'])
    +
    +def _is_RF_model_with_imp_pre_1_15(group_table, summary_table):
    +    """
    +        Check if the RF model is from MADlib < 1.15. The group table for
    +        >= 1.15 RF models should have a column named 
impurity_var_importance
    +        if it was learnt with importance param True.
    +    """
    +    _assert(table_exists(group_table),
    +        "Recursive Partitioning: Model group table does not exist.")
    +    # this flag has to be set to true for RF to report importance scores.
    +    isImportance = plpy.execute("SELECT importance FROM {summary_table}".
    +                    format(**locals()))[0]['importance']
    +    _assert(isImportance, """Recursive Partitioning: The model does """
    +            + """not have the importance information.""")
    +    if columns_exist_in_table(group_table, ['impurity_var_importance']):
    +        # If this column exists, then the RF model is >=1.15.
    +        return False
    +    else:
    +        return True
    +
    +def get_var_importance(schema_madlib, model_table, output_table, **kwargs):
    +    """ Create table capturing importance scores for each feature.
    +        For DT, this function will record the impurity importance score if 
it exists.
    +        For RF, this function will record the oob variable importance and 
impurity
    +        importance (only for models learnt with 1.15 onwards) for each 
variable.
    +
    +        Args:
    +            @param schema_madlib: str, MADlib schema name
    +            @param model_table: str, Model table name
    +            @param output_table: str, Output table name
    +
    +    """
    +    # Validate parameters
    +    summary_table = add_postfix(model_table, "_summary")
    +    _validate_var_importance_input(model_table, summary_table, 
output_table)
    +    grouping_cols = plpy.execute(
    +            "SELECT grouping_cols FROM {summary_table}".
    +                format(**locals()))[0]['grouping_cols']
    +    grouping_cols_comma = ''
    +    if grouping_cols :
    +        grouping_cols_comma = add_postfix(grouping_cols,", ")
    +    is_RF = _is_model_for_RF(summary_table)
    +    if is_RF:
    +        group_table = add_postfix(model_table, "_group")
    +        is_pre_1_15_RF_model = _is_RF_model_with_imp_pre_1_15(
    +                                    group_table, summary_table)
    +
    +        # Query to add oob variable importance for categorical vars
    +        if is_pre_1_15_RF_model:
    +            # In < 1.15 RF model, the variable importance was captured 
using two
    +            # different columns named 'cat_var_importance'
    +            plpy.execute(
    +                """ CREATE TABLE {output_table} AS
    +                    SELECT {grouping_cols_comma}
    +                        unnest(regexp_split_to_array(cat_features, ',')) 
AS feature,
    +                        unnest(cat_var_importance) AS var_importance
    +                    FROM {group_table}, {summary_table}
    +                """.format(**locals()))
    --- End diff --
    
    IMO, it would be cleaner to `UNION` the two queries together to produce the 
output. 


---

Reply via email to