GitHub user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/21428#discussion_r191615398
--- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala ---
@@ -288,4 +267,153 @@ class ContinuousShuffleReadSuite extends StreamTest {
     val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet
     assert(thirdEpoch == Set("writer1-row1", "writer2-row0"))
   }
+
+  test("one epoch") {
+    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val writer = new RPCContinuousShuffleWriter(
+      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+
+    writer.write(Iterator(1, 2, 3))
+
+    assert(readEpoch(reader) == Seq(1, 2, 3))
+  }
+
+  test("multiple epochs") {
+    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val writer = new RPCContinuousShuffleWriter(
+      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+
+    writer.write(Iterator(1, 2, 3))
+    writer.write(Iterator(4, 5, 6))
+
+    assert(readEpoch(reader) == Seq(1, 2, 3))
+    assert(readEpoch(reader) == Seq(4, 5, 6))
+  }
+
+  test("empty epochs") {
+    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val writer = new RPCContinuousShuffleWriter(
+      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+
+    writer.write(Iterator())
+    writer.write(Iterator(1, 2))
+    writer.write(Iterator())
+    writer.write(Iterator())
+    writer.write(Iterator(3, 4))
+    writer.write(Iterator())
+
+    assert(readEpoch(reader) == Seq())
+    assert(readEpoch(reader) == Seq(1, 2))
+    assert(readEpoch(reader) == Seq())
+    assert(readEpoch(reader) == Seq())
+    assert(readEpoch(reader) == Seq(3, 4))
+    assert(readEpoch(reader) == Seq())
+  }
+
+  test("blocks waiting for writer") {
+    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val writer = new RPCContinuousShuffleWriter(
+      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+
+    val readerEpoch = reader.compute(reader.partitions(0), ctx)
+
+    val readRowThread = new Thread {
+      override def run(): Unit = {
+        assert(readerEpoch.toSeq.map(_.getInt(0)) == Seq(1))
+      }
+    }
+    readRowThread.start()
+
+    eventually(timeout(streamingTimeout)) {
+      assert(readRowThread.getState == Thread.State.TIMED_WAITING)
+    }
+
+    // Once we write the epoch the thread should stop waiting and succeed.
+    writer.write(Iterator(1))
+    readRowThread.join()
+  }
+
+  test("multiple writer partitions") {
+    val numWriterPartitions = 3
+
+    val reader = new ContinuousShuffleReadRDD(
+      sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions)
+    val writers = (0 until 3).map { idx =>
+      new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+    }
+
+    writers(0).write(Iterator(1, 4, 7))
+    writers(1).write(Iterator(2, 5))
+    writers(2).write(Iterator(3, 6))
+
+    writers(0).write(Iterator(4, 7, 10))
+    writers(1).write(Iterator(5, 8))
+    writers(2).write(Iterator(6, 9))
+
+    // Since there are multiple asynchronous writers, the original row sequencing is not guaranteed.
+    // The epochs should be deterministically preserved, however.
+    assert(readEpoch(reader).toSet == Seq(1, 2, 3, 4, 5, 6, 7).toSet)
+    assert(readEpoch(reader).toSet == Seq(4, 5, 6, 7, 8, 9, 10).toSet)
+  }
+
+  test("reader epoch only ends when all writer partitions write it") {
+    val numWriterPartitions = 3
+
+    val reader = new ContinuousShuffleReadRDD(
+      sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions)
+    val writers = (0 until 3).map { idx =>
+      new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+    }
+
+    writers(1).write(Iterator())
+    writers(2).write(Iterator())
+
+    val readerEpoch = reader.compute(reader.partitions(0), ctx)
+
+    val readEpochMarkerThread = new Thread {
+      override def run(): Unit = {
+        assert(!readerEpoch.hasNext)
+      }
+    }
+
+    readEpochMarkerThread.start()
+    eventually(timeout(streamingTimeout)) {
+      assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING)
+    }
+
+    writers(0).write(Iterator())
+    readEpochMarkerThread.join()
+  }
+
+  test("receiver stopped with row last") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverEpochMarker(0),
+      ReceiverRow(0, unsafeRow(111))
+    )
+
+    ctx.markTaskCompleted(None)
+    val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
+    eventually(timeout(streamingTimeout)) {
+      assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get())
+    }
+  }
+
+  test("receiver stopped with marker last") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverRow(0, unsafeRow(111)),
+      ReceiverEpochMarker(0)
+    )
+
+    ctx.markTaskCompleted(None)
+    val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
+    eventually(timeout(streamingTimeout)) {
+      assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get())
+    }
+  }
--- End diff ---
There isn't a test where an RPCContinuousShuffleWriter writes to multiple reader endpoints.
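
Something along these lines could cover that case. This is only a rough sketch, reusing this suite's existing helpers (ctx and the implicit Int-to-UnsafeRow conversion); the test name is made up, and since the writer routes rows by the partitioner's hash of each UnsafeRow, the assertion only checks the union of rows across the reader partitions rather than a fixed row-to-partition assignment:

  test("one writer, multiple reader partitions") {
    val numReaderPartitions = 2
    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = numReaderPartitions)
    // Hand the writer one endpoint per reader partition.
    val endpoints = reader.partitions.map {
      _.asInstanceOf[ContinuousShuffleReadPartition].endpoint
    }
    val writer = new RPCContinuousShuffleWriter(
      0, new HashPartitioner(numReaderPartitions), endpoints)

    writer.write(Iterator(1, 2, 3, 4))

    // write() should send an epoch marker to every endpoint, so computing each
    // reader partition terminates; which rows land where depends on the partitioner.
    val epochs = (0 until numReaderPartitions).map { p =>
      reader.compute(reader.partitions(p), ctx).map(_.getInt(0)).toSet
    }
    assert(epochs.reduce(_ ++ _) == Set(1, 2, 3, 4))
  }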
---