viirya commented on code in PR #890:
URL: https://github.com/apache/datafusion-comet/pull/890#discussion_r1736798579


##########
common/src/main/scala/org/apache/comet/vector/NativeUtil.scala:
##########
@@ -44,43 +44,53 @@ class NativeUtil {
    * @param batch
    *   the input Comet columnar batch
    * @return
-   *   a list containing number of rows + pairs of memory addresses in the 
format of (address of
-   *   Arrow array, address of Arrow schema)
+   *   an ExportedBatch object containing an array with the number of rows
+   *   followed by pairs of memory addresses in the format of (address of
+   *   Arrow array, address of Arrow schema)
    */
-  def exportBatch(batch: ColumnarBatch): Array[Long] = {
+  def exportBatch(batch: ColumnarBatch): ExportedBatch = {
     val exportedVectors = mutable.ArrayBuffer.empty[Long]
     exportedVectors += batch.numRows()
 
+    // Run checks prior to exporting the batch
+    (0 until batch.numCols()).foreach { index =>
+      val c = batch.column(index)
+      if (!c.isInstanceOf[CometVector]) {
+        batch.close()
+        throw new SparkException(
+          "Comet execution only takes Arrow Arrays, but got " +
+            s"${c.getClass}")
+      }
+    }
+
+    val arrowSchemas = mutable.ArrayBuffer.empty[ArrowSchema]
+    val arrowArrays = mutable.ArrayBuffer.empty[ArrowArray]
+
     (0 until batch.numCols()).foreach { index =>
-      batch.column(index) match {
-        case a: CometVector =>
-          val valueVector = a.getValueVector
-
-          val provider = if (valueVector.getField.getDictionary != null) {
-            a.getDictionaryProvider
-          } else {
-            null
-          }
-
-          val arrowSchema = ArrowSchema.allocateNew(allocator)
-          val arrowArray = ArrowArray.allocateNew(allocator)
-          Data.exportVector(
-            allocator,
-            getFieldVector(valueVector, "export"),
-            provider,
-            arrowArray,
-            arrowSchema)
-
-          exportedVectors += arrowArray.memoryAddress()
-          exportedVectors += arrowSchema.memoryAddress()
-        case c =>
-          throw new SparkException(
-            "Comet execution only takes Arrow Arrays, but got " +
-              s"${c.getClass}")
+      val cometVector = batch.column(index).asInstanceOf[CometVector]
+      val valueVector = cometVector.getValueVector
+
+      val provider = if (valueVector.getField.getDictionary != null) {
+        cometVector.getDictionaryProvider
+      } else {
+        null
       }
+
+      val arrowSchema = ArrowSchema.allocateNew(allocator)
+      val arrowArray = ArrowArray.allocateNew(allocator)
+      arrowSchemas += arrowSchema
+      arrowArrays += arrowArray
+      Data.exportVector(
+        allocator,
+        getFieldVector(valueVector, "export"),
+        provider,
+        arrowArray,
+        arrowSchema)
+
+      exportedVectors += arrowArray.memoryAddress()
+      exportedVectors += arrowSchema.memoryAddress()
     }
 
-    exportedVectors.toArray

Review Comment:
   You can return `ExportedBatch` without any of the above changes.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to