Github user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/19196#discussion_r138762394
--- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala ---
@@ -381,4 +388,233 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
AddData(streamInput, 0, 1, 2, 3),
CheckLastBatch((0, 0, 2), (1, 1, 3)))
}
+
+ private def checkAggregationChain(
+ sq: StreamingQuery,
+ requiresShuffling: Boolean,
+ expectedPartition: Int): Unit = {
+ val executedPlan = sq.asInstanceOf[StreamingQueryWrapper].streamingQuery
+ .lastExecution.executedPlan
+ val restore = executedPlan
+ .collect { case ss: StateStoreRestoreExec => ss }
+ .head
+ restore.child match {
+ case node: UnaryExecNode =>
+ assert(node.outputPartitioning.numPartitions === expectedPartition)
+ if (requiresShuffling) {
+ assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: ${node.child}")
+ } else {
+ assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle")
+ }
+
+ case _ => fail("Expected no shuffling")
+ }
+ var reachedRestore = false
+ // Check that there are no exchanges after `StateStoreRestoreExec`
+ executedPlan.foreachUp { p =>
+ if (reachedRestore) {
+ assert(!p.isInstanceOf[Exchange], "There should be no further exchanges")
+ } else {
+ reachedRestore = p.isInstanceOf[StateStoreRestoreExec]
+ }
+ }
+ }
+
+ test("SPARK-21977: coalesce(1) with 0 partition RDD should be
repartitioned accordingly") {
+ val inputSource = new NonLocalRelationSource(spark)
+ MockSourceProvider.withMockSources(inputSource) {
+ withTempDir { tempDir =>
+ val aggregated: Dataset[Long] =
+ spark.readStream
+ .format((new MockSourceProvider).getClass.getCanonicalName)
+ .load()
+ .coalesce(1)
+ .groupBy()
+ .count()
+ .as[Long]
+
+ val sq = aggregated.writeStream
+ .format("memory")
+ .outputMode("complete")
+ .queryName("agg_test")
+ .option("checkpointLocation", tempDir.getAbsolutePath)
+ .start()
+
+ try {
+
+ inputSource.addData(1)
+ inputSource.releaseLock()
+ sq.processAllAvailable()
+
+ checkDataset(
+ spark.table("agg_test").as[Long],
+ 1L)
+
+ checkAggregationChain(sq, requiresShuffling = false, 1)
+
+ inputSource.addData()
+ inputSource.releaseLock()
+ sq.processAllAvailable()
+
+ checkAggregationChain(sq, requiresShuffling = true, 1)
+
+ checkDataset(
+ spark.table("agg_test").as[Long],
+ 1L)
+
+ inputSource.addData(2, 3)
+ inputSource.releaseLock()
+ sq.processAllAvailable()
+
+ checkDataset(
+ spark.table("agg_test").as[Long],
+ 3L)
+
+ inputSource.addData()
+ inputSource.releaseLock()
+ sq.processAllAvailable()
+
+ checkDataset(
+ spark.table("agg_test").as[Long],
+ 3L)
+ } finally {
+ sq.stop()
+ }
+ }
+ }
+ }
+
+ test("SPARK-21977: coalesce(1) should still be repartitioned when it has
keyExpressions") {
+ val inputSource = new NonLocalRelationSource(spark)
+ MockSourceProvider.withMockSources(inputSource) {
+ withTempDir { tempDir =>
+
+ val sq = spark.readStream
+ .format((new MockSourceProvider).getClass.getCanonicalName)
+ .load()
+ .coalesce(1)
+ .groupBy('a % 1) // just to give it a fake key
+ .count()
+ .as[(Long, Long)]
+ .writeStream
+ .format("memory")
+ .outputMode("complete")
+ .queryName("agg_test")
+ .option("checkpointLocation", tempDir.getAbsolutePath)
+ .start()
+
+ try {
+
+ inputSource.addData(1)
+ inputSource.releaseLock()
+ sq.processAllAvailable()
+
+ checkAggregationChain(
+ sq,
+ requiresShuffling = true,
+ spark.sessionState.conf.numShufflePartitions)
+
+ checkDataset(
+ spark.table("agg_test").as[(Long, Long)],
+ (0L, 1L))
+
+ } finally {
+ sq.stop()
+ }
+
+ val sq2 = spark.readStream
+ .format((new MockSourceProvider).getClass.getCanonicalName)
+ .load()
+ .coalesce(2)
+ .groupBy('a % 1) // just to give it a fake key
+ .count()
+ .as[(Long, Long)]
+ .writeStream
+ .format("memory")
+ .outputMode("complete")
+ .queryName("agg_test")
+ .option("checkpointLocation", tempDir.getAbsolutePath)
+ .start()
+
+ try {
+ sq2.processAllAvailable()
+ inputSource.addData(2)
+ inputSource.addData(3)
+ inputSource.addData(4)
+ inputSource.releaseLock()
+ sq2.processAllAvailable()
+
+ checkAggregationChain(
+ sq2,
+ requiresShuffling = false, // doesn't require extra shuffle as HashAggregate adds it
+ spark.sessionState.conf.numShufflePartitions)
+
+ checkDataset(
+ spark.table("agg_test").as[(Long, Long)],
+ (0L, 4L))
+
+ inputSource.addData()
+ inputSource.releaseLock()
+ sq2.processAllAvailable()
+
+ checkDataset(
+ spark.table("agg_test").as[(Long, Long)],
+ (0L, 4L))
+ } finally {
+ sq2.stop()
+ }
+ }
+ }
+ }
+}
+
+/**
+ * LocalRelation has some optimized properties during Spark planning. In order for the bugs in
+ * SPARK-21977 to occur, we need to create a logical relation from an existing RDD. We use a
+ * BlockRDD since it accepts 0 partitions. One requirement for one of the bugs is the use of
+ * `coalesce(1)`, which has several optimizations regarding [[SinglePartition]], and a 0 partition
+ * parentRDD.
+ */
+class NonLocalRelationSource(spark: SparkSession) extends Source {
--- End diff --
The docs should explain what this class does, not why it is implemented the
way it is. It really does not matter that LocalRelation is not the right
thing to use.
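
For illustration, a behaviour-focused rewording could look roughly like the
sketch below. This only infers the `addData` / `releaseLock` semantics from
the tests above, so the exact wording would need to match the actual
implementation:

    /**
     * A streaming Source for tests whose batches are backed by a real RDD (a BlockRDD)
     * rather than a LocalRelation. Values passed to `addData` become the next batch once
     * `releaseLock()` is called; calling `addData()` with no values produces a batch whose
     * underlying RDD has zero partitions.
     */
    class NonLocalRelationSource(spark: SparkSession) extends Source {
      // ...
    }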
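
If the zero-partition point is worth keeping, it reads better as behaviour
than as rationale. The property itself is easy to verify from inside the
suite (BlockRDD is private[spark], so this only compiles within Spark's own
packages); a minimal sketch, assuming a SparkSession named `spark` is in
scope as in the suite:

    import org.apache.spark.rdd.BlockRDD
    import org.apache.spark.storage.BlockId

    // An empty set of block ids yields an RDD with zero partitions, which is exactly
    // the input shape the coalesce(1) / SinglePartition code path has to handle.
    val emptyRdd = new BlockRDD[Long](spark.sparkContext, Array.empty[BlockId])
    assert(emptyRdd.getNumPartitions == 0)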