LuciferYang commented on code in PR #55925:
URL: https://github.com/apache/spark/pull/55925#discussion_r3442750667


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -1427,6 +1366,18 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED =
+    buildConf("spark.sql.optimizer.rewriteCountDistinctConditional.enabled")
+      .doc("When true, rewrites COUNT(DISTINCT IF(cond, base, NULL)) and " +
+        "COUNT(DISTINCT CASE WHEN cond THEN base END) into " +
+        "COUNT(DISTINCT base) FILTER (WHERE cond). This reduces the Expand 
factor " +
+        "in RewriteDistinctAggregates from Nx to 1x when multiple conditional 
distinct " +
+        "counts share the same base column.")
+      .version("4.2.0")

Review Comment:
   I think this should be `4.3.0`



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala:
##########
@@ -419,6 +421,46 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] 
{
     }
   }
 
+  /**
+   * Canonicalizes COUNT(DISTINCT IF(cond, base, NULL)) and
+   * COUNT(DISTINCT CASE WHEN cond THEN base END) to COUNT(DISTINCT base) 
FILTER (WHERE cond).
+   * This reduces the number of distinct groups: multiple conditional counts 
on the same base
+   * column collapse into one group, shrinking the Expand fan-out from Nx to 
1x.
+   */
+  private def normalizeCountDistinctConditional(a: Aggregate): Aggregate = {
+    if (!SQLConf.get.rewriteCountDistinctConditionalEnabled) return a
+    a.transformExpressionsUp {
+      case ae @ AggregateExpression(count: Count, _, true, None, _)
+          if count.children.size == 1 =>
+        extractCondAndBase(count.children.head) match {
+          case Some((cond, base)) =>
+            ae.copy(
+              aggregateFunction = 
count.withNewChildren(Seq(base)).asInstanceOf[Count],

Review Comment:
   `count.withNewChildren(Seq(base)).asInstanceOf[Count]` works, but 
`Count(base)` is more direct — `Count` is a `case class` and the rewrite 
produces a brand-new node with one child. The `asInstanceOf` is only needed to 
satisfy `withNewChildren`'s `BaseType` return. 



##########
sql/core/src/test/scala/org/apache/spark/sql/RewriteDistinctAggregatesConditionalQuerySuite.scala:
##########
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+class RewriteDistinctAggregatesConditionalQuerySuite extends QueryTest with 
SharedSparkSession {
+
+  private def checkRewriteAndResult(
+      conditionalSql: String,
+      filterSql: String): Unit = {
+    withTempView("t") {
+      // Verify the rewrite produces the same result as the explicit FILTER 
form.
+      val withRewrite = withSQLConf(
+        SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+        spark.sql(conditionalSql).collect()
+      }
+      val withoutRewrite = withSQLConf(
+        SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "false") {
+        spark.sql(conditionalSql).collect()
+      }
+      val explicitFilter = spark.sql(filterSql).collect()
+
+      assert(withRewrite.sameElements(explicitFilter),
+        "Rewritten query should match explicit FILTER query")
+      assert(withoutRewrite.sameElements(explicitFilter),
+        "Non-rewritten query should also match explicit FILTER query")
+    }
+  }
+
+  test("rewrite COUNT(DISTINCT IF(cond, col, NULL)) correctness") {
+    withTempView("t") {

Review Comment:
   Each test nests `withTempView("t") { ... createOrReplaceTempView("t") }`,
and then `checkRewriteAndResult` opens another `withTempView("t")` around the
same view. The outer wrapper is redundant because the inner one already drops
the view when it exits. Either drop the outer wrapper (the inner cleanup covers
it), or move view registration into `checkRewriteAndResult`, parameterized by
the seed DataFrame.



##########
sql/core/src/test/scala/org/apache/spark/sql/RewriteDistinctAggregatesConditionalQuerySuite.scala:
##########
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+class RewriteDistinctAggregatesConditionalQuerySuite extends QueryTest with 
SharedSparkSession {
+
+  private def checkRewriteAndResult(
+      conditionalSql: String,
+      filterSql: String): Unit = {
+    withTempView("t") {
+      // Verify the rewrite produces the same result as the explicit FILTER 
form.
+      val withRewrite = withSQLConf(
+        SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+        spark.sql(conditionalSql).collect()
+      }
+      val withoutRewrite = withSQLConf(
+        SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "false") {
+        spark.sql(conditionalSql).collect()
+      }
+      val explicitFilter = spark.sql(filterSql).collect()
+
+      assert(withRewrite.sameElements(explicitFilter),
+        "Rewritten query should match explicit FILTER query")
+      assert(withoutRewrite.sameElements(explicitFilter),
+        "Non-rewritten query should also match explicit FILTER query")
+    }
+  }
+
+  test("rewrite COUNT(DISTINCT IF(cond, col, NULL)) correctness") {
+    withTempView("t") {
+      spark.range(7)
+        .selectExpr(
+          "cast(id % 3 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "case when id % 4 = 0 then null else cast(id * 100 as int) end as 
col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t GROUP BY 
key",
+        "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t 
GROUP BY key")
+    }
+  }
+
+  test("rewrite COUNT(DISTINCT CASE WHEN cond THEN col END) correctness") {
+    withTempView("t") {
+      spark.range(7)
+        .selectExpr(
+          "cast(id % 3 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "case when id % 4 = 0 then null else cast(id * 100 as string) end as 
col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        "SELECT key, COUNT(DISTINCT CASE WHEN col1 > 10 THEN col2 END) FROM t 
GROUP BY key",
+        "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t 
GROUP BY key")
+    }
+  }
+
+  test("rewrite COUNT(DISTINCT CASE WHEN cond THEN col ELSE NULL END) 
correctness") {
+    withTempView("t") {
+      spark.range(6)
+        .selectExpr(
+          "cast(id % 2 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "case when id % 4 = 0 then null else cast(id * 1.0 as double) end as 
col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        """SELECT key, COUNT(DISTINCT CASE WHEN col1 > 10 THEN col2 ELSE NULL 
END)
+          |FROM t GROUP BY key""".stripMargin,
+        "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t 
GROUP BY key")
+    }
+  }
+
+  test("rewrite with no GROUP BY") {
+    withTempView("t") {
+      spark.range(5)
+        .selectExpr(
+          "cast(id * 10 as int) as col1",
+          "case when id % 3 = 0 then null else cast(id * 100 as int) end as 
col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        "SELECT COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t",
+        "SELECT COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t")
+    }
+  }
+
+  test("rewrite with all NULLs in conditional branch") {
+    withTempView("t") {
+      spark.range(3)
+        .selectExpr(
+          "cast(id % 2 + 1 as int) as key",
+          "cast(id * 5 as int) as col1",
+          "cast(id * 100 as int) as col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t GROUP BY 
key",
+        "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t 
GROUP BY key")
+    }
+  }
+
+  test("rewrite with duplicates in base column") {
+    withTempView("t") {
+      spark.range(6)
+        .selectExpr(
+          "cast(id % 2 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "case when id % 3 = 0 then 100 when id % 3 = 1 then 100 else 200 end 
as col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t GROUP BY 
key",
+        "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t 
GROUP BY key")
+    }
+  }
+
+  test("multiple conditional distinct counts collapse and produce correct 
results") {
+    withTempView("t") {
+      spark.range(5)
+        .selectExpr(
+          "cast(id % 2 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "case when id % 3 = 0 then null else cast(id * 100 as int) end as 
col2",
+          "case when id % 4 = 0 then null else cast(id * 10 as string) end as 
col3")
+        .createOrReplaceTempView("t")
+
+      val conditionalSql =
+        """SELECT key,
+          |  COUNT(DISTINCT IF(col1 > 10, col2, NULL)) as cnt1,
+          |  COUNT(DISTINCT IF(col1 > 5, col3, NULL)) as cnt2
+          |FROM t GROUP BY key""".stripMargin
+
+      val filterSql =
+        """SELECT key,
+          |  COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) as cnt1,
+          |  COUNT(DISTINCT col3) FILTER (WHERE col1 > 5) as cnt2
+          |FROM t GROUP BY key""".stripMargin
+
+      checkRewriteAndResult(conditionalSql, filterSql)
+    }
+  }
+
+  test("rewrite does not affect COUNT(DISTINCT IF(cond, col, non_null))") {

Review Comment:
   This test only checks `withRewrite.sameElements(withoutRewrite)`. That
assertion is trivially true whenever the rewrite does not fire, but it does not
verify that the plan actually lacks the conversion. A future bug that
incorrectly rewrites the non-null ELSE case could still pass on small datasets.
That rewrite would be semantically wrong: turning `COUNT(DISTINCT IF(c, x, 0))`
into `COUNT(DISTINCT x) FILTER (WHERE c)` drops `0` from the distinct set for
rows where `c` is false. Strengthen the test by also asserting that
`optimizedPlan.toString` does NOT contain `FILTER` for this query.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala:
##########
@@ -419,6 +421,46 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] 
{
     }
   }
 
+  /**
+   * Canonicalizes COUNT(DISTINCT IF(cond, base, NULL)) and
+   * COUNT(DISTINCT CASE WHEN cond THEN base END) to COUNT(DISTINCT base) 
FILTER (WHERE cond).
+   * This reduces the number of distinct groups: multiple conditional counts 
on the same base
+   * column collapse into one group, shrinking the Expand fan-out from Nx to 
1x.
+   */
+  private def normalizeCountDistinctConditional(a: Aggregate): Aggregate = {
+    if (!SQLConf.get.rewriteCountDistinctConditionalEnabled) return a
+    a.transformExpressionsUp {
+      case ae @ AggregateExpression(count: Count, _, true, None, _)
+          if count.children.size == 1 =>
+        extractCondAndBase(count.children.head) match {
+          case Some((cond, base)) =>
+            ae.copy(
+              aggregateFunction = 
count.withNewChildren(Seq(base)).asInstanceOf[Count],
+              filter = Some(cond))
+          case None => ae
+        }
+    }.asInstanceOf[Aggregate]
+  }
+
+  /**
+   * Matches IF(cond, base, null), CASE WHEN cond THEN base END, and

Review Comment:
   Doc gap: `extractCondAndBase` deliberately rejects multi-branch `CaseWhen`
(`CASE WHEN c1 THEN base WHEN c2 THEN base END`). It does so even though that
expression is semantically equivalent to `CASE WHEN c1 OR c2 THEN base END`.
The existing `do not rewrite multi-branch CASE WHEN` test confirms this intent.
Add one line to the docstring: "Multi-branch `CaseWhen` is intentionally not
rewritten — Or-flattening is out of scope."



##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala:
##########
@@ -125,4 +126,234 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
       fail(s"Plan is not as expected:\n$rewrite")
     }
   }
+
+  // 
---------------------------------------------------------------------------
+  // COUNT(DISTINCT IF/CASE) canonicalization (SPARK-56898)
+  // 
---------------------------------------------------------------------------
+
+  val conditionalTestRelation = LocalRelation(
+    Symbol("a").int, Symbol("b").int, Symbol("c").int, Symbol("d").string)
+
+  private def countDistinctIf(cond: Expression, base: Expression): Expression 
= {
+    Count(If(cond, base, Literal(null))).toAggregateExpression(isDistinct = 
true)
+  }
+
+  private def countDistinctCaseWhen(cond: Expression, base: Expression): 
Expression = {
+    val caseWhen = CaseWhen(
+      Seq((cond, base)),
+      None)
+    Count(caseWhen).toAggregateExpression(isDistinct = true)
+  }
+
+  private def countDistinctCaseWhenElseNull(cond: Expression, base: 
Expression): Expression = {
+    val caseWhen = CaseWhen(
+      Seq((cond, base)),
+      Some(Literal(null)))
+    Count(caseWhen).toAggregateExpression(isDistinct = true)
+  }
+
+  /**
+   * Asserts that the optimized plan has exactly one Expand node with one 
projection,
+   * that the projection contains `baseColName` as a plain attribute, and that 
it
+   * contains no expression of `removedWrapperType` (the IF/CaseWhen that was 
stripped).
+   */
+  private def assertSingleDistinctGroupExpand(
+      optimized: LogicalPlan,
+      baseColName: String,
+      removedWrapperType: Class[_]): Unit = {
+    val expand = optimized.collectFirst { case e: Expand => e }.get
+    assert(expand.projections.size == 1,
+      s"expected 1 distinct group but got ${expand.projections.size}")
+    val baseAttr = conditionalTestRelation.output.find(_.name == 
baseColName).get
+    assert(expand.projections.head.exists(_.semanticEquals(baseAttr)),
+      s"expected base column $baseColName in Expand projection")
+    assert(!expand.projections.head.exists(e => 
removedWrapperType.isInstance(e)),
+      s"${removedWrapperType.getSimpleName} wrapper should have been removed " 
+
+        "from Expand projection")
+  }
+
+  test("conditional: disabled by default") {
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"))
+      .analyze
+    val optimized = RewriteDistinctAggregates(input)
+    comparePlans(optimized, input)
+  }
+
+  test("conditional: rewrite COUNT(DISTINCT IF(cond, col, NULL)) to 
COUNT(DISTINCT col) FILTER") {
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"),
+        countDistinctIf(Symbol("b") > 2, Symbol("c")).as("cnt2"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> 
"true") {
+      val optimized = RewriteDistinctAggregates(input)
+      assertSingleDistinctGroupExpand(optimized, "c", classOf[If])
+    }
+  }
+
+  test("conditional: rewrite COUNT(DISTINCT CASE WHEN cond THEN col END) to " +
+      "COUNT(DISTINCT col) FILTER") {
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        countDistinctCaseWhen(Symbol("b") > 1, Symbol("c")).as("cnt1"),
+        countDistinctCaseWhen(Symbol("b") > 2, Symbol("c")).as("cnt2"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> 
"true") {
+      val optimized = RewriteDistinctAggregates(input)
+      assertSingleDistinctGroupExpand(optimized, "c", classOf[CaseWhen])
+    }
+  }
+
+  test("conditional: rewrite COUNT(DISTINCT CASE WHEN cond THEN col ELSE NULL 
END)") {
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        countDistinctCaseWhenElseNull(Symbol("b") > 1, Symbol("c")).as("cnt1"),
+        countDistinctCaseWhenElseNull(Symbol("b") > 2, Symbol("c")).as("cnt2"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> 
"true") {
+      val optimized = RewriteDistinctAggregates(input)
+      assertSingleDistinctGroupExpand(optimized, "c", classOf[CaseWhen])
+    }
+  }
+
+  test("conditional: multiple conditional distinct counts collapse to single 
distinct group") {
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"),
+        countDistinctIf(Symbol("b") > 2, Symbol("c")).as("cnt2"),
+        countDistinctIf(Symbol("b") > 3, Symbol("c")).as("cnt3"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> 
"true") {
+      val optimized = RewriteDistinctAggregates(input)
+      // All three counts share the same base column c, collapsed to 1 
distinct group.
+      assertSingleDistinctGroupExpand(optimized, "c", classOf[If])
+    }
+  }
+
+  test("conditional: single conditional distinct count does not produce 
Expand") {

Review Comment:
   Same class of weakness as `conditional: disabled by default` — verified 
empirically that this test's assertion (`expands.isEmpty`) holds with both 
`conf=off` and `conf=on`, because `mayNeedtoRewrite` returns false on a lone 
conditional distinct count. The test verifies `mayNeedtoRewrite`'s pre-existing 
behavior, not the new conf or helper. Either rename to something like `single 
conditional distinct count is gated out by mayNeedtoRewrite`, or wrap in 
`withSQLConf(... -> "false")` and add a paired `conf=on` case asserting the 
same outcome — to document the conf-independence.



##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala:
##########
@@ -125,4 +126,234 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
       fail(s"Plan is not as expected:\n$rewrite")
     }
   }
+
+  // 
---------------------------------------------------------------------------
+  // COUNT(DISTINCT IF/CASE) canonicalization (SPARK-56898)
+  // 
---------------------------------------------------------------------------
+
+  val conditionalTestRelation = LocalRelation(
+    Symbol("a").int, Symbol("b").int, Symbol("c").int, Symbol("d").string)
+
+  private def countDistinctIf(cond: Expression, base: Expression): Expression 
= {
+    Count(If(cond, base, Literal(null))).toAggregateExpression(isDistinct = 
true)
+  }
+
+  private def countDistinctCaseWhen(cond: Expression, base: Expression): 
Expression = {
+    val caseWhen = CaseWhen(
+      Seq((cond, base)),
+      None)
+    Count(caseWhen).toAggregateExpression(isDistinct = true)
+  }
+
+  private def countDistinctCaseWhenElseNull(cond: Expression, base: 
Expression): Expression = {
+    val caseWhen = CaseWhen(
+      Seq((cond, base)),
+      Some(Literal(null)))
+    Count(caseWhen).toAggregateExpression(isDistinct = true)
+  }
+
+  /**
+   * Asserts that the optimized plan has exactly one Expand node with one 
projection,
+   * that the projection contains `baseColName` as a plain attribute, and that 
it
+   * contains no expression of `removedWrapperType` (the IF/CaseWhen that was 
stripped).
+   */
+  private def assertSingleDistinctGroupExpand(

Review Comment:
   This helper checks three things: the Expand has 1 projection, `baseColName`
is in it, and the wrapper class (`If` / `CaseWhen`) is gone. It does NOT check
that the rewrite actually moved `cond` into a `FILTER` clause. A future bug that
strips the wrapper but forgets to populate `filter` would still pass.

Verified: the rewritten plan has 2 `Count(c)` aggregates with
`filter = Some(...)` and 2 `Max(...)` aggregates without one. The plan's
`toString` also contains the literal `FILTER`.

Add an assertion such as `optimized.toString.contains("FILTER")`. A more precise
option is to search the plan's expressions, not its nodes:
`optimized.flatMap(_.expressions).exists(_.exists { case ae: AggregateExpression => ae.filter.isDefined; case _ => false })`.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala:
##########
@@ -419,6 +421,46 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] 
{
     }
   }
 
+  /**
+   * Canonicalizes COUNT(DISTINCT IF(cond, base, NULL)) and
+   * COUNT(DISTINCT CASE WHEN cond THEN base END) to COUNT(DISTINCT base) 
FILTER (WHERE cond).
+   * This reduces the number of distinct groups: multiple conditional counts 
on the same base
+   * column collapse into one group, shrinking the Expand fan-out from Nx to 
1x.
+   */
+  private def normalizeCountDistinctConditional(a: Aggregate): Aggregate = {
+    if (!SQLConf.get.rewriteCountDistinctConditionalEnabled) return a
+    a.transformExpressionsUp {
+      case ae @ AggregateExpression(count: Count, _, true, None, _)
+          if count.children.size == 1 =>
+        extractCondAndBase(count.children.head) match {
+          case Some((cond, base)) =>
+            ae.copy(
+              aggregateFunction = 
count.withNewChildren(Seq(base)).asInstanceOf[Count],
+              filter = Some(cond))
+          case None => ae
+        }
+    }.asInstanceOf[Aggregate]

Review Comment:
   `transformExpressionsUp` is declared `: this.type` (`QueryPlan.scala:231`). 
Called on a value typed `Aggregate`, the result is already typed `Aggregate` — 
the trailing `.asInstanceOf[Aggregate]` is redundant. 



##########
sql/core/src/test/scala/org/apache/spark/sql/RewriteDistinctAggregatesConditionalQuerySuite.scala:
##########
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+class RewriteDistinctAggregatesConditionalQuerySuite extends QueryTest with 
SharedSparkSession {
+
+  private def checkRewriteAndResult(

Review Comment:
   The suite extends `QueryTest`, but `checkRewriteAndResult` bypasses
`checkAnswer` and compares results with raw `Array.sameElements`.
`sameElements` is order-sensitive. The test queries use `GROUP BY` without
`ORDER BY`, so row order is not part of Spark's contract, and the tests pass
today only by an accident of partition layout. Use `checkAnswer` instead, which
sorts the rows before comparing. `QueryTest` already provides a
`checkAnswer(df, expectedDf)` overload that takes two DataFrames directly.



##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala:
##########
@@ -125,4 +126,234 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
       fail(s"Plan is not as expected:\n$rewrite")
     }
   }
+
+  // 
---------------------------------------------------------------------------
+  // COUNT(DISTINCT IF/CASE) canonicalization (SPARK-56898)
+  // 
---------------------------------------------------------------------------
+
+  val conditionalTestRelation = LocalRelation(
+    Symbol("a").int, Symbol("b").int, Symbol("c").int, Symbol("d").string)
+
+  private def countDistinctIf(cond: Expression, base: Expression): Expression 
= {
+    Count(If(cond, base, Literal(null))).toAggregateExpression(isDistinct = 
true)
+  }
+
+  private def countDistinctCaseWhen(cond: Expression, base: Expression): 
Expression = {
+    val caseWhen = CaseWhen(
+      Seq((cond, base)),
+      None)
+    Count(caseWhen).toAggregateExpression(isDistinct = true)
+  }
+
+  private def countDistinctCaseWhenElseNull(cond: Expression, base: 
Expression): Expression = {
+    val caseWhen = CaseWhen(
+      Seq((cond, base)),
+      Some(Literal(null)))
+    Count(caseWhen).toAggregateExpression(isDistinct = true)
+  }
+
+  /**
+   * Asserts that the optimized plan has exactly one Expand node with one 
projection,
+   * that the projection contains `baseColName` as a plain attribute, and that 
it
+   * contains no expression of `removedWrapperType` (the IF/CaseWhen that was 
stripped).
+   */
+  private def assertSingleDistinctGroupExpand(
+      optimized: LogicalPlan,
+      baseColName: String,
+      removedWrapperType: Class[_]): Unit = {
+    val expand = optimized.collectFirst { case e: Expand => e }.get
+    assert(expand.projections.size == 1,
+      s"expected 1 distinct group but got ${expand.projections.size}")
+    val baseAttr = conditionalTestRelation.output.find(_.name == 
baseColName).get
+    assert(expand.projections.head.exists(_.semanticEquals(baseAttr)),
+      s"expected base column $baseColName in Expand projection")
+    assert(!expand.projections.head.exists(e => 
removedWrapperType.isInstance(e)),
+      s"${removedWrapperType.getSimpleName} wrapper should have been removed " 
+
+        "from Expand projection")
+  }
+
+  test("conditional: disabled by default") {

Review Comment:
   This test doesn't actually exercise the disabled state. The input has only
one conditional distinct count, so `mayNeedtoRewrite` returns false. As a
result, `rewrite()` is never called and the conf is never consulted.

Verified empirically: with a single-conditional input, both conf=default (false)
and conf=true produce zero `Expand` nodes.

To prove "disabled by default", switch to a multi-conditional input and assert
2 distinct groups when the conf is at its default.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala:
##########
@@ -419,6 +421,46 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] 
{
     }
   }
 
+  /**
+   * Canonicalizes COUNT(DISTINCT IF(cond, base, NULL)) and
+   * COUNT(DISTINCT CASE WHEN cond THEN base END) to COUNT(DISTINCT base) 
FILTER (WHERE cond).
+   * This reduces the number of distinct groups: multiple conditional counts 
on the same base
+   * column collapse into one group, shrinking the Expand fan-out from Nx to 
1x.
+   */
+  private def normalizeCountDistinctConditional(a: Aggregate): Aggregate = {
+    if (!SQLConf.get.rewriteCountDistinctConditionalEnabled) return a
+    a.transformExpressionsUp {
+      case ae @ AggregateExpression(count: Count, _, true, None, _)
+          if count.children.size == 1 =>
+        extractCondAndBase(count.children.head) match {
+          case Some((cond, base)) =>
+            ae.copy(
+              aggregateFunction = 
count.withNewChildren(Seq(base)).asInstanceOf[Count],
+              filter = Some(cond))
+          case None => ae

Review Comment:
   Style: `case None => ae` returning the unchanged input is structural noise — 
`transformExpressionsUp` already keeps any expression the partial function 
doesn't rewrite. Pull the option out: fold the 
`extractCondAndBase(...).isDefined` check into the case guard and inline `.get` 
in the body. 



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala:
##########
@@ -419,6 +421,46 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] 
{
     }
   }
 
+  /**
+   * Canonicalizes COUNT(DISTINCT IF(cond, base, NULL)) and
+   * COUNT(DISTINCT CASE WHEN cond THEN base END) to COUNT(DISTINCT base) 
FILTER (WHERE cond).
+   * This reduces the number of distinct groups: multiple conditional counts 
on the same base
+   * column collapse into one group, shrinking the Expand fan-out from Nx to 
1x.
+   */
+  private def normalizeCountDistinctConditional(a: Aggregate): Aggregate = {
+    if (!SQLConf.get.rewriteCountDistinctConditionalEnabled) return a
+    a.transformExpressionsUp {

Review Comment:
   Use `transformExpressionsUpWithPruning(_.containsAnyPattern(IF, CASE_WHEN))` 
so aggregates with no IF/CaseWhen short-circuit. Negligible perf win for 
typical workloads, but matches the convention in sister rules (e.g. 
`SimplifyConditionals` at `expressions.scala:635`).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to