Repository: spark
Updated Branches:
  refs/heads/master c6ff59a23 -> 280ff523f


[SPARK-21977] SinglePartition optimizations break certain Streaming Stateful Aggregation requirements

## What changes were proposed in this pull request?

This is a bit hard to explain, as there are several issues here; I'll try my
best. Here are the requirements for reproducing the problem:
  1. A Structured Streaming Source that can generate empty RDDs with 0 partitions
  2. A Structured Streaming query that uses the above source, performs a stateful
     aggregation (`mapGroupsWithState`, `groupBy().count()`, ...), and calls
     `coalesce(1)`

The crux of the problem is that when a dataset has a `coalesce(1)` call, it
receives a `SinglePartition` partitioning scheme. This scheme satisfies most of
the required distributions used by aggregation operators such as
`HashAggregateExec`, so no shuffle is planned. This causes a world of problems:
  Symptom 1. `coalesce(1)` only guarantees at most 1 partition. If the input RDD
  has 0 partitions, the whole lineage will receive 0 partitions, nothing will be
  executed, and the state store will not create any delta files. When this
  happens, the next trigger fails, because the StateStore fails to load the
  delta file for the previous trigger.
  Symptom 2. Suppose instead that there was data. If you stop your stream,
  replace `coalesce(1)` with `coalesce(2)`, and then restart your stream, it
  will fail, because `spark.sql.shuffle.partitions - 1` StateStores will fail
  to find their delta files.
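To make the delta-file bookkeeping behind both symptoms concrete, here is a toy model written as a Scala script. It is a hypothetical sketch, not Spark's `StateStore` code: `ToyCheckpoint` and `runBatch` are invented names, and it assumes (simplifying the real versioning) that a partition executing batch `b` writes a delta file for `b` and needs its own file for `b - 1` in order to load state.

```scala
import scala.collection.mutable

// Hypothetical toy model of per-partition state store delta files.
// Not Spark's StateStore: it only mirrors the bookkeeping described above.
final class ToyCheckpoint {
  // (partitionId, batchId) pairs for which a delta file has been written
  private val deltaFiles = mutable.Set.empty[(Int, Long)]

  /** Runs `batchId` with `numPartitions` tasks and returns the partitions whose
    * delta file for `batchId - 1` is missing (in Spark, such a task fails). */
  def runBatch(batchId: Long, numPartitions: Int): List[Int] = {
    val partitions = (0 until numPartitions).toList
    val missing = partitions.filter { p =>
      batchId > 0 && !deltaFiles.contains((p, batchId - 1))
    }
    // Tasks that could load their previous state commit a delta file for this batch.
    partitions.filterNot(p => missing.contains(p)).foreach(p => deltaFiles += ((p, batchId)))
    missing
  }
}

// Symptom 1: the empty trigger (batch 1) has 0 partitions, so no task runs
// and no delta file is written for it.
val s1 = new ToyCheckpoint
s1.runBatch(0, numPartitions = 1)
s1.runBatch(1, numPartitions = 0)
println(s1.runBatch(2, numPartitions = 1)) // List(0)

// Symptom 2: batch 0 ran with coalesce(1); after the restart the query shuffles
// to 200 partitions (the default value of spark.sql.shuffle.partitions).
val s2 = new ToyCheckpoint
s2.runBatch(0, numPartitions = 1)
println(s2.runBatch(1, numPartitions = 200).size) // 199
```

In the model, only partition 0 ever wrote state while the query was coalesced, which is why every other partition fails after the restart.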

To fix the issues above, we must check that the partitioning of the child of a
`StatefulOperator` satisfies the following, whether or not `coalesce(1)` exists
in the plan and whether or not the input RDD for the trigger has any data:
  If the grouping expressions are empty:
    a) `AllTuples` distribution
    b) a single physical partition
  If the grouping expressions are non-empty:
    a) `ClusteredDistribution`
    b) `spark.sql.shuffle.partitions` physical partitions
If it does not, an Exchange is added to the plan.
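The rule above can be sketched as a small decision function. This is a simplified, hypothetical model written as a Scala script: the `Distribution` and `Partitioning` types below are stand-ins that carry column names rather than Catalyst expressions, and plain equality approximates Spark's `Partitioning.guarantees`. The point it illustrates is that the check compares both the declared partitioning and the number of partitions the child's RDD actually has.

```scala
// Hypothetical stand-ins for Catalyst's physical Distribution / Partitioning.
sealed trait Distribution
case object AllTuples extends Distribution
final case class ClusteredDistribution(keys: Seq[String]) extends Distribution

sealed trait Partitioning { def numPartitions: Int }
case object SinglePartition extends Partitioning { val numPartitions = 1 }
final case class HashPartitioning(keys: Seq[String], numPartitions: Int) extends Partitioning

/** Partitioning a stateful operator's child must have, per the rule above. */
def expectedPartitioning(required: Distribution, numShufflePartitions: Int): Partitioning =
  required match {
    case AllTuples                   => SinglePartition
    case ClusteredDistribution(keys) => HashPartitioning(keys, numShufflePartitions)
  }

/** True if an exchange must be inserted below the stateful operator.
  * `declared` is the child's output partitioning; `actualPartitions` is the number
  * of partitions its RDD really has (0 when coalesce(1) sees an empty input). */
def needsShuffle(declared: Partitioning, actualPartitions: Int, expected: Partitioning): Boolean =
  !(declared == expected && actualPartitions == expected.numPartitions)

val shufflePartitions = 200 // default value of spark.sql.shuffle.partitions
val global = expectedPartitioning(AllTuples, shufflePartitions)
println(needsShuffle(SinglePartition, 1, global)) // false: coalesce(1) over data already fits
println(needsShuffle(SinglePartition, 0, global)) // true: the empty trigger has 0 partitions
val keyed = expectedPartitioning(ClusteredDistribution(Seq("a")), shufflePartitions)
println(needsShuffle(SinglePartition, 1, keyed))  // true: keys must be spread over 200 partitions
```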

Once you fix the above problem by adding an Exchange to the plan, you come
across the following bug:
If you call `coalesce(1).groupBy().count()` on a streaming DataFrame and a
trigger has no data, `StateStoreRestoreExec` receives an empty iterator and
doesn't return the prior state. However, for this specific aggregation, the
`HashAggregateExec` after the restore still returns a (0, 0) row, since we're
performing a count and there is no data. This row then gets stored by
`StateStoreSaveExec`, causing the previous counts to be overwritten and lost.
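The lost-count bug can be reproduced with a toy model of one trigger of `groupBy().count()`, written as a Scala script. This is a hypothetical simplification, not Spark's operator code: rows are plain `Long` partial counts, the state store is a single `Option[Long]`, and `restore`, `mergeCounts` and `trigger` are invented names. Passing `fixed = true` adds the new branch from this commit, which emits the stored state when both the input and the grouping keys are empty.

```scala
/** Restore step (empty grouping keys assumed). The old behaviour only looks up
  * stored state once per input row, so an empty input yields nothing. With
  * `fixed = true`, an empty input emits the stored state instead. */
def restore(input: Seq[Long], stored: Option[Long], fixed: Boolean): Seq[Long] =
  if (fixed && input.isEmpty) stored.toSeq
  else input.flatMap(row => row +: stored.toSeq)

/** The aggregation after the restore: with no grouping keys it always emits one
  * row, so an empty input produces a count of 0. */
def mergeCounts(rows: Seq[Long]): Long = rows.sum

/** One trigger; the result is what the save step writes back to the store. */
def trigger(batch: Seq[Long], stored: Option[Long], fixed: Boolean): Option[Long] =
  Some(mergeCounts(restore(batch, stored, fixed)))

val afterData = trigger(Seq(1L), stored = None, fixed = false) // Some(1)
println(trigger(Seq.empty, afterData, fixed = false)) // Some(0): previous count lost
println(trigger(Seq.empty, afterData, fixed = true))  // Some(1): previous count kept
```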

## How was this patch tested?

Regression tests

Author: Burak Yavuz <brk...@gmail.com>

Closes #19196 from brkyvz/sa-0.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/280ff523
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/280ff523
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/280ff523

Branch: refs/heads/master
Commit: 280ff523f4079dd9541efc95e6efcb69f9374d22
Parents: c6ff59a
Author: Burak Yavuz <brk...@gmail.com>
Authored: Wed Sep 20 00:01:21 2017 -0700
Committer: Burak Yavuz <brk...@gmail.com>
Committed: Wed Sep 20 00:01:21 2017 -0700

----------------------------------------------------------------------
 .../streaming/IncrementalExecution.scala        |  34 +++-
 .../execution/streaming/StreamExecution.scala   |   1 +
 .../execution/streaming/statefulOperators.scala |  37 +++-
 .../EnsureStatefulOpPartitioningSuite.scala     | 132 +++++++++++++
 .../apache/spark/sql/streaming/StreamTest.scala |  16 +-
 .../streaming/StreamingAggregationSuite.scala   | 196 ++++++++++++++++++-
 6 files changed, 395 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/280ff523/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 19d9598..027222e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -21,11 +21,13 @@ import java.util.UUID
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{SparkSession, Strategy}
+import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.streaming.OutputMode
 
 /**
@@ -89,7 +91,7 @@ class IncrementalExecution(
     override def apply(plan: SparkPlan): SparkPlan = plan transform {
       case StateStoreSaveExec(keys, None, None, None,
              UnaryExecNode(agg,
-               StateStoreRestoreExec(keys2, None, child))) =>
+               StateStoreRestoreExec(_, None, child))) =>
         val aggStateInfo = nextStatefulOperationStateInfo
         StateStoreSaveExec(
           keys,
@@ -117,8 +119,34 @@ class IncrementalExecution(
     }
   }
 
-  override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations
+  override def preparations: Seq[Rule[SparkPlan]] =
+    Seq(state, EnsureStatefulOpPartitioning) ++ super.preparations
 
   /** No need assert supported, as this check has already been done */
   override def assertSupported(): Unit = { }
 }
+
+object EnsureStatefulOpPartitioning extends Rule[SparkPlan] {
+  // Needs to be transformUp to avoid extra shuffles
+  override def apply(plan: SparkPlan): SparkPlan = plan transformUp {
+    case so: StatefulOperator =>
+      val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions
+      val distributions = so.requiredChildDistribution
+      val children = so.children.zip(distributions).map { case (child, reqDistribution) =>
+        val expectedPartitioning = reqDistribution match {
+          case AllTuples => SinglePartition
+          case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions)
+          case _ => throw new AnalysisException("Unexpected distribution expected for " +
+            s"Stateful Operator: $so. Expect AllTuples or ClusteredDistribution but got " +
+            s"$reqDistribution.")
+        }
+        if (child.outputPartitioning.guarantees(expectedPartitioning) &&
+            child.execute().getNumPartitions == expectedPartitioning.numPartitions) {
+          child
+        } else {
+          ShuffleExchange(expectedPartitioning, child)
+        }
+      }
+      so.withNewChildren(children)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/280ff523/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index b27a59b..18385f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -829,6 +829,7 @@ class StreamExecution(
     if (streamDeathCause != null) {
       throw streamDeathCause
     }
+    if (!isActive) return
     awaitBatchLock.lock()
     try {
       noNewData = false

http://git-wip-us.apache.org/repos/asf/spark/blob/280ff523/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index e463563..d6566b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -200,11 +200,20 @@ case class StateStoreRestoreExec(
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
         val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
-        iter.flatMap { row =>
-          val key = getKey(row)
-          val savedState = store.get(key)
-          numOutputRows += 1
-          row +: Option(savedState).toSeq
+        val hasInput = iter.hasNext
+        if (!hasInput && keyExpressions.isEmpty) {
+          // If our `keyExpressions` are empty, we're getting a global aggregation. In that case
+          // the `HashAggregateExec` will output a 0 value for the partial merge. We need to
+          // restore the value, so that we don't overwrite our state with a 0 value, but rather
+          // merge the 0 with existing state.
+          store.iterator().map(_.value)
+        } else {
+          iter.flatMap { row =>
+            val key = getKey(row)
+            val savedState = store.get(key)
+            numOutputRows += 1
+            row +: Option(savedState).toSeq
+          }
         }
     }
   }
@@ -212,6 +221,14 @@ case class StateStoreRestoreExec(
   override def output: Seq[Attribute] = child.output
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def requiredChildDistribution: Seq[Distribution] = {
+    if (keyExpressions.isEmpty) {
+      AllTuples :: Nil
+    } else {
+      ClusteredDistribution(keyExpressions) :: Nil
+    }
+  }
 }
 
 /**
@@ -351,6 +368,14 @@ case class StateStoreSaveExec(
   override def output: Seq[Attribute] = child.output
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def requiredChildDistribution: Seq[Distribution] = {
+    if (keyExpressions.isEmpty) {
+      AllTuples :: Nil
+    } else {
+      ClusteredDistribution(keyExpressions) :: Nil
+    }
+  }
 }
 
 /** Physical operator for executing streaming Deduplicate. */

http://git-wip-us.apache.org/repos/asf/spark/blob/280ff523/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala
new file mode 100644
index 0000000..66c0263
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.util.UUID
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode}
+import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange}
+import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo}
+import org.apache.spark.sql.test.SharedSQLContext
+
+class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext {
+
+  import testImplicits._
+  super.beforeAll()
+
+  private val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char")
+
+  testEnsureStatefulOpPartitioning(
+    "ClusteredDistribution generates Exchange with HashPartitioning",
+    baseDf.queryExecution.sparkPlan,
+    requiredDistribution = keys => ClusteredDistribution(keys),
+    expectedPartitioning =
+      keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
+    expectShuffle = true)
+
+  testEnsureStatefulOpPartitioning(
+    "ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning",
+    baseDf.coalesce(1).queryExecution.sparkPlan,
+    requiredDistribution = keys => ClusteredDistribution(keys),
+    expectedPartitioning =
+      keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
+    expectShuffle = true)
+
+  testEnsureStatefulOpPartitioning(
+    "AllTuples generates Exchange with SinglePartition",
+    baseDf.queryExecution.sparkPlan,
+    requiredDistribution = _ => AllTuples,
+    expectedPartitioning = _ => SinglePartition,
+    expectShuffle = true)
+
+  testEnsureStatefulOpPartitioning(
+    "AllTuples with coalesce(1) doesn't need Exchange",
+    baseDf.coalesce(1).queryExecution.sparkPlan,
+    requiredDistribution = _ => AllTuples,
+    expectedPartitioning = _ => SinglePartition,
+    expectShuffle = false)
+
+  /**
+   * For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan
+   * `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to
+   * ensure the expected partitioning.
+   */
+  private def testEnsureStatefulOpPartitioning(
+      testName: String,
+      inputPlan: SparkPlan,
+      requiredDistribution: Seq[Attribute] => Distribution,
+      expectedPartitioning: Seq[Attribute] => Partitioning,
+      expectShuffle: Boolean): Unit = {
+    test(testName) {
+      val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1)))
+      val executed = executePlan(operator, OutputMode.Complete())
+      if (expectShuffle) {
+        val exchange = executed.children.find(_.isInstanceOf[Exchange])
+        if (exchange.isEmpty) {
+          fail(s"Was expecting an exchange but didn't get one in:\n$executed")
+        }
+        assert(exchange.get ===
+          ShuffleExchange(expectedPartitioning(inputPlan.output.take(1)), inputPlan),
+          s"Exchange didn't have expected properties:\n${exchange.get}")
+      } else {
+        assert(!executed.children.exists(_.isInstanceOf[Exchange]),
+          s"Unexpected exchange found in:\n$executed")
+      }
+    }
+  }
+
+  /** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. */
+  private def executePlan(
+      p: SparkPlan,
+      outputMode: OutputMode = OutputMode.Append()): SparkPlan = {
+    val execution = new IncrementalExecution(
+      spark,
+      null,
+      OutputMode.Complete(),
+      "chk",
+      UUID.randomUUID(),
+      0L,
+      OffsetSeqMetadata()) {
+      override lazy val sparkPlan: SparkPlan = p transform {
+        case plan: SparkPlan =>
+          val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
+          plan transformExpressions {
+            case UnresolvedAttribute(Seq(u)) =>
+              inputMap.getOrElse(u,
+                sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
+          }
+      }
+    }
+    execution.executedPlan
+  }
+}
+
+/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */
+case class TestStatefulOperator(
+    child: SparkPlan,
+    requiredDist: Distribution) extends UnaryExecNode with StatefulOperator {
+  override def output: Seq[Attribute] = child.output
+  override def doExecute(): RDD[InternalRow] = child.execute()
+  override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil
+  override def stateInfo: Option[StatefulOperatorStateInfo] = None
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/280ff523/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 4f87640..70b39b9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -167,7 +167,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
   case class StartStream(
       trigger: Trigger = Trigger.ProcessingTime(0),
       triggerClock: Clock = new SystemClock,
-      additionalConfs: Map[String, String] = Map.empty)
+      additionalConfs: Map[String, String] = Map.empty,
+      checkpointLocation: String = null)
     extends StreamAction
 
   /** Advance the trigger clock's time manually. */
@@ -349,13 +350,14 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
          """.stripMargin)
     }
 
-    val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
     var manualClockExpectedTime = -1L
+    val defaultCheckpointLocation =
+      Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
     try {
       startedTest.foreach { action =>
         logInfo(s"Processing test stream action: $action")
         action match {
-          case StartStream(trigger, triggerClock, additionalConfs) =>
+          case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) =>
             verify(currentStream == null, "stream already running")
             verify(triggerClock.isInstanceOf[SystemClock]
               || triggerClock.isInstanceOf[StreamManualClock],
@@ -363,6 +365,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
             if (triggerClock.isInstanceOf[StreamManualClock]) {
               manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis()
             }
+            val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation)
 
             additionalConfs.foreach(pair => {
               val value =
@@ -479,7 +482,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
             verify(currentStream != null || lastStream != null,
               "cannot assert when no stream has been started")
             val streamToAssert = Option(currentStream).getOrElse(lastStream)
-            verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}")
+            try {
+              verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}")
+            } catch {
+              case NonFatal(e) =>
+                failTest(s"Assert on query failed: ${a.message}", e)
+            }
 
           case a: Assert =>
             val streamToAssert = Option(currentStream).getOrElse(lastStream)

http://git-wip-us.apache.org/repos/asf/spark/blob/280ff523/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index e0979ce..995cea3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -22,20 +22,24 @@ import java.util.{Locale, TimeZone}
 import org.scalatest.Assertions
 import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.SparkException
-import org.apache.spark.sql.{AnalysisException, DataFrame}
+import org.apache.spark.{SparkEnv, SparkException}
+import org.apache.spark.rdd.BlockRDD
+import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.exchange.Exchange
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.streaming.OutputMode._
-import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
 
-object FailureSinglton {
+object FailureSingleton {
   var firstTime = true
 }
 
@@ -226,12 +230,12 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
 
   testQuietly("midbatch failure") {
     val inputData = MemoryStream[Int]
-    FailureSinglton.firstTime = true
+    FailureSingleton.firstTime = true
     val aggregated =
       inputData.toDS()
           .map { i =>
-            if (i == 4 && FailureSinglton.firstTime) {
-              FailureSinglton.firstTime = false
+            if (i == 4 && FailureSingleton.firstTime) {
+              FailureSingleton.firstTime = false
               sys.error("injected failure")
             }
 
@@ -381,4 +385,180 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
       AddData(streamInput, 0, 1, 2, 3),
       CheckLastBatch((0, 0, 2), (1, 1, 3)))
   }
+
+  /**
+   * This method verifies certain properties in the SparkPlan of a streaming aggregation.
+   * First of all, it checks that the child of a `StateStoreRestoreExec` creates the desired
+   * data distribution, where the child could be an Exchange, or a `HashAggregateExec` which already
+   * provides the expected data distribution.
+   *
+   * The second thing it checks that the child provides the expected number of partitions.
+   *
+   * The third thing it checks that we don't add an unnecessary shuffle in-between
+   * `StateStoreRestoreExec` and `StateStoreSaveExec`.
+   */
+  private def checkAggregationChain(
+      se: StreamExecution,
+      expectShuffling: Boolean,
+      expectedPartition: Int): Boolean = {
+    val executedPlan = se.lastExecution.executedPlan
+    val restore = executedPlan
+      .collect { case ss: StateStoreRestoreExec => ss }
+      .head
+    restore.child match {
+      case node: UnaryExecNode =>
+        assert(node.outputPartitioning.numPartitions === expectedPartition,
+          "Didn't get the expected number of partitions.")
+        if (expectShuffling) {
+          assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: ${node.child}")
+        } else {
+          assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle")
+        }
+
+      case _ => fail("Expected no shuffling")
+    }
+    var reachedRestore = false
+    // Check that there should be no exchanges after `StateStoreRestoreExec`
+    executedPlan.foreachUp { p =>
+      if (reachedRestore) {
+        assert(!p.isInstanceOf[Exchange], "There should be no further exchanges")
+      } else {
+        reachedRestore = p.isInstanceOf[StateStoreRestoreExec]
+      }
+    }
+    true
+  }
+
+  test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned to 1") {
+    val inputSource = new BlockRDDBackedSource(spark)
+    MockSourceProvider.withMockSources(inputSource) {
+      // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default
+      // satisfies the required distributions of all aggregations. Therefore in our SparkPlan, we
+      // don't have any shuffling. However, `coalesce(1)` only guarantees that the RDD has at most 1
+      // partition. Which means that if we have an input RDD with 0 partitions, nothing gets
+      // executed. Therefore the StateStore's don't save any delta files for a given trigger. This
+      // then leads to `FileNotFoundException`s in the subsequent batch.
+      // This isn't the only problem though. Once we introduce a shuffle before
+      // `StateStoreRestoreExec`, the input to the operator is an empty iterator. When performing
+      // `groupBy().agg(...)`, `HashAggregateExec` returns a `0` value for all aggregations. If
+      // we fail to restore the previous state in `StateStoreRestoreExec`, we save the 0 value in
+      // `StateStoreSaveExec` losing all previous state.
+      val aggregated: Dataset[Long] =
+        spark.readStream.format((new MockSourceProvider).getClass.getCanonicalName)
+        .load().coalesce(1).groupBy().count().as[Long]
+
+      testStream(aggregated, Complete())(
+        AddBlockData(inputSource, Seq(1)),
+        CheckLastBatch(1),
+        AssertOnQuery("Verify no shuffling") { se =>
+          checkAggregationChain(se, expectShuffling = false, 1)
+        },
+        AddBlockData(inputSource), // create an empty trigger
+        CheckLastBatch(1),
+        AssertOnQuery("Verify addition of exchange operator") { se =>
+          checkAggregationChain(se, expectShuffling = true, 1)
+        },
+        AddBlockData(inputSource, Seq(2, 3)),
+        CheckLastBatch(3),
+        AddBlockData(inputSource),
+        CheckLastBatch(3),
+        StopStream
+      )
+    }
+  }
+
+  test("SPARK-21977: coalesce(1) with aggregation should still be repartitioned when it " +
+    "has non-empty grouping keys") {
+    val inputSource = new BlockRDDBackedSource(spark)
+    MockSourceProvider.withMockSources(inputSource) {
+      withTempDir { tempDir =>
+
+        // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default
+        // satisfies the required distributions of all aggregations. However, when we have
+        // non-empty grouping keys, in streaming, we must repartition to
+        // `spark.sql.shuffle.partitions`, otherwise only a single StateStore is used to process
+        // all keys. This may be fine, however, if the user removes the coalesce(1) or changes to
+        // a `coalesce(2)` for example, then the default behavior is to shuffle to
+        // `spark.sql.shuffle.partitions` many StateStores. When this happens, all StateStore's
+        // except 1 will be missing their previous delta files, which causes the stream to fail
+        // with FileNotFoundException.
+        def createDf(partitions: Int): Dataset[(Long, Long)] = {
+          spark.readStream
+            .format((new MockSourceProvider).getClass.getCanonicalName)
+            .load().coalesce(partitions).groupBy('a % 1).count().as[(Long, Long)]
+        }
+
+        testStream(createDf(1), Complete())(
+          StartStream(checkpointLocation = tempDir.getAbsolutePath),
+          AddBlockData(inputSource, Seq(1)),
+          CheckLastBatch((0L, 1L)),
+          AssertOnQuery("Verify addition of exchange operator") { se =>
+            checkAggregationChain(
+              se,
+              expectShuffling = true,
+              spark.sessionState.conf.numShufflePartitions)
+          },
+          StopStream
+        )
+
+        testStream(createDf(2), Complete())(
+          StartStream(checkpointLocation = tempDir.getAbsolutePath),
+          Execute(se => se.processAllAvailable()),
+          AddBlockData(inputSource, Seq(2), Seq(3), Seq(4)),
+          CheckLastBatch((0L, 4L)),
+          AssertOnQuery("Verify no exchange added") { se =>
+            checkAggregationChain(
+              se,
+              expectShuffling = false,
+              spark.sessionState.conf.numShufflePartitions)
+          },
+          AddBlockData(inputSource),
+          CheckLastBatch((0L, 4L)),
+          StopStream
+        )
+      }
+    }
+  }
+
+  /** Add blocks of data to the `BlockRDDBackedSource`. */
+  case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData {
+    override def addData(query: Option[StreamExecution]): (Source, Offset) = {
+      source.addBlocks(data: _*)
+      (source, LongOffset(source.counter))
+    }
+  }
+
+  /**
+   * A Streaming Source that is backed by a BlockRDD and that can create RDDs with 0 blocks at will.
+   */
+  class BlockRDDBackedSource(spark: SparkSession) extends Source {
+    var counter = 0L
+    private val blockMgr = SparkEnv.get.blockManager
+    private var blocks: Seq[BlockId] = Seq.empty
+
+    def addBlocks(dataBlocks: Seq[Int]*): Unit = synchronized {
+      dataBlocks.foreach { data =>
+        val id = TestBlockId(counter.toString)
+        blockMgr.putIterator(id, data.iterator, StorageLevel.MEMORY_ONLY)
+        blocks ++= id :: Nil
+        counter += 1
+      }
+      counter += 1
+    }
+
+    override def getOffset: Option[Offset] = synchronized {
+      if (counter == 0) None else Some(LongOffset(counter))
+    }
+
+    override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized {
+      val rdd = new BlockRDD[Int](spark.sparkContext, blocks.toArray)
+        .map(i => InternalRow(i)) // we don't really care about the values in this test
+      blocks = Seq.empty
+      spark.internalCreateDataFrame(rdd, schema, isStreaming = true).toDF()
+    }
+    override def schema: StructType = MockSourceProvider.fakeSchema
+    override def stop(): Unit = {
+      blockMgr.getMatchingBlockIds(_.isInstanceOf[TestBlockId]).foreach(blockMgr.removeBlock(_))
+    }
+  }
 }

