Github user zsxwing commented on a diff in the pull request:
https://github.com/apache/spark/pull/19984#discussion_r158134853
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala ---
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark._
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.continuous._
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.streaming.ProcessingTime
+import org.apache.spark.util.{SystemClock, ThreadUtils}
+
+class ContinuousDataSourceRDD(
+ sc: SparkContext,
+ sqlContext: SQLContext,
+ @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]])
+ extends RDD[UnsafeRow](sc, Nil) {
+
+ private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize
+ private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs
+
+ override protected def getPartitions: Array[Partition] = {
+ readTasks.asScala.zipWithIndex.map {
+ case (readTask, index) => new DataSourceRDDPartition(index, readTask)
+ }.toArray
+ }
+
+ override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
+ val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader()
+
+ val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)
+
+ // This queue contains two types of messages:
+ // * (null, null) representing an epoch boundary.
+ // * (row, off) containing a data row and its corresponding PartitionOffset.
+ val queue = new ArrayBlockingQueue[(UnsafeRow, PartitionOffset)](dataQueueSize)
+
+ val epochPollFailed = new AtomicBoolean(false)
+ val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
+ s"epoch-poll--${runId}--${context.partitionId()}")
+ val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed)
+ epochPollExecutor.scheduleWithFixedDelay(
+ epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS)
+
+ // Important sequencing - we must get start offset before the data reader thread begins
+ val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset
+
+ val dataReaderFailed = new AtomicBoolean(false)
+ val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed)
+ dataReaderThread.setDaemon(true)
+ dataReaderThread.start()
+
+ context.addTaskCompletionListener(_ => {
+ reader.close()
+ dataReaderThread.interrupt()
+ epochPollExecutor.shutdown()
+ })
+
+ val epochEndpoint = EpochCoordinatorRef.get(runId, SparkEnv.get)
+ new Iterator[UnsafeRow] {
+ private var currentRow: UnsafeRow = _
+ private var currentOffset: PartitionOffset = startOffset
+ private var currentEpoch =
+ context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+
+ override def hasNext(): Boolean = {
+ if (dataReaderFailed.get()) {
+ throw new SparkException("data read failed",
dataReaderThread.failureReason)
+ }
+ if (epochPollFailed.get()) {
+ throw new SparkException("epoch poll failed",
epochPollRunnable.failureReason)
+ }
+
+ queue.take() match {
+ // epoch boundary marker
+ case (null, null) =>
+ epochEndpoint.send(ReportPartitionOffset(
+ context.partitionId(),
+ currentEpoch,
+ currentOffset))
+ currentEpoch += 1
+ false
+ // real row
+ case (row, offset) =>
+ currentRow = row
+ currentOffset = offset
+ true
+ }
+ }
+
+ override def next(): UnsafeRow = {
--- End diff --
This method doesn't follow the Java `Iterator.next()` contract:
```
NoSuchElementException if the iteration has no more elements
```
You can extend `org.apache.spark.util.NextIterator` to fix it.
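Roughly, and untested, the anonymous iterator above could be reshaped along these lines (reusing `queue`, `epochEndpoint`, `startOffset`, and the failure flags already in scope in `compute()`, plus an `import org.apache.spark.util.NextIterator`). `getNext()` either returns a row or sets `finished = true`, and the base class then takes care of throwing `NoSuchElementException` from `next()`:
```
new NextIterator[UnsafeRow] {
  private var currentOffset: PartitionOffset = startOffset
  private var currentEpoch =
    context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong

  override protected def getNext(): UnsafeRow = {
    if (dataReaderFailed.get()) {
      throw new SparkException("data read failed", dataReaderThread.failureReason)
    }
    if (epochPollFailed.get()) {
      throw new SparkException("epoch poll failed", epochPollRunnable.failureReason)
    }

    queue.take() match {
      // Epoch boundary marker: report the last seen offset and mark the
      // iterator finished; NextIterator.next() then throws NoSuchElementException.
      case (null, null) =>
        epochEndpoint.send(ReportPartitionOffset(
          context.partitionId(), currentEpoch, currentOffset))
        currentEpoch += 1
        finished = true
        null
      // Real row: remember its offset so the next epoch marker can report it.
      case (row, offset) =>
        currentOffset = offset
        row
    }
  }

  // Cleanup already happens in the task completion listener above.
  override protected def close(): Unit = {}
}
```
That also keeps the queue consumption in one place: `getNext()` is only invoked once per element, so repeated `hasNext()` calls won't take extra entries off the queue.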