Github user njayaram2 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/277#discussion_r194918142
--- Diff: src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in ---
@@ -1097,28 +1121,21 @@ def _one_step(schema_madlib, training_table_name,
cat_features,
"$3", "$2",
null_proxy)
- # The arguments of the aggregate (in the same order):
- # 1. current tree state, madlib.bytea8
- # 2. categorical features (integer format) in a single array
- # 3. continuous features in a single array
- # 4. weight value
- # 5. categorical sorted levels (integer format) in a combined array
- # 6. continuous splits
- # 7. number of dependent levels
train_sql = """
SELECT (result).* from (
SELECT
- {schema_madlib}._dt_apply($1,
+ {schema_madlib}._dt_apply(
+ $1,
{schema_madlib}._compute_leaf_stats(
- $1,
- {cat_features_str},
- {con_features_str},
+ $1, -- current tree state, madlib.bytea8
+ {cat_features_str}, -- categorical features in an array
+ {con_features_str}, -- continuous features in an array
{dep_var},
- {weights},
- $2,
- $4,
- {dep_n_levels}::smallint,
- {subsample}::boolean
+ {weights}, -- weight value
+ $2, -- categorical sorted levels in a combined array
+ $4, -- continuous splits
+ {dep_n_levels}::smallint, -- number of dependent levels
+ {subsample}::boolean -- should we use a subsample of data
--- End diff --
We only use `$1`, `$2`, and `$4` in this query. Can we remove the entry
associated with `$3` (I guess it refers to `bins[cat_origin]`) from the prepare
and execute statements that follow?
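
One quick way to confirm which positional placeholders a query actually references is sketched below. The helper names `used_placeholders` and `unused_placeholders` and the simplified `sample_sql` string are hypothetical and are not part of the PR; the regex is naive and would also match `$n` inside string literals or SQL comments.

```python
import re


def used_placeholders(sql):
    """Return the sorted positional parameter numbers ($1, $2, ...) found in sql."""
    return sorted({int(n) for n in re.findall(r"\$(\d+)", sql)})


def unused_placeholders(sql, n_params):
    """Return the parameter numbers in 1..n_params that sql never references."""
    used = set(used_placeholders(sql))
    return [i for i in range(1, n_params + 1) if i not in used]


# Hypothetical, simplified stand-in for the query in the diff above.
sample_sql = "SELECT _dt_apply($1, _compute_leaf_stats($1, $2, $4))"

print(unused_placeholders(sample_sql, 4))  # prints [3]
```

If the `$3` entry is dropped from the argument list, the placeholders after it shift down by one, so `$4` in the query would presumably need to become `$3` as well.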
---