Github user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/21385#discussion_r190403472
--- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala ---
@@ -160,25 +170,122 @@ class ContinuousShuffleReadSuite extends StreamTest {
}
test("blocks waiting for new rows") {
- val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+ val rdd = new ContinuousShuffleReadRDD(
+ sparkContext, numPartitions = 1, checkpointIntervalMs = Long.MaxValue)
+ val epoch = rdd.compute(rdd.partitions(0), ctx)
val readRowThread = new Thread {
override def run(): Unit = {
- // set the non-inheritable thread local
- TaskContext.setTaskContext(ctx)
- val epoch = rdd.compute(rdd.partitions(0), ctx)
- epoch.next().getInt(0)
+ try {
+ epoch.next().getInt(0)
+ } catch {
+ case _: InterruptedException => // do nothing - expected at test ending
+ }
}
}
try {
readRowThread.start()
eventually(timeout(streamingTimeout)) {
- assert(readRowThread.getState == Thread.State.WAITING)
+ assert(readRowThread.getState == Thread.State.TIMED_WAITING)
}
} finally {
readRowThread.interrupt()
readRowThread.join()
}
}
+
+ test("multiple writers") {
+ val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3)
+ val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+ send(
+ endpoint,
+ ReceiverRow(0, unsafeRow("writer0-row0")),
+ ReceiverRow(1, unsafeRow("writer1-row0")),
+ ReceiverRow(2, unsafeRow("writer2-row0")),
+ ReceiverEpochMarker(0),
+ ReceiverEpochMarker(1),
+ ReceiverEpochMarker(2)
+ )
+
+ val firstEpoch = rdd.compute(rdd.partitions(0), ctx)
+ assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet ==
+ Set("writer0-row0", "writer1-row0", "writer2-row0"))
+ }
+
+ test("epoch only ends when all writers send markers") {
+ val rdd = new ContinuousShuffleReadRDD(
+ sparkContext, numPartitions = 1, numShuffleWriters = 3, checkpointIntervalMs = Long.MaxValue)
+ val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+ send(
+ endpoint,
+ ReceiverRow(0, unsafeRow("writer0-row0")),
+ ReceiverRow(1, unsafeRow("writer1-row0")),
+ ReceiverRow(2, unsafeRow("writer2-row0")),
+ ReceiverEpochMarker(0),
+ ReceiverEpochMarker(2)
+ )
+
+ val epoch = rdd.compute(rdd.partitions(0), ctx)
+ val rows = (0 until 3).map(_ => epoch.next()).toSet
+ assert(rows.map(_.getUTF8String(0).toString) ==
+ Set("writer0-row0", "writer1-row0", "writer2-row0"))
+
+ // After checking the right rows, block until we get an epoch marker indicating there's no next.
+ // (Also fail the assertion if for some reason we get a row.)
+ val readEpochMarkerThread = new Thread {
+ override def run(): Unit = {
+ assert(!epoch.hasNext)
+ }
+ }
+
+ readEpochMarkerThread.start()
+ eventually(timeout(streamingTimeout)) {
+ assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING)
--- End diff ---
Same question as above ... is this the only possible thread state?
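
A parked thread can show up as either `WAITING` or `TIMED_WAITING` depending on whether the underlying blocking call takes a timeout, so asserting one exact state ties the test to the reader's queue implementation. A minimal sketch of a looser check, reusing the names from this diff (the `blockedStates` set is just an illustration, not existing code):

```scala
// Hypothetical: accept any parked state rather than exactly TIMED_WAITING,
// since take() parks as WAITING while poll(timeout) parks as TIMED_WAITING.
val blockedStates = Set(Thread.State.WAITING, Thread.State.TIMED_WAITING)
eventually(timeout(streamingTimeout)) {
  assert(blockedStates.contains(readEpochMarkerThread.getState))
}
```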
---