GitHub user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19196#discussion_r138771978
  
    --- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala ---
    @@ -381,4 +388,233 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
           AddData(streamInput, 0, 1, 2, 3),
           CheckLastBatch((0, 0, 2), (1, 1, 3)))
       }
    +
    +  private def checkAggregationChain(
    +      sq: StreamingQuery,
    +      requiresShuffling: Boolean,
    +      expectedPartition: Int): Unit = {
    +    val executedPlan = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery
    +      .lastExecution.executedPlan
    +    val restore = executedPlan
    +      .collect { case ss: StateStoreRestoreExec => ss }
    +      .head
    +    restore.child match {
    +      case node: UnaryExecNode =>
    +        assert(node.outputPartitioning.numPartitions === expectedPartition)
    +        if (requiresShuffling) {
    +          assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: ${node.child}")
    +        } else {
    +          assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle")
    +        }
    +
    +      case _ => fail("Expected no shuffling")
    +    }
    +    var reachedRestore = false
    +    // Check that there are no exchanges after `StateStoreRestoreExec`
    +    executedPlan.foreachUp { p =>
    +      if (reachedRestore) {
    +        assert(!p.isInstanceOf[Exchange], "There should be no further exchanges")
    +      } else {
    +        reachedRestore = p.isInstanceOf[StateStoreRestoreExec]
    +      }
    +    }
    +  }
    +
    +  test("SPARK-21977: coalesce(1) with 0 partition RDD should be 
repartitioned accordingly") {
    +    val inputSource = new NonLocalRelationSource(spark)
    +    MockSourceProvider.withMockSources(inputSource) {
    +      withTempDir { tempDir =>
    +        val aggregated: Dataset[Long] =
    +          spark.readStream
    +            .format((new MockSourceProvider).getClass.getCanonicalName)
    +            .load()
    +            .coalesce(1)
    +            .groupBy()
    +            .count()
    +            .as[Long]
    +
    +        val sq = aggregated.writeStream
    +          .format("memory")
    +          .outputMode("complete")
    +          .queryName("agg_test")
    +          .option("checkpointLocation", tempDir.getAbsolutePath)
    +          .start()
    +
    +        try {
    +
    +          inputSource.addData(1)
    +          inputSource.releaseLock()
    +          sq.processAllAvailable()
    +
    +          checkDataset(
    +            spark.table("agg_test").as[Long],
    +            1L)
    +
    +          checkAggregationChain(sq, requiresShuffling = false, 1)
    +
    +          inputSource.addData()
    +          inputSource.releaseLock()
    +          sq.processAllAvailable()
    +
    +          checkAggregationChain(sq, requiresShuffling = true, 1)
    +
    +          checkDataset(
    +            spark.table("agg_test").as[Long],
    +            1L)
    +
    +          inputSource.addData(2, 3)
    +          inputSource.releaseLock()
    +          sq.processAllAvailable()
    +
    +          checkDataset(
    +            spark.table("agg_test").as[Long],
    +            3L)
    +
    +          inputSource.addData()
    +          inputSource.releaseLock()
    +          sq.processAllAvailable()
    +
    +          checkDataset(
    +            spark.table("agg_test").as[Long],
    +            3L)
    +        } finally {
    +          sq.stop()
    +        }
    +      }
    +    }
    +  }
    +
    +  test("SPARK-21977: coalesce(1) should still be repartitioned when it has 
keyExpressions") {
    +    val inputSource = new NonLocalRelationSource(spark)
    +    MockSourceProvider.withMockSources(inputSource) {
    +      withTempDir { tempDir =>
    +
    +        val sq = spark.readStream
    +          .format((new MockSourceProvider).getClass.getCanonicalName)
    +          .load()
    +          .coalesce(1)
    +          .groupBy('a % 1) // just to give it a fake key
    +          .count()
    +          .as[(Long, Long)]
    +          .writeStream
    +          .format("memory")
    +          .outputMode("complete")
    +          .queryName("agg_test")
    +          .option("checkpointLocation", tempDir.getAbsolutePath)
    +          .start()
    +
    +        try {
    +
    +          inputSource.addData(1)
    +          inputSource.releaseLock()
    +          sq.processAllAvailable()
    +
    +          checkAggregationChain(
    +            sq,
    +            requiresShuffling = true,
    +            spark.sessionState.conf.numShufflePartitions)
    +
    +          checkDataset(
    +            spark.table("agg_test").as[(Long, Long)],
    +            (0L, 1L))
    +
    +        } finally {
    +          sq.stop()
    +        }
    +
    +        val sq2 = spark.readStream
    +          .format((new MockSourceProvider).getClass.getCanonicalName)
    +          .load()
    +          .coalesce(2)
    +          .groupBy('a % 1) // just to give it a fake key
    +          .count()
    +          .as[(Long, Long)]
    +          .writeStream
    +          .format("memory")
    +          .outputMode("complete")
    +          .queryName("agg_test")
    +          .option("checkpointLocation", tempDir.getAbsolutePath)
    +          .start()
    --- End diff --
    
    this query code can be deduped into a helper function; it is identical to the `sq` block above except for the `coalesce` partition count
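    
    For example, a minimal sketch of such a helper (the name `startAggQuery` and its signature are just a suggestion, not part of the PR; it assumes the suite's `spark`, `MockSourceProvider`, implicit encoders, and a `java.io.File` checkpoint dir are in scope):
    
        // Hypothetical helper: builds and starts the keyed aggregation query,
        // parameterizing only the coalesce partition count.
        private def startAggQuery(coalescePartitions: Int, checkpointDir: File): StreamingQuery = {
          spark.readStream
            .format((new MockSourceProvider).getClass.getCanonicalName)
            .load()
            .coalesce(coalescePartitions)
            .groupBy('a % 1) // just to give it a fake key
            .count()
            .as[(Long, Long)]
            .writeStream
            .format("memory")
            .outputMode("complete")
            .queryName("agg_test")
            .option("checkpointLocation", checkpointDir.getAbsolutePath)
            .start()
        }
    
    The two duplicated query blocks would then collapse to `val sq = startAggQuery(1, tempDir)` and `val sq2 = startAggQuery(2, tempDir)`.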

