This is an automated email from the ASF dual-hosted git repository. vanzin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b62ef8f [SPARK-29007][STREAMING][MLLIB][TESTS] Enforce not leaking SparkContext in tests which creates new StreamingContext with new SparkContext b62ef8f is described below commit b62ef8f7935ae5c9a4a5e7e8a17aa5d7375c85b1 Author: Jungtaek Lim (HeartSaVioR) <kabh...@gmail.com> AuthorDate: Wed Sep 11 10:29:13 2019 -0700 [SPARK-29007][STREAMING][MLLIB][TESTS] Enforce not leaking SparkContext in tests which creates new StreamingContext with new SparkContext ### What changes were proposed in this pull request? This patch makes tests avoid leaking a SparkContext that is newly created when a StreamingContext is initialized. A leaked SparkContext in one test would make most of the following tests fail as well, so this patch applies defensive programming, trying its best to ensure the SparkContext is cleaned up. ### Why are the changes needed? We have seen cases in CI builds where a SparkContext is leaked and other tests are affected by the leaked SparkContext. Ideally we should isolate the environment among tests if possible. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Modified UTs. Closes #25709 from HeartSaVioR/SPARK-29007. 
Authored-by: Jungtaek Lim (HeartSaVioR) <kabh...@gmail.com> Signed-off-by: Marcelo Vanzin <van...@cloudera.com> --- external/kafka-0-10/pom.xml | 7 ++ .../kafka010/DirectKafkaStreamSuite.scala | 24 +++---- .../streaming/kinesis/KinesisStreamSuite.scala | 37 ++++------ mllib/pom.xml | 7 ++ .../StreamingLogisticRegressionSuite.scala | 16 ++--- .../mllib/clustering/StreamingKMeansSuite.scala | 13 +--- .../StreamingLinearRegressionSuite.scala | 16 ++--- .../apache/spark/streaming/CheckpointSuite.scala | 18 ++--- .../spark/streaming/DStreamClosureSuite.scala | 16 +---- .../apache/spark/streaming/DStreamScopeSuite.scala | 23 +++--- .../spark/streaming/LocalStreamingContext.scala | 83 ++++++++++++++++++++++ .../apache/spark/streaming/MapWithStateSuite.scala | 30 +++----- .../streaming/ReceiverInputDStreamSuite.scala | 16 ++--- .../spark/streaming/StreamingContextSuite.scala | 52 ++++++-------- .../spark/streaming/StreamingListenerSuite.scala | 11 +-- .../org/apache/spark/streaming/TestSuiteBase.scala | 30 ++++---- .../spark/streaming/WindowOperationsSuite.scala | 19 ++--- .../scheduler/ExecutorAllocationManagerSuite.scala | 19 ++--- .../scheduler/InputInfoTrackerSuite.scala | 22 ++---- .../streaming/scheduler/RateControllerSuite.scala | 6 +- .../ui/StreamingJobProgressListenerSuite.scala | 16 ++--- 21 files changed, 240 insertions(+), 241 deletions(-) diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 397de87..d11569d 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -47,6 +47,13 @@ </dependency> <dependency> <groupId>org.apache.spark</groupId> + <artifactId>spark-streaming_${scala.binary.version}</artifactId> + <version>${project.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> <artifactId>spark-core_${scala.binary.version}</artifactId> <version>${project.version}</version> <type>test-jar</type> diff --git 
a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 4d3e476..26b41e6 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.streaming.kafka010 import java.io.File -import java.lang.{ Long => JLong } -import java.util.{ Arrays, HashMap => JHashMap, Map => JMap, UUID } +import java.lang.{Long => JLong} +import java.util.{Arrays, HashMap => JHashMap, Map => JMap, UUID} import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicLong @@ -31,13 +31,12 @@ import scala.util.Random import org.apache.kafka.clients.consumer._ import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.serialization.StringDeserializer -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} +import org.apache.spark.streaming.{LocalStreamingContext, Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.scheduler.rate.RateEstimator @@ -45,8 +44,7 @@ import org.apache.spark.util.Utils class DirectKafkaStreamSuite extends SparkFunSuite - with BeforeAndAfter - with BeforeAndAfterAll + with LocalStreamingContext with Eventually with Logging { val sparkConf = new SparkConf() @@ -56,7 +54,6 @@ class DirectKafkaStreamSuite // Otherwise the poll timeout defaults to 2 minutes and causes test cases to run 
longer. .set("spark.streaming.kafka.consumer.poll.ms", "10000") - private var ssc: StreamingContext = _ private var testDir: File = _ private var kafkaTestUtils: KafkaTestUtils = _ @@ -78,12 +75,13 @@ class DirectKafkaStreamSuite } } - after { - if (ssc != null) { - ssc.stop(stopSparkContext = true) - } - if (testDir != null) { - Utils.deleteRecursively(testDir) + override def afterEach(): Unit = { + try { + if (testDir != null) { + Utils.deleteRecursively(testDir) + } + } finally { + super.afterEach() } } diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index ee53fba..eee62d2 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} -import org.apache.spark.streaming._ +import org.apache.spark.streaming.{LocalStreamingContext, _} import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.kinesis.KinesisInitialPositions.Latest import org.apache.spark.streaming.kinesis.KinesisReadConfigurations._ @@ -40,7 +40,7 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo import org.apache.spark.util.Utils abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFunSuite - with Eventually with BeforeAndAfter with BeforeAndAfterAll { + with LocalStreamingContext with Eventually with BeforeAndAfter with BeforeAndAfterAll { // This is the name that KCL will use to save metadata to DynamoDB private val appName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}" @@ -53,15 +53,9 @@ abstract class 
KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun private val dummyAWSSecretKey = "dummySecretKey" private var testUtils: KinesisTestUtils = null - private var ssc: StreamingContext = null private var sc: SparkContext = null override def beforeAll(): Unit = { - val conf = new SparkConf() - .setMaster("local[4]") - .setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name - sc = new SparkContext(conf) - runIfTestsEnabled("Prepare KinesisTestUtils") { testUtils = new KPLBasedKinesisTestUtils() testUtils.createStream() @@ -70,12 +64,6 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun override def afterAll(): Unit = { try { - if (ssc != null) { - ssc.stop() - } - if (sc != null) { - sc.stop() - } if (testUtils != null) { // Delete the Kinesis stream as well as the DynamoDB table generated by // Kinesis Client Library when consuming the stream @@ -87,17 +75,22 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun } } - before { + override def beforeEach(): Unit = { + super.beforeEach() + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name + sc = new SparkContext(conf) ssc = new StreamingContext(sc, batchDuration) } - after { - if (ssc != null) { - ssc.stop(stopSparkContext = false) - ssc = null - } - if (testUtils != null) { - testUtils.deleteDynamoDBTable(appName) + override def afterEach(): Unit = { + try { + if (testUtils != null) { + testUtils.deleteDynamoDBTable(appName) + } + } finally { + super.afterEach() } } diff --git a/mllib/pom.xml b/mllib/pom.xml index 11769ef..2d21196 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -57,6 +57,13 @@ </dependency> <dependency> <groupId>org.apache.spark</groupId> + <artifactId>spark-streaming_${scala.binary.version}</artifactId> + <version>${project.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + <dependency> 
+ <groupId>org.apache.spark</groupId> <artifactId>spark-sql_${scala.binary.version}</artifactId> <version>${project.version}</version> </dependency> diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index 5f797a6..7349e03 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -23,23 +23,17 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.{LocalStreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream -class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { +class StreamingLogisticRegressionSuite + extends SparkFunSuite + with LocalStreamingContext + with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 30000 - var ssc: StreamingContext = _ - - override def afterFunction() { - super.afterFunction() - if (ssc != null) { - ssc.stop() - } - } - // Test if we can accurately learn B for Y = logistic(BX) on streaming data test("parameter accuracy") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index a1ac10c..415ac87 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -20,23 +20,14 @@ package org.apache.spark.mllib.clustering import 
org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.{LocalStreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.random.XORShiftRandom -class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { +class StreamingKMeansSuite extends SparkFunSuite with LocalStreamingContext with TestSuiteBase { override def maxWaitTimeMillis: Int = 30000 - var ssc: StreamingContext = _ - - override def afterFunction() { - super.afterFunction() - if (ssc != null) { - ssc.stop() - } - } - test("accuracy for single center and equivalence to grand average") { // set parameters val numBatches = 10 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index eaeaa3fc..5b94f7e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -22,23 +22,17 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LinearDataGenerator -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.{LocalStreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream -class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { +class StreamingLinearRegressionSuite + extends SparkFunSuite + with LocalStreamingContext + with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 20000 - var ssc: StreamingContext = _ - - override def 
afterFunction() { - super.afterFunction() - if (ssc != null) { - ssc.stop() - } - } - // Assert that two values are equal within tolerance epsilon def assertEqual(v1: Double, v2: Double, epsilon: Double) { def errorMessage = v1.toString + " did not equal " + v2.toString diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 6072957..e6f4f04 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -39,8 +39,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.scheduler._ -import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, ResetSystemProperties, - Utils} +import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, ResetSystemProperties, Utils} /** * A input stream that records the times of restore() invoked @@ -206,24 +205,21 @@ trait DStreamCheckpointTester { self: SparkFunSuite => * the checkpointing of a DStream's RDDs as well as the checkpointing of * the whole DStream graph. 
*/ -class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester +class CheckpointSuite extends TestSuiteBase with LocalStreamingContext with DStreamCheckpointTester with ResetSystemProperties { - var ssc: StreamingContext = null - override def batchDuration: Duration = Milliseconds(500) - override def beforeFunction() { - super.beforeFunction() + override def beforeEach(): Unit = { + super.beforeEach() Utils.deleteRecursively(new File(checkpointDir)) } - override def afterFunction() { + override def afterEach(): Unit = { try { - if (ssc != null) { ssc.stop() } Utils.deleteRecursively(new File(checkpointDir)) } finally { - super.afterFunction() + super.afterEach() } } @@ -255,7 +251,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester .checkpoint(stateStreamCheckpointInterval) .map(t => (t._1, t._2)) } - var ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) var stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head def waitForCompletionOfBatch(numBatches: Long): Unit = { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala index 2ab600a..0576bf5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala @@ -29,24 +29,14 @@ import org.apache.spark.util.ReturnStatementInClosureException /** * Test that closures passed to DStream operations are actually cleaned. 
*/ -class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll { - private var ssc: StreamingContext = null +class DStreamClosureSuite extends SparkFunSuite with LocalStreamingContext with BeforeAndAfterAll { + override protected def beforeEach(): Unit = { + super.beforeEach() - override def beforeAll(): Unit = { - super.beforeAll() val sc = new SparkContext("local", "test") ssc = new StreamingContext(sc, Seconds(1)) } - override def afterAll(): Unit = { - try { - ssc.stop(stopSparkContext = true) - ssc = null - } finally { - super.afterAll() - } - } - test("user provided closures are actually cleaned") { val dstream = new DummyInputDStream(ssc) val pairDstream = dstream.map { i => (i, i) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala index 94f1bce..1bb4116 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala @@ -30,28 +30,29 @@ import org.apache.spark.util.ManualClock /** * Tests whether scope information is passed from DStream operations to RDDs correctly. 
*/ -class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { - private var ssc: StreamingContext = null - private val batchDuration: Duration = Seconds(1) +class DStreamScopeSuite + extends SparkFunSuite + with LocalStreamingContext { + + override def beforeEach(): Unit = { + super.beforeEach() - override def beforeAll(): Unit = { - super.beforeAll() val conf = new SparkConf().setMaster("local").setAppName("test") conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + val batchDuration: Duration = Seconds(1) ssc = new StreamingContext(new SparkContext(conf), batchDuration) + + assertPropertiesNotSet() } - override def afterAll(): Unit = { + override def afterEach(): Unit = { try { - ssc.stop(stopSparkContext = true) + assertPropertiesNotSet() } finally { - super.afterAll() + super.afterEach() } } - before { assertPropertiesNotSet() } - after { assertPropertiesNotSet() } - test("dstream without scope") { val dummyStream = new DummyDStream(ssc) dummyStream.initialize(Time(0)) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/LocalStreamingContext.scala b/streaming/src/test/scala/org/apache/spark/streaming/LocalStreamingContext.scala new file mode 100644 index 0000000..2008c1c --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/LocalStreamingContext.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.scalatest.{BeforeAndAfterEach, Suite} + +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging + +/** + * Manages a local `ssc` `StreamingContext` variable, correctly stopping it after each test. + * Note that it also stops active SparkContext if `stopSparkContext` is set to true (default). + * In most cases you may want to leave it, to isolate environment for SparkContext in each test. + */ +trait LocalStreamingContext extends BeforeAndAfterEach { self: Suite => + + @transient var ssc: StreamingContext = _ + @transient var stopSparkContext: Boolean = true + + override def afterEach() { + try { + resetStreamingContext() + } finally { + super.afterEach() + } + } + + def resetStreamingContext(): Unit = { + LocalStreamingContext.stop(ssc, stopSparkContext) + ssc = null + } +} + +object LocalStreamingContext extends Logging { + def stop(ssc: StreamingContext, stopSparkContext: Boolean): Unit = { + try { + if (ssc != null) { + ssc.stop(stopSparkContext = stopSparkContext) + } + } finally { + if (stopSparkContext) { + ensureNoActiveSparkContext() + } + } + } + + /** + * Clean up active SparkContext: try to stop first if there's an active SparkContext. + * If it fails to stop, log warning message and clear active SparkContext to avoid + * interfere between tests. 
+ */ + def ensureNoActiveSparkContext(): Unit = { + // if SparkContext is still active, try to clean up + SparkContext.getActive match { + case Some(sc) => + try { + sc.stop() + } catch { + case e: Throwable => + logError("Exception trying to stop SparkContext, clear active SparkContext...", e) + SparkContext.clearActiveContext() + throw e + } + case _ => + } + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala index 06c0c2a..14796c4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -23,46 +23,36 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.PrivateMethodTester._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.streaming.dstream.{DStream, InternalMapWithStateDStream, MapWithStateDStream, MapWithStateDStreamImpl} import org.apache.spark.util.{ManualClock, Utils} -class MapWithStateSuite extends SparkFunSuite - with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter { +class MapWithStateSuite extends SparkFunSuite with LocalStreamingContext + with DStreamCheckpointTester { private var sc: SparkContext = null protected var checkpointDir: File = null protected val batchDuration = Seconds(1) - before { - StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } - checkpointDir = Utils.createTempDir(namePrefix = "checkpoint") - } + override def beforeEach(): Unit = { + super.beforeEach() - after { - StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } - if (checkpointDir != null) { - Utils.deleteRecursively(checkpointDir) - } - } - - override def beforeAll(): Unit = { - 
super.beforeAll() val conf = new SparkConf().setMaster("local").setAppName("MapWithStateSuite") conf.set("spark.streaming.clock", classOf[ManualClock].getName()) sc = new SparkContext(conf) + + checkpointDir = Utils.createTempDir(namePrefix = "checkpoint") } - override def afterAll(): Unit = { + override def afterEach(): Unit = { try { - if (sc != null) { - sc.stop() + if (checkpointDir != null) { + Utils.deleteRecursively(checkpointDir) } } finally { - super.afterAll() + super.afterEach() } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala index 0349e11..5e2ce25 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.streaming import scala.util.Random -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} @@ -30,15 +28,9 @@ import org.apache.spark.streaming.receiver.{BlockManagerBasedStoreResult, Receiv import org.apache.spark.streaming.scheduler.ReceivedBlockInfo import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} -class ReceiverInputDStreamSuite extends TestSuiteBase with BeforeAndAfterAll { - - override def afterAll(): Unit = { - try { - StreamingContext.getActive().foreach(_.stop()) - } finally { - super.afterAll() - } - } +class ReceiverInputDStreamSuite + extends TestSuiteBase + with LocalStreamingContext { testWithoutWAL("createBlockRDD creates empty BlockRDD when no block info") { receiverStream => val rdd = receiverStream.createBlockRDD(Time(0), Seq.empty) @@ -127,7 +119,7 @@ class ReceiverInputDStreamSuite extends TestSuiteBase with BeforeAndAfterAll { 
conf.setMaster("local[4]").setAppName("ReceiverInputDStreamSuite") conf.set(WriteAheadLogUtils.RECEIVER_WAL_ENABLE_CONF_KEY, enableWAL.toString) require(WriteAheadLogUtils.enableReceiverLog(conf) === enableWAL) - val ssc = new StreamingContext(conf, Seconds(1)) + ssc = new StreamingContext(conf, Seconds(1)) val receiverStream = new ReceiverInputDStream[Int](ssc) { override def getReceiver(): Receiver[Int] = null } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index c4424b3..ea4c1d4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.Queue import org.apache.commons.io.FileUtils -import org.scalatest.{Assertions, BeforeAndAfter, PrivateMethodTester} +import org.scalatest.{Assertions, PrivateMethodTester} import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.Eventually._ import org.scalatest.exceptions.TestFailedDueToTimeoutException @@ -44,7 +44,11 @@ import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.{ManualClock, Utils} -class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeLimits with Logging { +class StreamingContextSuite + extends SparkFunSuite + with LocalStreamingContext + with TimeLimits + with Logging { // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x implicit val signaler: Signaler = ThreadSignaler @@ -56,20 +60,6 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL val envPair = "key" -> "value" val conf = new SparkConf().setMaster(master).setAppName(appName) - var sc: SparkContext = null - var ssc: StreamingContext = null - - after { - if (ssc 
!= null) { - ssc.stop() - ssc = null - } - if (sc != null) { - sc.stop() - sc = null - } - } - test("from no conf constructor") { ssc = new StreamingContext(master, appName, batchDuration) assert(ssc.sparkContext.conf.get("spark.master") === master) @@ -95,7 +85,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL } test("from existing SparkContext") { - sc = new SparkContext(master, appName) + val sc = new SparkContext(master, appName) ssc = new StreamingContext(sc, batchDuration) } @@ -272,7 +262,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL // Explicitly do not stop SparkContext ssc = new StreamingContext(conf, batchDuration) - sc = ssc.sparkContext + var sc = ssc.sparkContext addInputStream(ssc).register() ssc.start() ssc.stop(stopSparkContext = false) @@ -306,7 +296,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL test("stop gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) conf.set("spark.dummyTimeConfig", "3600s") - sc = new SparkContext(conf) + val sc = new SparkContext(conf) for (i <- 1 to 4) { logInfo("==================================\n\n\n") ssc = new StreamingContext(sc, Milliseconds(100)) @@ -338,7 +328,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL // This is not a deterministic unit. But if this unit test is flaky, then there is definitely // something wrong. 
See SPARK-5681 val conf = new SparkConf().setMaster(master).setAppName(appName) - sc = new SparkContext(conf) + val sc = new SparkContext(conf) ssc = new StreamingContext(sc, Milliseconds(100)) val input = ssc.receiverStream(new TestReceiver) input.foreachRDD(_ => {}) @@ -352,7 +342,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL test("stop slow receiver gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) conf.set("spark.streaming.gracefulStopTimeout", "20000s") - sc = new SparkContext(conf) + val sc = new SparkContext(conf) logInfo("==================================\n\n\n") ssc = new StreamingContext(sc, Milliseconds(100)) var runningCount = 0 @@ -591,7 +581,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL // getOrCreate should recover StreamingContext with existing SparkContext testGetOrCreate { - sc = new SparkContext(conf) + val sc = new SparkContext(conf) ssc = StreamingContext.getOrCreate(checkpointPath, () => creatingFunction()) assert(ssc != null, "no context created") assert(!newContextCreated, "old context not recovered") @@ -603,7 +593,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL require(StreamingContext.getActive().isEmpty, "context exists from before") var newContextCreated = false - def creatingFunc(): StreamingContext = { + def creatingFunc(sc: SparkContext)(): StreamingContext = { newContextCreated = true val newSsc = new StreamingContext(sc, batchDuration) val input = addInputStream(newSsc) @@ -627,8 +617,8 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL // getActiveOrCreate should create new context and getActive should return it only // after starting the context testGetActiveOrCreate { - sc = new SparkContext(conf) - ssc = StreamingContext.getActiveOrCreate(creatingFunc _) + val sc = new SparkContext(conf) + ssc = StreamingContext.getActiveOrCreate(creatingFunc(sc)) 
assert(ssc != null, "no context created") assert(newContextCreated, "new context not created") assert(StreamingContext.getActive().isEmpty, @@ -636,25 +626,25 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL ssc.start() assert(StreamingContext.getActive() === Some(ssc), "active context not returned") - assert(StreamingContext.getActiveOrCreate(creatingFunc _) === ssc, + assert(StreamingContext.getActiveOrCreate(creatingFunc(sc)) === ssc, "active context not returned") ssc.stop() assert(StreamingContext.getActive().isEmpty, "inactive context returned") - assert(StreamingContext.getActiveOrCreate(creatingFunc _) !== ssc, + assert(StreamingContext.getActiveOrCreate(creatingFunc(sc)) !== ssc, "inactive context returned") } // getActiveOrCreate and getActive should return independently created context after activating testGetActiveOrCreate { - sc = new SparkContext(conf) - ssc = creatingFunc() // Create + val sc = new SparkContext(conf) + ssc = creatingFunc(sc) // Create assert(StreamingContext.getActive().isEmpty, "new initialized context returned before starting") ssc.start() assert(StreamingContext.getActive() === Some(ssc), "active context not returned") - assert(StreamingContext.getActiveOrCreate(creatingFunc _) === ssc, + assert(StreamingContext.getActiveOrCreate(creatingFunc(sc)) === ssc, "active context not returned") ssc.stop() assert(StreamingContext.getActive().isEmpty, @@ -736,7 +726,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL } test("multiple streaming contexts") { - sc = new SparkContext( + val sc = new SparkContext( conf.clone.set("spark.streaming.clock", "org.apache.spark.util.ManualClock")) ssc = new StreamingContext(sc, Seconds(1)) val input = addInputStream(ssc) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 62fd433..9cd5d8c 100644 --- 
a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -36,20 +36,11 @@ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler._ -class StreamingListenerSuite extends TestSuiteBase with Matchers { +class StreamingListenerSuite extends TestSuiteBase with LocalStreamingContext with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) - var ssc: StreamingContext = _ - - override def afterFunction() { - super.afterFunction() - if (ssc != null) { - ssc.stop() - } - } - // To make sure that the processing start and end times in collected // information are different for successive batches override def batchDuration: Duration = Milliseconds(100) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index f2ae778..6f0475c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag -import org.scalatest.BeforeAndAfter +import org.scalatest.BeforeAndAfterEach import org.scalatest.concurrent.Eventually.timeout import org.scalatest.concurrent.PatienceConfiguration import org.scalatest.time.{Seconds => ScalaTestSeconds, Span} @@ -211,7 +211,7 @@ class BatchCounter(ssc: StreamingContext) { * This is the base trait for Spark Streaming testsuites. This provides basic functionality * to run user-defined set of input on user-defined stream operations, and verify the output. 
*/ -trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { +trait TestSuiteBase extends SparkFunSuite with BeforeAndAfterEach with Logging { // Name of the framework for Spark context def framework: String = this.getClass.getSimpleName @@ -250,7 +250,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { val eventuallyTimeout: PatienceConfiguration.Timeout = timeout(Span(10, ScalaTestSeconds)) // Default before function for any streaming test suite. Override this - // if you want to add your stuff to "before" (i.e., don't call before { } ) + // if you want to add your stuff to "beforeEach" def beforeFunction() { if (useManualClock) { logInfo("Using manual clock") @@ -262,13 +262,24 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { } // Default after function for any streaming test suite. Override this - // if you want to add your stuff to "after" (i.e., don't call after { } ) + // if you want to add your stuff to "afterEach" def afterFunction() { System.clearProperty("spark.streaming.clock") } - before(beforeFunction) - after(afterFunction) + override def beforeEach(): Unit = { + super.beforeEach() + beforeFunction() + } + + override def afterEach(): Unit = { + try { + afterFunction() + } finally { + super.afterEach() + } + + } /** * Run a block of code with the given StreamingContext and automatically @@ -278,12 +289,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { try { block(ssc) } finally { - try { - ssc.stop(stopSparkContext = true) - } catch { - case e: Exception => - logError("Error stopping StreamingContext", e) - } + LocalStreamingContext.stop(ssc, stopSparkContext = true) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala index c7d085e..f580b49 100644 --- 
a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -146,15 +146,16 @@ class WindowOperationsSuite extends TestSuiteBase { test("window - persistence level") { val input = Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)) - val ssc = new StreamingContext(conf, batchDuration) - val inputStream = new TestInputStream[Int](ssc, input, 1) - val windowStream1 = inputStream.window(batchDuration * 2) - assert(windowStream1.storageLevel === StorageLevel.NONE) - assert(inputStream.storageLevel === StorageLevel.MEMORY_ONLY_SER) - windowStream1.persist(StorageLevel.MEMORY_ONLY) - assert(windowStream1.storageLevel === StorageLevel.NONE) - assert(inputStream.storageLevel === StorageLevel.MEMORY_ONLY) - ssc.stop() + + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val inputStream = new TestInputStream[Int](ssc, input, 1) + val windowStream1 = inputStream.window(batchDuration * 2) + assert(windowStream1.storageLevel === StorageLevel.NONE) + assert(inputStream.storageLevel === StorageLevel.MEMORY_ONLY_SER) + windowStream1.persist(StorageLevel.MEMORY_ONLY) + assert(windowStream1.storageLevel === StorageLevel.NONE) + assert(inputStream.storageLevel === StorageLevel.MEMORY_ONLY) + } } // Testing naive reduceByKeyAndWindow (without invertible function) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala index a8b0055..a3026b2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala @@ -19,26 +19,25 @@ package org.apache.spark.streaming.scheduler import org.mockito.ArgumentMatchers.{eq => meq} import 
org.mockito.Mockito.{never, reset, times, verify, when} -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, PrivateMethodTester} +import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} import org.scalatest.concurrent.Eventually.{eventually, timeout} import org.scalatest.mockito.MockitoSugar import org.scalatest.time.SpanSugar._ -import org.apache.spark.{ExecutorAllocationClient, SparkConf, SparkFunSuite} +import org.apache.spark.{ExecutorAllocationClient, SparkConf} import org.apache.spark.internal.config.{DYN_ALLOCATION_ENABLED, DYN_ALLOCATION_TESTING} import org.apache.spark.internal.config.Streaming._ -import org.apache.spark.streaming.{DummyInputDStream, Seconds, StreamingContext} +import org.apache.spark.streaming.{DummyInputDStream, Seconds, StreamingContext, TestSuiteBase} import org.apache.spark.util.{ManualClock, Utils} - -class ExecutorAllocationManagerSuite extends SparkFunSuite - with BeforeAndAfter with BeforeAndAfterAll with MockitoSugar with PrivateMethodTester { +class ExecutorAllocationManagerSuite extends TestSuiteBase + with MockitoSugar with PrivateMethodTester { private val batchDurationMillis = 1000L private var allocationClient: ExecutorAllocationClient = null private var clock: StreamManualClock = null - before { + override def beforeEach(): Unit = { allocationClient = mock[ExecutorAllocationClient] clock = new StreamManualClock() } @@ -392,13 +391,9 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite .setAppName(this.getClass.getSimpleName) .set("spark.streaming.dynamicAllocation.testing", "true") // to test dynamic allocation - var ssc: StreamingContext = null - try { - ssc = new StreamingContext(conf, Seconds(1)) + withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc => new DummyInputDStream(ssc).foreachRDD(_ => { }) body(ssc) - } finally { - if (ssc != null) ssc.stop() } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala 
b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala index a7e3656..cc39342 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -17,27 +17,15 @@ package org.apache.spark.streaming.scheduler -import org.scalatest.BeforeAndAfter - import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.streaming.{Duration, StreamingContext, Time} - -class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { +import org.apache.spark.streaming.{Duration, LocalStreamingContext, StreamingContext, Time} - private var ssc: StreamingContext = _ +class InputInfoTrackerSuite extends SparkFunSuite with LocalStreamingContext { - before { + override def beforeEach(): Unit = { + super.beforeEach() val conf = new SparkConf().setMaster("local[2]").setAppName("DirectStreamTacker") - if (ssc == null) { - ssc = new StreamingContext(conf, Duration(1000)) - } - } - - after { - if (ssc != null) { - ssc.stop() - ssc = null - } + ssc = new StreamingContext(conf, Duration(1000)) } test("test report and get InputInfo from InputInfoTracker") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala index 37ca0ce..b5a45fc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala @@ -30,8 +30,7 @@ class RateControllerSuite extends TestSuiteBase { override def batchDuration: Duration = Milliseconds(50) test("RateController - rate controller publishes updates after batches complete") { - val ssc = new StreamingContext(conf, batchDuration) - withStreamingContext(ssc) { ssc => + withStreamingContext(new StreamingContext(conf, 
batchDuration)) { ssc => val dstream = new RateTestInputDStream(ssc) dstream.register() ssc.start() @@ -43,8 +42,7 @@ class RateControllerSuite extends TestSuiteBase { } test("ReceiverRateController - published rates reach receivers") { - val ssc = new StreamingContext(conf, batchDuration) - withStreamingContext(ssc) { ssc => + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => val estimator = new ConstantEstimator(100) val dstream = new RateTestInputDStream(ssc) { override val rateController = diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 56b4008..10f92f9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -22,24 +22,18 @@ import java.util.Properties import org.scalatest.Matchers import org.apache.spark.scheduler.SparkListenerJobStart -import org.apache.spark.streaming._ +import org.apache.spark.streaming.{LocalStreamingContext, _} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.scheduler._ -class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { +class StreamingJobProgressListenerSuite + extends TestSuiteBase + with LocalStreamingContext + with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) - var ssc: StreamingContext = _ - - override def afterFunction() { - super.afterFunction() - if (ssc != null) { - ssc.stop() - } - } - private def createJobStart( batchTime: Time, outputOpId: Int, jobId: Int): SparkListenerJobStart = { val properties = new Properties() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, 
e-mail: commits-h...@spark.apache.org