Copilot commented on code in PR #55728:
URL: https://github.com/apache/spark/pull/55728#discussion_r3200111161


##########
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala:
##########
@@ -712,17 +853,44 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     )
     val rdd = sc.parallelize(arr.toImmutableArraySeq)
 
-    val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = 
Variance, maxDepth = 4,
-      numClasses = 0, maxBins = 32)
-
-    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
-      seed = 42, instr = None).head
-
-    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
-      seed = 42, instr = None, prune = false).head
+    val strategy = new OldStrategy(
+      algo = OldAlgo.Regression,
+      impurity = Variance,
+      maxDepth = 4,
+      numClasses = 0,
+      maxBins = 32)
+
+    val prunedTree = RandomForest
+      .run(
+        rdd,
+        strategy,
+        numTrees = 1,
+        featureSubsetStrategy = "auto",
+        seed = 42,
+        instr = None,
+        prune = true)
+      .head
+
+    val unprunedTree = RandomForest
+      .run(
+        rdd,
+        strategy,
+        numTrees = 1,
+        featureSubsetStrategy = "auto",
+        seed = 42,
+        instr = None,
+        prune = false)
+      .head
+
+    val defaultBehaviorTree = RandomForest
+      .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed = 
42, instr = None)

Review Comment:
   `RandomForest.run` no longer accepts a `prune` named argument, but the
`prunedTree` and `unprunedTree` calls above still pass `prune = true` and
`prune = false`. This will not compile. Set `strategy.pruneTree` on the
`OldStrategy` before each call instead.
   



##########
mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala:
##########
@@ -211,10 +211,32 @@ private[ml] trait TreeClassifierParams extends Params {
     (value: String) =>
       
TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
 
-  setDefault(impurity -> "gini")
+  /**
+   * If true, the trained tree will undergo a 'pruning' process after training 
in which nodes
+   * that have the same class predictions will be merged.  This drawback means 
that the class
+   * probabilities will be lost.  The benefit being that at prediction time 
the tree will be
+   * smaller and have faster predictions
+   * If false, the post-training tree will undergo no pruning.  The benefit 
being that you
+   * maintain the class prediction probabilities
+   * (default = true)
+   * @group param
+   */
+  final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" +
+    "If true, the trained tree will undergo a 'pruning' process after training 
in which nodes" +
+    " that have the same class predictions will be merged.  This drawback 
means that the class" +
+    " probabilities will be lost.  The benefit being that at prediction time 
the tree will be" +
+    " smaller and have faster predictions" +
+    " If false, the post-training tree will undergo no pruning.  The benefit 
being that you" +
+    " maintain the class prediction probabilities"

Review Comment:
   The `pruneTree` Scaladoc/help text reads as a single run-on sentence (e.g., 
"faster predictions If false") and contains grammar issues ("This drawback 
means..."). Please add punctuation and rephrase for clarity since this text 
becomes user-facing Param documentation.
   



##########
mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala:
##########
@@ -225,7 +234,7 @@ private[spark] object RandomForest extends Logging with 
Serializable {
       timer.stop("findBestSplits")
 
       if (earlyStopModelSizeThresholdInBytes > 0) {
-        val nodes = topNodes.map(_.toNode(prune))
+        val nodes = topNodes.map(_.toNode(strategy.pruneTree))

Review Comment:
   Switching pruning from an explicit method parameter to `strategy.pruneTree`
means pruning now depends on the `Strategy` default for all callers. Previously,
`RandomForest.run` pruned unless `prune = false` was passed. Since
`Strategy.pruneTree` defaults to `false`, this silently changes behavior for
code paths that don’t explicitly set it (e.g., the old mllib API, regressors,
and GBT). If the intent is only to make pruning configurable, consider
defaulting `pruneTree` to `true` in `Strategy.defaultStrategy(...)` and the
`Strategy` constructor. Alternatively, explicitly set it in
`DecisionTreeParams.getOldStrategy`. Either preserves the prior behavior.
   



##########
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala:
##########
@@ -682,18 +797,44 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     val rdd = sc.parallelize(arr.toImmutableArraySeq)
 
     val numClasses = 2
-    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = 
Gini, maxDepth = 4,
-      numClasses = numClasses, maxBins = 32)
-
-    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
-      seed = 42, instr = None).head
-
-    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
-      seed = 42, instr = None, prune = false).head
+    val strategy = new OldStrategy(
+      algo = OldAlgo.Classification,
+      impurity = Gini,
+      maxDepth = 4,
+      numClasses = numClasses,
+      maxBins = 32)
+
+    val prunedTree = RandomForest
+      .run(
+        rdd,
+        strategy,
+        numTrees = 1,
+        featureSubsetStrategy = "auto",
+        seed = 42,
+        instr = None,
+        prune = true)
+      .head
+
+    val unprunedTree = RandomForest
+      .run(
+        rdd,
+        strategy,
+        numTrees = 1,
+        featureSubsetStrategy = "auto",
+        seed = 42,
+        instr = None,
+        prune = false)
+      .head

Review Comment:
   `RandomForest.run` no longer accepts a `prune` named argument, but this call
still passes `prune = false`. This will not compile. Instead, set
`strategy.pruneTree = false` on the strategy before this call to test the
unpruned behavior.



##########
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala:
##########
@@ -77,6 +79,7 @@ class Strategy @Since("1.3.0") (
     @Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = 
Map[Int, Int](),
     @Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1,
     @Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0,
+    @Since("3.1.2") @BeanProperty var pruneTree: Boolean = false,

Review Comment:
   The new `pruneTree` field is annotated `@Since("3.1.2")`, but this 
repository is Spark 5.0.0 (see `python/pyspark/version.py`). Please update the 
`@Since` version to the correct release where this parameter is actually 
introduced.
   



##########
mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala:
##########
@@ -211,10 +211,32 @@ private[ml] trait TreeClassifierParams extends Params {
     (value: String) =>
       
TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
 
-  setDefault(impurity -> "gini")
+  /**
+   * If true, the trained tree will undergo a 'pruning' process after training 
in which nodes
+   * that have the same class predictions will be merged.  This drawback means 
that the class
+   * probabilities will be lost.  The benefit being that at prediction time 
the tree will be
+   * smaller and have faster predictions
+   * If false, the post-training tree will undergo no pruning.  The benefit 
being that you
+   * maintain the class prediction probabilities
+   * (default = true)
+   * @group param
+   */
+  final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" +
+    "If true, the trained tree will undergo a 'pruning' process after training 
in which nodes" +
+    " that have the same class predictions will be merged.  This drawback 
means that the class" +
+    " probabilities will be lost.  The benefit being that at prediction time 
the tree will be" +
+    " smaller and have faster predictions" +
+    " If false, the post-training tree will undergo no pruning.  The benefit 
being that you" +
+    " maintain the class prediction probabilities"
+  )
+
+  // HERE

Review Comment:
   Leftover debug/comment marker `// HERE` should be removed before merging.
   



##########
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala:
##########
@@ -431,18 +512,32 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     val input = sc.parallelize(arr.map(_.toInstance).toImmutableArraySeq)
 
     // Must set maxBins s.t. the feature will be treated as an ordered 
categorical feature.
-    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = 
Gini, maxDepth = 1,
-      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
-
-    val model = RandomForest.run(input, strategy, numTrees = 1, 
featureSubsetStrategy = "all",
-      seed = 42, instr = None, prune = false).head
+    val strategy = new OldStrategy(
+      algo = OldAlgo.Classification,
+      impurity = Gini,
+      maxDepth = 1,
+      numClasses = 2,
+      categoricalFeaturesInfo = Map(0 -> 3),
+      maxBins = 3)
+
+    val model = RandomForest
+      .run(
+        input,
+        strategy,
+        numTrees = 1,
+        featureSubsetStrategy = "all",
+        seed = 42,
+        instr = None,
+        prune = false)
+      .head

Review Comment:
   `RandomForest.run` no longer takes a `prune` parameter (it was removed from 
`ml/tree/impl/RandomForest.scala`), but this test still passes `prune = false` 
as a named argument. This will not compile; update the test to set 
`strategy.pruneTree` instead (or reintroduce the `prune` parameter if it’s 
still intended for testing).



##########
mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala:
##########
@@ -74,6 +74,10 @@ class DecisionTreeClassifier @Since("1.4.0") (
   @Since("1.4.0")
   def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
 
+  /** @group setParam */
+  @Since("3.1.2")

Review Comment:
   `pruneTree` is introduced as a new API but is annotated `@Since("3.1.2")`. 
This should be updated to the Spark version which will actually ship this 
change (the repo is currently 5.0.0).
   



##########
python/pyspark/ml/classification.py:
##########
@@ -1838,7 +1841,7 @@ def setParams(
         """
         setParams(self, \\*, featuresCol="features", labelCol="label", 
predictionCol="prediction", \
                   probabilityCol="probability", 
rawPredictionCol="rawPrediction", \
-                  maxDepth=5, maxBins=32, minInstancesPerNode=1, 
minInfoGain=0.0, \
+                  maxDepth=5, maxBins=32, minInstancesPerNode=1, 
minInfoGain=0.0, pruneTree=True\

Review Comment:
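   The generated docstring signature is missing the comma and space between
`pruneTree=True` and the `\` line continuation. Compare `minInfoGain=0.0, \` on
the line above. As a result, nothing separates `pruneTree=True` from the next
parameter (`maxMemoryInMB=...`), and the documented call signature is malformed.
Please add the missing separator so the docstring matches the real Python
signature.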
   



##########
python/pyspark/ml/classification.py:
##########
@@ -2097,7 +2108,7 @@ def __init__(
         """
         __init__(self, \\*, featuresCol="features", labelCol="label", 
predictionCol="prediction", \
                  probabilityCol="probability", 
rawPredictionCol="rawPrediction", \
-                 maxDepth=5, maxBins=32, minInstancesPerNode=1, 
minInfoGain=0.0, \
+                 maxDepth=5, maxBins=32, minInstancesPerNode=1, 
minInfoGain=0.0, pruneTree=True\

Review Comment:
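   The generated docstring signature is missing the comma and space between
`pruneTree=True` and the `\` line continuation. Compare `minInfoGain=0.0, \` on
the line above. As a result, nothing separates `pruneTree=True` from the next
parameter (`maxMemoryInMB=...`), and the documented call signature is malformed.
Please add the missing separator so the docstring matches the real Python
signature.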
   



##########
python/pyspark/ml/tree.py:
##########
@@ -424,6 +432,12 @@ def getImpurity(self) -> str:
         Gets the value of impurity or its default value.
         """
         return self.getOrDefault(self.impurity)
+    @since("3.1.2")
+    def getPruneTree(self):
+        """
+        Gets the value of pruneTree or its default value.
+        """
+        return self.getOrDefault(self.pruneTree)

Review Comment:
   `getPruneTree` should follow the typing and versioning conventions used by 
neighboring getters: add a return type annotation (`-> bool`) and update the 
`@since` version (currently `3.1.2`) to the Spark version where this new param 
is introduced (repo is 5.0.0).



##########
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala:
##########
@@ -682,18 +797,44 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     val rdd = sc.parallelize(arr.toImmutableArraySeq)
 
     val numClasses = 2
-    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = 
Gini, maxDepth = 4,
-      numClasses = numClasses, maxBins = 32)
-
-    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
-      seed = 42, instr = None).head
-
-    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
-      seed = 42, instr = None, prune = false).head
+    val strategy = new OldStrategy(
+      algo = OldAlgo.Classification,
+      impurity = Gini,
+      maxDepth = 4,
+      numClasses = numClasses,
+      maxBins = 32)
+
+    val prunedTree = RandomForest
+      .run(
+        rdd,
+        strategy,
+        numTrees = 1,
+        featureSubsetStrategy = "auto",
+        seed = 42,
+        instr = None,
+        prune = true)
+      .head
+
+    val unprunedTree = RandomForest
+      .run(
+        rdd,
+        strategy,
+        numTrees = 1,
+        featureSubsetStrategy = "auto",
+        seed = 42,
+        instr = None,
+        prune = false)
+      .head
+
+    val defaultBehaviorTree = RandomForest
+      .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed = 
42, instr = None)

Review Comment:
   `RandomForest.run` no longer accepts a `prune` named argument, but the
`prunedTree` call above still passes `prune = true`. This will not compile. Set
`strategy.pruneTree = true` before that call. Similarly, set
`strategy.pruneTree = false` before the corresponding unpruned call.
   



##########
python/pyspark/ml/tree.py:
##########
@@ -415,6 +415,14 @@ class _TreeClassifierParams(Params):
         typeConverter=TypeConverters.toString,
     )
 
+    pruneTree = Param(Params._dummy(), "pruneTree", "" +
+                      "If true, the trained tree will undergo a 'pruning' 
process after training in which nodes" +
+                      " that have the same class predictions will be merged.  
This drawback means that the class" +
+                      " probabilities will be lost.  The benefit being that at 
prediction time the tree will be" +
+                      " smaller and have faster predictions" +
+                      " If false, the post-training tree will undergo no 
pruning.  The benefit being that you" +
+                      " maintain the class prediction probabilities", 
typeConverter=TypeConverters.toBoolean)

Review Comment:
   This PR adds a new user-facing Python param (`pruneTree`) and new setter(s), 
but there’s no corresponding PySpark test asserting the param is 
exposed/round-trips correctly (e.g., `setPruneTree(False)` affects 
`getPruneTree()` / JVM param map, and default matches the intended behavior). 
Consider adding coverage in `python/pyspark/ml/tests/test_classification.py` 
alongside the existing DecisionTreeClassifier/RandomForestClassifier tests.



##########
mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala:
##########
@@ -76,6 +76,10 @@ class RandomForestClassifier @Since("1.4.0") (
   @Since("1.4.0")
   def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
 
+  /** @group setParam */
+  @Since("3.1.2")

Review Comment:
   `pruneTree` is introduced as a new API but is annotated `@Since("3.1.2")`. 
This should be updated to the Spark version which will actually ship this 
change (the repo is currently 5.0.0).
   



##########
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala:
##########
@@ -712,17 +853,44 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     )
     val rdd = sc.parallelize(arr.toImmutableArraySeq)
 
-    val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = 
Variance, maxDepth = 4,
-      numClasses = 0, maxBins = 32)
-
-    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
-      seed = 42, instr = None).head
-
-    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
-      seed = 42, instr = None, prune = false).head
+    val strategy = new OldStrategy(
+      algo = OldAlgo.Regression,
+      impurity = Variance,
+      maxDepth = 4,
+      numClasses = 0,
+      maxBins = 32)
+
+    val prunedTree = RandomForest
+      .run(
+        rdd,
+        strategy,
+        numTrees = 1,
+        featureSubsetStrategy = "auto",
+        seed = 42,
+        instr = None,
+        prune = true)
+      .head
+
+    val unprunedTree = RandomForest
+      .run(
+        rdd,
+        strategy,
+        numTrees = 1,
+        featureSubsetStrategy = "auto",
+        seed = 42,
+        instr = None,
+        prune = false)
+      .head
+
+    val defaultBehaviorTree = RandomForest
+      .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed = 
42, instr = None)

Review Comment:
   `RandomForest.run` no longer accepts a `prune` named argument, but this 
regression test still passes `prune = true`. This will not compile; set 
`strategy.pruneTree = true` before calling `run` (and mirror for the unpruned 
call).
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to