This is an automated email from the ASF dual-hosted git repository.

chengchengjin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 2d33f383ba [GLUTEN-8668][VL] Support complex type in ColumnarPartialProject (#8669)
2d33f383ba is described below

commit 2d33f383ba9a3a9ea3d3326bd3b4c7a39d5a6b38
Author: WangGuangxin <[email protected]>
AuthorDate: Thu Feb 20 16:54:34 2025 +0800

    [GLUTEN-8668][VL] Support complex type in ColumnarPartialProject (#8669)
---
 .../execution/ColumnarPartialProjectExec.scala     |  34 +--
 .../java/org/apache/gluten/udf}/CustomerUDF.java   |   2 +-
 .../gluten/expression/UDFPartialProjectSuite.scala |  57 ++++++
 .../spark/sql/execution/GlutenHiveUDFSuite.scala   | 203 ++++++++++++++++++
 .../apache/gluten/vectorized/ArrowColumnarRow.java |  18 ++
 .../vectorized/ArrowWritableColumnVector.java      | 227 ++++++++++++++++++++-
 .../expression/InterpretedArrowProjection.scala    |  42 ++--
 .../sql/hive/execution/GlutenHiveUDFSuite.scala    | 139 -------------
 8 files changed, 521 insertions(+), 201 deletions(-)

diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
index 500c70fefa..9d6077ec62 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
@@ -29,12 +29,11 @@ import org.apache.gluten.vectorized.{ArrowColumnarRow, 
ArrowWritableColumnVector
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
AttributeReference, CaseWhen, Coalesce, Expression, If, LambdaFunction, 
NamedExpression, NaNvl, ScalaUDF}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.execution.{ExplainUtils, ProjectExec, SparkPlan, 
UnaryExecNode}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.hive.HiveUdfUtil
-import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
 
 import scala.collection.mutable.ListBuffer
@@ -59,10 +58,8 @@ case class ColumnarPartialProjectExec(original: ProjectExec, 
child: SparkPlan)(
   private val projectAttributes: ListBuffer[Attribute] = ListBuffer()
   private val projectIndexInChild: ListBuffer[Int] = ListBuffer()
   private var UDFAttrNotExists = false
-  private var hasUnsupportedDataType = replacedAliasUdf.exists(a => 
!validateDataType(a.dataType))
-  if (!hasUnsupportedDataType) {
-    getProjectIndexInChildOutput(replacedAliasUdf)
-  }
+  private var hasUnsupportedDataType = false
+  getProjectIndexInChildOutput(replacedAliasUdf)
 
   @transient override lazy val metrics = Map(
     "time" -> SQLMetrics.createTimingMetric(sparkContext, "total time of 
partial project"),
@@ -102,26 +99,6 @@ case class ColumnarPartialProjectExec(original: 
ProjectExec, child: SparkPlan)(
       .forall(validateExpression)
   }
 
-  private def validateDataType(dataType: DataType): Boolean = {
-    dataType match {
-      case _: BooleanType => true
-      case _: ByteType => true
-      case _: ShortType => true
-      case _: IntegerType => true
-      case _: LongType => true
-      case _: FloatType => true
-      case _: DoubleType => true
-      case _: StringType => true
-      case _: TimestampType => true
-      case _: DateType => true
-      case _: BinaryType => true
-      case _: DecimalType => true
-      case YearMonthIntervalType.DEFAULT => true
-      case _: NullType => true
-      case _ => false
-    }
-  }
-
   private def getProjectIndexInChildOutput(exprs: Seq[Expression]): Unit = {
     exprs.forall {
       case a: AttributeReference =>
@@ -131,7 +108,9 @@ case class ColumnarPartialProjectExec(original: 
ProjectExec, child: SparkPlan)(
           UDFAttrNotExists = true
           log.debug(s"Expression $a should exist in child output 
${child.output}")
           false
-        } else if (!validateDataType(a.dataType)) {
+        } else if (
+          
BackendsApiManager.getValidatorApiInstance.doSchemaValidate(a.dataType).isDefined
+        ) {
           hasUnsupportedDataType = true
           log.debug(s"Expression $a contains unsupported data type 
${a.dataType}")
           false
@@ -253,6 +232,7 @@ case class ColumnarPartialProjectExec(original: 
ProjectExec, child: SparkPlan)(
       targetRow.rowId = i
       proj.target(targetRow).apply(arrowBatch.getRow(i))
     }
+    targetRow.finishWriteRow()
     val targetBatch = new 
ColumnarBatch(vectors.map(_.asInstanceOf[ColumnVector]), numRows)
     val start2 = System.currentTimeMillis()
     val veloxBatch = VeloxColumnarBatches.toVeloxBatch(
diff --git 
a/gluten-ut/spark35/src/test/java/org/apache/gluten/execution/CustomerUDF.java 
b/backends-velox/src/test/java/org/apache/gluten/udf/CustomerUDF.java
similarity index 97%
rename from 
gluten-ut/spark35/src/test/java/org/apache/gluten/execution/CustomerUDF.java
rename to backends-velox/src/test/java/org/apache/gluten/udf/CustomerUDF.java
index 257bd07021..b677710dbc 100644
--- 
a/gluten-ut/spark35/src/test/java/org/apache/gluten/execution/CustomerUDF.java
+++ b/backends-velox/src/test/java/org/apache/gluten/udf/CustomerUDF.java
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.execution;
+package org.apache.gluten.udf;
 
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDF;
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
index 30b8cb0f69..3f1a68993f 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
@@ -24,6 +24,8 @@ import org.apache.spark.sql.functions.udf
 
 import java.io.File
 
+case class MyStruct(a: Long, b: Array[Long])
+
 class UDFPartialProjectSuiteRasOff extends UDFPartialProjectSuite {
   override protected def sparkConf: SparkConf = {
     super.sparkConf
@@ -147,4 +149,59 @@ abstract class UDFPartialProjectSuite extends 
WholeStageTransformerSuite {
     }
   }
 
+  test("udf with array") {
+    spark.udf.register("array_plus_one", udf((arr: Array[Int]) => arr.map(_ + 
1)))
+    runQueryAndCompare("""
+                         |SELECT
+                         |  l_partkey,
+                         |  sort_array(array_plus_one(array_data)) as 
orderkey_arr_plus_one
+                         |FROM (
+                         | SELECT l_partkey, collect_list(l_orderkey) as 
array_data
+                         | FROM lineitem
+                         | GROUP BY l_partkey
+                         |)
+                         |""".stripMargin) {
+      checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+    }
+  }
+
+  test("udf with map") {
+    spark.udf.register(
+      "map_value_plus_one",
+      udf((m: Map[String, Long]) => m.map { case (key, value) => key -> (value 
+ 1) }))
+    runQueryAndCompare("""
+                         |SELECT
+                         |  l_partkey,
+                         |  map_value_plus_one(map_data)
+                         |FROM (
+                         | SELECT l_partkey,
+                         | map(
+                         |   concat('hello', l_orderkey % 2), l_orderkey,
+                         |   concat('world', l_orderkey % 2), l_orderkey
+                         | ) as map_data
+                         | FROM lineitem
+                         |)
+                         |""".stripMargin) {
+      checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+    }
+  }
+
+  test("udf with struct and array") {
+    spark.udf.register("struct_plus_one", udf((m: MyStruct) => MyStruct(m.a + 
1, m.b.map(_ + 1))))
+    runQueryAndCompare("""
+                         |SELECT
+                         |  l_partkey,
+                         |  struct_plus_one(struct_data)
+                         |FROM (
+                         | SELECT l_partkey,
+                         | struct(
+                         |   l_orderkey % 2 as a,
+                         |   array(l_orderkey % 2, l_orderkey % 2 + 1, 
l_orderkey % 2 + 2) as b
+                         | ) as struct_data
+                         | FROM lineitem
+                         |)
+                         |""".stripMargin) {
+      checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+    }
+  }
 }
diff --git 
a/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala
 
b/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala
new file mode 100644
index 0000000000..c8404fb93b
--- /dev/null
+++ 
b/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.gluten.execution.ColumnarPartialProjectExec
+import org.apache.gluten.udf.CustomerUDF
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.config
+import org.apache.spark.internal.config.UI.UI_ENABLED
+import org.apache.spark.sql.{DataFrame, GlutenQueryTest, Row, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
+import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
+import org.apache.spark.sql.hive.HiveUtils
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
+import org.apache.spark.sql.test.SQLTestUtils
+
+import java.io.File
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+class GlutenHiveUDFSuite extends GlutenQueryTest with SQLTestUtils {
+  private var _spark: SparkSession = _
+
+  override protected def beforeAll(): Unit = {
+    super.beforeAll()
+
+    if (_spark == null) {
+      _spark = 
SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()
+    }
+
+    _spark.sparkContext.setLogLevel("info")
+
+    createTestTable()
+  }
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+  }
+
+  override protected def spark: SparkSession = _spark
+
+  protected def defaultSparkConf: SparkConf = {
+    val conf = new SparkConf()
+      .set("spark.master", "local[1]")
+      .set("spark.sql.test", "")
+      .set("spark.sql.testkey", "true")
+      .set(SQLConf.CODEGEN_FALLBACK.key, "false")
+      .set(SQLConf.CODEGEN_FACTORY_MODE.key, 
CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
+      .set(
+        HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key,
+        "org.apache.spark.sql.hive.execution.PairSerDe")
+      // SPARK-8910
+      .set(UI_ENABLED, false)
+      .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
+      // Hive changed the default of 
hive.metastore.disallow.incompatible.col.type.changes
+      // from false to true. For details, see the JIRA HIVE-12320 and 
HIVE-17764.
+      
.set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", 
"false")
+      // Disable ConvertToLocalRelation for better test coverage. Test cases 
built on
+      // LocalRelation will exercise the optimization rules better by 
disabling it as
+      // this rule may potentially block testing of other optimization rules 
such as
+      // ConstantPropagation etc.
+      .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, 
ConvertToLocalRelation.ruleName)
+
+    conf.set(
+      StaticSQLConf.WAREHOUSE_PATH,
+      conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName)
+  }
+
+  protected def sparkConf: SparkConf = {
+    defaultSparkConf
+      .set("spark.plugins", "org.apache.gluten.GlutenPlugin")
+      .set("spark.default.parallelism", "1")
+      .set("spark.memory.offHeap.enabled", "true")
+      .set("spark.memory.offHeap.size", "1024MB")
+      .set("spark.gluten.sql.native.writer.enabled", "true")
+  }
+
+  private def withTempFunction(funcName: String)(f: => Unit): Unit = {
+    try f
+    finally sql(s"DROP TEMPORARY FUNCTION IF EXISTS $funcName")
+  }
+
+  private def checkOperatorMatch[T <: SparkPlan](df: DataFrame)(implicit tag: 
ClassTag[T]): Unit = {
+    val executedPlan = getExecutedPlan(df)
+    assert(executedPlan.exists(plan => plan.getClass == tag.runtimeClass))
+  }
+
+  private def createTestTable(): Unit = {
+    val table = "lineitem"
+    val tableDir = getClass.getResource("/tpch-data-parquet").getFile
+    val tablePath = new File(tableDir, table).getAbsolutePath
+    val tableDF = spark.read.format("parquet").load(tablePath)
+    tableDF.createOrReplaceTempView(table)
+  }
+
+  test("customer udf") {
+    withTempFunction("testUDF") {
+      sql(s"CREATE TEMPORARY FUNCTION testUDF AS 
'${classOf[CustomerUDF].getName}'")
+      val df = sql("select l_partkey, testUDF(l_comment) from lineitem")
+      df.show()
+      checkOperatorMatch[ColumnarPartialProjectExec](df)
+    }
+  }
+
+  test("customer udf wrapped in function") {
+    withTempFunction("testUDF") {
+      sql(s"CREATE TEMPORARY FUNCTION testUDF AS 
'${classOf[CustomerUDF].getName}'")
+      val df = sql("select l_partkey, hash(testUDF(l_comment)) from lineitem")
+      df.show()
+      checkOperatorMatch[ColumnarPartialProjectExec](df)
+    }
+  }
+
+  test("example") {
+    withTempFunction("testUDF") {
+      sql("CREATE TEMPORARY FUNCTION testUDF AS 
'org.apache.hadoop.hive.ql.udf.UDFSubstr';")
+      val df = sql("select testUDF('l_commen', 1, 5)")
+      df.show()
+      // It should not be converted to ColumnarPartialProjectExec, since
+      // the UDF need all the columns in child output.
+      assert(!getExecutedPlan(df).exists {
+        case _: ColumnarPartialProjectExec => true
+        case _ => false
+      })
+    }
+  }
+
+  test("udf with array") {
+    withTempFunction("udf_sort_array") {
+      sql("""
+            |CREATE TEMPORARY FUNCTION udf_sort_array AS
+            |'org.apache.hadoop.hive.ql.udf.generic.GenericUDFSortArray';
+            |""".stripMargin)
+
+      val df = sql("""
+                     |SELECT
+                     |  l_orderkey,
+                     |  l_partkey,
+                     |  udf_sort_array(array(10, l_orderkey, 1)) as udf_result
+                     |FROM lineitem WHERE l_partkey <= 5 and l_orderkey <1000
+                     |""".stripMargin)
+
+      checkAnswer(
+        df,
+        Seq(
+          Row(35, 5, mutable.WrappedArray.make(Array(1, 10, 35))),
+          Row(321, 4, mutable.WrappedArray.make(Array(1, 10, 321))),
+          Row(548, 2, mutable.WrappedArray.make(Array(1, 10, 548))),
+          Row(640, 5, mutable.WrappedArray.make(Array(1, 10, 640))),
+          Row(807, 2, mutable.WrappedArray.make(Array(1, 10, 807)))
+        )
+      )
+      checkOperatorMatch[ColumnarPartialProjectExec](df)
+    }
+  }
+
+  test("udf with map") {
+    withTempFunction("udf_str_to_map") {
+      sql("""
+            |CREATE TEMPORARY FUNCTION udf_str_to_map AS
+            |'org.apache.hadoop.hive.ql.udf.generic.GenericUDFStringToMap';
+            |""".stripMargin)
+
+      val df = sql(
+        """
+          |SELECT
+          |  l_orderkey,
+          |  l_partkey,
+          |  udf_str_to_map(
+          |    concat_ws(',', array(concat('hello', l_partkey), 'world')), 
',', 'l') as udf_result
+          |FROM lineitem WHERE l_partkey <= 5 and l_orderkey <1000
+          |""".stripMargin)
+
+      checkAnswer(
+        df,
+        Seq(
+          Row(321, 4, Map("he" -> "lo4", "wor" -> "d")),
+          Row(35, 5, Map("he" -> "lo5", "wor" -> "d")),
+          Row(548, 2, Map("he" -> "lo2", "wor" -> "d")),
+          Row(640, 5, Map("he" -> "lo5", "wor" -> "d")),
+          Row(807, 2, Map("he" -> "lo2", "wor" -> "d"))
+        )
+      )
+      checkOperatorMatch[ColumnarPartialProjectExec](df)
+    }
+  }
+}
diff --git 
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarRow.java 
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarRow.java
index e0ac5858e7..907ce1ca73 100644
--- 
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarRow.java
+++ 
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarRow.java
@@ -16,6 +16,8 @@
  */
 package org.apache.gluten.vectorized;
 
+import org.apache.gluten.exception.GlutenException;
+
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
 import org.apache.spark.sql.types.ArrayType;
@@ -316,4 +318,20 @@ public final class ArrowColumnarRow extends InternalRow {
   public void setBinary(int ordinal, byte[] value) {
     columns[ordinal].putBytes(rowId, value.length, value, 0);
   }
+
+  public void writeRow(GenericInternalRow input) {
+    if (input.numFields() != columns.length) {
+      throw new GlutenException(
+          "The numFields of input row should be equal to the number of column 
vector!");
+    }
+    for (int i = 0; i < input.numFields(); ++i) {
+      columns[i].write(input, i);
+    }
+  }
+
+  public void finishWriteRow() {
+    for (int i = 0; i < columns.length; ++i) {
+      columns[i].finishWrite();
+    }
+  }
 }
diff --git 
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
 
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
index 336d33771b..0d74b7d4ac 100644
--- 
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
+++ 
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
@@ -46,13 +46,14 @@ import 
org.apache.arrow.vector.holders.NullableVarCharHolder;
 import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters;
+import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.catalyst.util.DateTimeUtils;
+import org.apache.spark.sql.catalyst.util.MapData;
 import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
 import org.apache.spark.sql.execution.vectorized.WritableColumnVectorShim;
-import org.apache.spark.sql.types.ArrayType;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.*;
 import org.apache.spark.sql.utils.SparkArrowUtil;
 import org.apache.spark.sql.utils.SparkSchemaUtil;
 import org.apache.spark.sql.vectorized.ColumnarArray;
@@ -307,6 +308,12 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
       return new DateWriter((DateDayVector) vector);
     } else if (vector instanceof TimeStampMicroVector || vector instanceof 
TimeStampMicroTZVector) {
       return new TimestampMicroWriter((TimeStampVector) vector);
+    } else if (vector instanceof MapVector) {
+      MapVector mapVector = (MapVector) vector;
+      StructVector entries = (StructVector) mapVector.getDataVector();
+      ArrowVectorWriter keyWriter = 
createVectorWriter(entries.getChild(MapVector.KEY_NAME));
+      ArrowVectorWriter valueWriter = 
createVectorWriter(entries.getChild(MapVector.VALUE_NAME));
+      return new MapWriter(mapVector, entries, keyWriter, valueWriter);
     } else if (vector instanceof ListVector) {
       ListVector listVector = (ListVector) vector;
       ArrowVectorWriter elementVector = 
createVectorWriter(listVector.getDataVector());
@@ -742,6 +749,14 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
     return accessor.getBinary(rowId);
   }
 
+  public void write(SpecializedGetters input, int ordinal) {
+    writer.write(input, ordinal);
+  }
+
+  public void finishWrite() {
+    writer.finish();
+  }
+
   private abstract static class ArrowVectorAccessor {
     private final ValueVector vector;
 
@@ -1394,6 +1409,28 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
     void setBytes(int rowId, BigDecimal value) {
       throw new UnsupportedOperationException();
     }
+
+    protected int count = 0;
+
+    abstract void setValueNullSafe(SpecializedGetters input, int ordinal);
+
+    void write(SpecializedGetters input, int ordinal) {
+      if (input.isNullAt(ordinal)) {
+        setNull(count);
+      } else {
+        setValueNullSafe(input, ordinal);
+      }
+      count = count + 1;
+    }
+
+    void finish() {
+      vector.setValueCount(count);
+    }
+
+    void reset() {
+      vector.reset();
+      count = 0;
+    }
   }
 
   private static class BooleanWriter extends ArrowVectorWriter {
@@ -1427,6 +1464,11 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
         writer.setSafe(rowId + i, value ? 1 : 0);
       }
     }
+
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      this.setBoolean(count, input.getBoolean(ordinal));
+    }
   }
 
   private static class ByteWriter extends ArrowVectorWriter {
@@ -1467,6 +1509,11 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
         writer.setSafe(rowId + i, src[srcIndex + i]);
       }
     }
+
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      this.setByte(count, input.getByte(ordinal));
+    }
   }
 
   private static class ShortWriter extends ArrowVectorWriter {
@@ -1514,6 +1561,11 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
         writer.setSafe(rowId + i, src[srcIndex + i]);
       }
     }
+
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      this.setShort(count, input.getShort(ordinal));
+    }
   }
 
   private static class IntWriter extends ArrowVectorWriter {
@@ -1574,6 +1626,11 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
         writer.setSafe(rowId + i, tmp);
       }
     }
+
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      setInt(count, input.getInt(ordinal));
+    }
   }
 
   private static class LongWriter extends ArrowVectorWriter {
@@ -1629,6 +1686,11 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
       writer.setSafe(rowId, val);
     }
 
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      setLong(count, input.getLong(ordinal));
+    }
+
     @Override
     void setLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
       int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET;
@@ -1680,6 +1742,11 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
         writer.setSafe(rowId + i, src[srcIndex + i]);
       }
     }
+
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      setFloat(count, input.getFloat(ordinal));
+    }
   }
 
   private static class DoubleWriter extends ArrowVectorWriter {
@@ -1720,14 +1787,26 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
         writer.setSafe(rowId + i, src[srcIndex + i]);
       }
     }
+
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      setDouble(count, input.getDouble(ordinal));
+    }
   }
 
   private static class DecimalWriter extends ArrowVectorWriter {
     private final DecimalVector writer;
+    private final int precision;
+    private final int scale;
 
     DecimalWriter(DecimalVector vector) {
       super(vector);
       this.writer = vector;
+
+      DataType dataType = SparkArrowUtil.fromArrowField(vector.getField());
+      DecimalType decimalType = (DecimalType) dataType;
+      this.precision = decimalType.precision();
+      this.scale = decimalType.scale();
     }
 
     @Override
@@ -1757,6 +1836,16 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
       }
       throw new UnsupportedOperationException();
     }
+
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      Decimal decimal = input.getDecimal(ordinal, precision, scale);
+      if (decimal.changePrecision(precision, scale)) {
+        setBytes(count, decimal.toJavaBigDecimal());
+      } else {
+        setNull(count);
+      }
+    }
   }
 
   private static class StringWriter extends ArrowVectorWriter {
@@ -1791,6 +1880,12 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
       writer.setSafe(rowId, value, offset, length);
       rowId++;
     }
+
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      UTF8String value = input.getUTF8String(ordinal);
+      setBytes(count, value.numBytes(), value.getBytes(), 0);
+    }
   }
 
   private static class BinaryWriter extends ArrowVectorWriter {
@@ -1817,6 +1912,12 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
     final void setBytes(int rowId, int count, byte[] src, int srcIndex) {
       writer.setSafe(rowId, src, srcIndex, count);
     }
+
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      byte[] value = input.getBinary(ordinal);
+      setBytes(count, value.length, value, 0);
+    }
   }
 
   private static class DateWriter extends ArrowVectorWriter {
@@ -1850,6 +1951,11 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
         writer.setNull(rowId + i);
       }
     }
+
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      setInt(count, input.getInt(ordinal));
+    }
   }
 
   private static class TimestampMicroWriter extends ArrowVectorWriter {
@@ -1883,14 +1989,21 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
         writer.setNull(rowId + i);
       }
     }
+
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      setLong(count, input.getLong(ordinal));
+    }
   }
 
   private static class ArrayWriter extends ArrowVectorWriter {
     private final ListVector writer;
+    private final ArrowVectorWriter elementWriter;
 
     ArrayWriter(ListVector vector, ArrowVectorWriter elementVector) {
       super(vector);
       this.writer = vector;
+      this.elementWriter = elementVector;
     }
 
     @Override
@@ -1905,18 +2018,46 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
     final void setNull(int rowId) {
       writer.setNull(rowId);
     }
+
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      ArrayData arrayData = input.getArray(ordinal);
+      writer.startNewValue(count);
+      for (int i = 0; i < arrayData.numElements(); ++i) {
+        elementWriter.write(arrayData, i);
+      }
+      writer.endValue(count, arrayData.numElements());
+    }
+
+    @Override
+    void finish() {
+      super.finish();
+      elementWriter.finish();
+    }
+
+    @Override
+    void reset() {
+      super.reset();
+      elementWriter.reset();
+    }
   }
 
   private static class StructWriter extends ArrowVectorWriter {
     private final StructVector writer;
+    private final ArrowVectorWriter[] childrenWriter;
 
-    StructWriter(StructVector vector, ArrowVectorWriter[] children) {
+    StructWriter(StructVector vector, ArrowVectorWriter[] childrenWriter) {
       super(vector);
       this.writer = vector;
+      this.childrenWriter = childrenWriter;
     }
 
     @Override
     void setNull(int rowId) {
+      for (int i = 0; i < childrenWriter.length; ++i) {
+        childrenWriter[i].setNull(rowId);
+        childrenWriter[i].count += 1;
+      }
       writer.setNull(rowId);
     }
 
@@ -1924,6 +2065,77 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
     void setNotNull(int rowId) {
       writer.setIndexDefined(rowId);
     }
+
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      InternalRow struct = input.getStruct(ordinal, childrenWriter.length);
+      writer.setIndexDefined(count);
+      for (int i = 0; i < struct.numFields(); ++i) {
+        childrenWriter[i].write(struct, i);
+      }
+    }
+
+    @Override
+    void finish() {
+      super.finish();
+      Arrays.stream(childrenWriter).forEach(c -> c.finish());
+    }
+
+    @Override
+    void reset() {
+      super.reset();
+      Arrays.stream(childrenWriter).forEach(c -> c.reset());
+    }
+  }
+
+  private static class MapWriter extends ArrowVectorWriter {
+    private final MapVector writer;
+    private StructVector structVector;
+    private final ArrowVectorWriter keyWriter;
+    private final ArrowVectorWriter valueWriter;
+
+    MapWriter(
+        MapVector mapVector,
+        StructVector structVector,
+        ArrowVectorWriter mapWriter,
+        ArrowVectorWriter valueWriter) {
+      super(mapVector);
+      this.writer = mapVector;
+      this.structVector = structVector;
+      this.keyWriter = mapWriter;
+      this.valueWriter = valueWriter;
+    }
+
+    @Override
+    void setNull(int rowId) {}
+
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      MapData mapData = input.getMap(ordinal);
+      writer.startNewValue(count);
+      ArrayData keys = mapData.keyArray();
+      ArrayData values = mapData.valueArray();
+      for (int i = 0; i < mapData.numElements(); ++i) {
+        structVector.setIndexDefined(i);
+        keyWriter.write(keys, i);
+        valueWriter.write(values, i);
+      }
+      writer.endValue(count, mapData.numElements());
+    }
+
+    @Override
+    void finish() {
+      super.finish();
+      keyWriter.finish();
+      valueWriter.finish();
+    }
+
+    @Override
+    void reset() {
+      super.reset();
+      keyWriter.reset();
+      valueWriter.reset();
+    }
   }
 
   private static class NullWriter extends ArrowVectorWriter {
@@ -1938,5 +2150,10 @@ public final class ArrowWritableColumnVector extends 
WritableColumnVectorShim {
     void setNull(int rowId) {
       writer.setValueCount(writer.getValueCount() + 1);
     }
+
+    @Override
+    void setValueNullSafe(SpecializedGetters input, int ordinal) {
+      setNull(count);
+    }
   }
 }
diff --git 
a/gluten-arrow/src/main/scala/org/apache/gluten/expression/InterpretedArrowProjection.scala
 
b/gluten-arrow/src/main/scala/org/apache/gluten/expression/InterpretedArrowProjection.scala
index 33a3802f88..7d073f3b20 100644
--- 
a/gluten-arrow/src/main/scala/org/apache/gluten/expression/InterpretedArrowProjection.scala
+++ 
b/gluten-arrow/src/main/scala/org/apache/gluten/expression/InterpretedArrowProjection.scala
@@ -19,12 +19,10 @@ package org.apache.gluten.expression
 import org.apache.gluten.vectorized.ArrowColumnarRow
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, 
GenericInternalRow}
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BinaryType, DataType, StringType}
-import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * A [[ArrowProjection]] that is calculated by calling `eval` on each of the 
specified expressions.
@@ -57,43 +55,29 @@ class InterpretedArrowProjection(expressions: 
Seq[Expression]) extends ArrowProj
     this
   }
 
-  private def getStringWriter(ordinal: Int, dt: DataType): (ArrowColumnarRow, 
Any) => Unit =
-    dt match {
-      case StringType => (input, v) => input.setUTF8String(ordinal, 
v.asInstanceOf[UTF8String])
-      case BinaryType => (input, v) => input.setBinary(ordinal, 
v.asInstanceOf[Array[Byte]])
-      case _ => (input, v) => input.update(ordinal, v)
-    }
+  /** Number of (top level) fields in the resulting row: one per entry of `validExprs`. */
+  private[this] val numFields = validExprs.length
+
+  /**
+   * Buffer that holds the expression results for the current input row. Slot `i` receives the
+   * result of `validExprs(i)`. It is allocated once and refilled on every call to `apply`.
+   */
+  private[this] val values = new Array[Any](numFields)
 
-  private[this] val fieldWriters: Array[Any => Unit] = validExprs.map {
-    case (e, i) =>
-      val writer = if (e.dataType.isInstanceOf[StringType] || 
e.dataType.isInstanceOf[BinaryType]) {
-        getStringWriter(i, e.dataType)
-      } else InternalRow.getWriter(i, e.dataType)
-      if (!e.nullable) { (v: Any) => writer(mutableRow, v) }
-      else {
-        (v: Any) =>
-          {
-            if (v == null) {
-              mutableRow.setNullAt(i)
-            } else {
-              writer(mutableRow, v)
-            }
-          }
-      }
-  }.toArray
+  /**
+   * The row representing the expression results. It is built once over `values`. `apply`
+   * refills `values` and passes this same row to `mutableRow.writeRow`, relying on
+   * GenericInternalRow sharing the array rather than copying it.
+   */
+  private[this] val intermediate = new GenericInternalRow(values)
 
+  /**
+   * Evaluates each valid expression on `input` and buffers the results in `values`. Only then
+   * does it copy the buffered row into `mutableRow` with a single `writeRow` call, so
+   * `mutableRow` is not modified until every expression has been evaluated.
+   *
+   * @return `mutableRow` itself (not a copy)
+   */
   override def apply(input: InternalRow): ArrowColumnarRow = {
     if (subExprEliminationEnabled) {
       runtime.setInput(input)
     }
 
+    // Put the expression results in the intermediate row.
     var i = 0
-    while (i < validExprs.length) {
+    while (i < numFields) {
       val (_, ordinal) = validExprs(i)
+      // NOTE(review): the result is stored at position `i` (its index in `validExprs`), while
+      // `ordinal` only selects which expression to evaluate. The removed writer-based code wrote
+      // to the ordinal carried in `validExprs`. The two agree only if `validExprs` keeps every
+      // expression in its original position. Confirm how `validExprs` is built and how
+      // `ArrowColumnarRow.writeRow` maps fields.
-      // Store the result into buffer first, to make the projection atomic 
(needed by aggregation)
-      fieldWriters(i)(exprs(ordinal).eval(input))
+      values(i) = exprs(ordinal).eval(input)
       i += 1
     }
+
+    mutableRow.writeRow(intermediate)
     mutableRow
   }
 }
diff --git 
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveUDFSuite.scala
 
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveUDFSuite.scala
deleted file mode 100644
index 5ee5cdf202..0000000000
--- 
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveUDFSuite.scala
+++ /dev/null
@@ -1,139 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.hive.execution
-
-import org.apache.gluten.execution.CustomerUDF
-
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
-import org.apache.spark.internal.config
-import org.apache.spark.internal.config.UI.UI_ENABLED
-import org.apache.spark.sql.{GlutenTestsBaseTrait, QueryTest, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
-import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
-import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils}
-import org.apache.spark.sql.hive.client.HiveClient
-import org.apache.spark.sql.hive.test.TestHiveContext
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH
-import org.apache.spark.sql.test.SQLTestUtils
-
-import org.scalatest.BeforeAndAfterAll
-
-import java.io.File
-
-/** Test mix-in that turns off the automatic thread audit (`enableAutoThreadAudit = false`). */
-trait GlutenTestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll {
-  override protected val enableAutoThreadAudit = false
-
-}
-
-/**
- * Shared Hive test context on a local SparkContext. The master comes from
- * `spark.sql.test.master` and defaults to `local[1]`. The context loads the Gluten plugin and
- * the columnar shuffle manager, enables off-heap memory, and forces codegen-only evaluation
- * with no fallback.
- */
-object GlutenTestHive
-  extends TestHiveContext(
-    new SparkContext(
-      System.getProperty("spark.sql.test.master", "local[1]"),
-      "TestSQLContext",
-      new SparkConf()
-        .set("spark.sql.test", "")
-        .set(SQLConf.CODEGEN_FALLBACK.key, "false")
-        .set(SQLConf.CODEGEN_FACTORY_MODE.key, 
CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
-        .set(
-          HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key,
-          "org.apache.spark.sql.hive.execution.PairSerDe")
-        .set(WAREHOUSE_PATH.key, 
TestHiveContext.makeWarehouseDir().toURI.getPath)
-        // SPARK-8910
-        .set(UI_ENABLED, false)
-        .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
-        // Hive changed the default of 
hive.metastore.disallow.incompatible.col.type.changes
-        // from false to true. For details, see the JIRA HIVE-12320 and 
HIVE-17764.
-        
.set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", 
"false")
-        .set("spark.driver.memory", "1G")
-        .set("spark.sql.adaptive.enabled", "true")
-        .set("spark.sql.shuffle.partitions", "1")
-        .set("spark.sql.files.maxPartitionBytes", "134217728")
-        .set("spark.memory.offHeap.enabled", "true")
-        .set("spark.memory.offHeap.size", "1024MB")
-        .set("spark.plugins", "org.apache.gluten.GlutenPlugin")
-        .set("spark.shuffle.manager", 
"org.apache.spark.shuffle.sort.ColumnarShuffleManager")
-        // Disable ConvertToLocalRelation for better test coverage. Test cases 
built on
-        // LocalRelation will exercise the optimization rules better by 
disabling it as
-        // this rule may potentially block testing of other optimization rules 
such as
-        // ConstantPropagation etc.
-        .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, 
ConvertToLocalRelation.ruleName)
-    ),
-    // NOTE(review): presumably TestHiveContext's `loadTestTables` flag; confirm.
-    false
-  ) {}
-
-/**
- * Hive UDF tests that run on the shared `GlutenTestHive` session. `beforeAll` registers the
- * TPC-H `lineitem` parquet data from backends-velox test resources as the temp view
- * `lineitem`.
- */
-class GlutenHiveUDFSuite
-  extends QueryTest
-  with GlutenTestHiveSingleton
-  with SQLTestUtils
-  with GlutenTestsBaseTrait {
-  override protected lazy val spark: SparkSession = GlutenTestHive.sparkSession
-  protected lazy val hiveContext: TestHiveContext = GlutenTestHive
-  protected lazy val hiveClient: HiveClient =
-    
spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client
-
-  override protected def beforeAll(): Unit = {
-    super.beforeAll()
-    val table = "lineitem"
-    // Resolves the repository's tpch-data-parquet directory relative to this class's location.
-    val tableDir =
-      getClass.getResource("").getPath + "/../../../../../../../../../../../" +
-        "/backends-velox/src/test/resources/tpch-data-parquet/"
-    val tablePath = new File(tableDir, table).getAbsolutePath
-    val tableDF = spark.read.format("parquet").load(tablePath)
-    tableDF.createOrReplaceTempView(table)
-  }
-
-  override protected def afterAll(): Unit = {
-    try {
-      hiveContext.reset()
-    } finally {
-      super.afterAll()
-    }
-  }
-
-  // NOTE(review): always returns false, so presumably every test in this suite is skipped;
-  // confirm against GlutenTestsBaseTrait.shouldRun.
-  override protected def shouldRun(testName: String): Boolean = {
-    false
-  }
-
-  // NOTE(review): the tests below only show results / print plans; they assert nothing.
-  test("customer udf") {
-    sql(s"CREATE TEMPORARY FUNCTION testUDF AS 
'${classOf[CustomerUDF].getName}'")
-    val df = spark.sql("""select testUDF(l_comment)
-                         | from   lineitem""".stripMargin)
-    df.show()
-    print(df.queryExecution.executedPlan)
-    sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
-    hiveContext.reset()
-  }
-
-  test("customer udf wrapped in function") {
-    sql(s"CREATE TEMPORARY FUNCTION testUDF AS 
'${classOf[CustomerUDF].getName}'")
-    val df = spark.sql("""select hash(testUDF(l_comment))
-                         | from   lineitem""".stripMargin)
-    df.show()
-    print(df.queryExecution.executedPlan)
-    sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
-    hiveContext.reset()
-  }
-
-  test("example") {
-    spark.sql("CREATE TEMPORARY FUNCTION testUDF AS 
'org.apache.hadoop.hive.ql.udf.UDFSubstr';")
-    spark.sql("select testUDF('l_commen', 1, 5)").show()
-    sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
-    hiveContext.reset()
-  }
-
-}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to