Github user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/19196#discussion_r139078823
--- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala ---
@@ -381,4 +388,187 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
AddData(streamInput, 0, 1, 2, 3),
CheckLastBatch((0, 0, 2), (1, 1, 3)))
}
+
+ /**
+ * This method verifies certain properties in the SparkPlan of a streaming aggregation.
+ * First, it checks that the child of a `StateStoreRestoreExec` creates the desired
+ * data distribution, where the child could be an Exchange, or a `HashAggregateExec` which
+ * already provides the expected data distribution.
+ *
+ * Second, it checks that the child provides the expected number of partitions.
+ *
+ * Third, it checks that we don't add an unnecessary shuffle between
+ * `StateStoreRestoreExec` and `StateStoreSaveExec`.
+ */
+ private def checkAggregationChain(
+ se: StreamExecution,
+ expectShuffling: Boolean,
+ expectedPartition: Int): Boolean = {
+ val executedPlan = se.lastExecution.executedPlan
+ val restore = executedPlan
+ .collect { case ss: StateStoreRestoreExec => ss }
+ .head
+ restore.child match {
+ case node: UnaryExecNode =>
+ assert(node.outputPartitioning.numPartitions === expectedPartition,
+ "Didn't get the expected number of partitions.")
+ if (expectShuffling) {
+ assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: ${node.child}")
+ } else {
+ assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle")
+ }
+
+ case _ => fail(s"Expected a UnaryExecNode above StateStoreRestoreExec, got: ${restore.child}")
+ }
+ var reachedRestore = false
+ // Check that there are no exchanges after `StateStoreRestoreExec`
+ executedPlan.foreachUp { p =>
+ if (reachedRestore) {
+ assert(!p.isInstanceOf[Exchange], "There should be no further exchanges")
+ } else {
+ reachedRestore = p.isInstanceOf[StateStoreRestoreExec]
+ }
+ }
+ true
+ }
+
+ /** Add blocks of data to the `BlockRDDBackedSource`. */
+ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData {
+ override def addData(query: Option[StreamExecution]): (Source, Offset) = {
+ if (data.nonEmpty) {
+ data.foreach(source.addData)
+ } else {
+ // We want to create empty BlockRDDs, so add an empty block here.
+ source.addData()
+ }
+ source.releaseLock()
+ (source, LongOffset(source.counter))
+ }
+ }
+
+ test("SPARK-21977: coalesce(1) with 0 partition RDD should be
repartitioned to 1") {
+ val inputSource = new BlockRDDBackedSource(spark)
+ MockSourceProvider.withMockSources(inputSource) {
+ withTempDir { tempDir =>
+ val aggregated: Dataset[Long] =
+ spark.readStream
+ .format((new MockSourceProvider).getClass.getCanonicalName)
+ .load()
+ .coalesce(1)
+ .groupBy()
+ .count()
+ .as[Long]
+
+ testStream(aggregated, Complete())(
+ AddBlockData(inputSource, Seq(1)),
+ CheckLastBatch(1),
+ AssertOnQuery("Verify no shuffling") { se =>
+ checkAggregationChain(se, expectShuffling = false, 1)
+ },
+ AddBlockData(inputSource), // create an empty trigger
+ CheckLastBatch(1),
+ AssertOnQuery("Verify addition of exchange operator") { se =>
+ checkAggregationChain(se, expectShuffling = true, 1)
+ },
+ AddBlockData(inputSource, Seq(2, 3)),
+ CheckLastBatch(3),
+ AddBlockData(inputSource),
+ CheckLastBatch(3),
+ StopStream
+ )
+ }
+ }
+ }
+
+ test("SPARK-21977: coalesce(1) should still be repartitioned when it has
keyExpressions") {
+ val inputSource = new BlockRDDBackedSource(spark)
+ MockSourceProvider.withMockSources(inputSource) {
+ withTempDir { tempDir =>
+
+ def createDf(partitions: Int): Dataset[(Long, Long)] = {
+ spark.readStream
+ .format((new MockSourceProvider).getClass.getCanonicalName)
+ .load()
+ .coalesce(partitions)
+ .groupBy('a % 1) // just to give it a fake key
+ .count()
+ .as[(Long, Long)]
+ }
+
+ testStream(createDf(1), Complete())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ AddBlockData(inputSource, Seq(1)),
+ CheckLastBatch((0L, 1L)),
+ AssertOnQuery("Verify addition of exchange operator") { se =>
+ checkAggregationChain(
+ se,
+ expectShuffling = true,
+ spark.sessionState.conf.numShufflePartitions)
+ },
+ StopStream
+ )
+
+ testStream(createDf(2), Complete())(
+ StartStream(checkpointLocation = tempDir.getAbsolutePath),
+ Execute(se => se.processAllAvailable()),
+ AddBlockData(inputSource, Seq(2), Seq(3), Seq(4)),
+ CheckLastBatch((0L, 4L)),
+ AssertOnQuery("Verify no exchange added") { se =>
+ checkAggregationChain(
+ se,
+ expectShuffling = false,
+ spark.sessionState.conf.numShufflePartitions)
+ },
+ AddBlockData(inputSource),
+ CheckLastBatch((0L, 4L)),
+ StopStream
+ )
+ }
+ }
+ }
+}
+
+/**
+ * A Streaming Source that is backed by a BlockRDD and that can create RDDs with 0 blocks at will.
+ */
+class BlockRDDBackedSource(spark: SparkSession) extends Source {
+ var counter = 0L
+ private val blockMgr = SparkEnv.get.blockManager
+ private var blocks: Seq[BlockId] = Seq.empty
+
+ private var streamLock: CountDownLatch = new CountDownLatch(1)
+
+ def addData(data: Int*): Unit = {
+ if (streamLock.getCount == 0) {
+ streamLock = new CountDownLatch(1)
--- End diff ---
This is complicated. See how AddFileData is implemented. It's much simpler.
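
For comparison, the `AddFileData` action in `FileStreamSourceTest` gets by without any explicit locking: the action resolves the source and computes the new offset in a single step. Below is a rough, latch-free sketch of what `AddBlockData` could look like along those lines; the `addBlocksAndGetOffset` helper is hypothetical and would need to be added to `BlockRDDBackedSource`:

    /** Sketch only: a latch-free AddBlockData, modeled loosely on the AddFileData pattern. */
    case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData {
      override def addData(query: Option[StreamExecution]): (Source, Offset) = {
        // Hypothetical helper: the source appends the given blocks (zero blocks yields an
        // empty BlockRDD) and returns the new offset in one synchronized call, so no
        // CountDownLatch is needed to coordinate with getBatch.
        val newOffset = source.addBlocksAndGetOffset(data: _*)
        (source, newOffset)
      }
    }

The intent is that `getBatch` can then be driven purely by the (start, end) offsets handed back here, and the empty-trigger case falls out without any extra synchronization.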