This is an automated email from the ASF dual-hosted git repository.

wangguangxin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new cc55a28787 [VL] Make PartialProject support struct with null fields (#10706)
cc55a28787 is described below

commit cc55a28787bfb033d925c4fe6d4447131e833967
Author: jiangjiangtian <[email protected]>
AuthorDate: Tue Oct 21 16:15:41 2025 +0800

    [VL] Make PartialProject support struct with null fields (#10706)
    
    * Make PartialProject support struct with null fields
    
    * fix compilation error
    
    * fix
    
    ---------
    
    Co-authored-by: 蒋添 <[email protected]>
---
 .../execution/ColumnarPartialGenerateExec.scala    | 10 ++-
 .../execution/ColumnarPartialProjectExec.scala     |  6 +-
 .../gluten/expression/UDFPartialProjectSuite.scala | 25 +++++++
 .../gluten/columnarbatch/ColumnarBatches.java      | 10 +++
 .../vectorized/ArrowWritableColumnVector.java      | 10 +++
 .../gluten/vectorized/ArrowColumnarBatch.scala     | 83 ++++++++++++++++++++++
 .../gluten/vectorized/ArrowColumnarRow.scala       | 48 +++++++------
 7 files changed, 166 insertions(+), 26 deletions(-)
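
For context, a minimal sketch (hypothetical helper, not taken from the patch) of the
failure mode being fixed: before Spark 4.0, `ColumnarRow.get` dispatches on the data
type without first checking `isNullAt`, so reading a null struct field can surface a
garbage non-null value. The guard below is the pattern this patch applies inside
`ArrowColumnarRow.get`:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types._

    // Hypothetical helper: null-check before type-based dispatch.
    def readField(row: InternalRow, ordinal: Int, dt: DataType): AnyRef = {
      // Without this guard, getLong on a null slot returns whatever bytes
      // happen to sit in the underlying buffer.
      if (row.isNullAt(ordinal)) return null
      dt match {
        case _: LongType => java.lang.Long.valueOf(row.getLong(ordinal))
        case s: StructType => row.getStruct(ordinal, s.fields.length)
        case other =>
          throw new UnsupportedOperationException(s"Datatype not supported $other")
      }
    }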

diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialGenerateExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialGenerateExec.scala
index ea1be35995..4e447df064 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialGenerateExec.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialGenerateExec.scala
@@ -23,7 +23,7 @@ import org.apache.gluten.extension.columnar.transition.Convention
 import org.apache.gluten.iterator.Iterators
 import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
 import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.gluten.vectorized.{ArrowColumnarRow, ArrowWritableColumnVector}
+import org.apache.gluten.vectorized.{ArrowColumnarBatch, ArrowColumnarRow, ArrowWritableColumnVector}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -179,12 +179,16 @@ case class ColumnarPartialGenerateExec(generateExec: GenerateExec, child: SparkP
     }
   }
 
-  private def loadArrowBatch(inputData: ColumnarBatch): ColumnarBatch = {
-    if (inputData.numCols() == 0) {
+  private def loadArrowBatch(inputData: ColumnarBatch): ArrowColumnarBatch = {
+    val sparkColumnarBatch = if (inputData.numCols() == 0) {
       inputData
     } else {
       ColumnarBatches.load(ArrowBufferAllocators.contextInstance(), inputData)
     }
+    // In Spark versions below 4.0, `ColumnarRow`'s get method doesn't check whether
+    // the column being read is null, so we convert the batch to `ArrowColumnarBatch`
+    // manually. `ArrowColumnarBatch` returns `ArrowColumnarRow`, which fixes the bug.
+    ColumnarBatches.convertToArrowColumnarBatch(sparkColumnarBatch)
   }
 
   private def isVariableWidthType(dt: DataType): Boolean = dt match {
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
index 232c535fda..8fa33d97a1 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
@@ -214,11 +214,15 @@ case class ColumnarPartialProjectExec(projectList: Seq[NamedExpression], child:
     val proj = ArrowProjection.create(replacedAlias, projectAttributes.toSeq)
     val numRows = childData.numRows()
     val start = System.currentTimeMillis()
-    val arrowBatch = if (childData.numCols() == 0) {
+    val sparkColumnarBatch = if (childData.numCols() == 0) {
       childData
     } else {
       ColumnarBatches.load(ArrowBufferAllocators.contextInstance(), childData)
     }
+    // In Spark versions below 4.0, `ColumnarRow`'s get method doesn't check whether
+    // the column being read is null, so we convert the batch to `ArrowColumnarBatch`
+    // manually. `ArrowColumnarBatch` returns `ArrowColumnarRow`, which fixes the bug.
+    val arrowBatch = ColumnarBatches.convertToArrowColumnarBatch(sparkColumnarBatch)
     c2a += System.currentTimeMillis() - start
 
     val schema =
diff --git a/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
index 5152cbc457..ab6d111214 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
@@ -28,6 +28,8 @@ import java.io.File
 
 case class MyStruct(a: Long, b: Array[Long])
 
+case class MyStructWithNullValue(a: Option[Long], b: Array[Long])
+
 class UDFPartialProjectSuiteRasOff extends UDFPartialProjectSuite {
   override protected def sparkConf: SparkConf = {
     super.sparkConf
@@ -247,4 +249,27 @@ abstract class UDFPartialProjectSuite extends WholeStageTransformerSuite {
         }
     }
   }
+
+  test("test struct data with null fields") {
+    spark.udf.register(
+      "struct_plus_one",
+      udf(
+        (m: MyStructWithNullValue) =>
+          MyStructWithNullValue(if (m.a.isEmpty) None else Some(m.a.get + 1), m.b.map(_ + 1))))
+    runQueryAndCompare("""
+                         |SELECT
+                         |  l_partkey,
+                         |  struct_plus_one(struct_data)
+                         |FROM (
+                         | SELECT l_partkey,
+                         | struct(
+                         |   CASE WHEN l_orderkey % 2 == 0 THEN l_orderkey ELSE null END as a,
+                         |   array(l_orderkey % 2, l_orderkey % 2 + 1, l_orderkey % 2 + 2) as b
+                         | ) as struct_data
+                         | FROM lineitem
+                         |)
+                         |""".stripMargin) {
+      checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+    }
+  }
 }
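
A side note on the new test: Spark's Scala UDF encoders map a nullable struct field
to `Option`, so `None` is how the UDF sees a SQL NULL. A minimal sketch of that
mapping (assuming an active SparkSession named `spark`; names are illustrative):

    import org.apache.spark.sql.functions.udf

    case class PairWithNullable(a: Option[Long], b: Array[Long])

    // SQL NULL in field `a` arrives as None. Before this fix, the columnar
    // path could instead hand the UDF a bogus non-null value for that field.
    spark.udf.register(
      "pair_plus_one",
      udf((p: PairWithNullable) => PairWithNullable(p.a.map(_ + 1), p.b.map(_ + 1))))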
diff --git a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java b/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java
index 156de4e0d8..01ceb7d20c 100644
--- a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java
+++ b/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java
@@ -22,6 +22,7 @@ import org.apache.gluten.runtime.Runtimes;
 import org.apache.gluten.utils.ArrowAbiUtil;
 import org.apache.gluten.utils.ArrowUtil;
 import org.apache.gluten.utils.InternalRowUtl;
+import org.apache.gluten.vectorized.ArrowColumnarBatch;
 import org.apache.gluten.vectorized.ArrowWritableColumnVector;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -171,6 +172,15 @@ public final class ColumnarBatches {
     }
   }
 
+  public static ArrowColumnarBatch convertToArrowColumnarBatch(ColumnarBatch sparkColumnarBatch) {
+    int numCols = sparkColumnarBatch.numCols();
+    ArrowWritableColumnVector[] writableColumns = new ArrowWritableColumnVector[numCols];
+    for (int i = 0; i < numCols; i++) {
+      writableColumns[i] = (ArrowWritableColumnVector) sparkColumnarBatch.column(i);
+    }
+    return new ArrowColumnarBatch(writableColumns, sparkColumnarBatch.numRows());
+  }
+
  public static ColumnarBatch load(BufferAllocator allocator, ColumnarBatch input) {
     if (isZeroColumnBatch(input)) {
       return input;
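
The conversion above assumes every column is already an ArrowWritableColumnVector,
which is why both call sites in this patch run ColumnarBatches.load(...) first. A
sketch of the combined contract in Scala, mirroring loadArrowBatch above (the method
name toArrowBatch is illustrative):

    import org.apache.gluten.columnarbatch.ColumnarBatches
    import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
    import org.apache.gluten.vectorized.ArrowColumnarBatch
    import org.apache.spark.sql.vectorized.ColumnarBatch

    def toArrowBatch(input: ColumnarBatch): ArrowColumnarBatch = {
      // Load the batch into Arrow memory first; on an unloaded batch the
      // cast inside convertToArrowColumnarBatch would fail.
      val loaded =
        if (input.numCols() == 0) input
        else ColumnarBatches.load(ArrowBufferAllocators.contextInstance(), input)
      ColumnarBatches.convertToArrowColumnarBatch(loaded)
    }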
diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
index 5491b19ca3..d00786f3f4 100644
--- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
+++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
@@ -411,6 +411,16 @@ public final class ArrowWritableColumnVector extends WritableColumnVectorShim {
     return "vectorCounter is " + vectorCount.get();
   }
 
+  public ArrowColumnarRow getStructInternal(int rowId) {
+    if (isNullAt(rowId)) return null;
+    ArrowWritableColumnVector[] writableColumns =
+        new ArrowWritableColumnVector[childColumns.length];
+    for (int i = 0; i < writableColumns.length; i++) {
+      writableColumns[i] = (ArrowWritableColumnVector) childColumns[i];
+    }
+    return new ArrowColumnarRow(writableColumns, rowId);
+  }
+
   @Override
   public boolean hasNull() {
     return accessor.getNullCount() > 0;
diff --git a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarBatch.scala b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarBatch.scala
new file mode 100644
index 0000000000..decd87e78f
--- /dev/null
+++ b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarBatch.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.vectorized
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.vectorized.ColumnVector
+
+/**
+ * Because Spark 3.2 declares `ColumnarBatch` as final, `ArrowColumnarBatch` can't extend
+ * `ColumnarBatch`. The code is mostly copied from Spark 3.2.
+ *
+ * @param writableColumns
+ *   the columns this class wraps
+ * @param rowNumbers
+ *   the number of rows this batch contains
+ */
+class ArrowColumnarBatch(writableColumns: Array[ArrowWritableColumnVector], var rowNumbers: Int) {
+  private val arrowColumnarRow = new ArrowColumnarRow(writableColumns)
+
+  /**
+   * Called to close all the columns in this batch. It is not valid to access the data
+   * after calling this. This must be called at the end to clean up memory allocations.
+   */
+  def close(): Unit = {
+    for (c <- writableColumns) {
+      c.close()
+    }
+  }
+
+  /** Returns an iterator over the rows in this batch. */
+  def rowIterator: Iterator[InternalRow] = {
+    val maxRows = numRows
+    val row = new ArrowColumnarRow(writableColumns)
+    new Iterator[InternalRow]() {
+      var rowId = 0
+
+      override def hasNext: Boolean = rowId < maxRows
+
+      override def next: InternalRow = {
+        if (rowId >= maxRows) {
+          throw new NoSuchElementException()
+        }
+        row.rowId = rowId
+        rowId = rowId + 1
+        row
+      }
+    }
+  }
+
+  /** Sets the number of rows in this batch. */
+  def setNumRows(numRows: Int): Unit = {
+    this.rowNumbers = numRows
+  }
+
+  /** Returns the number of columns that make up this batch. */
+  def numCols: Int = writableColumns.length
+
+  /** Returns the number of rows for read, including filtered rows. */
+  def numRows: Int = this.rowNumbers
+
+  /** Returns the column at `ordinal`. */
+  def column(ordinal: Int): ColumnVector = writableColumns(ordinal)
+
+  def getRow(rowId: Int): InternalRow = {
+    assert(rowId >= 0 && rowId < this.numRows)
+    arrowColumnarRow.rowId = rowId
+    arrowColumnarRow
+  }
+}
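
One usage note on the class above: like Spark's ColumnarBatch.rowIterator, both
rowIterator and getRow reuse a single mutable ArrowColumnarRow, so callers must copy
a row before retaining it past the next call. A short sketch (the method name
sumFirstLongColumn is illustrative):

    import org.apache.gluten.vectorized.ArrowColumnarBatch

    def sumFirstLongColumn(batch: ArrowColumnarBatch): Long = {
      var total = 0L
      val it = batch.rowIterator
      while (it.hasNext) {
        val row = it.next()
        // The same row object is returned on every call; read, don't retain.
        if (!row.isNullAt(0)) total += row.getLong(0)
      }
      total
    }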
diff --git a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala
index e5452e4ae5..f0e2c4dabf 100644
--- a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala
+++ b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala
@@ -22,7 +22,7 @@ import org.apache.gluten.execution.InternalRowGetVariantCompatible
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarMap, ColumnarRow}
+import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarMap}
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 import java.math.BigDecimal
@@ -30,10 +30,9 @@ import java.math.BigDecimal
 // Copy from Spark MutableColumnarRow mostly but class member columns' type is
 // ArrowWritableColumnVector. And support string and binary type to write,
 // Arrow writer does not need to setNotNull before writing a value.
-final class ArrowColumnarRow(writableColumns: Array[ArrowWritableColumnVector])
+final class ArrowColumnarRow(writableColumns: Array[ArrowWritableColumnVector], var rowId: Int = 0)
   extends InternalRowGetVariantCompatible {
 
-  var rowId: Int = 0
   private val columns: Array[ArrowWritableColumnVector] = writableColumns
 
   override def numFields(): Int = columns.length
@@ -109,8 +108,8 @@ final class ArrowColumnarRow(writableColumns: Array[ArrowWritableColumnVector])
   override def getInterval(ordinal: Int): CalendarInterval =
     columns(ordinal).getInterval(rowId)
 
-  override def getStruct(ordinal: Int, numFields: Int): ColumnarRow =
-    columns(ordinal).getStruct(rowId)
+  override def getStruct(ordinal: Int, numFields: Int): ArrowColumnarRow =
+    columns(ordinal).getStructInternal(rowId)
 
   override def getArray(ordinal: Int): ColumnarArray =
     columns(ordinal).getArray(rowId)
@@ -118,23 +117,28 @@ final class ArrowColumnarRow(writableColumns: Array[ArrowWritableColumnVector])
   override def getMap(ordinal: Int): ColumnarMap =
     columns(ordinal).getMap(rowId)
 
-  override def get(ordinal: Int, dataType: DataType): AnyRef = dataType match {
-    case _: BooleanType => java.lang.Boolean.valueOf(getBoolean(ordinal))
-    case _: ByteType => java.lang.Byte.valueOf(getByte(ordinal))
-    case _: ShortType => java.lang.Short.valueOf(getShort(ordinal))
-    case _: IntegerType => java.lang.Integer.valueOf(getInt(ordinal))
-    case _: LongType => java.lang.Long.valueOf(getLong(ordinal))
-    case _: FloatType => java.lang.Float.valueOf(getFloat(ordinal))
-    case _: DoubleType => java.lang.Double.valueOf(getDouble(ordinal))
-    case _: StringType => getUTF8String(ordinal)
-    case _: BinaryType => getBinary(ordinal)
-    case t: DecimalType => getDecimal(ordinal, t.precision, t.scale)
-    case _: DateType => java.lang.Integer.valueOf(getInt(ordinal))
-    case _: TimestampType => java.lang.Long.valueOf(getLong(ordinal))
-    case _: ArrayType => getArray(ordinal)
-    case s: StructType => getStruct(ordinal, s.fields.length)
-    case _: MapType => getMap(ordinal)
-    case _ => throw new UnsupportedOperationException(s"Datatype not supported $dataType")
+  override def get(ordinal: Int, dataType: DataType): AnyRef = {
+    if (isNullAt(ordinal)) {
+      return null
+    }
+    dataType match {
+      case _: BooleanType => java.lang.Boolean.valueOf(getBoolean(ordinal))
+      case _: ByteType => java.lang.Byte.valueOf(getByte(ordinal))
+      case _: ShortType => java.lang.Short.valueOf(getShort(ordinal))
+      case _: IntegerType => java.lang.Integer.valueOf(getInt(ordinal))
+      case _: LongType => java.lang.Long.valueOf(getLong(ordinal))
+      case _: FloatType => java.lang.Float.valueOf(getFloat(ordinal))
+      case _: DoubleType => java.lang.Double.valueOf(getDouble(ordinal))
+      case _: StringType => getUTF8String(ordinal)
+      case _: BinaryType => getBinary(ordinal)
+      case t: DecimalType => getDecimal(ordinal, t.precision, t.scale)
+      case _: DateType => java.lang.Integer.valueOf(getInt(ordinal))
+      case _: TimestampType => java.lang.Long.valueOf(getLong(ordinal))
+      case _: ArrayType => getArray(ordinal)
+      case s: StructType => getStruct(ordinal, s.fields.length)
+      case _: MapType => getMap(ordinal)
+      case _ => throw new UnsupportedOperationException(s"Datatype not supported $dataType")
+    }
   }
 
   override def update(ordinal: Int, value: Any): Unit = {

