GitHub user iyerr3 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/248#discussion_r177503935
--- Diff: src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in ---
@@ -970,16 +970,35 @@ def _get_bins_grps(
"one value are dropped from the tree model.")
cat_features = [feature for feature in cat_features if feature
in use_cat_features]
- grp_to_col_to_row = dict((grp_key, dict(
- (row['colname'], row['levels']) for row in items))
- for grp_key, items in groupby(all_levels, key=itemgetter('grp_key')))
-
+ # grp_col_to_levels is a list of tuples (pairs) with
+ # first value = group value,
+ # second value = a dict mapping a categorical column to its levels in data
+ # (these levels are specific to the group and can be different
+ # for different groups)
+ # The list of tuples can be converted to a dict, but the ordering
+ # will be lost.
+ # eg. grp_col_to_levels =
+ # [
+ # ('3', {'vs': [0, 1], 'cyl': [4,6,8]}),
+ # ('4', {'vs': [0, 1], 'cyl': [4,6]}),
+ # ('5', {'vs': [0, 1], 'cyl': [4,6,8]})
+ # ]
+ grp_to_col_to_levels = [
+ (grp_key, dict((row['colname'], row['levels']) for row in items))
+ for grp_key, items in groupby(all_levels, key=itemgetter('grp_key'))]
if cat_features:
- cat_items_list = [rows[col] for col in cat_features
- for grp_key, rows in grp_to_col_to_row.items() if col in rows]
+ # Below statements collect the grp_to_col_to_levels into multiple variables
+ # From above eg.
+ # cat_items_list = [[0,1], [4,6,8], [0,1], [4,6], [0,1], [4,6,8]]
+ # cat_n = [2, 3, 2, 2, 2, 3]
+ # cat_n = [0, 1, 4, 6, 8, 0, 1, 4, 6, 0, 1, 4, 6, 8]
--- End diff --
Yes, thanks for the catch. I'll update before merging.
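For reference, here is a minimal standalone sketch of the two steps the new comments describe, reproducing the `mtcars`-style example from the diff. It is not the PR's actual code. The row shape of `all_levels` (keys `grp_key`, `colname`, `levels`) is inferred from the diff, the input is assumed to be already ordered by `grp_key` (`itertools.groupby` only merges adjacent rows), and the group-major ordering of `cat_items_list` is inferred from the example values in the comment. Since the comment lists `cat_n` twice, the flattened third list uses the hypothetical name `flat_levels`.

```python
from itertools import groupby
from operator import itemgetter

# Hypothetical input: one row per (group, categorical column), assumed to be
# sorted by grp_key, because groupby() only groups *adjacent* equal keys.
all_levels = [
    {'grp_key': '3', 'colname': 'vs', 'levels': [0, 1]},
    {'grp_key': '3', 'colname': 'cyl', 'levels': [4, 6, 8]},
    {'grp_key': '4', 'colname': 'vs', 'levels': [0, 1]},
    {'grp_key': '4', 'colname': 'cyl', 'levels': [4, 6]},
    {'grp_key': '5', 'colname': 'vs', 'levels': [0, 1]},
    {'grp_key': '5', 'colname': 'cyl', 'levels': [4, 6, 8]},
]
cat_features = ['vs', 'cyl']

# A list of (group, {column: levels}) pairs instead of a dict, so the group
# order is kept (plain dicts are unordered before Python 3.7).
grp_to_col_to_levels = [
    (grp_key, dict((row['colname'], row['levels']) for row in items))
    for grp_key, items in groupby(all_levels, key=itemgetter('grp_key'))]

# Levels per (group, column): groups in list order, columns in the order
# given by cat_features.
cat_items_list = [col_to_levels[col]
                  for _, col_to_levels in grp_to_col_to_levels
                  for col in cat_features if col in col_to_levels]

# Number of levels in each entry of cat_items_list.
cat_n = [len(levels) for levels in cat_items_list]

# All levels concatenated into a single flat list (hypothetical name).
flat_levels = [level for levels in cat_items_list for level in levels]

print(cat_n)        # [2, 3, 2, 2, 2, 3]
print(flat_levels)  # [0, 1, 4, 6, 8, 0, 1, 4, 6, 0, 1, 4, 6, 8]
```

Because the inner per-group dict is only ever indexed by the names in `cat_features`, its own key order does not matter; only the outer group order needs the list-of-tuples form.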
---