peter-toth commented on code in PR #52529:
URL: https://github.com/apache/spark/pull/52529#discussion_r2481988611
##########
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala:
##########
@@ -976,6 +979,89 @@ class DataSourceV2Suite extends QueryTest with
SharedSparkSession with AdaptiveS
assert(result.length == 1)
}
}
+
+ test("SPARK-53809: scan canonicalization") {
+ val table = new
SimpleDataSourceV2().getTable(CaseInsensitiveStringMap.empty())
+
+ def createDsv2ScanRelation(): DataSourceV2ScanRelation = {
+ val relation = DataSourceV2Relation.create(
+ table, None, None, CaseInsensitiveStringMap.empty())
+ val scan =
relation.table.asReadable.newScanBuilder(relation.options).build()
+ DataSourceV2ScanRelation(relation, scan, relation.output)
+ }
+
+ // Create two DataSourceV2ScanRelation instances, representing the scan of
the same table
+ val scanRelation1 = createDsv2ScanRelation()
+ val scanRelation2 = createDsv2ScanRelation()
+
+ // the two instances should not be the same, as they should have different
attribute IDs
+ assert(scanRelation1 != scanRelation2,
+ "Two created DataSourceV2ScanRelation instances should not be the same")
+ assert(scanRelation1.output.map(_.exprId).toSet !=
scanRelation2.output.map(_.exprId).toSet,
+ "Output attributes should have different expression IDs before
canonicalization")
+ assert(scanRelation1.relation.output.map(_.exprId).toSet !=
+ scanRelation2.relation.output.map(_.exprId).toSet,
+ "Relation output attributes should have different expression IDs before
canonicalization")
+
+ // After canonicalization, the two instances should be equal
+ assert(scanRelation1.canonicalized == scanRelation2.canonicalized,
+ "Canonicalized DataSourceV2ScanRelation instances should be equal")
+ }
+
+ test("SPARK-53809: check mergeScalarSubqueries is effective for
DataSourceV2ScanRelation") {
+ val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load()
+ df.createOrReplaceTempView("df")
+
+ val query = sql("select (select max(i) from df) as max_i, (select min(i)
from df) as min_i")
+ val optimizedPlan = query.queryExecution.optimizedPlan
+
+ // check optimizedPlan merged scalar subqueries `select max(i), min(i)
from df`
+ val sub1 = optimizedPlan.asInstanceOf[Project].projectList.head.collect {
+ case s: ScalarSubquery => s
+ }
+ val sub2 = optimizedPlan.asInstanceOf[Project].projectList(1).collect {
+ case s: ScalarSubquery => s
+ }
+
+ // Both subqueries should reference the same merged plan `select max(i),
min(i) from df`
+ assert(sub1.nonEmpty && sub2.nonEmpty, "Both scalar subqueries should
exist")
+ assert(sub1.head.plan == sub2.head.plan,
+ "Both subqueries should reference the same merged plan")
+
+ // Extract the aggregate from the merged plan
+ val agg = sub1.head.plan.collect {
+ case a: Aggregate => a
+ }.head
+
+ // Check that the aggregate contains both max(i) and min(i)
+ val aggExprs = agg.aggregateExpressions
+
+ val hasMax = aggExprs.exists { expr =>
Review Comment:
Maybe it would be better to extract this logic into a helper function — or, even
better, to simply collect the aggregate functions into a collection and test it
against the expected set.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]