advancedxy commented on code in PR #206:
URL: 
https://github.com/apache/arrow-datafusion-comet/pull/206#discussion_r1539191962


##########
common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala:
##########
@@ -0,0 +1,466 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.arrow
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.complex._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.comet.util.Utils
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types._
+
+private[arrow] object ArrowWriter {

Review Comment:
   Of course, let me add some comments.



##########
common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala:
##########
@@ -0,0 +1,466 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.arrow
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.complex._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.comet.util.Utils
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types._
+
+private[arrow] object ArrowWriter {
+  def create(root: VectorSchemaRoot): ArrowWriter = {
+    val children = root.getFieldVectors().asScala.map { vector =>
+      vector.allocateNew()
+      createFieldWriter(vector)
+    }
+    new ArrowWriter(root, children.toArray)
+  }
+
+  private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
+    val field = vector.getField()
+    (Utils.fromArrowField(field), vector) match {
+      case (BooleanType, vector: BitVector) => new BooleanWriter(vector)
+      case (ByteType, vector: TinyIntVector) => new ByteWriter(vector)
+      case (ShortType, vector: SmallIntVector) => new ShortWriter(vector)
+      case (IntegerType, vector: IntVector) => new IntegerWriter(vector)
+      case (LongType, vector: BigIntVector) => new LongWriter(vector)
+      case (FloatType, vector: Float4Vector) => new FloatWriter(vector)
+      case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector)
+      case (DecimalType.Fixed(precision, scale), vector: DecimalVector) =>
+        new DecimalWriter(vector, precision, scale)
+      case (StringType, vector: VarCharVector) => new StringWriter(vector)
+      case (StringType, vector: LargeVarCharVector) => new 
LargeStringWriter(vector)
+      case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
+      case (BinaryType, vector: LargeVarBinaryVector) => new 
LargeBinaryWriter(vector)
+      case (DateType, vector: DateDayVector) => new DateWriter(vector)
+      case (TimestampType, vector: TimeStampMicroTZVector) => new 
TimestampWriter(vector)
+      case (TimestampNTZType, vector: TimeStampMicroVector) => new 
TimestampNTZWriter(vector)

Review Comment:
   Hmmm, although `TimestampNTZType` is removed in Spark 3.2 after this PR: 
`https://github.com/apache/spark/pull/33444`, it's still possible to access the 
`TimestampNTZType` type, which makes the code a lot cleaner for later versions 
of Spark: we don't have to add special directories for Spark 3.2/3.3/3.4 etc.
   
   Per my understanding, since Spark 3.2 will not produce a schema with 
`TimestampNTZType`, this pattern match case is effectively a no-op there. And 
it can take effect for Spark 3.3 and Spark 3.4.
   
   



##########
common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala:
##########
@@ -0,0 +1,466 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.arrow
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.complex._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.comet.util.Utils
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types._
+
+private[arrow] object ArrowWriter {
+  def create(root: VectorSchemaRoot): ArrowWriter = {
+    val children = root.getFieldVectors().asScala.map { vector =>
+      vector.allocateNew()
+      createFieldWriter(vector)
+    }
+    new ArrowWriter(root, children.toArray)
+  }
+
+  private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
+    val field = vector.getField()
+    (Utils.fromArrowField(field), vector) match {
+      case (BooleanType, vector: BitVector) => new BooleanWriter(vector)
+      case (ByteType, vector: TinyIntVector) => new ByteWriter(vector)
+      case (ShortType, vector: SmallIntVector) => new ShortWriter(vector)
+      case (IntegerType, vector: IntVector) => new IntegerWriter(vector)
+      case (LongType, vector: BigIntVector) => new LongWriter(vector)
+      case (FloatType, vector: Float4Vector) => new FloatWriter(vector)
+      case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector)
+      case (DecimalType.Fixed(precision, scale), vector: DecimalVector) =>
+        new DecimalWriter(vector, precision, scale)
+      case (StringType, vector: VarCharVector) => new StringWriter(vector)
+      case (StringType, vector: LargeVarCharVector) => new 
LargeStringWriter(vector)
+      case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
+      case (BinaryType, vector: LargeVarBinaryVector) => new 
LargeBinaryWriter(vector)
+      case (DateType, vector: DateDayVector) => new DateWriter(vector)
+      case (TimestampType, vector: TimeStampMicroTZVector) => new 
TimestampWriter(vector)
+      case (TimestampNTZType, vector: TimeStampMicroVector) => new 
TimestampNTZWriter(vector)
+      case (ArrayType(_, _), vector: ListVector) =>
+        val elementVector = createFieldWriter(vector.getDataVector())
+        new ArrayWriter(vector, elementVector)
+      case (MapType(_, _, _), vector: MapVector) =>
+        val structVector = vector.getDataVector.asInstanceOf[StructVector]
+        val keyWriter = 
createFieldWriter(structVector.getChild(MapVector.KEY_NAME))
+        val valueWriter = 
createFieldWriter(structVector.getChild(MapVector.VALUE_NAME))
+        new MapWriter(vector, structVector, keyWriter, valueWriter)
+      case (StructType(_), vector: StructVector) =>
+        val children = (0 until vector.size()).map { ordinal =>
+          createFieldWriter(vector.getChildByOrdinal(ordinal))
+        }
+        new StructWriter(vector, children.toArray)
+      case (NullType, vector: NullVector) => new NullWriter(vector)
+      case (_: YearMonthIntervalType, vector: IntervalYearVector) =>

Review Comment:
   Hmm, I believe this ArrowWriter doesn't produce dictionary-encoded 
ColumnVectors. 
   But I do believe we should match how 
`org.apache.spark.sql.comet.util.Utils#toArrowType` maps Spark types to 
Arrow types. Let me comment out the following pattern match cases.



##########
common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala:
##########
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.arrow
+
+import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.comet.util.Utils
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import org.apache.comet.vector.NativeUtil
+
+object CometArrowConverters extends Logging {
+  // TODO: we should reuse the same root allocator in the comet code base?
+  val rootAllocator: BufferAllocator = new RootAllocator(Long.MaxValue)
+
+  // This is similar how Spark converts internal row to Arrow format except 
that it is transforming
+  // the result batch to Comet's ColumnarBatch instead of serialized bytes.
+  // There's another big difference that Comet may consume the ColumnarBatch 
by exporting it to
+  // the native side. Hence, we need to:
+  // 1. reset the Arrow writer after the ColumnarBatch is consumed
+  // 2. close the allocator when the task is finished but not when the 
iterator is all consumed
+  // The reason for the second point is that when ColumnarBatch is exported to 
the native side, the
+  // exported process increases the reference count of the Arrow vectors. The 
reference count is
+  // only decreased when the native plan is done with the vectors, which is 
usually longer than
+  // all the ColumnarBatches are consumed.
+  private[sql] class ArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Long,
+      timeZoneId: String,
+      context: TaskContext)
+      extends Iterator[ColumnarBatch]
+      with AutoCloseable {
+
+    private val arrowSchema = Utils.toArrowSchema(schema, timeZoneId)
+    // Reuse the same root allocator here.
+    private val allocator =
+      rootAllocator.newChildAllocator(s"to${this.getClass.getSimpleName}", 0, 
Long.MaxValue)
+    private val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    private val arrowWriter = ArrowWriter.create(root)
+
+    private var currentBatch: ColumnarBatch = null
+    private var closed: Boolean = false
+
+    Option(context).foreach {
+      _.addTaskCompletionListener[Unit] { _ =>
+        close(true)
+      }
+    }
+
+    override def hasNext: Boolean = rowIter.hasNext || {
+      close(false)

Review Comment:
   > I wonder why we need to call close here and whether just calling close in 
the TaskCompletionListener is sufficient.
   
   It might not be sufficient to close only in the `TaskCompletionListener`, as 
the task might live much longer than the iterator. For example, if a task 
contains `Range -> CometRowToColumnar --> Sort --> ShuffleWrite`, the 
`CometRowToColumnar` will be drained much earlier than Sort or ShuffleWrite. So 
we need to close the iterator earlier to make sure there's no memory 
buffering/leaking.
   
   However, due to the comments in 
https://github.com/apache/arrow-datafusion-comet/pull/206/files#diff-04037044481f9a656275a63ebb6a3a63badf866f19700d4a6909d2e17c8d7b72R37-R46,
 we cannot close the allocator right after the iterator is consumed. The batch 
is already exported to the native side, and it might be dropped later than the 
iterator is consumed. So I added `allocator.close` to the task completion 
callback.



##########
common/src/main/scala/org/apache/comet/CometConf.scala:
##########
@@ -336,6 +336,25 @@ object CometConf {
         "enabled when reading from Iceberg tables.")
     .booleanConf
     .createWithDefault(false)
+
+  val COMET_ROW_TO_COLUMNAR_ENABLED: ConfigEntry[Boolean] = conf(
+    "spark.comet.rowToColumnar.enabled")
+    .internal()
+    .doc(
+      "Whether to enable row to columnar conversion for LeafExecNodes. When 
this is turned on, " +

Review Comment:
   Ah, yeah. The rowToColumnar operator could be general enough to convert any 
row-based operator to a columnar-based one.
   
   When I was implementing this feature, I was thinking that we should convert 
the leaf node first so that Comet has the best chance to convert the whole 
operator tree.
   
   Let me rethink this part; a bit more work should be done to retract the 
`CometRowToColumnar` operator if the whole plan cannot be converted.



##########
spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala:
##########
@@ -626,6 +635,15 @@ object CometSparkSessionExtensions extends Logging {
     op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec]
   }
 
+  private def applyRowToColumnar(conf: SQLConf, op: SparkPlan): Boolean = {

Review Comment:
   Ok, sounds better to me.



##########
common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala:
##########
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.arrow
+
+import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.comet.util.Utils
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import org.apache.comet.vector.NativeUtil
+
+object CometArrowConverters extends Logging {
+  // TODO: we should reuse the same root allocator in the comet code base?
+  val rootAllocator: BufferAllocator = new RootAllocator(Long.MaxValue)
+
+  // This is similar how Spark converts internal row to Arrow format except 
that it is transforming
+  // the result batch to Comet's ColumnarBatch instead of serialized bytes.
+  // There's another big difference that Comet may consume the ColumnarBatch 
by exporting it to
+  // the native side. Hence, we need to:
+  // 1. reset the Arrow writer after the ColumnarBatch is consumed
+  // 2. close the allocator when the task is finished but not when the 
iterator is all consumed
+  // The reason for the second point is that when ColumnarBatch is exported to 
the native side, the
+  // exported process increases the reference count of the Arrow vectors. The 
reference count is
+  // only decreased when the native plan is done with the vectors, which is 
usually longer than
+  // all the ColumnarBatches are consumed.
+  private[sql] class ArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Long,
+      timeZoneId: String,
+      context: TaskContext)
+      extends Iterator[ColumnarBatch]
+      with AutoCloseable {
+
+    private val arrowSchema = Utils.toArrowSchema(schema, timeZoneId)
+    // Reuse the same root allocator here.
+    private val allocator =
+      rootAllocator.newChildAllocator(s"to${this.getClass.getSimpleName}", 0, 
Long.MaxValue)
+    private val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    private val arrowWriter = ArrowWriter.create(root)
+
+    private var currentBatch: ColumnarBatch = null
+    private var closed: Boolean = false
+
+    Option(context).foreach {
+      _.addTaskCompletionListener[Unit] { _ =>
+        close(true)
+      }
+    }
+
+    override def hasNext: Boolean = rowIter.hasNext || {
+      close(false)
+      false
+    }
+
+    override def next(): ColumnarBatch = {
+      currentBatch = nextBatch()
+      currentBatch
+    }
+
+    override def close(): Unit = {
+      close(false)
+    }
+
+    private def nextBatch(): ColumnarBatch = {
+      if (rowIter.hasNext) {
+        // the arrow writer shall be reset before writing the next batch
+        arrowWriter.reset()
+        var rowCount = 0L
+        while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < 
maxRecordsPerBatch)) {
+          val row = rowIter.next()
+          arrowWriter.write(row)
+          rowCount += 1
+        }
+        arrowWriter.finish()
+        NativeUtil.rootAsBatch(root)
+      } else {
+        null
+      }
+    }
+
+    private def close(closeAllocator: Boolean): Unit = {
+      try {
+        if (!closed) {
+          if (currentBatch != null) {
+            arrowWriter.reset()

Review Comment:
   There's no close method for `arrowWriter`. The `ColumnarBatch` shall be 
closed to close all the `ValueVector`s, which is already achieved by 
`root.close`.



##########
common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala:
##########
@@ -0,0 +1,466 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.arrow
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.complex._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.comet.util.Utils
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types._
+
+private[arrow] object ArrowWriter {
+  def create(root: VectorSchemaRoot): ArrowWriter = {
+    val children = root.getFieldVectors().asScala.map { vector =>
+      vector.allocateNew()
+      createFieldWriter(vector)
+    }
+    new ArrowWriter(root, children.toArray)
+  }
+
+  private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
+    val field = vector.getField()
+    (Utils.fromArrowField(field), vector) match {
+      case (BooleanType, vector: BitVector) => new BooleanWriter(vector)
+      case (ByteType, vector: TinyIntVector) => new ByteWriter(vector)
+      case (ShortType, vector: SmallIntVector) => new ShortWriter(vector)
+      case (IntegerType, vector: IntVector) => new IntegerWriter(vector)
+      case (LongType, vector: BigIntVector) => new LongWriter(vector)
+      case (FloatType, vector: Float4Vector) => new FloatWriter(vector)
+      case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector)
+      case (DecimalType.Fixed(precision, scale), vector: DecimalVector) =>
+        new DecimalWriter(vector, precision, scale)
+      case (StringType, vector: VarCharVector) => new StringWriter(vector)
+      case (StringType, vector: LargeVarCharVector) => new 
LargeStringWriter(vector)
+      case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
+      case (BinaryType, vector: LargeVarBinaryVector) => new 
LargeBinaryWriter(vector)
+      case (DateType, vector: DateDayVector) => new DateWriter(vector)
+      case (TimestampType, vector: TimeStampMicroTZVector) => new 
TimestampWriter(vector)
+      case (TimestampNTZType, vector: TimeStampMicroVector) => new 
TimestampNTZWriter(vector)
+      case (ArrayType(_, _), vector: ListVector) =>
+        val elementVector = createFieldWriter(vector.getDataVector())
+        new ArrayWriter(vector, elementVector)
+      case (MapType(_, _, _), vector: MapVector) =>
+        val structVector = vector.getDataVector.asInstanceOf[StructVector]
+        val keyWriter = 
createFieldWriter(structVector.getChild(MapVector.KEY_NAME))
+        val valueWriter = 
createFieldWriter(structVector.getChild(MapVector.VALUE_NAME))
+        new MapWriter(vector, structVector, keyWriter, valueWriter)
+      case (StructType(_), vector: StructVector) =>
+        val children = (0 until vector.size()).map { ordinal =>
+          createFieldWriter(vector.getChildByOrdinal(ordinal))
+        }
+        new StructWriter(vector, children.toArray)
+      case (NullType, vector: NullVector) => new NullWriter(vector)
+      case (_: YearMonthIntervalType, vector: IntervalYearVector) =>
+        new IntervalYearWriter(vector)
+      case (_: DayTimeIntervalType, vector: DurationVector) => new 
DurationWriter(vector)
+      case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) =>
+        new IntervalMonthDayNanoWriter(vector)
+      case (dt, _) =>
+        throw QueryExecutionErrors.notSupportTypeError(dt)
+    }
+  }
+}
+
+class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) 
{
+
+  def schema: StructType = Utils.fromArrowSchema(root.getSchema())
+
+  private var count: Int = 0
+
+  def write(row: InternalRow): Unit = {
+    var i = 0
+    while (i < fields.length) {
+      fields(i).write(row, i)
+      i += 1
+    }
+    count += 1
+  }
+
+  def finish(): Unit = {
+    root.setRowCount(count)
+    fields.foreach(_.finish())
+  }
+
+  def reset(): Unit = {
+    root.setRowCount(0)
+    count = 0
+    fields.foreach(_.reset())
+  }
+}
+
+private[arrow] abstract class ArrowFieldWriter {
+
+  def valueVector: ValueVector
+
+  def name: String = valueVector.getField().getName()
+  def dataType: DataType = Utils.fromArrowField(valueVector.getField())
+  def nullable: Boolean = valueVector.getField().isNullable()
+
+  def setNull(): Unit
+  def setValue(input: SpecializedGetters, ordinal: Int): Unit
+
+  private[arrow] var count: Int = 0
+
+  def write(input: SpecializedGetters, ordinal: Int): Unit = {
+    if (input.isNullAt(ordinal)) {
+      setNull()
+    } else {
+      setValue(input, ordinal)
+    }
+    count += 1
+  }
+
+  def finish(): Unit = {
+    valueVector.setValueCount(count)
+  }
+
+  def reset(): Unit = {
+    valueVector.reset()
+    count = 0
+  }
+}
+
+private[arrow] class BooleanWriter(val valueVector: BitVector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0)
+  }
+}
+
+private[arrow] class ByteWriter(val valueVector: TinyIntVector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getByte(ordinal))
+  }
+}
+
+private[arrow] class ShortWriter(val valueVector: SmallIntVector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getShort(ordinal))
+  }
+}
+
+private[arrow] class IntegerWriter(val valueVector: IntVector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getInt(ordinal))
+  }
+}
+
+private[arrow] class LongWriter(val valueVector: BigIntVector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getLong(ordinal))
+  }
+}
+
+private[arrow] class FloatWriter(val valueVector: Float4Vector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getFloat(ordinal))
+  }
+}
+
+private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getDouble(ordinal))
+  }
+}
+
+private[arrow] class DecimalWriter(val valueVector: DecimalVector, precision: 
Int, scale: Int)
+    extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val decimal = input.getDecimal(ordinal, precision, scale)
+    if (decimal.changePrecision(precision, scale)) {
+      valueVector.setSafe(count, decimal.toJavaBigDecimal)
+    } else {
+      setNull()
+    }
+  }
+}
+
+private[arrow] class StringWriter(val valueVector: VarCharVector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val utf8 = input.getUTF8String(ordinal)
+    val utf8ByteBuffer = utf8.getByteBuffer
+    // todo: for off-heap UTF8String, how to pass in to arrow without copy?
+    valueVector.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), 
utf8.numBytes())
+  }
+}
+
+private[arrow] class LargeStringWriter(val valueVector: LargeVarCharVector)
+    extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val utf8 = input.getUTF8String(ordinal)
+    val utf8ByteBuffer = utf8.getByteBuffer
+    // todo: for off-heap UTF8String, how to pass in to arrow without copy?

Review Comment:
   Yeah, this comment comes from Spark's original code.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to