szehon-ho commented on PR #54330:
URL: https://github.com/apache/spark/pull/54330#issuecomment-3930501172

   Thanks for the refactor. I was wondering whether the following approach would work 
(it was generated with Cursor, so please double-check that it makes sense). It is more 
localized to the SPJ case than the current one:
   
   ```
   /**
    * An RDD that reads several partitions of a [[DataSourceRDD]] as a single output partition.
    *
    * Each output partition reads its parent input partitions one after another, keeping at
    * most one [[PartitionReader]] open at a time. A reader is closed, and its custom metrics
    * reported, as soon as it is exhausted. A reader still open when the task ends early
    * (failure, cancellation, or a downstream operator that stops consuming) is closed by a
    * single task-completion listener registered once per `compute` call.
    *
    * @param dataSourceRDD the scan RDD whose partitions are grouped; dereferenced only on the
    *                      driver (construction, `getPartitions`), never inside `compute`
    * @param groupedPartitions for each output partition, the indices of the parent partitions
    *                          it reads, in read order
    */
   class GroupedPartitionsRDD(
       @transient private val dataSourceRDD: DataSourceRDD,
       groupedPartitions: Seq[Seq[Int]]
     ) extends RDD[InternalRow](dataSourceRDD) {

     // `dataSourceRDD` is @transient, so it is null inside `compute` on executors. Capture what
     // the readers need here, on the driver, so that it is serialized with this RDD.
     private val partitionReaderFactory = dataSourceRDD.partitionReaderFactory
     private val columnarReads = dataSourceRDD.columnarReads
     private val customMetrics = dataSourceRDD.customMetrics

     override protected def getPartitions: Array[Partition] =
       groupedPartitions.zipWithIndex.map { case (parentIndices, index) =>
         // Resolve the input partitions on the driver; they are shipped inside the partition.
         // NOTE(review): assumes `inputPartitions(i)` is an Option[InputPartition] — confirm.
         val inputPartitions = parentIndices.map(i => dataSourceRDD.inputPartitions(i)).toArray
         val location = parentIndices.iterator
           .flatMap(i => dataSourceRDD.preferredLocations(dataSourceRDD.partitions(i)))
           .nextOption()
         GroupedPartitionsRDDPartition(index, parentIndices.toArray, inputPartitions, location)
       }.toArray

     override protected def getPreferredLocations(split: Partition): Seq[String] =
       split.asInstanceOf[GroupedPartitionsRDDPartition].preferredLocation.toSeq

     override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
       val groupedPartition = split.asInstanceOf[GroupedPartitionsRDDPartition]

       new Iterator[InternalRow] {
         private val remaining = groupedPartition.inputPartitions.iterator
         private var currentReader: PartitionReader[_] = null
         private var currentIterator: Iterator[InternalRow] = Iterator.empty

         // Close whatever reader is still open if the task ends before this iterator is drained.
         context.addTaskCompletionListener[Unit](_ => closeCurrentReader())

         /** Reports the open reader's custom metrics and closes it; a no-op if none is open. */
         private def closeCurrentReader(): Unit = {
           if (currentReader != null) {
             val reader = currentReader
             // Forget the reader before closing it so a failing close() is not retried and its
             // metrics are never reported twice (here and again from the completion listener).
             currentReader = null
             try {
               CustomMetrics.updateMetrics(
                 reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
               reader.close()
             } catch {
               case scala.util.control.NonFatal(e) =>
                 logWarning(s"Error closing reader: ${e.getMessage}", e)
             }
           }
         }

         /** Opens a reader over `inputPartition` and returns a metrics-tracking row iterator. */
         private def open(inputPartition: InputPartition): Iterator[InternalRow] = {
           // NOTE(review): DataSourceRDD also forces a final input-metrics update on its metrics
           // iterator when a task stops early; consider doing the same here — confirm.
           if (columnarReads) {
             val reader = partitionReaderFactory.createColumnarReader(inputPartition)
             // Track the reader immediately so it is closed even if wrapping it fails.
             currentReader = reader
             new MetricsBatchIterator(new PartitionIterator[ColumnarBatch](reader, customMetrics))
               .asInstanceOf[Iterator[InternalRow]]
           } else {
             val reader = partitionReaderFactory.createReader(inputPartition)
             currentReader = reader
             new MetricsRowIterator(new PartitionIterator[InternalRow](reader, customMetrics))
           }
         }

         override def hasNext: Boolean = {
           // Skip exhausted and empty (`None`) parents, closing each reader before the next opens.
           while (!currentIterator.hasNext && remaining.hasNext) {
             closeCurrentReader()
             currentIterator = remaining.next().fold(Iterator.empty[InternalRow])(open)
           }
           val more = currentIterator.hasNext
           if (!more) closeCurrentReader()
           more
         }

         override def next(): InternalRow = {
           if (!hasNext) {
             throw new NoSuchElementException("next on empty iterator")
           }
           currentIterator.next()
         }
       }
     }
   }


   /**
    * One output partition of [[GroupedPartitionsRDD]].
    *
    * @param index position of this partition in the grouped RDD
    * @param parentIndices indices of the grouped [[DataSourceRDD]] partitions, in read order
    * @param inputPartitions the parents' input partitions, resolved on the driver; `None`
    *                        entries are skipped when reading
    * @param preferredLocation a preferred host for this partition, taken from the first parent
    *                          that reports one
    */
   private case class GroupedPartitionsRDDPartition(
       index: Int,
       parentIndices: Array[Int],
       inputPartitions: Array[Option[InputPartition]],
       preferredLocation: Option[String] = None
     ) extends Partition
   ```
   It could then be used in GroupPartitionsExec like this:
   
   ```
   /**
    * Produces one output partition per entry of `groupedPartitions`. Each output partition
    * reads the child's grouped scan partitions through a [[GroupedPartitionsRDD]].
    *
    * @throws IllegalStateException if the child's RDD is not a [[DataSourceRDD]]
    */
   override protected def doExecute(): RDD[InternalRow] = {
     if (groupedPartitions.isEmpty) {
       sparkContext.emptyRDD[InternalRow]
     } else {
       child.execute() match {
         case dsRDD: DataSourceRDD =>
           // A single compute() call per grouped partition manages all of that group's readers.
           new GroupedPartitionsRDD(dsRDD, groupedPartitions.map(_._2))
         case other =>
           // NOTE(review): if the child's execute() can return something other than a bare
           // DataSourceRDD (e.g. a scan that wraps its input RDD to count output rows, or a
           // non-scan child), this branch is reachable and needs a fallback such as the existing
           // coalesce-based implementation instead of an error — confirm.
           throw new IllegalStateException(
             "GroupPartitionsExec expected its child to produce a DataSourceRDD, but got " +
               other.getClass.getName)
       }
     }
   }
   ```
   
   The current code is definitely more Spark-native, since it reuses CoalescedRDD, but my 
concern is the thread-local and the chance of a memory leak like the one fixed by @viirya in 
https://github.com/apache/spark/pull/51503. That said, I'll defer to others on whether people 
prefer this approach.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to