LuciferYang commented on code in PR #56604:
URL: https://github.com/apache/spark/pull/56604#discussion_r3440065104
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala:
##########
@@ -1225,6 +1225,48 @@ class WholeStageCodegenSuite extends SharedSparkSession
"CSE-disabled codegen (i.e. fall back to the lazy, short-circuiting
non-CSE path)")
}
+ test("SPARK-56032: FilterExec skips CSE codegen when the only common
subexpression is a leaf") {
+ // `c BETWEEN lo AND hi` lowers to `c >= lo AND c <= hi`, so a column used
in a BETWEEN (or in
+ // several conjuncts) becomes a "common subexpression" -- but it is a bare
leaf column whose
+ // load CSE cannot meaningfully cache, since the non-CSE path already
loads columns lazily. The
+ // gate must not take the CSE path for it: doing so emits the eager
prologue that decodes every
+ // referenced column (the decimals `p1`/`p2` here) up front, defeating the
cheap `q` filter's
+ // short-circuiting. This is the TPC-DS q28 shape. Verify the leaf-only
case falls back to the
+ // same code as CSE-disabled; `p1`/`p2` stand in for q28's decimal columns
whose eager decode
+ // is the cost, though the fallback is type-independent.
+ val schema = StructType(Seq(
+ StructField("q", IntegerType, nullable = true),
+ StructField("p1", IntegerType, nullable = true),
+ StructField("p2", IntegerType, nullable = true)))
+ val data = spark.sparkContext.parallelize(Seq(
+ Row(4, 10, 7), Row(1, 10, 7), Row(null, 10, 7),
+ Row(5, 100, 7), Row(6, 100, 100), Row(3, 9, 1)))
+ val expected = Seq(Row(4, 10, 7), Row(5, 100, 7), Row(3, 9, 1))
+
+ def filterCode(cseEnabled: Boolean): String = {
+ withSQLConf(
+ SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> cseEnabled.toString,
+ SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+ val df = spark.createDataFrame(data, schema)
+ // The only repeated expressions are the bare columns q, p1, p2 (each
referenced by the two
+ // halves of its BETWEEN). No non-leaf expression is shared.
+ val filtered = df.where(
Review Comment:
Let me verify the validity of this test.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]