Repository: spark
Updated Branches:
  refs/heads/master 37a1c0e46 -> 9314c0837


[SPARK-19774] StreamExecution should call stop() on sources when a stream fails

## What changes were proposed in this pull request?

Currently, stop() is called on a Structured Streaming Source only when the stream is
shut down because a user called streamingQuery.stop(). We should also stop all
sources when the stream fails; otherwise we may leak resources, e.g.
connections to Kafka.

## How was this patch tested?

Unit tests in `StreamingQuerySuite`.

Author: Burak Yavuz <brk...@gmail.com>

Closes #17107 from brkyvz/close-source.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9314c083
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9314c083
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9314c083

Branch: refs/heads/master
Commit: 9314c08377cc8da88f4e31d1a9d41376e96a81b3
Parents: 37a1c0e
Author: Burak Yavuz <brk...@gmail.com>
Authored: Fri Mar 3 10:35:15 2017 -0800
Committer: Shixiong Zhu <shixi...@databricks.com>
Committed: Fri Mar 3 10:35:15 2017 -0800

----------------------------------------------------------------------
 .../execution/streaming/StreamExecution.scala   | 14 +++-
 .../sql/streaming/StreamingQuerySuite.scala     | 75 +++++++++++++++++-
 .../sql/streaming/util/MockSourceProvider.scala | 83 ++++++++++++++++++++
 3 files changed, 169 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9314c083/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 4bd6431..6e77f35 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -321,6 +321,7 @@ class StreamExecution(
       initializationLatch.countDown()
 
       try {
+        stopSources()
         state.set(TERMINATED)
         currentStatus = status.copy(isTriggerActive = false, isDataAvailable = 
false)
 
@@ -558,6 +559,18 @@ class StreamExecution(
     sparkSession.streams.postListenerEvent(event)
   }
 
+  /**
+   * Stops all streaming sources safely.
+   *
+   * Each source is stopped independently: a non-fatal exception thrown by one source's `stop()`
+   * is logged as a warning and does not prevent the remaining sources from being stopped.
+   * Fatal errors (anything `NonFatal` does not match) still propagate.
+   */
+  private def stopSources(): Unit = {
+    uniqueSources.foreach { source =>
+      try {
+        source.stop()
+      } catch {
+        case NonFatal(e) =>
+          logWarning(s"Failed to stop streaming source: $source. Resources may have leaked.", e)
+      }
+    }
+  }
+
   /**
    * Signals to the thread executing micro-batches that it should stop running 
after the next
    * batch. This method blocks until the thread stops running.
@@ -570,7 +583,6 @@ class StreamExecution(
       microBatchThread.interrupt()
       microBatchThread.join()
     }
-    uniqueSources.foreach(_.stop())
     logInfo(s"Query $prettyIdString was stopped")
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9314c083/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 1525ad5..a0a2b2b 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -20,10 +20,12 @@ package org.apache.spark.sql.streaming
 import java.util.concurrent.CountDownLatch
 
 import org.apache.commons.lang3.RandomStringUtils
+import org.mockito.Mockito._
 import org.scalactic.TolerantNumerics
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.mock.MockitoSugar
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -32,11 +34,11 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.util.BlockingSource
+import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider}
 import org.apache.spark.util.ManualClock
 
 
-class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging {
+class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging 
with MockitoSugar {
 
   import AwaitTerminationTester._
   import testImplicits._
@@ -481,6 +483,75 @@ class StreamingQuerySuite extends StreamTest with 
BeforeAndAfter with Logging {
     }
   }
 
+  test("StreamExecution should call stop() on sources when a stream is stopped") {
+    // Flipped by the mock source's stop(); checked after StopStream has stopped the query.
+    var calledStop = false
+    val source = new Source {
+      override def stop(): Unit = {
+        calledStop = true
+      }
+      // getOffset returns None, so the query only needs to start and then be stopped.
+      override def getOffset: Option[Offset] = None
+      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+        spark.emptyDataFrame
+      }
+      override def schema: StructType = MockSourceProvider.fakeSchema
+    }
+
+    MockSourceProvider.withMockSources(source) {
+      val df = spark.readStream
+        .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+        .load()
+
+      testStream(df)(StopStream)
+
+      assert(calledStop, "Did not call stop on source for stopped stream")
+    }
+  }
+
+  testQuietly("SPARK-19774: StreamExecution should call stop() on sources when a stream fails") {
+    var calledStop = false
+    // Always reports an offset and returns rows, so a batch runs and hits the `i / 0` below.
+    // Its stop() throws, to exercise the "keep stopping other sources" path of stopSources.
+    // NOTE(review): that path is only covered if this source is stopped before source2 —
+    // confirm the ordering of StreamExecution.uniqueSources.
+    val source1 = new Source {
+      override def stop(): Unit = {
+        throw new RuntimeException("Oh no!")
+      }
+      override def getOffset: Option[Offset] = Some(LongOffset(1))
+      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+        spark.range(2).toDF(MockSourceProvider.fakeSchema.fieldNames: _*)
+      }
+      override def schema: StructType = MockSourceProvider.fakeSchema
+    }
+    // Records that stop() was invoked even though the query terminated with an exception.
+    val source2 = new Source {
+      override def stop(): Unit = {
+        calledStop = true
+      }
+      override def getOffset: Option[Offset] = None
+      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+        spark.emptyDataFrame
+      }
+      override def schema: StructType = MockSourceProvider.fakeSchema
+    }
+
+    MockSourceProvider.withMockSources(source1, source2) {
+      val df1 = spark.readStream
+        .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+        .load()
+        .as[Int]
+
+      val df2 = spark.readStream
+        .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+        .load()
+        .as[Int]
+
+      // Integer division by zero throws ArithmeticException, failing the query.
+      testStream(df1.union(df2).map(i => i / 0))(
+        AssertOnQuery { sq =>
+          intercept[StreamingQueryException](sq.processAllAvailable())
+          sq.exception.isDefined && !sq.isActive
+        }
+      )
+
+      assert(calledStop, "Did not call stop on source for failed stream")
+    }
+  }
+
   /** Create a streaming DF that only execute one batch in which it returns 
the given static DF */
   private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame 
= {
     require(!triggerDF.isStreaming)

http://git-wip-us.apache.org/repos/asf/spark/blob/9314c083/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala
new file mode 100644
index 0000000..0bf0538
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming.util
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.execution.streaming.Source
+import org.apache.spark.sql.sources.StreamSourceProvider
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+/**
+ * A StreamSourceProvider that provides mocked Sources for unit testing. Example usage:
+ *
+ * {{{
+ *    MockSourceProvider.withMockSources(source1, source2) {
+ *      val df1 = spark.readStream
+ *        .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+ *        .load()
+ *
+ *      val df2 = spark.readStream
+ *        .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
+ *        .load()
+ *
+ *      df1.union(df2)
+ *      ...
+ *    }
+ * }}}
+ *
+ * Every instance reports the source name "dummySource" with `MockSourceProvider.fakeSchema`, and
+ * delegates source creation to the function installed by `MockSourceProvider.withMockSources`.
+ * Creating a source outside of that scope fails with an IllegalStateException.
+ */
+class MockSourceProvider extends StreamSourceProvider {
+  override def sourceSchema(
+      spark: SQLContext,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): (String, StructType) = {
+    ("dummySource", MockSourceProvider.fakeSchema)
+  }
+
+  override def createSource(
+      spark: SQLContext,
+      metadataPath: String,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
+    // Read the shared function once; null means we are not inside withMockSources.
+    val provider = MockSourceProvider.sourceProviderFunction
+    if (provider == null) {
+      throw new IllegalStateException(
+        "MockSourceProvider must be used inside MockSourceProvider.withMockSources")
+    }
+    provider()
+  }
+}
+
+object MockSourceProvider {
+  // Supplies the Source returned by each MockSourceProvider.createSource call. It is null outside
+  // of withMockSources.
+  private var sourceProviderFunction: () => Source = _
+
+  final val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)
+
+  /**
+   * Runs `f` with `source` and `otherSources` installed as the sources handed out by
+   * `MockSourceProvider#createSource`. Sources are returned round-robin in the order
+   * `createSource` is called, cycling back to `source` after the last one.
+   *
+   * The previously installed provider function (if any) is restored when `f` returns or throws,
+   * so calls may be nested.
+   *
+   * @param source the first source to hand out
+   * @param otherSources further sources, handed out after `source` in order
+   * @param f the test body that creates streaming DataFrames from MockSourceProvider
+   */
+  def withMockSources(source: Source, otherSources: Source*)(f: => Unit): Unit = {
+    val sources = source +: otherSources
+    // Never exhausted: `sources` always contains at least `source`.
+    val roundRobin = Iterator.continually(sources).flatten
+    val previous = sourceProviderFunction
+    sourceProviderFunction = () => roundRobin.next()
+    try {
+      f
+    } finally {
+      sourceProviderFunction = previous
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to