Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/21564#discussion_r195897400
--- Diff: sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala ---
@@ -686,6 +686,68 @@ class PlannerSuite extends SharedSQLContext {
         Range(1, 2, 1, 1)))
     df.queryExecution.executedPlan.execute()
   }
+
+  test("SPARK-24556: always rewrite output partitioning in InMemoryTableScanExec " +
+    "and ReusedExchangeExec") {
+    def checkOutputPartitioningRewrite(
+        plans: Seq[SparkPlan],
+        expectedPartitioningClass: Class[_]): Unit = {
+      plans.foreach { plan =>
+        val partitioning = plan.outputPartitioning
+        assert(partitioning.getClass == expectedPartitioningClass)
+        val partitionedAttrs = partitioning.asInstanceOf[Expression].references
+        assert(partitionedAttrs.subsetOf(plan.outputSet))
+      }
+    }
+
+    def checkInMemoryTableScanOutputPartitioningRewrite(
+        df: DataFrame,
+        expectedPartitioningClass: Class[_]): Unit = {
+      val inMemoryScans = df.queryExecution.executedPlan.collect {
+        case m: InMemoryTableScanExec => m
+      }
+      checkOutputPartitioningRewrite(inMemoryScans, expectedPartitioningClass)
+    }
+
+    def checkReusedExchangeOutputPartitioningRewrite(
+        df: DataFrame,
+        expectedPartitioningClass: Class[_]): Unit = {
+      val reusedExchange = df.queryExecution.executedPlan.collect {
+        case r: ReusedExchangeExec => r
+      }
+      checkOutputPartitioningRewrite(reusedExchange, expectedPartitioningClass)
+    }
+
+    // InMemoryTableScan is HashPartitioning
+    val df1 = Seq(1 -> "a").toDF("i", "j").repartition($"i").persist()
+    val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i").persist()
+    checkInMemoryTableScanOutputPartitioningRewrite(df1.union(df2), classOf[HashPartitioning])
+
+    // InMemoryTableScan is RangePartitioning
+    val df3 = Seq(1 -> "a").toDF("i", "j").orderBy($"i").persist()
+    val df4 = Seq(1 -> "a").toDF("i", "j").orderBy($"i").persist()
+    checkInMemoryTableScanOutputPartitioningRewrite(df3.union(df4), classOf[RangePartitioning])
+
+    // InMemoryTableScan is PartitioningCollection
+    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
+      val df5 =
+        Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m").persist()
+      val df6 =
+        Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m").persist()
+      checkInMemoryTableScanOutputPartitioningRewrite(
+        df5.union(df6), classOf[PartitioningCollection])
+    }
+
+    // ReusedExchange is HashPartitioning
+    val df7 = Seq(1 -> "a").toDF("i", "j").repartition($"i")
+    val df8 = Seq(1 -> "a").toDF("i", "j").repartition($"i")
+    checkReusedExchangeOutputPartitioningRewrite(df7.union(df8), classOf[HashPartitioning])
+
+    // ReusedExchange is RangePartitioning
+    val df9 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
+    val df10 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
--- End diff --
Seems this test can be simplified. For example, the only difference between df3/df4 and df9/df10 is `persist`. You could just define the dataframes once and reuse them.
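
A rough, untested sketch of what that could look like for the range-partitioned pair (the `dfRange*` names are made up here, and it assumes the helper functions and implicits from the test above):

```scala
// Sketch: define the dataframes once, run the ReusedExchange check first,
// then persist the same dataframes for the InMemoryTableScan check.
// persist() must come after the ReusedExchange check, since caching
// replaces the sort/exchange subtree with an InMemoryRelation scan.
val dfRange1 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
val dfRange2 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")

// ReusedExchange is RangePartitioning
checkReusedExchangeOutputPartitioningRewrite(
  dfRange1.union(dfRange2), classOf[RangePartitioning])

// InMemoryTableScan is RangePartitioning: same dataframes, now persisted
checkInMemoryTableScanOutputPartitioningRewrite(
  dfRange1.persist().union(dfRange2.persist()), classOf[RangePartitioning])
```

The same trick would apply to the hash-partitioned pair, since df1/df2 and df7/df8 also differ only by `persist`.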