Github user pralabhkumar commented on a diff in the pull request:
https://github.com/apache/spark/pull/18118#discussion_r148197285
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
---
@@ -118,11 +119,12 @@ class DecisionTreeRegressor @Since("1.4.0")
(@Since("1.4.0") override val uid: S
/** (private[ml]) Train a decision tree on an RDD */
private[ml] def train(data: RDD[LabeledPoint],
- oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
+ oldStrategy: OldStrategy, featureSubsetStrategy: String): DecisionTreeRegressionModel = {
val instr = Instrumentation.create(this, data)
instr.logParams(params: _*)
- val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
+ val trees = RandomForest.run(data, oldStrategy, numTrees = 1,
+ featureSubsetStrategy,
--- End diff --
done
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]