peter-toth commented on code in PR #54330:
URL: https://github.com/apache/spark/pull/54330#discussion_r2826611189
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala:
##########
@@ -56,94 +91,60 @@ class DataSourceRDD(
   }
   override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
-
-    val iterator = new Iterator[Object] {
-      private val inputPartitions = castPartition(split).inputPartitions
-      private var currentIter: Option[Iterator[Object]] = None
-      private var currentIndex: Int = 0
-
-      private val partitionMetricCallback = new PartitionMetricCallback(customMetrics)
-
-      // In case of early stopping before consuming the entire iterator,
-      // we need to do one more metric update at the end of the task.
-      context.addTaskCompletionListener[Unit] { _ =>
-        partitionMetricCallback.execute()
-      }
-
-      override def hasNext: Boolean = currentIter.exists(_.hasNext) || advanceToNextIter()
-
-      override def next(): Object = {
-        if (!hasNext) throw new NoSuchElementException("No more elements")
-        currentIter.get.next()
+    castPartition(split).inputPartition.iterator.flatMap { inputPartition =>
+      val (iter, reader) = if (columnarReads) {
+        val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
+        val iter = new MetricsBatchIterator(
+          new PartitionIterator[ColumnarBatch](batchReader, customMetrics), readerStateThreadLocal)
+        (iter, batchReader)
+      } else {
+        val rowReader = partitionReaderFactory.createReader(inputPartition)
+        val iter = new MetricsRowIterator(
+          new PartitionIterator[InternalRow](rowReader, customMetrics), readerStateThreadLocal)
+        (iter, rowReader)
       }
-      private def advanceToNextIter(): Boolean = {
-        if (currentIndex >= inputPartitions.length) {
-          false
-        } else {
-          val inputPartition = inputPartitions(currentIndex)
-          currentIndex += 1
-
-          // TODO: SPARK-25083 remove the type erasure hack in data source scan
-          val (iter, reader) = if (columnarReads) {
-            val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
-            val iter = new MetricsBatchIterator(
-              new PartitionIterator[ColumnarBatch](batchReader, customMetrics))
-            (iter, batchReader)
-          } else {
-            val rowReader = partitionReaderFactory.createReader(inputPartition)
-            val iter = new MetricsRowIterator(
-              new PartitionIterator[InternalRow](rowReader, customMetrics))
-            (iter, rowReader)
+      // Add completion listener only once per thread (null means no listener added yet)
+      val readerState = readerStateThreadLocal.get()
+      if (readerState == null) {
+        context.addTaskCompletionListener[Unit] { _ =>
+          // Use the reader and iterator from ThreadLocal (the last ones created in this thread)
+          val readerState = readerStateThreadLocal.get()
+          if (readerState != null) {
+            // In case of early stopping before consuming the entire iterator,
+            // we need to do one more metric update at the end of the task.
+            CustomMetrics.updateMetrics(
+              readerState.reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
+            readerState.iterator.forceUpdateMetrics()
+            readerState.reader.close()
           }
-
-          // Once we advance to the next partition, update the metric callback for early finish
-          val previousMetrics = partitionMetricCallback.advancePartition(iter, reader)
-          previousMetrics.foreach(reader.initMetricsValues)
-
-          currentIter = Some(iter)
-          hasNext
+          readerStateThreadLocal.remove()
         }
+      } else {
+        readerState.metrics.foreach(reader.initMetricsValues)
Review Comment:
Indeed. I refined `ReaderState` and now close `readerState.reader` properly in https://github.com/apache/spark/pull/54330/commits/5ca18c313e5d6dc39c640d74921a68ed5ec22af3.