Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21428#discussion_r191615398
  
    --- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala ---
    @@ -288,4 +267,153 @@ class ContinuousShuffleReadSuite extends StreamTest {
         val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet
         assert(thirdEpoch == Set("writer1-row1", "writer2-row0"))
       }
    +
    +  test("one epoch") {
     +    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
    +    val writer = new RPCContinuousShuffleWriter(
    +      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
    +
    +    writer.write(Iterator(1, 2, 3))
    +
    +    assert(readEpoch(reader) == Seq(1, 2, 3))
    +  }
    +
    +  test("multiple epochs") {
     +    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
    +    val writer = new RPCContinuousShuffleWriter(
    +      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
    +
    +    writer.write(Iterator(1, 2, 3))
    +    writer.write(Iterator(4, 5, 6))
    +
    +    assert(readEpoch(reader) == Seq(1, 2, 3))
    +    assert(readEpoch(reader) == Seq(4, 5, 6))
    +  }
    +
    +  test("empty epochs") {
     +    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
    +    val writer = new RPCContinuousShuffleWriter(
    +      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
    +
    +    writer.write(Iterator())
    +    writer.write(Iterator(1, 2))
    +    writer.write(Iterator())
    +    writer.write(Iterator())
    +    writer.write(Iterator(3, 4))
    +    writer.write(Iterator())
    +
    +    assert(readEpoch(reader) == Seq())
    +    assert(readEpoch(reader) == Seq(1, 2))
    +    assert(readEpoch(reader) == Seq())
    +    assert(readEpoch(reader) == Seq())
    +    assert(readEpoch(reader) == Seq(3, 4))
    +    assert(readEpoch(reader) == Seq())
    +  }
    +
    +  test("blocks waiting for writer") {
     +    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
    +    val writer = new RPCContinuousShuffleWriter(
    +      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
    +
    +    val readerEpoch = reader.compute(reader.partitions(0), ctx)
    +
    +    val readRowThread = new Thread {
    +      override def run(): Unit = {
    +        assert(readerEpoch.toSeq.map(_.getInt(0)) == Seq(1))
    +      }
    +    }
    +    readRowThread.start()
    +
    +    eventually(timeout(streamingTimeout)) {
    +      assert(readRowThread.getState == Thread.State.TIMED_WAITING)
    +    }
    +
    +    // Once we write the epoch the thread should stop waiting and succeed.
    +    writer.write(Iterator(1))
    +    readRowThread.join()
    +  }
    +
    +  test("multiple writer partitions") {
    +    val numWriterPartitions = 3
    +
    +    val reader = new ContinuousShuffleReadRDD(
     +      sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions)
    +    val writers = (0 until 3).map { idx =>
     +      new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
    +    }
    +
    +    writers(0).write(Iterator(1, 4, 7))
    +    writers(1).write(Iterator(2, 5))
    +    writers(2).write(Iterator(3, 6))
    +
    +    writers(0).write(Iterator(4, 7, 10))
    +    writers(1).write(Iterator(5, 8))
    +    writers(2).write(Iterator(6, 9))
    +
     +    // Since there are multiple asynchronous writers, the original row sequencing is not guaranteed.
    +    // The epochs should be deterministically preserved, however.
    +    assert(readEpoch(reader).toSet == Seq(1, 2, 3, 4, 5, 6, 7).toSet)
    +    assert(readEpoch(reader).toSet == Seq(4, 5, 6, 7, 8, 9, 10).toSet)
    +  }
    +
    +  test("reader epoch only ends when all writer partitions write it") {
    +    val numWriterPartitions = 3
    +
    +    val reader = new ContinuousShuffleReadRDD(
     +      sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions)
    +    val writers = (0 until 3).map { idx =>
     +      new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
    +    }
    +
    +    writers(1).write(Iterator())
    +    writers(2).write(Iterator())
    +
    +    val readerEpoch = reader.compute(reader.partitions(0), ctx)
    +
    +    val readEpochMarkerThread = new Thread {
    +      override def run(): Unit = {
    +        assert(!readerEpoch.hasNext)
    +      }
    +    }
    +
    +    readEpochMarkerThread.start()
    +    eventually(timeout(streamingTimeout)) {
    +      assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING)
    +    }
    +
    +    writers(0).write(Iterator())
    +    readEpochMarkerThread.join()
    +  }
    +
    +  test("receiver stopped with row last") {
    +    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
     +    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
    +    send(
    +      endpoint,
    +      ReceiverEpochMarker(0),
    +      ReceiverRow(0, unsafeRow(111))
    +    )
    +
    +    ctx.markTaskCompleted(None)
     +    val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
    +    eventually(timeout(streamingTimeout)) {
     +      assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get())
    +    }
    +  }
    +
    +  test("receiver stopped with marker last") {
    +    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
     +    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
    +    send(
    +      endpoint,
    +      ReceiverRow(0, unsafeRow(111)),
    +      ReceiverEpochMarker(0)
    +    )
    +
    +    ctx.markTaskCompleted(None)
     +    val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
    +    eventually(timeout(streamingTimeout)) {
     +      assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get())
    +    }
    +  }
    --- End diff --
    
    there isn't a test where an RPCContinuousShuffleWriter writes to multiple reader endpoints.
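    
    For reference, a minimal sketch of what such a test could look like, reusing the patterns already in this suite. The per-partition endpoint lookup mirrors the "receiver stopped" tests above, and writing plain Ints relies on the same implicit Int-to-UnsafeRow conversion the other tests use; this is only an illustration of the missing case, not a definitive implementation:
    
        test("single writer, multiple reader partitions") {
          val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 2)
          // Grab one endpoint per reader partition, the same way the "receiver stopped"
          // tests do for partition 0.
          val endpoints = (0 until 2).map { p =>
            reader.partitions(p).asInstanceOf[ContinuousShuffleReadPartition].endpoint
          }.toArray
          val writer = new RPCContinuousShuffleWriter(0, new HashPartitioner(2), endpoints)
    
          writer.write(Iterator(1, 2, 3, 4))
    
          // Which reader partition each row lands on is decided by the partitioner, so
          // only the union across partitions is deterministic: every row exactly once.
          val part0 = reader.compute(reader.partitions(0), ctx).map(_.getInt(0)).toSet
          val part1 = reader.compute(reader.partitions(1), ctx).map(_.getInt(0)).toSet
          assert(part0.union(part1) == Set(1, 2, 3, 4))
          assert(part0.intersect(part1).isEmpty)
        }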

