advancedxy commented on code in PR #211:
URL: 
https://github.com/apache/arrow-datafusion-comet/pull/211#discussion_r1535186578


##########
spark/src/main/scala/org/apache/spark/sql/comet/operators.scala:
##########
@@ -588,6 +589,41 @@ case class CometHashAggregateExec(
     Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, 
child)
 }
 
+case class CometSortMergeJoinExec(

Review Comment:
   ditto



##########
spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala:
##########
@@ -257,9 +257,28 @@ class CometBatchRDD(
   }
 
   override def compute(split: Partition, context: TaskContext): 
Iterator[ColumnarBatch] = {
-    val partition = split.asInstanceOf[CometBatchPartition]
+    new Iterator[ColumnarBatch] {
+      val partition = split.asInstanceOf[CometBatchPartition]
+      val batchesIter = 
partition.value.value.map(CometExec.decodeBatches(_)).toIterator
+      var iter: Iterator[ColumnarBatch] = null
+
+      override def hasNext: Boolean = {
+        if (iter != null) {
+          if (iter.hasNext) {
+            return true
+          }
+        }
+        if (batchesIter.hasNext) {
+          iter = batchesIter.next()
+          return iter.hasNext
+        }
+        false
+      }
 
-    partition.value.value.flatMap(CometExec.decodeBatches(_)).toIterator

Review Comment:
   Hmm, I think you can avoid this entirely by converting to an iterator first,
such as:
   
   ```scala
   partition.value.value.toIterator.flatMap(CometExec.decodeBatches)
   ```



##########
spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala:
##########
@@ -335,6 +335,26 @@ class CometSparkSessionExtensions
               op
           }
 
+        case op: SortMergeJoinExec

Review Comment:
   This looks like an unrelated change.
   
   Could you add the `BroadcastHashJoinExec` case after `SortMergeJoinExec`, so
that this unrelated change does not need to be included?



##########
common/src/main/scala/org/apache/comet/vector/NativeUtil.scala:
##########
@@ -46,29 +48,27 @@ class NativeUtil {
    *   the output stream
    */
   def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): 
Long = {
-    var schemaRoot: Option[VectorSchemaRoot] = None
-    var writer: Option[ArrowStreamWriter] = None
+    var writer: Option[CometArrowStreamWriter] = None
     var rowCount = 0
 
     batches.foreach { batch =>
       val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
-      val root = schemaRoot.getOrElse(new 
VectorSchemaRoot(fieldVectors.asJava))
+      val root = new VectorSchemaRoot(fieldVectors.asJava)
       val provider = batchProviderOpt.getOrElse(dictionaryProvider)

Review Comment:
   Looking deeper: what if the incoming batches have different dictionary providers?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to