GitHub user njayaram2 commented on a diff in the pull request: https://github.com/apache/madlib/pull/248#discussion_r177495864 --- Diff: src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in --- @@ -970,16 +970,35 @@ def _get_bins_grps( "one value are dropped from the tree model.") cat_features = [feature for feature in cat_features if feature in use_cat_features] - grp_to_col_to_row = dict((grp_key, dict( - (row['colname'], row['levels']) for row in items)) - for grp_key, items in groupby(all_levels, key=itemgetter('grp_key'))) - + # grp_to_col_to_levels is a list of tuples (pairs) with + # first value = group value, + # second value = a dict mapping a categorical column to its levels in data + # (these levels are specific to the group and can be different + # for different groups) + # The list of tuples can be converted to a dict, but the ordering + # will be lost. + # e.g. grp_to_col_to_levels = + # [ + # ('3', {'vs': [0, 1], 'cyl': [4,6,8]}), + # ('4', {'vs': [0, 1], 'cyl': [4,6]}), + # ('5', {'vs': [0, 1], 'cyl': [4,6,8]}) + # ] + grp_to_col_to_levels = [ + (grp_key, dict((row['colname'], row['levels']) for row in items)) + for grp_key, items in groupby(all_levels, key=itemgetter('grp_key'))] if cat_features: - cat_items_list = [rows[col] for col in cat_features - for grp_key, rows in grp_to_col_to_row.items() if col in rows] + # Below statements collect the grp_to_col_to_levels into multiple variables + # From above e.g. + # cat_items_list = [[0,1], [4,6,8], [0,1], [4,6], [0,1], [4,6,8]] + # cat_n = [2, 3, 2, 2, 2, 3] + # cat_t = [0, 1, 4, 6, 8, 0, 1, 4, 6, 0, 1, 4, 6, 8] + # grp_key_cat = ['3', '4', '5'] --- End diff -- +1 for the examples, very helpful.
---