Github user tdas commented on a diff in the pull request:
    --- Diff: 
    @@ -381,4 +388,233 @@ class StreamingAggregationSuite extends 
           AddData(streamInput, 0, 1, 2, 3),
           CheckLastBatch((0, 0, 2), (1, 1, 3)))
    +  private def checkAggregationChain(
    +      sq: StreamingQuery,
    +      requiresShuffling: Boolean,
    +      expectedPartition: Int): Unit = {
    +    val executedPlan = 
    +      .lastExecution.executedPlan
    +    val restore = executedPlan
    +      .collect { case ss: StateStoreRestoreExec => ss }
    +      .head
    +    restore.child match {
    +      case node: UnaryExecNode =>
    +        assert(node.outputPartitioning.numPartitions === expectedPartition)
    +        if (requiresShuffling) {
    +          assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: 
    +        } else {
    +          assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle")
    +        }
    +      case _ => fail("Expected no shuffling")
    +    }
    +    var reachedRestore = false
    +    // Check that there should be no exchanges after 
    +    executedPlan.foreachUp { p =>
    +      if (reachedRestore) {
    +        assert(!p.isInstanceOf[Exchange], "There should be no further 
    +      } else {
    +        reachedRestore = p.isInstanceOf[StateStoreRestoreExec]
    +      }
    +    }
    +  }
    +  test("SPARK-21977: coalesce(1) with 0 partition RDD should be 
repartitioned accordingly") {
    +    val inputSource = new NonLocalRelationSource(spark)
    +    MockSourceProvider.withMockSources(inputSource) {
    +      withTempDir { tempDir =>
    +        val aggregated: Dataset[Long] =
    +          spark.readStream
    +            .format((new MockSourceProvider).getClass.getCanonicalName)
    +            .load()
    +            .coalesce(1)
    +            .groupBy()
    +            .count()
    +            .as[Long]
    +        val sq = aggregated.writeStream
    +          .format("memory")
    +          .outputMode("complete")
    +          .queryName("agg_test")
    +          .option("checkpointLocation", tempDir.getAbsolutePath)
    +          .start()
    +        try {
    +          inputSource.addData(1)
    +          inputSource.releaseLock()
    +          sq.processAllAvailable()
    +          checkDataset(
    +            spark.table("agg_test").as[Long],
    +            1L)
    +          checkAggregationChain(sq, requiresShuffling = false, 1)
    +          inputSource.addData()
    +          inputSource.releaseLock()
    +          sq.processAllAvailable()
    +          checkAggregationChain(sq, requiresShuffling = true, 1)
    +          checkDataset(
    +            spark.table("agg_test").as[Long],
    +            1L)
    +          inputSource.addData(2, 3)
    +          inputSource.releaseLock()
    +          sq.processAllAvailable()
    +          checkDataset(
    +            spark.table("agg_test").as[Long],
    +            3L)
    +          inputSource.addData()
    +          inputSource.releaseLock()
    +          sq.processAllAvailable()
    +          checkDataset(
    +            spark.table("agg_test").as[Long],
    +            3L)
    +        } finally {
    +          sq.stop()
    +        }
    +      }
    +    }
    +  }
    +  test("SPARK-21977: coalesce(1) should still be repartitioned when it has 
keyExpressions") {
    +    val inputSource = new NonLocalRelationSource(spark)
    +    MockSourceProvider.withMockSources(inputSource) {
    +      withTempDir { tempDir =>
    +        val sq = spark.readStream
    +          .format((new MockSourceProvider).getClass.getCanonicalName)
    +          .load()
    +          .coalesce(1)
    +          .groupBy('a % 1) // just to give it a fake key
    +          .count()
    +          .as[(Long, Long)]
    +          .writeStream
    +          .format("memory")
    +          .outputMode("complete")
    +          .queryName("agg_test")
    +          .option("checkpointLocation", tempDir.getAbsolutePath)
    +          .start()
    +        try {
    +          inputSource.addData(1)
    +          inputSource.releaseLock()
    +          sq.processAllAvailable()
    +          checkAggregationChain(
    +            sq,
    +            requiresShuffling = true,
    +            spark.sessionState.conf.numShufflePartitions)
    +          checkDataset(
    +            spark.table("agg_test").as[(Long, Long)],
    +            (0L, 1L))
    +        } finally {
    +          sq.stop()
    +        }
    +        val sq2 = spark.readStream
    +          .format((new MockSourceProvider).getClass.getCanonicalName)
    +          .load()
    +          .coalesce(2)
    +          .groupBy('a % 1) // just to give it a fake key
    +          .count()
    +          .as[(Long, Long)]
    +          .writeStream
    +          .format("memory")
    +          .outputMode("complete")
    +          .queryName("agg_test")
    +          .option("checkpointLocation", tempDir.getAbsolutePath)
    +          .start()
    +        try {
    +          sq2.processAllAvailable()
    +          inputSource.addData(2)
    +          inputSource.addData(3)
    +          inputSource.addData(4)
    +          inputSource.releaseLock()
    +          sq2.processAllAvailable()
    +          checkAggregationChain(
    +            sq2,
    +            requiresShuffling = false, // doesn't require extra shuffle as 
HashAggregate adds it
    +            spark.sessionState.conf.numShufflePartitions)
    +          checkDataset(
    +            spark.table("agg_test").as[(Long, Long)],
    +            (0L, 4L))
    +          inputSource.addData()
    +          inputSource.releaseLock()
    +          sq2.processAllAvailable()
    +          checkDataset(
    +            spark.table("agg_test").as[(Long, Long)],
    +            (0L, 4L))
    +        } finally {
    +          sq2.stop()
    +        }
    +      }
    +    }
    +  }
    +/**
    + * LocalRelation has some optimized properties during Spark planning. In 
order for the bugs in
    + * SPARK-21977 to occur, we need to create a logical relation from an 
existing RDD. We use a
    + * BlockRDD since it accepts 0 partitions. One requirement for one of 
the bugs is the use of
    + * `coalesce(1)`, which has several optimizations regarding 
[[SinglePartition]], and a 0 partition
    + * parentRDD.
    + */
    +class NonLocalRelationSource(spark: SparkSession) extends Source {
    --- End diff --
    The docs should explain what the class does, not why it is implemented the 
way it is. It really does not matter that a local relation is not the right 
thing to use.


To unsubscribe, e-mail:
For additional commands, e-mail:

Reply via email to