Github user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21564#discussion_r195897400
  
    --- Diff: sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala ---
    @@ -686,6 +686,68 @@ class PlannerSuite extends SharedSQLContext {
           Range(1, 2, 1, 1)))
         df.queryExecution.executedPlan.execute()
       }
    +
    +  test("SPARK-24556: always rewrite output partitioning in 
InMemoryTableScanExec" +
    +    "and ReusedExchangeExec") {
    +    def checkOutputPartitioningRewrite(
    +        plans: Seq[SparkPlan],
    +        expectedPartitioningClass: Class[_]): Unit = {
    +      plans.foreach { plan =>
    +        val partitioning = plan.outputPartitioning
    +        assert(partitioning.getClass == expectedPartitioningClass)
    +        val partitionedAttrs = partitioning.asInstanceOf[Expression].references
    +        assert(partitionedAttrs.subsetOf(plan.outputSet))
    +      }
    +    }
    +
    +    def checkInMemoryTableScanOutputPartitioningRewrite(
    +        df: DataFrame,
    +        expectedPartitioningClass: Class[_]): Unit = {
    +      val inMemoryScans = df.queryExecution.executedPlan.collect {
    +        case m: InMemoryTableScanExec => m
    +      }
    +      checkOutputPartitioningRewrite(inMemoryScans, expectedPartitioningClass)
    +    }
    +
    +    def checkReusedExchangeOutputPartitioningRewrite(
    +        df: DataFrame,
    +        expectedPartitioningClass: Class[_]): Unit = {
    +      val reusedExchange = df.queryExecution.executedPlan.collect {
    +        case r: ReusedExchangeExec => r
    +      }
    +      checkOutputPartitioningRewrite(reusedExchange, expectedPartitioningClass)
    +    }
    +
    +    // InMemoryTableScan is HashPartitioning
    +    val df1 = Seq(1 -> "a").toDF("i", "j").repartition($"i").persist()
    +    val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i").persist()
    +    checkInMemoryTableScanOutputPartitioningRewrite(df1.union(df2), classOf[HashPartitioning])
    +
    +    // InMemoryTableScan is RangePartitioning
    +    val df3 = Seq(1 -> "a").toDF("i", "j").orderBy($"i").persist()
    +    val df4 = Seq(1 -> "a").toDF("i", "j").orderBy($"i").persist()
    +    checkInMemoryTableScanOutputPartitioningRewrite(df3.union(df4), classOf[RangePartitioning])
    +
    +    // InMemoryTableScan is PartitioningCollection
    +    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
    +      val df5 =
    +        Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), 
$"i" === $"m").persist()
    +      val df6 =
    +        Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), 
$"i" === $"m").persist()
    +      checkInMemoryTableScanOutputPartitioningRewrite(
    +        df5.union(df6), classOf[PartitioningCollection])
    +    }
    +
    +    // ReusedExchange is HashPartitioning
    +    val df7 = Seq(1 -> "a").toDF("i", "j").repartition($"i")
    +    val df8 = Seq(1 -> "a").toDF("i", "j").repartition($"i")
    +    checkReusedExchangeOutputPartitioningRewrite(df7.union(df8), classOf[HashPartitioning])
    +
    +    // ReusedExchange is RangePartitioning
    +    val df9 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
    +    val df10 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
    --- End diff ---
    
    Seems this test can be simplified. For example, the only difference
    between df3/df4 and df9/df10 is `persist`. You could just define the
    DataFrames once and reuse them.
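    
    A minimal untested sketch of what I mean (`dfA`/`dfB` are illustrative
    names): build each pair once, check the ReusedExchangeExec case before
    caching, then persist the same DataFrames and check the
    InMemoryTableScanExec case.
    
    ```scala
    // Same construction as df3/df4 and df9/df10; `persist` is the only difference.
    val dfA = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
    val dfB = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
    
    // Before caching, the shuffle is reused: ReusedExchangeExec.
    checkReusedExchangeOutputPartitioningRewrite(dfA.union(dfB), classOf[RangePartitioning])
    
    // After persisting the same DataFrames, the scan goes through the cache:
    // InMemoryTableScanExec.
    dfA.persist()
    dfB.persist()
    checkInMemoryTableScanOutputPartitioningRewrite(dfA.union(dfB), classOf[RangePartitioning])
    ```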


---
