Github user yucai commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21564#discussion_r195898387
  
    --- Diff: 
sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala ---
    @@ -686,6 +686,68 @@ class PlannerSuite extends SharedSQLContext {
           Range(1, 2, 1, 1)))
         df.queryExecution.executedPlan.execute()
       }
    +
    +  test("SPARK-24556: always rewrite output partitioning in 
InMemoryTableScanExec" +
    +    "and ReusedExchangeExec") {
    +    def checkOutputPartitioningRewrite(
    +        plans: Seq[SparkPlan],
    +        expectedPartitioningClass: Class[_]): Unit = {
    +      plans.foreach { plan =>
    +        val partitioning = plan.outputPartitioning
    +        assert(partitioning.getClass == expectedPartitioningClass)
    +        val partitionedAttrs = 
partitioning.asInstanceOf[Expression].references
    +        assert(partitionedAttrs.subsetOf(plan.outputSet))
    +      }
    +    }
    +
    +    def checkInMemoryTableScanOutputPartitioningRewrite(
    +        df: DataFrame,
    +        expectedPartitioningClass: Class[_]): Unit = {
    +      val inMemoryScans = df.queryExecution.executedPlan.collect {
    +        case m: InMemoryTableScanExec => m
    +      }
    +      checkOutputPartitioningRewrite(inMemoryScans, 
expectedPartitioningClass)
    +    }
    +
    +    def checkReusedExchangeOutputPartitioningRewrite(
    +        df: DataFrame,
    +        expectedPartitioningClass: Class[_]): Unit = {
    +      val reusedExchange = df.queryExecution.executedPlan.collect {
    +        case r: ReusedExchangeExec => r
    +      }
    +      checkOutputPartitioningRewrite(reusedExchange, 
expectedPartitioningClass)
    +    }
    +
    +    // InMemoryTableScan is HashPartitioning
    +    val df1 = Seq(1 -> "a").toDF("i", "j").repartition($"i").persist()
    +    val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i").persist()
    +    checkInMemoryTableScanOutputPartitioningRewrite(df1.union(df2), 
classOf[HashPartitioning])
    +
    +    // InMemoryTableScan is RangePartitioning
    +    val df3 = Seq(1 -> "a").toDF("i", "j").orderBy($"i").persist()
    --- End diff --
    
    I want `RangePartitioning` here, so I am using `orderBy`.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to