Github user njayaram2 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/248#discussion_r177495864
  
    --- Diff: src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in ---
    @@ -970,16 +970,35 @@ def _get_bins_grps(
                              "one value are dropped from the tree model.")
                 cat_features = [feature for feature in cat_features if feature in use_cat_features]
     
    -        grp_to_col_to_row = dict((grp_key, dict(
    -            (row['colname'], row['levels']) for row in items))
    -            for grp_key, items in groupby(all_levels, key=itemgetter('grp_key')))
    -
    +        # grp_col_to_levels is a list of tuples (pairs) with
    +        #   first value = group value,
    +        #   second value = a dict mapping a categorical column to its levels in data
    +        #           (these levels are specific to the group and can be different
    +        #               for different groups)
    +        #   The list of tuples can be converted to a dict, but the ordering
    +        #   will be lost.
    +        # eg. grp_col_to_levels =
    +        #   [
    +        #       ('3', {'vs': [0, 1], 'cyl': [4,6,8]}),
    +        #       ('4', {'vs': [0, 1], 'cyl': [4,6]}),
    +        #       ('5', {'vs': [0, 1], 'cyl': [4,6,8]})
    +        #   ]
    +        grp_to_col_to_levels = [
    +            (grp_key, dict((row['colname'], row['levels']) for row in items))
    +            for grp_key, items in groupby(all_levels, key=itemgetter('grp_key'))]
         if cat_features:
    -        cat_items_list = [rows[col] for col in cat_features
    -                          for grp_key, rows in grp_to_col_to_row.items() if col in rows]
    +        # Below statements collect the grp_to_col_to_levels into multiple variables
    +        # From above eg.
    +        #   cat_items_list = [[0,1], [4,6,8], [0,1], [4,6], [0,1], [4,6,8]]
    +        #   cat_n = [2, 3, 2, 2, 2, 3]
    +        #   cat_n = [0, 1, 4, 6, 8, 0, 1, 4, 6, 0, 1, 4, 6, 8]
    +        #   grp_key_cat = ['3', '4', '5']
    --- End diff --
    
    +1 for the examples, very helpful.
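
    For anyone reading along, here is a minimal sketch that reproduces the
    example values from the comments in the diff. Only the
    grp_to_col_to_levels comprehension is taken from the diff. The sample
    all_levels rows, the cat_features list, and the way the remaining lists
    are derived are assumptions made to match the commented example, not the
    actual code in decision_tree.py_in. The flattened list is given the
    hypothetical name flat_levels, since the diff comment labels it cat_n a
    second time.

```python
from itertools import groupby
from operator import itemgetter

# Hypothetical rows shaped like all_levels in the diff.
# groupby only merges adjacent rows, so they must already be ordered by grp_key.
all_levels = [
    {'grp_key': '3', 'colname': 'vs', 'levels': [0, 1]},
    {'grp_key': '3', 'colname': 'cyl', 'levels': [4, 6, 8]},
    {'grp_key': '4', 'colname': 'vs', 'levels': [0, 1]},
    {'grp_key': '4', 'colname': 'cyl', 'levels': [4, 6]},
    {'grp_key': '5', 'colname': 'vs', 'levels': [0, 1]},
    {'grp_key': '5', 'colname': 'cyl', 'levels': [4, 6, 8]},
]
cat_features = ['vs', 'cyl']  # assumed column order

# Taken from the diff: a list of (group, {column: levels}) pairs,
# which keeps the group order that a plain dict would not guarantee.
grp_to_col_to_levels = [
    (grp_key, dict((row['colname'], row['levels']) for row in items))
    for grp_key, items in groupby(all_levels, key=itemgetter('grp_key'))]

# Assumed derivations, chosen to match the example values in the diff comments.
cat_items_list = [col_to_levels[col]
                  for _, col_to_levels in grp_to_col_to_levels
                  for col in cat_features if col in col_to_levels]
cat_n = [len(levels) for levels in cat_items_list]
flat_levels = [level for levels in cat_items_list for level in levels]
grp_key_cat = [grp_key for grp_key, _ in grp_to_col_to_levels]
```

    Walking the list of pairs group by group is what keeps cat_items_list,
    cat_n and grp_key_cat aligned with each other.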

