Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19196#discussion_r138762394
  
    --- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala ---
    @@ -381,4 +388,233 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
           AddData(streamInput, 0, 1, 2, 3),
           CheckLastBatch((0, 0, 2), (1, 1, 3)))
       }
    +
    +  private def checkAggregationChain(
    +      sq: StreamingQuery,
    +      requiresShuffling: Boolean,
    +      expectedPartition: Int): Unit = {
    +    val executedPlan = 
sq.asInstanceOf[StreamingQueryWrapper].streamingQuery
    +      .lastExecution.executedPlan
    +    val restore = executedPlan
    +      .collect { case ss: StateStoreRestoreExec => ss }
    +      .head
    +    restore.child match {
    +      case node: UnaryExecNode =>
    +        assert(node.outputPartitioning.numPartitions === expectedPartition)
    +        if (requiresShuffling) {
    +          assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: 
${node.child}")
    +        } else {
    +          assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle")
    +        }
    +
    +      case _ => fail("Expected no shuffling")
    +    }
    +    var reachedRestore = false
    +    // Check that there should be no exchanges after `StateStoreRestoreExec`
    +    executedPlan.foreachUp { p =>
    +      if (reachedRestore) {
    +        assert(!p.isInstanceOf[Exchange], "There should be no further 
exchanges")
    +      } else {
    +        reachedRestore = p.isInstanceOf[StateStoreRestoreExec]
    +      }
    +    }
    +  }
    +
    +  test("SPARK-21977: coalesce(1) with 0 partition RDD should be 
repartitioned accordingly") {
    +    val inputSource = new NonLocalRelationSource(spark)
    +    MockSourceProvider.withMockSources(inputSource) {
    +      withTempDir { tempDir =>
    +        val aggregated: Dataset[Long] =
    +          spark.readStream
    +            .format((new MockSourceProvider).getClass.getCanonicalName)
    +            .load()
    +            .coalesce(1)
    +            .groupBy()
    +            .count()
    +            .as[Long]
    +
    +        val sq = aggregated.writeStream
    +          .format("memory")
    +          .outputMode("complete")
    +          .queryName("agg_test")
    +          .option("checkpointLocation", tempDir.getAbsolutePath)
    +          .start()
    +
    +        try {
    +
    +          inputSource.addData(1)
    +          inputSource.releaseLock()
    +          sq.processAllAvailable()
    +
    +          checkDataset(
    +            spark.table("agg_test").as[Long],
    +            1L)
    +
    +          checkAggregationChain(sq, requiresShuffling = false, 1)
    +
    +          inputSource.addData()
    +          inputSource.releaseLock()
    +          sq.processAllAvailable()
    +
    +          checkAggregationChain(sq, requiresShuffling = true, 1)
    +
    +          checkDataset(
    +            spark.table("agg_test").as[Long],
    +            1L)
    +
    +          inputSource.addData(2, 3)
    +          inputSource.releaseLock()
    +          sq.processAllAvailable()
    +
    +          checkDataset(
    +            spark.table("agg_test").as[Long],
    +            3L)
    +
    +          inputSource.addData()
    +          inputSource.releaseLock()
    +          sq.processAllAvailable()
    +
    +          checkDataset(
    +            spark.table("agg_test").as[Long],
    +            3L)
    +        } finally {
    +          sq.stop()
    +        }
    +      }
    +    }
    +  }
    +
    +  test("SPARK-21977: coalesce(1) should still be repartitioned when it has 
keyExpressions") {
    +    val inputSource = new NonLocalRelationSource(spark)
    +    MockSourceProvider.withMockSources(inputSource) {
    +      withTempDir { tempDir =>
    +
    +        val sq = spark.readStream
    +          .format((new MockSourceProvider).getClass.getCanonicalName)
    +          .load()
    +          .coalesce(1)
    +          .groupBy('a % 1) // just to give it a fake key
    +          .count()
    +          .as[(Long, Long)]
    +          .writeStream
    +          .format("memory")
    +          .outputMode("complete")
    +          .queryName("agg_test")
    +          .option("checkpointLocation", tempDir.getAbsolutePath)
    +          .start()
    +
    +        try {
    +
    +          inputSource.addData(1)
    +          inputSource.releaseLock()
    +          sq.processAllAvailable()
    +
    +          checkAggregationChain(
    +            sq,
    +            requiresShuffling = true,
    +            spark.sessionState.conf.numShufflePartitions)
    +
    +          checkDataset(
    +            spark.table("agg_test").as[(Long, Long)],
    +            (0L, 1L))
    +
    +        } finally {
    +          sq.stop()
    +        }
    +
    +        val sq2 = spark.readStream
    +          .format((new MockSourceProvider).getClass.getCanonicalName)
    +          .load()
    +          .coalesce(2)
    +          .groupBy('a % 1) // just to give it a fake key
    +          .count()
    +          .as[(Long, Long)]
    +          .writeStream
    +          .format("memory")
    +          .outputMode("complete")
    +          .queryName("agg_test")
    +          .option("checkpointLocation", tempDir.getAbsolutePath)
    +          .start()
    +
    +        try {
    +          sq2.processAllAvailable()
    +          inputSource.addData(2)
    +          inputSource.addData(3)
    +          inputSource.addData(4)
    +          inputSource.releaseLock()
    +          sq2.processAllAvailable()
    +
    +          checkAggregationChain(
    +            sq2,
    +            requiresShuffling = false, // doesn't require extra shuffle as 
HashAggregate adds it
    +            spark.sessionState.conf.numShufflePartitions)
    +
    +          checkDataset(
    +            spark.table("agg_test").as[(Long, Long)],
    +            (0L, 4L))
    +
    +          inputSource.addData()
    +          inputSource.releaseLock()
    +          sq2.processAllAvailable()
    +
    +          checkDataset(
    +            spark.table("agg_test").as[(Long, Long)],
    +            (0L, 4L))
    +        } finally {
    +          sq2.stop()
    +        }
    +      }
    +    }
    +  }
    +}
    +
    +/**
    + * LocalRelation has some optimized properties during Spark planning. In order for the bugs in
    + * SPARK-21977 to occur, we need to create a logical relation from an existing RDD. We use a
    + * BlockRDD since it accepts 0 partitions. One requirement for one of the bugs is the use of
    + * `coalesce(1)`, which has several optimizations regarding [[SinglePartition]], and a
    + * 0-partition parent RDD.
    + */
    +class NonLocalRelationSource(spark: SparkSession) extends Source {
    --- End diff --
    
    The docs should explain what this source does, not why it is implemented the way it is. It
    really does not matter to the reader that a local relation is not the right thing to use here.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to