Copilot commented on code in PR #55763:
URL: https://github.com/apache/spark/pull/55763#discussion_r3208313912


##########
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala:
##########
@@ -77,6 +79,7 @@ class Strategy @Since("1.3.0") (
     @Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
     @Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1,
     @Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0,
+    @Since("4.3.0") @BeanProperty var pruneTree: Boolean = true,
     @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,

Review Comment:
   Adding `pruneTree` as a new parameter in the primary `Strategy` constructor
changes the public constructor signature (and shifts positional parameters).
This breaks binary compatibility for any code compiled against the existing
constructor, and source compatibility for callers that pass the later
arguments (such as `maxMemoryInMB`) positionally. To preserve compatibility,
keep the existing constructor signature and add `pruneTree` as a separate
`var` field (or add an overloaded/secondary constructor while leaving the
original one intact).
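A minimal sketch of the suggested pattern, using a hypothetical, trimmed-down `CompatStrategy` class (not Spark's real `Strategy`): the primary constructor keeps its original parameters and order, and the new flag is a body field with a default, plus an optional auxiliary constructor for callers who want to set it up front.

```scala
import scala.beans.BeanProperty

// Hypothetical stand-in for mllib's Strategy; only three of its fields are kept.
// The primary constructor keeps its original parameters and order, so existing
// source calls and the compiled (Int, Double, Int) constructor are unaffected.
class CompatStrategy(
    @BeanProperty var minInstancesPerNode: Int = 1,
    @BeanProperty var minInfoGain: Double = 0.0,
    @BeanProperty var maxMemoryInMB: Int = 256) {

  // New option declared as a body field with its default instead of a
  // constructor parameter; @BeanProperty adds getPruneTree()/setPruneTree().
  @BeanProperty var pruneTree: Boolean = true

  // Optional auxiliary constructor for callers that want to pass the flag.
  // Only the primary constructor defines default arguments, as Scala requires.
  def this(
      minInstancesPerNode: Int,
      minInfoGain: Double,
      maxMemoryInMB: Int,
      pruneTree: Boolean) = {
    this(minInstancesPerNode, minInfoGain, maxMemoryInMB)
    this.pruneTree = pruneTree
  }
}

object CompatStrategyDemo {
  def main(args: Array[String]): Unit = {
    val positional = new CompatStrategy(1, 0.0, 512) // pre-existing call shape
    val withFlag = new CompatStrategy(1, 0.0, 512, false)
    println(s"positional: pruneTree=${positional.getPruneTree()}")
    println(s"withFlag:   pruneTree=${withFlag.getPruneTree()}")
  }
}
```

Had `pruneTree: Boolean` instead been inserted before `maxMemoryInMB` in the primary constructor, the call `new CompatStrategy(1, 0.0, 512)` would no longer compile, because `512` would land in the `Boolean` slot.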



##########
mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala:
##########
@@ -211,10 +211,27 @@ private[ml] trait TreeClassifierParams extends Params {
     (value: String) =>
       TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
 
-  setDefault(impurity -> "gini")
+  /**
+   * If true, the trained tree will undergo a pruning process after training, in which nodes
+   * with the same class predictions are merged. The resulting tree will be smaller and have
+   * faster predictions, but class probabilities will be lost.
+   * If false, no pruning is applied after training, and class probabilities are preserved.
+   * (default = true)
+   * @group param
+   */
+  final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" +
+    "If true, the trained tree will undergo a pruning process after training, in which nodes" +
+    " with the same class predictions are merged. The resulting tree will be smaller and have" +
+    " faster predictions, but class probabilities will be lost." +
+    " If false, no pruning is applied after training, and class probabilities are preserved."

Review Comment:
   The `pruneTree` parameter description claims that "class probabilities will 
be lost" when pruning. However, pruning in `LearningNode.toNode(prune = true)` 
merges sibling leaf nodes but still constructs a `LeafNode` with a non-null 
`impurityStats` (used by `DecisionTreeClassificationModel.predictRaw`), so 
probability output remains available (though the tree structure changes). 
Please clarify this description to avoid misleading users about 
`probabilityCol` support.
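A simplified model of why probabilities survive pruning, assuming a leaf's statistics are per-class counts and that a merged leaf carries its parent's counts (the sum over the two siblings it replaces). The helper names are hypothetical; this is not Spark's `ImpurityCalculator` API. Leaves with counts (8, 2) and (6, 4) both predict class 0, with probabilities (0.8, 0.2) and (0.6, 0.4); the merged leaf has counts (14, 6), still predicts class 0, and reports probabilities (0.7, 0.3). The output is coarser, not lost.

```scala
// Simplified model: a leaf's statistics are per-class instance counts.
// Hypothetical helpers for illustration; not Spark's ImpurityCalculator API.
object MergedLeafSketch {
  type Counts = Vector[Double]

  // Predicted class: index of the largest count.
  def prediction(c: Counts): Int = c.indices.maxBy(i => c(i))

  // Class probabilities: counts normalized by their total.
  def probability(c: Counts): Vector[Double] = {
    val total = c.sum
    c.map(_ / total)
  }

  // The parent of two sibling leaves saw every instance of both children,
  // so its counts are the element-wise sum of theirs.
  def merge(a: Counts, b: Counts): Counts =
    a.zip(b).map { case (x, y) => x + y }

  def main(args: Array[String]): Unit = {
    val left = Vector(8.0, 2.0)
    val right = Vector(6.0, 4.0)
    val merged = merge(left, right)
    println(s"left:   class ${prediction(left)}, p = ${probability(left)}")
    println(s"right:  class ${prediction(right)}, p = ${probability(right)}")
    println(s"merged: class ${prediction(merged)}, p = ${probability(merged)}")
  }
}
```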
   



##########
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala:
##########
@@ -55,6 +55,8 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
  * @param minInfoGain Minimum information gain a split must get. Default value is 0.0.
  *                    If a split has less information gain than minInfoGain,
  *                    this split will not be considered as a valid split.
+ * @param pruneTree If this is true, the final training tree will undergo a pruning in which
+ *                  nodes with the same classifications are merged.

Review Comment:
   The new `pruneTree` Scaladoc says "nodes with the same classifications" are 
merged, but this flag is also used for regression trees (where there are no 
classifications). Consider updating the wording to be algorithm-agnostic (e.g., 
"same prediction") so the doc matches behavior for both classification and 
regression.
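The merge rule only compares leaf predictions, which is why "same prediction" fits both algorithms. A minimal sketch with a hypothetical tree type (not Spark's `Node`/`LearningNode`), assuming sibling leaves are merged bottom-up when their predictions are equal; a prediction is a class index for classification and a mean label for regression.

```scala
// Hypothetical tree type for illustration; not Spark's Node/LearningNode.
sealed trait Tree
final case class Leaf(prediction: Double) extends Tree
final case class Split(left: Tree, right: Tree) extends Tree

object PruneSketch {
  def numNodes(t: Tree): Int = t match {
    case Leaf(_)     => 1
    case Split(l, r) => 1 + numNodes(l) + numNodes(r)
  }

  // Bottom-up: prune both children first, then collapse this split if both
  // resulting children are leaves with the same prediction. The rule never
  // looks at what the prediction means, so it serves both algorithms.
  def prune(t: Tree): Tree = t match {
    case leaf: Leaf => leaf
    case Split(l, r) =>
      (prune(l), prune(r)) match {
        case (Leaf(a), Leaf(b)) if a == b => Leaf(a)
        case (pl, pr)                     => Split(pl, pr)
      }
  }

  def main(args: Array[String]): Unit = {
    val t = Split(Split(Leaf(0.0), Leaf(0.0)), Split(Leaf(0.0), Leaf(1.0)))
    println(s"${numNodes(t)} nodes before pruning, ${numNodes(prune(t))} after")
  }
}
```

For example, the classification-style tree in `main` has 7 nodes and prunes to 5, and `Split(Leaf(1.5), Split(Leaf(2.25), Leaf(2.25)))` (regression-style means) has 5 nodes and prunes to 3. These toy trees are illustrative only; they are not the trees built in `RandomForestSuite`.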
   



##########
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala:
##########
@@ -682,18 +797,45 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val rdd = sc.parallelize(arr.toImmutableArraySeq)
 
     val numClasses = 2
-    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4,
-      numClasses = numClasses, maxBins = 32)
-
-    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
-      seed = 42, instr = None).head
-
-    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
-      seed = 42, instr = None, prune = false).head
+    val strategy = new OldStrategy(
+      algo = OldAlgo.Classification,
+      impurity = Gini,
+      maxDepth = 4,
+      numClasses = numClasses,
+      maxBins = 32)
+
+    strategy.pruneTree = true
+    val prunedTree = RandomForest
+      .run(
+        rdd,
+        strategy,
+        numTrees = 1,
+        featureSubsetStrategy = "auto",
+        seed = 42,
+        instr = None)
+      .head
+
+    strategy.pruneTree = false
+    val unprunedTree = RandomForest
+      .run(
+        rdd,
+        strategy,
+        numTrees = 1,
+        featureSubsetStrategy = "auto",
+        seed = 42,
+        instr = None)
+      .head
+
+    strategy.pruneTree = true
+    val defaultBehaviorTree = RandomForest
+      .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed = 42, instr = None)
+      .head
 
     assert(prunedTree.numNodes === 5)
     assert(unprunedTree.numNodes === 7)
 
+    assert(defaultBehaviorTree.numNodes == prunedTree.numNodes)
+

Review Comment:
   The new `defaultBehaviorTree` assertions don't actually validate the default 
value of `Strategy.pruneTree`: the test explicitly sets `strategy.pruneTree = 
true` immediately before training, and reuses a mutated `strategy` instance. To 
test the default, construct a fresh `OldStrategy` (without setting `pruneTree`) 
and compare its result to the `pruneTree = true` case.
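A minimal sketch of why the reused-instance check cannot catch a wrong default. It uses a hypothetical `FlagConfig` class in place of `OldStrategy` and a stand-in `trainNumNodes` that simply returns the 5 and 7 node counts asserted in the classification test as placeholders. With the default deliberately flipped to `false`, the PR's reused-instance pattern still passes, while the fresh-instance pattern suggested in the comment fails, as it should.

```scala
// Hypothetical stand-in for a mutable config such as OldStrategy; the default
// is a constructor argument so the sketch can simulate a wrong default.
class FlagConfig(defaultPrune: Boolean) {
  var pruneTree: Boolean = defaultPrune
}

object DefaultCheckSketch {
  // Stand-in for RandomForest.run(...).head.numNodes: 5 when pruned, 7 when
  // not (placeholder values borrowed from the classification test).
  def trainNumNodes(c: FlagConfig): Int = if (c.pruneTree) 5 else 7

  // Pattern in the PR: one shared instance, flag set explicitly before each run.
  def reusedInstanceCheck(defaultPrune: Boolean): Boolean = {
    val strategy = new FlagConfig(defaultPrune)
    strategy.pruneTree = true
    val pruned = trainNumNodes(strategy)
    strategy.pruneTree = false
    val unpruned = trainNumNodes(strategy)
    strategy.pruneTree = true
    // The flag was just set, so this run never sees the default value.
    val defaultBehavior = trainNumNodes(strategy)
    unpruned != pruned && defaultBehavior == pruned
  }

  // Suggested pattern: compare against a fresh instance whose flag is untouched.
  def freshInstanceCheck(defaultPrune: Boolean): Boolean = {
    val explicitCfg = new FlagConfig(defaultPrune)
    explicitCfg.pruneTree = true
    val pruned = trainNumNodes(explicitCfg)
    val defaultBehavior = trainNumNodes(new FlagConfig(defaultPrune))
    defaultBehavior == pruned
  }

  def main(args: Array[String]): Unit = {
    for (d <- Seq(true, false))
      println(s"default=$d reused=${reusedInstanceCheck(d)} fresh=${freshInstanceCheck(d)}")
  }
}
```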



##########
python/pyspark/ml/classification.py:
##########
@@ -1861,6 +1864,12 @@ def setMaxBins(self, value: int) -> "DecisionTreeClassifier":
         """
         return self._set(maxBins=value)
 
+    def setPruneTree(self, value: bool) -> "DecisionTreeClassifier":
+        """
+        Sets the value of :py:attr:`pruneTree`.
+        """
+        return self._set(pruneTree=value)
+

Review Comment:
   `setPruneTree` is a newly added public Python API but is missing an 
`@since("4.3.0")` annotation (while `getPruneTree` is annotated). Please add 
the matching `@since` decorator for consistency and documentation generation.



##########
python/pyspark/ml/tree.py:
##########
@@ -424,6 +431,12 @@ def getImpurity(self) -> str:
         Gets the value of impurity or its default value.
         """
         return self.getOrDefault(self.impurity)

Review Comment:
   There should be a blank line between `getImpurity` and the new 
`getPruneTree` method. As written, `@since("4.3.0")` is immediately after the 
previous method body, which violates PEP 8 spacing and may trip format/lint 
checks.
   



##########
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala:
##########
@@ -712,17 +854,45 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     )
     val rdd = sc.parallelize(arr.toImmutableArraySeq)
 
-    val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4,
-      numClasses = 0, maxBins = 32)
-
-    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
-      seed = 42, instr = None).head
-
-    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
-      seed = 42, instr = None, prune = false).head
+    val strategy = new OldStrategy(
+      algo = OldAlgo.Regression,
+      impurity = Variance,
+      maxDepth = 4,
+      numClasses = 0,
+      maxBins = 32)
+
+    strategy.pruneTree = true
+    val prunedTree = RandomForest
+      .run(
+        rdd,
+        strategy,
+        numTrees = 1,
+        featureSubsetStrategy = "auto",
+        seed = 42,
+        instr = None)
+      .head
+
+    strategy.pruneTree = false
+    val unprunedTree = RandomForest
+      .run(
+        rdd,
+        strategy,
+        numTrees = 1,
+        featureSubsetStrategy = "auto",
+        seed = 42,
+        instr = None)
+      .head
+
+    strategy.pruneTree = true
+    val defaultBehaviorTree = RandomForest
+      .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed = 42, instr = None)
+      .head
 
     assert(prunedTree.numNodes === 3)
     assert(unprunedTree.numNodes === 5)
+
+    assert(defaultBehaviorTree.numNodes == prunedTree.numNodes)
+

Review Comment:
   Same as the classification test: `defaultBehaviorTree` sets 
`strategy.pruneTree = true` and reuses the same `strategy`, so it doesn't 
verify the default value for regression. Use a new `OldStrategy` instance 
without touching `pruneTree` to assert the default behavior.



##########
python/pyspark/ml/tree.py:
##########
@@ -415,6 +415,13 @@ class _TreeClassifierParams(Params):
         typeConverter=TypeConverters.toString,
     )
 
+    pruneTree = Param(Params._dummy(), "pruneTree", "" +
+                      "If true, the trained tree will undergo a pruning process after training, in which nodes" +
+                      " with the same class predictions are merged. The resulting tree will be smaller and have" +
+                      " faster predictions, but class probabilities will be lost." +
+                      " If false, no pruning is applied after training, and class probabilities are preserved.",
+                      typeConverter=TypeConverters.toBoolean)

Review Comment:
   `pruneTree` is introduced without the `Param[bool]` type annotation used by 
other params in this file (e.g., `impurity: Param[str] = ...`), and the 
definition formatting is inconsistent with the surrounding style. Also, the 
description says class probabilities will be lost, but pruning retains 
`impurityStats` for probability outputs; consider aligning the text with actual 
behavior.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
