Github user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/19196#discussion_r138761881
--- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala ---
@@ -381,4 +388,233 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
       AddData(streamInput, 0, 1, 2, 3),
       CheckLastBatch((0, 0, 2), (1, 1, 3)))
   }
+
+  private def checkAggregationChain(
+      sq: StreamingQuery,
+      requiresShuffling: Boolean,
+      expectedPartition: Int): Unit = {
+    val executedPlan = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery
+      .lastExecution.executedPlan
+    val restore = executedPlan
+      .collect { case ss: StateStoreRestoreExec => ss }
+      .head
+    restore.child match {
+      case node: UnaryExecNode =>
+        assert(node.outputPartitioning.numPartitions === expectedPartition)
+        if (requiresShuffling) {
+          assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: ${node.child}")
+        } else {
+          assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle")
+        }
+
+      case _ => fail("Expected no shuffling")
+    }
+    var reachedRestore = false
+    // Check that there should be no exchanges after `StateStoreRestoreExec`
+    executedPlan.foreachUp { p =>
+      if (reachedRestore) {
+        assert(!p.isInstanceOf[Exchange], "There should be no further exchanges")
+      } else {
+        reachedRestore = p.isInstanceOf[StateStoreRestoreExec]
+      }
+    }
+  }
+
+ test("SPARK-21977: coalesce(1) with 0 partition RDD should be
repartitioned accordingly") {
+ val inputSource = new NonLocalRelationSource(spark)
+ MockSourceProvider.withMockSources(inputSource) {
+ withTempDir { tempDir =>
+ val aggregated: Dataset[Long] =
+ spark.readStream
+ .format((new MockSourceProvider).getClass.getCanonicalName)
+ .load()
+ .coalesce(1)
+ .groupBy()
+ .count()
+ .as[Long]
+
+ val sq = aggregated.writeStream
+ .format("memory")
+ .outputMode("complete")
+ .queryName("agg_test")
+ .option("checkpointLocation", tempDir.getAbsolutePath)
+ .start()
+
+ try {
+
+ inputSource.addData(1)
+ inputSource.releaseLock()
+ sq.processAllAvailable()
+
+ checkDataset(
+ spark.table("agg_test").as[Long],
+ 1L)
+
+ checkAggregationChain(sq, requiresShuffling = false, 1)
+
+ inputSource.addData()
+ inputSource.releaseLock()
+ sq.processAllAvailable()
+
+ checkAggregationChain(sq, requiresShuffling = true, 1)
+
+ checkDataset(
+ spark.table("agg_test").as[Long],
+ 1L)
+
+ inputSource.addData(2, 3)
+ inputSource.releaseLock()
+ sq.processAllAvailable()
+
+ checkDataset(
+ spark.table("agg_test").as[Long],
+ 3L)
+
+ inputSource.addData()
+ inputSource.releaseLock()
+ sq.processAllAvailable()
+
+ checkDataset(
+ spark.table("agg_test").as[Long],
+ 3L)
+ } finally {
+ sq.stop()
+ }
+ }
+ }
+ }
+
+ test("SPARK-21977: coalesce(1) should still be repartitioned when it has
keyExpressions") {
+ val inputSource = new NonLocalRelationSource(spark)
+ MockSourceProvider.withMockSources(inputSource) {
+ withTempDir { tempDir =>
+
+ val sq = spark.readStream
+ .format((new MockSourceProvider).getClass.getCanonicalName)
+ .load()
+ .coalesce(1)
+ .groupBy('a % 1) // just to give it a fake key
+ .count()
+ .as[(Long, Long)]
+ .writeStream
+ .format("memory")
+ .outputMode("complete")
+ .queryName("agg_test")
+ .option("checkpointLocation", tempDir.getAbsolutePath)
+ .start()
+
+ try {
+
+ inputSource.addData(1)
+ inputSource.releaseLock()
+ sq.processAllAvailable()
+
+ checkAggregationChain(
+ sq,
+ requiresShuffling = true,
+ spark.sessionState.conf.numShufflePartitions)
+
+ checkDataset(
+ spark.table("agg_test").as[(Long, Long)],
+ (0L, 1L))
+
+ } finally {
+ sq.stop()
+ }
+
+ val sq2 = spark.readStream
+ .format((new MockSourceProvider).getClass.getCanonicalName)
+ .load()
+ .coalesce(2)
+ .groupBy('a % 1) // just to give it a fake key
+ .count()
+ .as[(Long, Long)]
+ .writeStream
+ .format("memory")
+ .outputMode("complete")
+ .queryName("agg_test")
+ .option("checkpointLocation", tempDir.getAbsolutePath)
+ .start()
+
+ try {
+ sq2.processAllAvailable()
+ inputSource.addData(2)
+ inputSource.addData(3)
+ inputSource.addData(4)
+ inputSource.releaseLock()
+ sq2.processAllAvailable()
+
+ checkAggregationChain(
+ sq2,
+ requiresShuffling = false, // doesn't require extra shuffle as
HashAggregate adds it
+ spark.sessionState.conf.numShufflePartitions)
+
+ checkDataset(
+ spark.table("agg_test").as[(Long, Long)],
+ (0L, 4L))
+
+ inputSource.addData()
+ inputSource.releaseLock()
+ sq2.processAllAvailable()
+
+ checkDataset(
+ spark.table("agg_test").as[(Long, Long)],
+ (0L, 4L))
+ } finally {
+ sq2.stop()
+ }
+ }
+ }
+ }
+}
+
+/**
+ * LocalRelation has some optimized properties during Spark planning. In order for the bugs in
+ * SPARK-21977 to occur, we need to create a logical relation from an existing RDD. We use a
+ * BlockRDD since it accepts 0 partitions. One requirement for one of the bugs is the use of
+ * `coalesce(1)`, which has several optimizations regarding [[SinglePartition]], and a
+ * 0-partition parent RDD.
+ */
+class NonLocalRelationSource(spark: SparkSession) extends Source {
--- End diff --
The point of this source is basically to create empty batches; local vs. non-local is just an
internal detail, so it should be named accordingly.
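
For illustration only, here is a rough sketch of the kind of source that name would describe: one
that just emits batches (possibly empty ones) backed by a BlockRDD, so the planner never sees a
LocalRelation. The class name `EmptyBatchBlockRDDSource` is a hypothetical suggestion, the locking
helper used by the tests above is omitted, and the sketch assumes it sits in Spark's own
`org.apache.spark.sql` test sources so that the `private[sql]` `SparkSession.internalCreateDataFrame`
is accessible; it is not the implementation in this PR.

```scala
import org.apache.spark.SparkEnv
import org.apache.spark.rdd.BlockRDD
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Source}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}

/**
 * Hypothetical name: a source whose only job is to emit (possibly empty) batches
 * backed by a BlockRDD, so the planner never sees a LocalRelation.
 */
class EmptyBatchBlockRDDSource(spark: SparkSession) extends Source {
  private val blockMgr = SparkEnv.get.blockManager
  private var blocks: Seq[BlockId] = Seq.empty
  private var counter = 0L

  override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil)

  /** Calling this with no arguments still advances the offset, which forces an empty batch. */
  def addData(data: Int*): Unit = synchronized {
    if (data.nonEmpty) {
      val id = TestBlockId(counter.toString)
      blockMgr.putIterator(id, data.iterator, StorageLevel.MEMORY_ONLY)
      blocks :+= id
    }
    counter += 1
  }

  override def getOffset: Option[Offset] = synchronized {
    if (counter == 0) None else Some(LongOffset(counter))
  }

  override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized {
    // A BlockRDD over zero blocks has zero partitions, which is exactly the shape the
    // doc comment above describes for triggering SPARK-21977.
    val rdd = new BlockRDD[Int](spark.sparkContext, blocks.toArray)
      .map(i => InternalRow(i))
    blocks = Seq.empty
    // internalCreateDataFrame is private[sql]; this sketch assumes the class lives in
    // Spark's own org.apache.spark.sql test sources.
    spark.internalCreateDataFrame(rdd, schema, isStreaming = true)
  }

  override def stop(): Unit = synchronized {
    blockMgr.getMatchingBlockIds(_.isInstanceOf[TestBlockId]).foreach(blockMgr.removeBlock(_))
  }
}
```

The behavior the tests above rely on is that `addData()` with no arguments still bumps the offset,
so the next trigger produces an empty, BlockRDD-backed batch with zero partitions.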
---