Copilot commented on code in PR #55758:
URL: https://github.com/apache/spark/pull/55758#discussion_r3207902691


##########
mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala:
##########
@@ -211,10 +211,27 @@ private[ml] trait TreeClassifierParams extends Params {
     (value: String) =>
       TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
 
-  setDefault(impurity -> "gini")
+  /**
+   * If true, the trained tree will undergo a pruning process after training, in which nodes
+   * with the same class predictions are merged. The resulting tree will be smaller and have
+   * faster predictions, but class probabilities will be lost.
+   * If false, no pruning is applied after training, and class probabilities are preserved.
+   * (default = true)
+   * @group param
+   */
+  final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" +
+    "If true, the trained tree will undergo a pruning process after training, in which nodes" +
+    " with the same class predictions are merged. The resulting tree will be smaller and have" +
+    " faster predictions, but class probabilities will be lost." +
+    " If false, no pruning is applied after training, and class probabilities are preserved."

Review Comment:
   The docstring claims pruning means “class probabilities will be lost” and 
“preserved” when disabled, but pruning here only merges sibling leaves with 
identical predicted class; the resulting merged leaf still carries an 
ImpurityCalculator (and thus probabilities), though probability estimates 
become less fine-grained. Please reword this param description to avoid 
implying probabilities are unavailable/undefined after pruning.
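
To illustrate the point, here is a minimal, Spark-free sketch of what merging two sibling leaves does to class probabilities. The helper names (`leaf_probabilities`, `merge_siblings`) and the class counts are hypothetical and only stand in for the per-leaf statistics an impurity calculator would hold; this is not Spark's pruning code.

```python
from collections import Counter


def leaf_probabilities(class_counts):
    """Turn per-class counts at a leaf into class probabilities."""
    total = sum(class_counts.values())
    return {label: count / total for label, count in class_counts.items()}


def merge_siblings(left_counts, right_counts):
    """Merge two sibling leaves by summing their per-class counts."""
    return Counter(left_counts) + Counter(right_counts)


# Hypothetical sibling leaves that both predict class 0,
# but with different confidence.
left = Counter({0: 9, 1: 1})
right = Counter({0: 6, 1: 4})

merged = merge_siblings(left, right)

print(leaf_probabilities(left))    # per-leaf estimate before merging
print(leaf_probabilities(right))   # per-leaf estimate before merging
print(leaf_probabilities(merged))  # single, coarser estimate after merging
```

Under these assumed counts the merged leaf still predicts class 0 and still yields a probability vector; what disappears is the distinction between the two original leaves.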
   



##########
mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala:
##########
@@ -211,10 +211,27 @@ private[ml] trait TreeClassifierParams extends Params {
     (value: String) =>
       TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
 
-  setDefault(impurity -> "gini")
+  /**
+   * If true, the trained tree will undergo a pruning process after training, in which nodes
+   * with the same class predictions are merged. The resulting tree will be smaller and have
+   * faster predictions, but class probabilities will be lost.
+   * If false, no pruning is applied after training, and class probabilities are preserved.
+   * (default = true)
+   * @group param
+   */
+  final val pruneTree: BooleanParam = new BooleanParam(this, "pruneTree", "" +
+    "If true, the trained tree will undergo a pruning process after training, in which nodes" +
+    " with the same class predictions are merged. The resulting tree will be smaller and have" +
+    " faster predictions, but class probabilities will be lost." +
+    " If false, no pruning is applied after training, and class probabilities are preserved."

Review Comment:
   Adding pruneTree introduces new user-facing behavior, but the tests updated 
here only exercise the old API (OldStrategy + impl.RandomForest). Please add 
ML-level unit tests (e.g., in DecisionTreeClassifierSuite / 
RandomForestClassifierSuite) verifying that setPruneTree(false) changes the 
resulting ML model's shape/probabilities and that the default behavior matches 
pruneTree=true.



##########
python/pyspark/ml/classification.py:
##########
@@ -1861,6 +1864,12 @@ def setMaxBins(self, value: int) -> "DecisionTreeClassifier":
         """
         return self._set(maxBins=value)
 
+    def setPruneTree(self, value: bool) -> "DecisionTreeClassifier":
+        """
+        Sets the value of :py:attr:`pruneTree`.
+        """
+        return self._set(pruneTree=value)
+

Review Comment:
   This new public setter should be annotated with `@since("4.3.0")` (matching 
the new param/getter) so the Python API is correctly versioned in the generated 
docs.
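
A rough sketch of the suggested shape follows. To keep it runnable without a Spark installation, `since` here is a simplified, hypothetical stand-in for the decorator PySpark provides (assumed to append a `versionadded` note to the docstring), and `DecisionTreeClassifierSketch` with its `_set` method is a placeholder class, not the real estimator.

```python
def since(version):
    """Simplified stand-in (assumption) for PySpark's `since` decorator."""

    def decorator(func):
        # Append a Sphinx version note to the existing docstring.
        doc = (func.__doc__ or "").rstrip()
        func.__doc__ = f"{doc}\n\n.. versionadded:: {version}"
        return func

    return decorator


class DecisionTreeClassifierSketch:
    """Placeholder class; only mimics the setter pattern from the diff."""

    def __init__(self):
        self._params = {}

    def _set(self, **kwargs):
        # Store the given params and return self so setters can be chained.
        self._params.update(kwargs)
        return self

    @since("4.3.0")
    def setPruneTree(self, value: bool) -> "DecisionTreeClassifierSketch":
        """
        Sets the value of :py:attr:`pruneTree`.
        """
        return self._set(pruneTree=value)


clf = DecisionTreeClassifierSketch().setPruneTree(False)
print(clf._params)
print(DecisionTreeClassifierSketch.setPruneTree.__doc__)
```

In the actual patch only the decorator line would need to be added above the existing setter; the rest of the method stays as in the diff.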
   



##########
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala:
##########
@@ -113,12 +116,13 @@ class Strategy @Since("1.3.0") (
       categoricalFeaturesInfo: Map[Int, Int],
       minInstancesPerNode: Int,
       minInfoGain: Double,

Review Comment:
   This change modifies the existing “backwards compatible” constructor 
signature by inserting `pruneTree`, which can break binary compatibility for 
downstream code compiled against the previous constructor. Please keep the old 
constructor overload (without pruneTree) delegating to the new one with 
`pruneTree = true`, and/or avoid adding new fields to the primary constructor 
for public APIs.
   



##########
python/pyspark/ml/tree.py:
##########
@@ -424,6 +431,12 @@ def getImpurity(self) -> str:
         Gets the value of impurity or its default value.
         """
         return self.getOrDefault(self.impurity)
+    @since("4.3.0")
+    def getPruneTree(self) -> bool:
+        """
+        Gets the value of pruneTree or its default value.
+        """
+        return self.getOrDefault(self.pruneTree)

Review Comment:
   For consistency with the rest of this module’s param declarations, 
`pruneTree` should be type-annotated (e.g., `Param[bool]`), and the new getter 
should be separated from the previous method by a blank line (PEP 8) before the 
`@since` decorator. This improves static typing and keeps formatting consistent 
with nearby params/getters.



##########
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala:
##########
@@ -682,18 +797,45 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val rdd = sc.parallelize(arr.toImmutableArraySeq)
 
     val numClasses = 2
-    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4,
-      numClasses = numClasses, maxBins = 32)
-
-    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
-      seed = 42, instr = None).head
-
-    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
-      seed = 42, instr = None, prune = false).head
+    val strategy = new OldStrategy(
+      algo = OldAlgo.Classification,
+      impurity = Gini,
+      maxDepth = 4,
+      numClasses = numClasses,
+      maxBins = 32)
+
+    strategy.pruneTree = true
+    val prunedTree = RandomForest
+      .run(
+        rdd,
+        strategy,
+        numTrees = 1,
+        featureSubsetStrategy = "auto",
+        seed = 42,
+        instr = None)
+      .head
+
+    strategy.pruneTree = false
+    val unprunedTree = RandomForest
+      .run(
+        rdd,
+        strategy,
+        numTrees = 1,
+        featureSubsetStrategy = "auto",
+        seed = 42,
+        instr = None)
+      .head
+
+    strategy.pruneTree = true
+    val defaultBehaviorTree = RandomForest
+      .run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", seed = 42, instr = None)
+      .head
 
     assert(prunedTree.numNodes === 5)
     assert(unprunedTree.numNodes === 7)
 
+    assert(defaultBehaviorTree.numNodes == prunedTree.numNodes)

Review Comment:
   These new assertions use `==`, while the surrounding assertions in this 
suite use ScalaTest’s `===`. Please use `===` here for consistency with the 
rest of the file.
   



##########
python/pyspark/ml/classification.py:
##########
@@ -2163,6 +2175,12 @@ def setMaxBins(self, value: int) -> "RandomForestClassifier":
         """
         return self._set(maxBins=value)
 
+    def setPruneTree(self, value: bool) -> "RandomForestClassifier":
+        """
+        Sets the value of :py:attr:`pruneTree`.
+        """
+        return self._set(pruneTree=value)
+

Review Comment:
   This new public setter should be annotated with `@since("4.3.0")` (matching 
the new param/getter) so the Python API is correctly versioned in the generated 
docs.



##########
python/pyspark/ml/tree.py:
##########
@@ -415,6 +415,13 @@ class _TreeClassifierParams(Params):
         typeConverter=TypeConverters.toString,
     )
 
+    pruneTree = Param(Params._dummy(), "pruneTree", "" +
+                      "If true, the trained tree will undergo a pruning process after training, in which nodes" +
+                      " with the same class predictions are merged. The resulting tree will be smaller and have" +
+                      " faster predictions, but class probabilities will be lost." +
+                      " If false, no pruning is applied after training, and class probabilities are preserved.",

Review Comment:
   The pruneTree param description says “class probabilities will be lost,” but 
pruning in Spark merges sibling leaves with the same predicted class and still 
retains an impurity calculator (so probabilities remain available, though less 
fine-grained). Please reword to reflect that pruning may change/coarsen 
probability estimates rather than removing probabilities entirely.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

