Repository: spark
Updated Branches:
  refs/heads/branch-1.6 699f497cf -> f6d866173
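This commit renames Spark Streaming's experimental trackStateByKey API to mapWithState across the test suites: JavaTrackStateByKeySuite, TrackStateByKeySuite, and TrackStateRDDSuite are deleted, and MapWithStateSuite and MapWithStateRDDSuite replace them, with the "emitted" records of the old API becoming "mapped" data. For orientation, below is a minimal sketch of the renamed API using the same running-count mapping function the suites exercise. The socket source, port, and checkpoint path are illustrative placeholders, not part of this commit.

  import org.apache.spark.SparkConf
  import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}

  object MapWithStateExample {
    def main(args: Array[String]): Unit = {
      val conf = new SparkConf().setMaster("local[2]").setAppName("MapWithStateExample")
      val ssc = new StreamingContext(conf, Seconds(1))
      ssc.checkpoint("/tmp/checkpoint")  // mapWithState requires a checkpoint directory

      // Running count per key, as in the suites: the function receives the new
      // value and the current state, updates the state, and returns a mapped record.
      val mappingFunc = (word: String, value: Option[Int], state: State[Int]) => {
        val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
        state.update(sum)
        (word, sum)
      }

      val words = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" "))
      val counts = words.map(w => (w, 1)).mapWithState(
        StateSpec.function(mappingFunc).numPartitions(2).timeout(Seconds(30)))

      counts.print()                   // records returned by the mapping function
      counts.stateSnapshots().print()  // full key-to-state table, once per batch
      ssc.start()
      ssc.awaitTermination()
    }
  }

As the suites verify, the mapped stream carries only what the mapping function returns each batch, while stateSnapshots() exposes the complete key-to-state table.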
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d86617/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java ---------------------------------------------------------------------- diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java deleted file mode 100644 index eac4cdd..0000000 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java +++ /dev/null @@ -1,210 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Set; - -import scala.Tuple2; - -import com.google.common.base.Optional; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.streaming.api.java.JavaDStream; -import org.apache.spark.util.ManualClock; -import org.junit.Assert; -import org.junit.Test; - -import org.apache.spark.HashPartitioner; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.Function4; -import org.apache.spark.streaming.api.java.JavaPairDStream; -import org.apache.spark.streaming.api.java.JavaTrackStateDStream; - -public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implements Serializable { - - /** - * This test is only for testing the APIs. It's not necessary to run it. 
- */ - public void testAPI() { - JavaPairRDD<String, Boolean> initialRDD = null; - JavaPairDStream<String, Integer> wordsDstream = null; - - final Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>> - trackStateFunc = - new Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>() { - - @Override - public Optional<Double> call( - Time time, String word, Optional<Integer> one, State<Boolean> state) { - // Use all State's methods here - state.exists(); - state.get(); - state.isTimingOut(); - state.remove(); - state.update(true); - return Optional.of(2.0); - } - }; - - JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream = - wordsDstream.trackStateByKey( - StateSpec.function(trackStateFunc) - .initialState(initialRDD) - .numPartitions(10) - .partitioner(new HashPartitioner(10)) - .timeout(Durations.seconds(10))); - - JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots(); - - final Function2<Optional<Integer>, State<Boolean>, Double> trackStateFunc2 = - new Function2<Optional<Integer>, State<Boolean>, Double>() { - - @Override - public Double call(Optional<Integer> one, State<Boolean> state) { - // Use all State's methods here - state.exists(); - state.get(); - state.isTimingOut(); - state.remove(); - state.update(true); - return 2.0; - } - }; - - JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 = - wordsDstream.trackStateByKey( - StateSpec.<String, Integer, Boolean, Double> function(trackStateFunc2) - .initialState(initialRDD) - .numPartitions(10) - .partitioner(new HashPartitioner(10)) - .timeout(Durations.seconds(10))); - - JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots(); - } - - @Test - public void testBasicFunction() { - List<List<String>> inputData = Arrays.asList( - Collections.<String>emptyList(), - Arrays.asList("a"), - Arrays.asList("a", "b"), - Arrays.asList("a", "b", "c"), - Arrays.asList("a", "b"), - Arrays.asList("a"), - Collections.<String>emptyList() - ); - - List<Set<Integer>> outputData = Arrays.asList( - Collections.<Integer>emptySet(), - Sets.newHashSet(1), - Sets.newHashSet(2, 1), - Sets.newHashSet(3, 2, 1), - Sets.newHashSet(4, 3), - Sets.newHashSet(5), - Collections.<Integer>emptySet() - ); - - List<Set<Tuple2<String, Integer>>> stateData = Arrays.asList( - Collections.<Tuple2<String, Integer>>emptySet(), - Sets.newHashSet(new Tuple2<String, Integer>("a", 1)), - Sets.newHashSet(new Tuple2<String, Integer>("a", 2), new Tuple2<String, Integer>("b", 1)), - Sets.newHashSet( - new Tuple2<String, Integer>("a", 3), - new Tuple2<String, Integer>("b", 2), - new Tuple2<String, Integer>("c", 1)), - Sets.newHashSet( - new Tuple2<String, Integer>("a", 4), - new Tuple2<String, Integer>("b", 3), - new Tuple2<String, Integer>("c", 1)), - Sets.newHashSet( - new Tuple2<String, Integer>("a", 5), - new Tuple2<String, Integer>("b", 3), - new Tuple2<String, Integer>("c", 1)), - Sets.newHashSet( - new Tuple2<String, Integer>("a", 5), - new Tuple2<String, Integer>("b", 3), - new Tuple2<String, Integer>("c", 1)) - ); - - Function2<Optional<Integer>, State<Integer>, Integer> trackStateFunc = - new Function2<Optional<Integer>, State<Integer>, Integer>() { - - @Override - public Integer call(Optional<Integer> value, State<Integer> state) throws Exception { - int sum = value.or(0) + (state.exists() ? 
state.get() : 0); - state.update(sum); - return sum; - } - }; - testOperation( - inputData, - StateSpec.<String, Integer, Integer, Integer>function(trackStateFunc), - outputData, - stateData); - } - - private <K, S, T> void testOperation( - List<List<K>> input, - StateSpec<K, Integer, S, T> trackStateSpec, - List<Set<T>> expectedOutputs, - List<Set<Tuple2<K, S>>> expectedStateSnapshots) { - int numBatches = expectedOutputs.size(); - JavaDStream<K> inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2); - JavaTrackStateDStream<K, Integer, S, T> trackeStateStream = - JavaPairDStream.fromJavaDStream(inputStream.map(new Function<K, Tuple2<K, Integer>>() { - @Override - public Tuple2<K, Integer> call(K x) throws Exception { - return new Tuple2<K, Integer>(x, 1); - } - })).trackStateByKey(trackStateSpec); - - final List<Set<T>> collectedOutputs = - Collections.synchronizedList(Lists.<Set<T>>newArrayList()); - trackeStateStream.foreachRDD(new Function<JavaRDD<T>, Void>() { - @Override - public Void call(JavaRDD<T> rdd) throws Exception { - collectedOutputs.add(Sets.newHashSet(rdd.collect())); - return null; - } - }); - final List<Set<Tuple2<K, S>>> collectedStateSnapshots = - Collections.synchronizedList(Lists.<Set<Tuple2<K, S>>>newArrayList()); - trackeStateStream.stateSnapshots().foreachRDD(new Function<JavaPairRDD<K, S>, Void>() { - @Override - public Void call(JavaPairRDD<K, S> rdd) throws Exception { - collectedStateSnapshots.add(Sets.newHashSet(rdd.collect())); - return null; - } - }); - BatchCounter batchCounter = new BatchCounter(ssc.ssc()); - ssc.start(); - ((ManualClock) ssc.ssc().scheduler().clock()) - .advance(ssc.ssc().progressListener().batchDuration() * numBatches + 1); - batchCounter.waitUntilBatchesCompleted(numBatches, 10000); - - Assert.assertEquals(expectedOutputs, collectedOutputs); - Assert.assertEquals(expectedStateSnapshots, collectedStateSnapshots); - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/f6d86617/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala ---------------------------------------------------------------------- diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala new file mode 100644 index 0000000..4b08085 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -0,0 +1,581 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import java.io.File + +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.reflect.ClassTag + +import org.scalatest.PrivateMethodTester._ +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.streaming.dstream.{DStream, InternalMapWithStateDStream, MapWithStateDStream, MapWithStateDStreamImpl} +import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} + +class MapWithStateSuite extends SparkFunSuite + with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter { + + private var sc: SparkContext = null + protected var checkpointDir: File = null + protected val batchDuration = Seconds(1) + + before { + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + checkpointDir = Utils.createTempDir("checkpoint") + } + + after { + if (checkpointDir != null) { + Utils.deleteRecursively(checkpointDir) + } + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + } + + override def beforeAll(): Unit = { + val conf = new SparkConf().setMaster("local").setAppName("MapWithStateSuite") + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + sc = new SparkContext(conf) + } + + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + } + + test("state - get, exists, update, remove, ") { + var state: StateImpl[Int] = null + + def testState( + expectedData: Option[Int], + shouldBeUpdated: Boolean = false, + shouldBeRemoved: Boolean = false, + shouldBeTimingOut: Boolean = false + ): Unit = { + if (expectedData.isDefined) { + assert(state.exists) + assert(state.get() === expectedData.get) + assert(state.getOption() === expectedData) + assert(state.getOption.getOrElse(-1) === expectedData.get) + } else { + assert(!state.exists) + intercept[NoSuchElementException] { + state.get() + } + assert(state.getOption() === None) + assert(state.getOption.getOrElse(-1) === -1) + } + + assert(state.isTimingOut() === shouldBeTimingOut) + if (shouldBeTimingOut) { + intercept[IllegalArgumentException] { + state.remove() + } + intercept[IllegalArgumentException] { + state.update(-1) + } + } + + assert(state.isUpdated() === shouldBeUpdated) + + assert(state.isRemoved() === shouldBeRemoved) + if (shouldBeRemoved) { + intercept[IllegalArgumentException] { + state.remove() + } + intercept[IllegalArgumentException] { + state.update(-1) + } + } + } + + state = new StateImpl[Int]() + testState(None) + + state.wrap(None) + testState(None) + + state.wrap(Some(1)) + testState(Some(1)) + + state.update(2) + testState(Some(2), shouldBeUpdated = true) + + state = new StateImpl[Int]() + state.update(2) + testState(Some(2), shouldBeUpdated = true) + + state.remove() + testState(None, shouldBeRemoved = true) + + state.wrapTiminoutState(3) + testState(Some(3), shouldBeTimingOut = true) + } + + test("mapWithState - basic operations with simple API") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq(1), + Seq(2, 1), + Seq(3, 2, 1), + Seq(4, 3), + Seq(5), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + // state maintains running count, and updated count is returned + val mappingFunc = (key: 
String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + state.update(sum) + sum + } + + testOperation[String, Int, Int]( + inputData, StateSpec.function(mappingFunc), outputData, stateData) + } + + test("mapWithState - basic operations with advanced API") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq("aa"), + Seq("aa", "bb"), + Seq("aa", "bb", "cc"), + Seq("aa", "bb"), + Seq("aa"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + // state maintains running count, key string doubled and returned + val mappingFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + state.update(sum) + Some(key * 2) + } + + testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData) + } + + test("mapWithState - type inferencing and class tags") { + + // Simple track state function with value as Int, state as Double and mapped type as Double + val simpleFunc = (key: String, value: Option[Int], state: State[Double]) => { + 0L + } + + // Advanced track state function with key as String, value as Int, state as Double and + // mapped type as Double + val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => { + Some(0L) + } + + def testTypes(dstream: MapWithStateDStream[_, _, _, _]): Unit = { + val dstreamImpl = dstream.asInstanceOf[MapWithStateDStreamImpl[_, _, _, _]] + assert(dstreamImpl.keyClass === classOf[String]) + assert(dstreamImpl.valueClass === classOf[Int]) + assert(dstreamImpl.stateClass === classOf[Double]) + assert(dstreamImpl.mappedClass === classOf[Long]) + } + val ssc = new StreamingContext(sc, batchDuration) + val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2) + + // Defining StateSpec inline with mapWithState and simple function implicitly gets the types + val simpleFunctionStateStream1 = inputStream.mapWithState( + StateSpec.function(simpleFunc).numPartitions(1)) + testTypes(simpleFunctionStateStream1) + + // Separately defining StateSpec with simple function requires explicitly specifying types + val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc) + val simpleFunctionStateStream2 = inputStream.mapWithState(simpleFuncSpec) + testTypes(simpleFunctionStateStream2) + + // Separately defining StateSpec with advanced function implicitly gets the types + val advFuncSpec1 = StateSpec.function(advancedFunc) + val advFunctionStateStream1 = inputStream.mapWithState(advFuncSpec1) + testTypes(advFunctionStateStream1) + + // Defining StateSpec inline with mapWithState and advanced func implicitly gets the types + val advFunctionStateStream2 = inputStream.mapWithState( + StateSpec.function(simpleFunc).numPartitions(1)) + testTypes(advFunctionStateStream2) + + // Defining StateSpec inline with mapWithState and advanced func implicitly gets the types + val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc) + val advFunctionStateStream3 = inputStream.mapWithState[Double, Long](advFuncSpec2) + testTypes(advFunctionStateStream3) + } + + test("mapWithState - states as mapped data") { + val inputData = + Seq( + 
Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3)), + Seq(("a", 5)), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + val output = (key, sum) + state.update(sum) + Some(output) + } + + testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData) + } + + test("mapWithState - initial states, with nothing returned as from mapping function") { + + val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)) + + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = Seq.fill(inputData.size)(Seq.empty[Int]) + + val stateData = + Seq( + Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)), + Seq(("a", 6), ("b", 10), ("c", -20), ("d", 0)), + Seq(("a", 7), ("b", 11), ("c", -20), ("d", 0)), + Seq(("a", 8), ("b", 12), ("c", -19), ("d", 0)), + Seq(("a", 9), ("b", 13), ("c", -19), ("d", 0)), + Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)), + Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)) + ) + + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + val output = (key, sum) + state.update(sum) + None.asInstanceOf[Option[Int]] + } + + val mapWithStateSpec = StateSpec.function(mappingFunc).initialState(sc.makeRDD(initialState)) + testOperation(inputData, mapWithStateSpec, outputData, stateData) + } + + test("mapWithState - state removing") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), // a will be removed + Seq("a", "b", "c"), // b will be removed + Seq("a", "b", "c"), // a and c will be removed + Seq("a", "b"), // b will be removed + Seq("a"), // a will be removed + Seq() + ) + + // States that were removed + val outputData = + Seq( + Seq(), + Seq(), + Seq("a"), + Seq("b"), + Seq("a", "c"), + Seq("b"), + Seq("a"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("b", 1)), + Seq(("a", 1), ("c", 1)), + Seq(("b", 1)), + Seq(("a", 1)), + Seq(), + Seq() + ) + + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + if (state.exists) { + state.remove() + Some(key) + } else { + state.update(value.get) + None + } + } + + testOperation( + inputData, StateSpec.function(mappingFunc).numPartitions(1), outputData, stateData) + } + + test("mapWithState - state timing out") { + val inputData = + Seq( + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq(), // c will time out + Seq(), // b will time out + Seq("a") // a will not time out + ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active + + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + if (value.isDefined) { + state.update(1) + } + if (state.isTimingOut) { + Some(key) + } else { + None + } + } + + val (collectedOutputs, collectedStateSnapshots) = getOperationOutput( + inputData, StateSpec.function(mappingFunc).timeout(Seconds(3)), 20) + + // b and c should be returned once each, when they were 
marked as expired + assert(collectedOutputs.flatten.sorted === Seq("b", "c")) + + // States for a, b, c should be defined at one point of time + assert(collectedStateSnapshots.exists { + _.toSet == Set(("a", 1), ("b", 1), ("c", 1)) + }) + + // Finally state should be defined only for a + assert(collectedStateSnapshots.last.toSet === Set(("a", 1))) + } + + test("mapWithState - checkpoint durations") { + val privateMethod = PrivateMethod[InternalMapWithStateDStream[_, _, _, _]]('internalStream) + + def testCheckpointDuration( + batchDuration: Duration, + expectedCheckpointDuration: Duration, + explicitCheckpointDuration: Option[Duration] = None + ): Unit = { + val ssc = new StreamingContext(sc, batchDuration) + + try { + val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1) + val dummyFunc = (key: Int, value: Option[Int], state: State[Int]) => 0 + val mapWithStateStream = inputStream.mapWithState(StateSpec.function(dummyFunc)) + val internalmapWithStateStream = mapWithStateStream invokePrivate privateMethod() + + explicitCheckpointDuration.foreach { d => + mapWithStateStream.checkpoint(d) + } + mapWithStateStream.register() + ssc.checkpoint(checkpointDir.toString) + ssc.start() // should initialize all the checkpoint durations + assert(mapWithStateStream.checkpointDuration === null) + assert(internalmapWithStateStream.checkpointDuration === expectedCheckpointDuration) + } finally { + ssc.stop(stopSparkContext = false) + } + } + + testCheckpointDuration(Milliseconds(100), Seconds(1)) + testCheckpointDuration(Seconds(1), Seconds(10)) + testCheckpointDuration(Seconds(10), Seconds(100)) + + testCheckpointDuration(Milliseconds(100), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(1), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20))) + } + + + test("mapWithState - driver failure recovery") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + def operation(dstream: DStream[String]): DStream[(String, Int)] = { + + val checkpointDuration = batchDuration * (stateData.size / 2) + + val runningCount = (key: String, value: Option[Int], state: State[Int]) => { + state.update(state.getOption().getOrElse(0) + value.getOrElse(0)) + state.get() + } + + val mapWithStateStream = dstream.map { _ -> 1 }.mapWithState( + StateSpec.function(runningCount)) + // Set internval make sure there is one RDD checkpointing + mapWithStateStream.checkpoint(checkpointDuration) + mapWithStateStream.stateSnapshots() + } + + testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2, + batchDuration = batchDuration, stopSparkContextAfterTest = false) + } + + private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( + input: Seq[Seq[K]], + mapWithStateSpec: StateSpec[K, Int, S, T], + expectedOutputs: Seq[Seq[T]], + expectedStateSnapshots: Seq[Seq[(K, S)]] + ): Unit = { + require(expectedOutputs.size == expectedStateSnapshots.size) + + val (collectedOutputs, collectedStateSnapshots) = + getOperationOutput(input, mapWithStateSpec, expectedOutputs.size) + assert(expectedOutputs, collectedOutputs, "outputs") + assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots") + } + + private def 
getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag]( + input: Seq[Seq[K]], + mapWithStateSpec: StateSpec[K, Int, S, T], + numBatches: Int + ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = { + + // Setup the stream computation + val ssc = new StreamingContext(sc, Seconds(1)) + val inputStream = new TestInputStream(ssc, input, numPartitions = 2) + val trackeStateStream = inputStream.map(x => (x, 1)).mapWithState(mapWithStateSpec) + val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] + val outputStream = new TestOutputStream(trackeStateStream, collectedOutputs) + val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]] + val stateSnapshotStream = new TestOutputStream( + trackeStateStream.stateSnapshots(), collectedStateSnapshots) + outputStream.register() + stateSnapshotStream.register() + + val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir.toString) + ssc.start() + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds * numBatches) + + batchCounter.waitUntilBatchesCompleted(numBatches, 10000) + ssc.stop(stopSparkContext = false) + (collectedOutputs, collectedStateSnapshots) + } + + private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String) { + val debugString = "\nExpected:\n" + expected.mkString("\n") + + "\nCollected:\n" + collected.mkString("\n") + assert(expected.size === collected.size, + s"number of collected $typ (${collected.size}) different from expected (${expected.size})" + + debugString) + expected.zip(collected).foreach { case (c, e) => + assert(c.toSet === e.toSet, + s"collected $typ is different from expected $debugString" + ) + } + } +} + http://git-wip-us.apache.org/repos/asf/spark/blob/f6d86617/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala ---------------------------------------------------------------------- diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala deleted file mode 100644 index 1fc320d..0000000 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ /dev/null @@ -1,581 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming - -import java.io.File - -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} -import scala.reflect.ClassTag - -import org.scalatest.PrivateMethodTester._ -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} - -import org.apache.spark.streaming.dstream.{DStream, InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} -import org.apache.spark.util.{ManualClock, Utils} -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} - -class TrackStateByKeySuite extends SparkFunSuite - with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter { - - private var sc: SparkContext = null - protected var checkpointDir: File = null - protected val batchDuration = Seconds(1) - - before { - StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } - checkpointDir = Utils.createTempDir("checkpoint") - } - - after { - if (checkpointDir != null) { - Utils.deleteRecursively(checkpointDir) - } - StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } - } - - override def beforeAll(): Unit = { - val conf = new SparkConf().setMaster("local").setAppName("TrackStateByKeySuite") - conf.set("spark.streaming.clock", classOf[ManualClock].getName()) - sc = new SparkContext(conf) - } - - override def afterAll(): Unit = { - if (sc != null) { - sc.stop() - } - } - - test("state - get, exists, update, remove, ") { - var state: StateImpl[Int] = null - - def testState( - expectedData: Option[Int], - shouldBeUpdated: Boolean = false, - shouldBeRemoved: Boolean = false, - shouldBeTimingOut: Boolean = false - ): Unit = { - if (expectedData.isDefined) { - assert(state.exists) - assert(state.get() === expectedData.get) - assert(state.getOption() === expectedData) - assert(state.getOption.getOrElse(-1) === expectedData.get) - } else { - assert(!state.exists) - intercept[NoSuchElementException] { - state.get() - } - assert(state.getOption() === None) - assert(state.getOption.getOrElse(-1) === -1) - } - - assert(state.isTimingOut() === shouldBeTimingOut) - if (shouldBeTimingOut) { - intercept[IllegalArgumentException] { - state.remove() - } - intercept[IllegalArgumentException] { - state.update(-1) - } - } - - assert(state.isUpdated() === shouldBeUpdated) - - assert(state.isRemoved() === shouldBeRemoved) - if (shouldBeRemoved) { - intercept[IllegalArgumentException] { - state.remove() - } - intercept[IllegalArgumentException] { - state.update(-1) - } - } - } - - state = new StateImpl[Int]() - testState(None) - - state.wrap(None) - testState(None) - - state.wrap(Some(1)) - testState(Some(1)) - - state.update(2) - testState(Some(2), shouldBeUpdated = true) - - state = new StateImpl[Int]() - state.update(2) - testState(Some(2), shouldBeUpdated = true) - - state.remove() - testState(None, shouldBeRemoved = true) - - state.wrapTiminoutState(3) - testState(Some(3), shouldBeTimingOut = true) - } - - test("trackStateByKey - basic operations with simple API") { - val inputData = - Seq( - Seq(), - Seq("a"), - Seq("a", "b"), - Seq("a", "b", "c"), - Seq("a", "b"), - Seq("a"), - Seq() - ) - - val outputData = - Seq( - Seq(), - Seq(1), - Seq(2, 1), - Seq(3, 2, 1), - Seq(4, 3), - Seq(5), - Seq() - ) - - val stateData = - Seq( - Seq(), - Seq(("a", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 3), ("b", 2), ("c", 1)), - Seq(("a", 4), ("b", 3), ("c", 1)), - Seq(("a", 5), ("b", 3), ("c", 1)), - Seq(("a", 5), ("b", 3), ("c", 1)) - ) - - // state maintains running count, and updated count is returned - val trackStateFunc = 
(value: Option[Int], state: State[Int]) => { - val sum = value.getOrElse(0) + state.getOption.getOrElse(0) - state.update(sum) - sum - } - - testOperation[String, Int, Int]( - inputData, StateSpec.function(trackStateFunc), outputData, stateData) - } - - test("trackStateByKey - basic operations with advanced API") { - val inputData = - Seq( - Seq(), - Seq("a"), - Seq("a", "b"), - Seq("a", "b", "c"), - Seq("a", "b"), - Seq("a"), - Seq() - ) - - val outputData = - Seq( - Seq(), - Seq("aa"), - Seq("aa", "bb"), - Seq("aa", "bb", "cc"), - Seq("aa", "bb"), - Seq("aa"), - Seq() - ) - - val stateData = - Seq( - Seq(), - Seq(("a", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 3), ("b", 2), ("c", 1)), - Seq(("a", 4), ("b", 3), ("c", 1)), - Seq(("a", 5), ("b", 3), ("c", 1)), - Seq(("a", 5), ("b", 3), ("c", 1)) - ) - - // state maintains running count, key string doubled and returned - val trackStateFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => { - val sum = value.getOrElse(0) + state.getOption.getOrElse(0) - state.update(sum) - Some(key * 2) - } - - testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData) - } - - test("trackStateByKey - type inferencing and class tags") { - - // Simple track state function with value as Int, state as Double and emitted type as Double - val simpleFunc = (value: Option[Int], state: State[Double]) => { - 0L - } - - // Advanced track state function with key as String, value as Int, state as Double and - // emitted type as Double - val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => { - Some(0L) - } - - def testTypes(dstream: TrackStateDStream[_, _, _, _]): Unit = { - val dstreamImpl = dstream.asInstanceOf[TrackStateDStreamImpl[_, _, _, _]] - assert(dstreamImpl.keyClass === classOf[String]) - assert(dstreamImpl.valueClass === classOf[Int]) - assert(dstreamImpl.stateClass === classOf[Double]) - assert(dstreamImpl.emittedClass === classOf[Long]) - } - val ssc = new StreamingContext(sc, batchDuration) - val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2) - - // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types - val simpleFunctionStateStream1 = inputStream.trackStateByKey( - StateSpec.function(simpleFunc).numPartitions(1)) - testTypes(simpleFunctionStateStream1) - - // Separately defining StateSpec with simple function requires explicitly specifying types - val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc) - val simpleFunctionStateStream2 = inputStream.trackStateByKey(simpleFuncSpec) - testTypes(simpleFunctionStateStream2) - - // Separately defining StateSpec with advanced function implicitly gets the types - val advFuncSpec1 = StateSpec.function(advancedFunc) - val advFunctionStateStream1 = inputStream.trackStateByKey(advFuncSpec1) - testTypes(advFunctionStateStream1) - - // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types - val advFunctionStateStream2 = inputStream.trackStateByKey( - StateSpec.function(simpleFunc).numPartitions(1)) - testTypes(advFunctionStateStream2) - - // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types - val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc) - val advFunctionStateStream3 = inputStream.trackStateByKey[Double, Long](advFuncSpec2) - testTypes(advFunctionStateStream3) - } - - test("trackStateByKey - states as emitted records") { - 
val inputData = - Seq( - Seq(), - Seq("a"), - Seq("a", "b"), - Seq("a", "b", "c"), - Seq("a", "b"), - Seq("a"), - Seq() - ) - - val outputData = - Seq( - Seq(), - Seq(("a", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 3), ("b", 2), ("c", 1)), - Seq(("a", 4), ("b", 3)), - Seq(("a", 5)), - Seq() - ) - - val stateData = - Seq( - Seq(), - Seq(("a", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 3), ("b", 2), ("c", 1)), - Seq(("a", 4), ("b", 3), ("c", 1)), - Seq(("a", 5), ("b", 3), ("c", 1)), - Seq(("a", 5), ("b", 3), ("c", 1)) - ) - - val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { - val sum = value.getOrElse(0) + state.getOption.getOrElse(0) - val output = (key, sum) - state.update(sum) - Some(output) - } - - testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData) - } - - test("trackStateByKey - initial states, with nothing emitted") { - - val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)) - - val inputData = - Seq( - Seq(), - Seq("a"), - Seq("a", "b"), - Seq("a", "b", "c"), - Seq("a", "b"), - Seq("a"), - Seq() - ) - - val outputData = Seq.fill(inputData.size)(Seq.empty[Int]) - - val stateData = - Seq( - Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)), - Seq(("a", 6), ("b", 10), ("c", -20), ("d", 0)), - Seq(("a", 7), ("b", 11), ("c", -20), ("d", 0)), - Seq(("a", 8), ("b", 12), ("c", -19), ("d", 0)), - Seq(("a", 9), ("b", 13), ("c", -19), ("d", 0)), - Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)), - Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)) - ) - - val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { - val sum = value.getOrElse(0) + state.getOption.getOrElse(0) - val output = (key, sum) - state.update(sum) - None.asInstanceOf[Option[Int]] - } - - val trackStateSpec = StateSpec.function(trackStateFunc).initialState(sc.makeRDD(initialState)) - testOperation(inputData, trackStateSpec, outputData, stateData) - } - - test("trackStateByKey - state removing") { - val inputData = - Seq( - Seq(), - Seq("a"), - Seq("a", "b"), // a will be removed - Seq("a", "b", "c"), // b will be removed - Seq("a", "b", "c"), // a and c will be removed - Seq("a", "b"), // b will be removed - Seq("a"), // a will be removed - Seq() - ) - - // States that were removed - val outputData = - Seq( - Seq(), - Seq(), - Seq("a"), - Seq("b"), - Seq("a", "c"), - Seq("b"), - Seq("a"), - Seq() - ) - - val stateData = - Seq( - Seq(), - Seq(("a", 1)), - Seq(("b", 1)), - Seq(("a", 1), ("c", 1)), - Seq(("b", 1)), - Seq(("a", 1)), - Seq(), - Seq() - ) - - val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { - if (state.exists) { - state.remove() - Some(key) - } else { - state.update(value.get) - None - } - } - - testOperation( - inputData, StateSpec.function(trackStateFunc).numPartitions(1), outputData, stateData) - } - - test("trackStateByKey - state timing out") { - val inputData = - Seq( - Seq("a", "b", "c"), - Seq("a", "b"), - Seq("a"), - Seq(), // c will time out - Seq(), // b will time out - Seq("a") // a will not time out - ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active - - val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { - if (value.isDefined) { - state.update(1) - } - if (state.isTimingOut) { - Some(key) - } else { - None - } - } - - val (collectedOutputs, collectedStateSnapshots) = getOperationOutput( - inputData, StateSpec.function(trackStateFunc).timeout(Seconds(3)), 20) - - // b and c should be emitted once 
each, when they were marked as expired - assert(collectedOutputs.flatten.sorted === Seq("b", "c")) - - // States for a, b, c should be defined at one point of time - assert(collectedStateSnapshots.exists { - _.toSet == Set(("a", 1), ("b", 1), ("c", 1)) - }) - - // Finally state should be defined only for a - assert(collectedStateSnapshots.last.toSet === Set(("a", 1))) - } - - test("trackStateByKey - checkpoint durations") { - val privateMethod = PrivateMethod[InternalTrackStateDStream[_, _, _, _]]('internalStream) - - def testCheckpointDuration( - batchDuration: Duration, - expectedCheckpointDuration: Duration, - explicitCheckpointDuration: Option[Duration] = None - ): Unit = { - val ssc = new StreamingContext(sc, batchDuration) - - try { - val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1) - val dummyFunc = (value: Option[Int], state: State[Int]) => 0 - val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc)) - val internalTrackStateStream = trackStateStream invokePrivate privateMethod() - - explicitCheckpointDuration.foreach { d => - trackStateStream.checkpoint(d) - } - trackStateStream.register() - ssc.checkpoint(checkpointDir.toString) - ssc.start() // should initialize all the checkpoint durations - assert(trackStateStream.checkpointDuration === null) - assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration) - } finally { - ssc.stop(stopSparkContext = false) - } - } - - testCheckpointDuration(Milliseconds(100), Seconds(1)) - testCheckpointDuration(Seconds(1), Seconds(10)) - testCheckpointDuration(Seconds(10), Seconds(100)) - - testCheckpointDuration(Milliseconds(100), Seconds(2), Some(Seconds(2))) - testCheckpointDuration(Seconds(1), Seconds(2), Some(Seconds(2))) - testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20))) - } - - - test("trackStateByKey - driver failure recovery") { - val inputData = - Seq( - Seq(), - Seq("a"), - Seq("a", "b"), - Seq("a", "b", "c"), - Seq("a", "b"), - Seq("a"), - Seq() - ) - - val stateData = - Seq( - Seq(), - Seq(("a", 1)), - Seq(("a", 2), ("b", 1)), - Seq(("a", 3), ("b", 2), ("c", 1)), - Seq(("a", 4), ("b", 3), ("c", 1)), - Seq(("a", 5), ("b", 3), ("c", 1)), - Seq(("a", 5), ("b", 3), ("c", 1)) - ) - - def operation(dstream: DStream[String]): DStream[(String, Int)] = { - - val checkpointDuration = batchDuration * (stateData.size / 2) - - val runningCount = (value: Option[Int], state: State[Int]) => { - state.update(state.getOption().getOrElse(0) + value.getOrElse(0)) - state.get() - } - - val trackStateStream = dstream.map { _ -> 1 }.trackStateByKey( - StateSpec.function(runningCount)) - // Set internval make sure there is one RDD checkpointing - trackStateStream.checkpoint(checkpointDuration) - trackStateStream.stateSnapshots() - } - - testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2, - batchDuration = batchDuration, stopSparkContextAfterTest = false) - } - - private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( - input: Seq[Seq[K]], - trackStateSpec: StateSpec[K, Int, S, T], - expectedOutputs: Seq[Seq[T]], - expectedStateSnapshots: Seq[Seq[(K, S)]] - ): Unit = { - require(expectedOutputs.size == expectedStateSnapshots.size) - - val (collectedOutputs, collectedStateSnapshots) = - getOperationOutput(input, trackStateSpec, expectedOutputs.size) - assert(expectedOutputs, collectedOutputs, "outputs") - assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots") - } - - private def getOperationOutput[K: 
ClassTag, S: ClassTag, T: ClassTag]( - input: Seq[Seq[K]], - trackStateSpec: StateSpec[K, Int, S, T], - numBatches: Int - ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = { - - // Setup the stream computation - val ssc = new StreamingContext(sc, Seconds(1)) - val inputStream = new TestInputStream(ssc, input, numPartitions = 2) - val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec) - val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] - val outputStream = new TestOutputStream(trackeStateStream, collectedOutputs) - val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]] - val stateSnapshotStream = new TestOutputStream( - trackeStateStream.stateSnapshots(), collectedStateSnapshots) - outputStream.register() - stateSnapshotStream.register() - - val batchCounter = new BatchCounter(ssc) - ssc.checkpoint(checkpointDir.toString) - ssc.start() - - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - clock.advance(batchDuration.milliseconds * numBatches) - - batchCounter.waitUntilBatchesCompleted(numBatches, 10000) - ssc.stop(stopSparkContext = false) - (collectedOutputs, collectedStateSnapshots) - } - - private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String) { - val debugString = "\nExpected:\n" + expected.mkString("\n") + - "\nCollected:\n" + collected.mkString("\n") - assert(expected.size === collected.size, - s"number of collected $typ (${collected.size}) different from expected (${expected.size})" + - debugString) - expected.zip(collected).foreach { case (c, e) => - assert(c.toSet === e.toSet, - s"collected $typ is different from expected $debugString" - ) - } - } -} - http://git-wip-us.apache.org/repos/asf/spark/blob/f6d86617/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala ---------------------------------------------------------------------- diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala new file mode 100644 index 0000000..aa95bd3 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.rdd + +import java.io.File + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.util.OpenHashMapBasedStateMap +import org.apache.spark.streaming.{State, Time} +import org.apache.spark.util.Utils + +class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll { + + private var sc: SparkContext = null + private var checkpointDir: File = _ + + override def beforeAll(): Unit = { + sc = new SparkContext( + new SparkConf().setMaster("local").setAppName("MapWithStateRDDSuite")) + checkpointDir = Utils.createTempDir() + sc.setCheckpointDir(checkpointDir.toString) + } + + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + Utils.deleteRecursively(checkpointDir) + } + + override def sparkContext: SparkContext = sc + + test("creation from pair RDD") { + val data = Seq((1, "1"), (2, "2"), (3, "3")) + val partitioner = new HashPartitioner(10) + val rdd = MapWithStateRDD.createFromPairRDD[Int, Int, String, Int]( + sc.parallelize(data), partitioner, Time(123)) + assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty) + assert(rdd.partitions.size === partitioner.numPartitions) + + assert(rdd.partitioner === Some(partitioner)) + } + + test("updating state and generating mapped data in MapWithStateRDDRecord") { + + val initialTime = 1000L + val updatedTime = 2000L + val thresholdTime = 1500L + @volatile var functionCalled = false + + /** + * Assert that applying given data on a prior record generates correct updated record, with + * correct state map and mapped data + */ + def assertRecordUpdate( + initStates: Iterable[Int], + data: Iterable[String], + expectedStates: Iterable[(Int, Long)], + timeoutThreshold: Option[Long] = None, + removeTimedoutData: Boolean = false, + expectedOutput: Iterable[Int] = None, + expectedTimingOutStates: Iterable[Int] = None, + expectedRemovedStates: Iterable[Int] = None + ): Unit = { + val initialStateMap = new OpenHashMapBasedStateMap[String, Int]() + initStates.foreach { s => initialStateMap.put("key", s, initialTime) } + functionCalled = false + val record = MapWithStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty) + val dataIterator = data.map { v => ("key", v) }.iterator + val removedStates = new ArrayBuffer[Int] + val timingOutStates = new ArrayBuffer[Int] + /** + * Mapping function that updates/removes state based on instructions in the data, and + * return state (when instructed or when state is timing out). 
+ */ + def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = { + functionCalled = true + + assert(t.milliseconds === updatedTime, "mapping func called with wrong time") + + data match { + case Some("noop") => + None + case Some("get-state") => + Some(state.getOption().getOrElse(-1)) + case Some("update-state") => + if (state.exists) state.update(state.get + 1) else state.update(0) + None + case Some("remove-state") => + removedStates += state.get() + state.remove() + None + case None => + assert(state.isTimingOut() === true, "State is not timing out when data = None") + timingOutStates += state.get() + None + case _ => + fail("Unexpected test data") + } + } + + val updatedRecord = MapWithStateRDDRecord.updateRecordWithData[String, String, Int, Int]( + Some(record), dataIterator, testFunc, + Time(updatedTime), timeoutThreshold, removeTimedoutData) + + val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) } + assert(updatedStateData.toSet === expectedStates.toSet, + "states do not match after updating the MapWithStateRDDRecord") + + assert(updatedRecord.mappedData.toSet === expectedOutput.toSet, + "mapped data do not match after updating the MapWithStateRDDRecord") + + assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " + + "match those that were expected to do so while updating the MapWithStateRDDRecord") + + assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " + + "match those that were expected to do so while updating the MapWithStateRDDRecord") + + } + + // No data, no state should be changed, function should not be called, + assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil) + assert(functionCalled === false) + assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime))) + assert(functionCalled === false) + + // Data present, function should be called irrespective of whether state exists + assertRecordUpdate(initStates = Seq(0), data = Seq("noop"), + expectedStates = Seq((0, initialTime))) + assert(functionCalled === true) + assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None) + assert(functionCalled === true) + + // Function called with right state data + assertRecordUpdate(initStates = None, data = Seq("get-state"), + expectedStates = None, expectedOutput = Seq(-1)) + assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"), + expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123)) + + // Update state and timestamp, when timeout not present + assertRecordUpdate(initStates = Nil, data = Seq("update-state"), + expectedStates = Seq((0, updatedTime))) + assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"), + expectedStates = Seq((1, updatedTime))) + + // Remove state + assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"), + expectedStates = Nil, expectedRemovedStates = Seq(345)) + + // State strictly older than timeout threshold should be timed out + assertRecordUpdate(initStates = Seq(123), data = Nil, + timeoutThreshold = Some(initialTime), removeTimedoutData = true, + expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil) + + assertRecordUpdate(initStates = Seq(123), data = Nil, + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Nil, expectedTimingOutStates = Seq(123)) + + // State should not be timed out after it has received data + assertRecordUpdate(initStates = 
Seq(123), data = Seq("noop"), + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil) + assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"), + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123)) + + } + + test("states generated by MapWithStateRDD") { + val initStates = Seq(("k1", 0), ("k2", 0)) + val initTime = 123 + val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet + val partitioner = new HashPartitioner(2) + val initStateRDD = MapWithStateRDD.createFromPairRDD[String, Int, Int, Int]( + sc.parallelize(initStates), partitioner, Time(initTime)).persist() + assertRDD(initStateRDD, initStateWthTime, Set.empty) + + val updateTime = 345 + + /** + * Test that the test state RDD, when operated with new data, + * creates a new state RDD with expected states + */ + def testStateUpdates( + testStateRDD: MapWithStateRDD[String, Int, Int, Int], + testData: Seq[(String, Int)], + expectedStates: Set[(String, Int, Int)]): MapWithStateRDD[String, Int, Int, Int] = { + + // Persist the test MapWithStateRDD so that its not recomputed while doing the next operation. + // This is to make sure that we only touch which state keys are being touched in the next op. + testStateRDD.persist().count() + + // To track which keys are being touched + MapWithStateRDDSuite.touchedStateKeys.clear() + + val mappingFunction = (time: Time, key: String, data: Option[Int], state: State[Int]) => { + + // Track the key that has been touched + MapWithStateRDDSuite.touchedStateKeys += key + + // If the data is 0, do not do anything with the state + // else if the data is 1, increment the state if it exists, or set new state to 0 + // else if the data is 2, remove the state if it exists + data match { + case Some(1) => + if (state.exists()) { state.update(state.get + 1) } + else state.update(0) + case Some(2) => + state.remove() + case _ => + } + None.asInstanceOf[Option[Int]] // Do not return anything, not being tested + } + val newDataRDD = sc.makeRDD(testData).partitionBy(testStateRDD.partitioner.get) + + // Assert that the new state RDD has expected state data + val newStateRDD = assertOperation( + testStateRDD, newDataRDD, mappingFunction, updateTime, expectedStates, Set.empty) + + // Assert that the function was called only for the keys present in the data + assert(MapWithStateRDDSuite.touchedStateKeys.size === testData.size, + "More number of keys are being touched than that is expected") + assert(MapWithStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys, + "Keys not in the data are being touched unexpectedly") + + // Assert that the test RDD's data has not changed + assertRDD(initStateRDD, initStateWthTime, Set.empty) + newStateRDD + } + + // Test no-op, no state should change + testStateUpdates(initStateRDD, Seq(), initStateWthTime) // should not scan any state + testStateUpdates( + initStateRDD, Seq(("k1", 0)), initStateWthTime) // should not update existing state + testStateUpdates( + initStateRDD, Seq(("k3", 0)), initStateWthTime) // should not create new state + + // Test creation of new state + val rdd1 = testStateUpdates(initStateRDD, Seq(("k3", 1)), // should create k3's state as 0 + Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime))) + + val rdd2 = testStateUpdates(rdd1, Seq(("k4", 1)), // should create k4's state as 0 + Set(("k1", 0, initTime), ("k2", 0, 
initTime), ("k3", 0, updateTime), ("k4", 0, updateTime))) + + // Test updating of state + val rdd3 = testStateUpdates( + initStateRDD, Seq(("k1", 1)), // should increment k1's state 0 -> 1 + Set(("k1", 1, updateTime), ("k2", 0, initTime))) + + val rdd4 = testStateUpdates(rdd3, + Seq(("x", 0), ("k2", 1), ("k2", 1), ("k3", 1)), // should update k2, 0 -> 2 and create k3, 0 + Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 0, updateTime))) + + val rdd5 = testStateUpdates( + rdd4, Seq(("k3", 1)), // should update k3's state 0 -> 2 + Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 1, updateTime))) + + // Test removing of state + val rdd6 = testStateUpdates( // should remove k1's state + initStateRDD, Seq(("k1", 2)), Set(("k2", 0, initTime))) + + val rdd7 = testStateUpdates( // should remove k2's state + rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime))) + + val rdd8 = testStateUpdates( // should remove k3's state + rdd7, Seq(("k3", 2)), Set()) + } + + test("checkpointing") { + /** + * This tests whether the MapWithStateRDD correctly truncates any references to its parent RDDs + * - the data RDD and the parent MapWithStateRDD. + */ + def rddCollectFunc(rdd: RDD[MapWithStateRDDRecord[Int, Int, Int]]) + : Set[(List[(Int, Int, Long)], List[Int])] = { + rdd.map { record => (record.stateMap.getAll().toList, record.mappedData.toList) } + .collect.toSet + } + + /** Generate MapWithStateRDD with data RDD having a long lineage */ + def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int]) + : MapWithStateRDD[Int, Int, Int, Int] = { + MapWithStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0)) + } + + testRDD( + makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _) + testRDDPartitions( + makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _) + + /** Generate MapWithStateRDD with parent state RDD having a long lineage */ + def makeStateRDDWithLongLineageParenttateRDD( + longLineageRDD: RDD[Int]): MapWithStateRDD[Int, Int, Int, Int] = { + + // Create a MapWithStateRDD that has a long lineage using the data RDD with a long lineage + val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD) + + // Create a new MapWithStateRDD, with the lineage lineage MapWithStateRDD as the parent + new MapWithStateRDD[Int, Int, Int, Int]( + stateRDDWithLongLineage, + stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner), + (time: Time, key: Int, value: Option[Int], state: State[Int]) => None, + Time(10), + None + ) + } + + testRDD( + makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) + testRDDPartitions( + makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) + } + + test("checkpointing empty state RDD") { + val emptyStateRDD = MapWithStateRDD.createFromPairRDD[Int, Int, Int, Int]( + sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0)) + emptyStateRDD.checkpoint() + assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + val cpRDD = sc.checkpointFile[MapWithStateRDDRecord[Int, Int, Int]]( + emptyStateRDD.getCheckpointFile.get) + assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + } + + /** Assert whether the `mapWithState` operation generates expected results */ + private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + testStateRDD: MapWithStateRDD[K, V, S, T], + newDataRDD: RDD[(K, V)], + mappingFunction: (Time, K, 
Option[V], State[S]) => Option[T], + currentTime: Long, + expectedStates: Set[(K, S, Int)], + expectedMappedData: Set[T], + doFullScan: Boolean = false + ): MapWithStateRDD[K, V, S, T] = { + + val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) { + newDataRDD.partitionBy(testStateRDD.partitioner.get) + } else { + newDataRDD + } + + val newStateRDD = new MapWithStateRDD[K, V, S, T]( + testStateRDD, newDataRDD, mappingFunction, Time(currentTime), None) + if (doFullScan) newStateRDD.setFullScan() + + // Persist to make sure that it gets computed only once and we can track precisely how many + // state keys the computing touched + newStateRDD.persist().count() + assertRDD(newStateRDD, expectedStates, expectedMappedData) + newStateRDD + } + + /** Assert whether the [[MapWithStateRDD]] has the expected state and mapped data */ + private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + stateRDD: MapWithStateRDD[K, V, S, T], + expectedStates: Set[(K, S, Int)], + expectedMappedData: Set[T]): Unit = { + val states = stateRDD.flatMap { _.stateMap.getAll() }.collect().toSet + val mappedData = stateRDD.flatMap { _.mappedData }.collect().toSet + assert(states === expectedStates, + "states after mapWithState operation were not as expected") + assert(mappedData === expectedMappedData, + "mapped data after mapWithState operation were not as expected") + } +} + +object MapWithStateRDDSuite { + private val touchedStateKeys = new ArrayBuffer[String]() +} http://git-wip-us.apache.org/repos/asf/spark/blob/f6d86617/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala ---------------------------------------------------------------------- diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala deleted file mode 100644 index 3b2d43f..0000000 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ /dev/null @@ -1,389 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d86617/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
deleted file mode 100644
index 3b2d43f..0000000
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
+++ /dev/null
@@ -1,389 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.rdd
-
-import java.io.File
-
-import scala.collection.mutable.ArrayBuffer
-import scala.reflect.ClassTag
-
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
-import org.apache.spark.streaming.{State, Time}
-import org.apache.spark.util.Utils
-
-class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll {
-
-  private var sc: SparkContext = null
-  private var checkpointDir: File = _
-
-  override def beforeAll(): Unit = {
-    sc = new SparkContext(
-      new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite"))
-    checkpointDir = Utils.createTempDir()
-    sc.setCheckpointDir(checkpointDir.toString)
-  }
-
-  override def afterAll(): Unit = {
-    if (sc != null) {
-      sc.stop()
-    }
-    Utils.deleteRecursively(checkpointDir)
-  }
-
-  override def sparkContext: SparkContext = sc
-
-  test("creation from pair RDD") {
-    val data = Seq((1, "1"), (2, "2"), (3, "3"))
-    val partitioner = new HashPartitioner(10)
-    val rdd = TrackStateRDD.createFromPairRDD[Int, Int, String, Int](
-      sc.parallelize(data), partitioner, Time(123))
-    assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty)
-    assert(rdd.partitions.size === partitioner.numPartitions)
-
-    assert(rdd.partitioner === Some(partitioner))
-  }
-
-  test("updating state and generating emitted data in TrackStateRecord") {
-
-    val initialTime = 1000L
-    val updatedTime = 2000L
-    val thresholdTime = 1500L
-    @volatile var functionCalled = false
-
-    /**
-     * Assert that applying given data on a prior record generates correct updated record, with
-     * correct state map and emitted data
-     */
-    def assertRecordUpdate(
-        initStates: Iterable[Int],
-        data: Iterable[String],
-        expectedStates: Iterable[(Int, Long)],
-        timeoutThreshold: Option[Long] = None,
-        removeTimedoutData: Boolean = false,
-        expectedOutput: Iterable[Int] = None,
-        expectedTimingOutStates: Iterable[Int] = None,
-        expectedRemovedStates: Iterable[Int] = None
-      ): Unit = {
-      val initialStateMap = new OpenHashMapBasedStateMap[String, Int]()
-      initStates.foreach { s => initialStateMap.put("key", s, initialTime) }
-      functionCalled = false
-      val record = TrackStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
-      val dataIterator = data.map { v => ("key", v) }.iterator
-      val removedStates = new ArrayBuffer[Int]
-      val timingOutStates = new ArrayBuffer[Int]
-      /**
-       * Tracking function that updates/removes state based on instructions in the data, and
-       * returns state (when instructed or when state is timing out).
-       */
-      def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = {
-        functionCalled = true
-
-        assert(t.milliseconds === updatedTime, "tracking func called with wrong time")
-
-        data match {
-          case Some("noop") =>
-            None
-          case Some("get-state") =>
-            Some(state.getOption().getOrElse(-1))
-          case Some("update-state") =>
-            if (state.exists) state.update(state.get + 1) else state.update(0)
-            None
-          case Some("remove-state") =>
-            removedStates += state.get()
-            state.remove()
-            None
-          case None =>
-            assert(state.isTimingOut() === true, "State is not timing out when data = None")
-            timingOutStates += state.get()
-            None
-          case _ =>
-            fail("Unexpected test data")
-        }
-      }
-
-      val updatedRecord = TrackStateRDDRecord.updateRecordWithData[String, String, Int, Int](
-        Some(record), dataIterator, testFunc,
-        Time(updatedTime), timeoutThreshold, removeTimedoutData)
-
-      val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) }
-      assert(updatedStateData.toSet === expectedStates.toSet,
-        "states do not match after updating the TrackStateRecord")
-
-      assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet,
-        "emitted data do not match after updating the TrackStateRecord")
-
-      assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " +
-        "match those that were expected to do so while updating the TrackStateRecord")
-
-      assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " +
-        "match those that were expected to do so while updating the TrackStateRecord")
-
-    }
-
-    // No data, no state should be changed, and the function should not be called.
-    assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil)
-    assert(functionCalled === false)
-    assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime)))
-    assert(functionCalled === false)
-
-    // Data present, function should be called irrespective of whether state exists
-    assertRecordUpdate(initStates = Seq(0), data = Seq("noop"),
-      expectedStates = Seq((0, initialTime)))
-    assert(functionCalled === true)
-    assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None)
-    assert(functionCalled === true)
-
-    // Function called with right state data
-    assertRecordUpdate(initStates = None, data = Seq("get-state"),
-      expectedStates = None, expectedOutput = Seq(-1))
-    assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"),
-      expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123))
-
-    // Update state and timestamp, when timeout not present
-    assertRecordUpdate(initStates = Nil, data = Seq("update-state"),
-      expectedStates = Seq((0, updatedTime)))
-    assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"),
-      expectedStates = Seq((1, updatedTime)))
-
-    // Remove state
-    assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"),
-      expectedStates = Nil, expectedRemovedStates = Seq(345))
-
-    // State strictly older than timeout threshold should be timed out
-    assertRecordUpdate(initStates = Seq(123), data = Nil,
-      timeoutThreshold = Some(initialTime), removeTimedoutData = true,
-      expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil)
-
-    assertRecordUpdate(initStates = Seq(123), data = Nil,
-      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
-      expectedStates = Nil, expectedTimingOutStates = Seq(123))
-
-    // State should not be timed out after it has received data
Seq("noop"), - timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, - expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil) - assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"), - timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, - expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123)) - - } - - test("states generated by TrackStateRDD") { - val initStates = Seq(("k1", 0), ("k2", 0)) - val initTime = 123 - val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet - val partitioner = new HashPartitioner(2) - val initStateRDD = TrackStateRDD.createFromPairRDD[String, Int, Int, Int]( - sc.parallelize(initStates), partitioner, Time(initTime)).persist() - assertRDD(initStateRDD, initStateWthTime, Set.empty) - - val updateTime = 345 - - /** - * Test that the test state RDD, when operated with new data, - * creates a new state RDD with expected states - */ - def testStateUpdates( - testStateRDD: TrackStateRDD[String, Int, Int, Int], - testData: Seq[(String, Int)], - expectedStates: Set[(String, Int, Int)]): TrackStateRDD[String, Int, Int, Int] = { - - // Persist the test TrackStateRDD so that its not recomputed while doing the next operation. - // This is to make sure that we only track which state keys are being touched in the next op. - testStateRDD.persist().count() - - // To track which keys are being touched - TrackStateRDDSuite.touchedStateKeys.clear() - - val trackingFunc = (time: Time, key: String, data: Option[Int], state: State[Int]) => { - - // Track the key that has been touched - TrackStateRDDSuite.touchedStateKeys += key - - // If the data is 0, do not do anything with the state - // else if the data is 1, increment the state if it exists, or set new state to 0 - // else if the data is 2, remove the state if it exists - data match { - case Some(1) => - if (state.exists()) { state.update(state.get + 1) } - else state.update(0) - case Some(2) => - state.remove() - case _ => - } - None.asInstanceOf[Option[Int]] // Do not return anything, not being tested - } - val newDataRDD = sc.makeRDD(testData).partitionBy(testStateRDD.partitioner.get) - - // Assert that the new state RDD has expected state data - val newStateRDD = assertOperation( - testStateRDD, newDataRDD, trackingFunc, updateTime, expectedStates, Set.empty) - - // Assert that the function was called only for the keys present in the data - assert(TrackStateRDDSuite.touchedStateKeys.size === testData.size, - "More number of keys are being touched than that is expected") - assert(TrackStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys, - "Keys not in the data are being touched unexpectedly") - - // Assert that the test RDD's data has not changed - assertRDD(initStateRDD, initStateWthTime, Set.empty) - newStateRDD - } - - // Test no-op, no state should change - testStateUpdates(initStateRDD, Seq(), initStateWthTime) // should not scan any state - testStateUpdates( - initStateRDD, Seq(("k1", 0)), initStateWthTime) // should not update existing state - testStateUpdates( - initStateRDD, Seq(("k3", 0)), initStateWthTime) // should not create new state - - // Test creation of new state - val rdd1 = testStateUpdates(initStateRDD, Seq(("k3", 1)), // should create k3's state as 0 - Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime))) - - val rdd2 = testStateUpdates(rdd1, Seq(("k4", 1)), // should create k4's state as 0 - Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime), ("k4", 
-      Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime), ("k4", 0, updateTime)))
-
-    // Test updating of state
-    val rdd3 = testStateUpdates(
-      initStateRDD, Seq(("k1", 1)), // should increment k1's state 0 -> 1
-      Set(("k1", 1, updateTime), ("k2", 0, initTime)))
-
-    val rdd4 = testStateUpdates(rdd3,
-      Seq(("x", 0), ("k2", 1), ("k2", 1), ("k3", 1)), // should update k2, 0 -> 2 and create k3, 0
-      Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 0, updateTime)))
-
-    val rdd5 = testStateUpdates(
-      rdd4, Seq(("k3", 1)), // should update k3's state 0 -> 1
-      Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 1, updateTime)))
-
-    // Test removing of state
-    val rdd6 = testStateUpdates( // should remove k1's state
-      initStateRDD, Seq(("k1", 2)), Set(("k2", 0, initTime)))
-
-    val rdd7 = testStateUpdates( // should remove k2's state
-      rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime)))
-
-    val rdd8 = testStateUpdates( // should remove k3's state
-      rdd7, Seq(("k3", 2)), Set())
-  }
-
-  test("checkpointing") {
-    /**
-     * This tests whether the TrackStateRDD correctly truncates any references to its parent RDDs -
-     * the data RDD and the parent TrackStateRDD.
-     */
-    def rddCollectFunc(rdd: RDD[TrackStateRDDRecord[Int, Int, Int]])
-      : Set[(List[(Int, Int, Long)], List[Int])] = {
-      rdd.map { record => (record.stateMap.getAll().toList, record.emittedRecords.toList) }
-        .collect.toSet
-    }
-
-    /** Generate TrackStateRDD with data RDD having a long lineage */
-    def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int])
-      : TrackStateRDD[Int, Int, Int, Int] = {
-      TrackStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0))
-    }
-
-    testRDD(
-      makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _)
-    testRDDPartitions(
-      makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _)
-
-    /** Generate TrackStateRDD with parent state RDD having a long lineage */
-    def makeStateRDDWithLongLineageParentStateRDD(
-        longLineageRDD: RDD[Int]): TrackStateRDD[Int, Int, Int, Int] = {
-
-      // Create a TrackStateRDD that has a long lineage using the data RDD with a long lineage
-      val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD)
-
-      // Create a new TrackStateRDD, with the long-lineage TrackStateRDD as the parent
-      new TrackStateRDD[Int, Int, Int, Int](
-        stateRDDWithLongLineage,
-        stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner),
-        (time: Time, key: Int, value: Option[Int], state: State[Int]) => None,
-        Time(10),
-        None
-      )
-    }
-
-    testRDD(
-      makeStateRDDWithLongLineageParentStateRDD, reliableCheckpoint = true, rddCollectFunc _)
-    testRDDPartitions(
-      makeStateRDDWithLongLineageParentStateRDD, reliableCheckpoint = true, rddCollectFunc _)
-  }
-
-  test("checkpointing empty state RDD") {
-    val emptyStateRDD = TrackStateRDD.createFromPairRDD[Int, Int, Int, Int](
-      sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0))
-    emptyStateRDD.checkpoint()
-    assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
-    val cpRDD = sc.checkpointFile[TrackStateRDDRecord[Int, Int, Int]](
-      emptyStateRDD.getCheckpointFile.get)
-    assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
-  }
-
-  /** Assert whether the `trackStateByKey` operation generates expected results */
-  private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
-      testStateRDD: TrackStateRDD[K, V, S, T],
-      newDataRDD: RDD[(K, V)],
-      trackStateFunc: (Time, K, Option[V], State[S]) => Option[T],
-      currentTime: Long,
-      expectedStates: Set[(K, S, Int)],
-      expectedEmittedRecords: Set[T],
-      doFullScan: Boolean = false
-    ): TrackStateRDD[K, V, S, T] = {
-
-    val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) {
-      newDataRDD.partitionBy(testStateRDD.partitioner.get)
-    } else {
-      newDataRDD
-    }
-
-    val newStateRDD = new TrackStateRDD[K, V, S, T](
-      testStateRDD, partitionedNewDataRDD, trackStateFunc, Time(currentTime), None)
-    if (doFullScan) newStateRDD.setFullScan()
-
-    // Persist to make sure that it gets computed only once and we can track precisely how many
-    // state keys the computation touched
-    newStateRDD.persist().count()
-    assertRDD(newStateRDD, expectedStates, expectedEmittedRecords)
-    newStateRDD
-  }
-
-  /** Assert whether the [[TrackStateRDD]] has the expected state and emitted records */
-  private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
-      trackStateRDD: TrackStateRDD[K, V, S, T],
-      expectedStates: Set[(K, S, Int)],
-      expectedEmittedRecords: Set[T]): Unit = {
-    val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
-    val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet
-    assert(states === expectedStates,
-      "states after track state operation were not as expected")
-    assert(emittedRecords === expectedEmittedRecords,
-      "emitted records after track state operation were not as expected")
-  }
-}
-
-object TrackStateRDDSuite {
-  private val touchedStateKeys = new ArrayBuffer[String]()
-}
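
Both suites pin down the same timeout rule: a key is timed out only when it received no data in the current batch and its last-update timestamp is strictly older than the timeout threshold. A standalone restatement of that predicate (a hypothetical helper written here for illustration, not part of either suite):

    // True iff the state for a key should be treated as timing out.
    def timesOut(lastUpdateTime: Long, receivedDataInBatch: Boolean, threshold: Option[Long]): Boolean =
      !receivedDataInBatch && threshold.exists(lastUpdateTime < _)

    // Mirrors the assertRecordUpdate cases above (initialTime = 1000L):
    assert(!timesOut(1000L, receivedDataInBatch = false, threshold = Some(1000L))) // equal to threshold: kept
    assert(timesOut(1000L, receivedDataInBatch = false, threshold = Some(1001L)))  // strictly older: timed out
    assert(!timesOut(1000L, receivedDataInBatch = true, threshold = Some(1001L)))  // received data: kept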