This is an automated email from the ASF dual-hosted git repository.

chengchengjin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 932d8a261f [GLUTEN-7359][VL] Optimize string in partial project (#7592)
932d8a261f is described below

commit 932d8a261fa0e0537bd191dd05bc3b7988d31acc
Author: Jin Chengcheng <[email protected]>
AuthorDate: Mon Oct 28 09:32:40 2024 +0800

    [GLUTEN-7359][VL] Optimize string in partial project (#7592)
---
 .../execution/ColumnarPartialProjectExec.scala     |  95 ++----
 .../gluten/expression/UDFPartialProjectSuite.scala |   8 +
 .../apache/gluten/vectorized/ArrowColumnarRow.java | 319 +++++++++++++++++++++
 .../apache/gluten/expression/ArrowProjection.scala |  67 +++++
 .../expression/InterpretedArrowProjection.scala    | 108 +++++++
 .../expressions/ExpressionsEvaluator.scala         |  50 ++++
 .../expressions/ExpressionsEvaluator.scala         |  50 ++++
 7 files changed, 626 insertions(+), 71 deletions(-)

diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
index 1c394103db..d993e399db 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
@@ -18,19 +18,18 @@ package org.apache.gluten.execution
 
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.columnarbatch.{ColumnarBatches, VeloxColumnarBatches}
-import org.apache.gluten.expression.ExpressionUtils
+import org.apache.gluten.expression.{ArrowProjection, ExpressionUtils}
 import org.apache.gluten.extension.{GlutenPlan, ValidationResult}
 import org.apache.gluten.iterator.Iterators
 import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
 import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.gluten.vectorized.ArrowWritableColumnVector
+import org.apache.gluten.vectorized.{ArrowColumnarRow, 
ArrowWritableColumnVector}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
AttributeReference, CaseWhen, Coalesce, Expression, If, LambdaFunction, 
MutableProjection, NamedExpression, NaNvl, ScalaUDF, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
AttributeReference, CaseWhen, Coalesce, Expression, If, LambdaFunction, 
NamedExpression, NaNvl, ScalaUDF}
 import org.apache.spark.sql.execution.{ExplainUtils, ProjectExec, SparkPlan, 
UnaryExecNode}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, 
WritableColumnVector}
 import org.apache.spark.sql.hive.HiveUdfUtil
 import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, 
DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, 
NullType, ShortType, StringType, TimestampType, YearMonthIntervalType}
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
@@ -38,7 +37,7 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, 
ColumnVector}
 import scala.collection.mutable.ListBuffer
 
 /**
- * By rule <PartialProhectRule>, the project not offload-able that is changed 
to
+ * By rule <PartialProjectRule>, the project not offload-able that is changed 
to
  * ProjectExecTransformer + ColumnarPartialProjectExec e.g. sum(myudf(a) + b + 
hash(c)), child is
  * (a, b, c) ColumnarPartialProjectExec (a, b, c, myudf(a) as 
_SparkPartialProject1),
  * ProjectExecTransformer(_SparkPartialProject1 + b + hash(c))
@@ -64,12 +63,12 @@ case class ColumnarPartialProjectExec(original: 
ProjectExec, child: SparkPlan)(
 
   @transient override lazy val metrics = Map(
     "time" -> SQLMetrics.createTimingMetric(sparkContext, "total time of 
partial project"),
-    "column_to_row_time" -> SQLMetrics.createTimingMetric(
+    "velox_to_arrow_time" -> SQLMetrics.createTimingMetric(
       sparkContext,
-      "time of velox to Arrow ColumnarBatch or UnsafeRow"),
-    "row_to_column_time" -> SQLMetrics.createTimingMetric(
+      "time of velox to Arrow ColumnarBatch"),
+    "arrow_to_velox_time" -> SQLMetrics.createTimingMetric(
       sparkContext,
-      "time of Arrow ColumnarBatch or UnsafeRow to velox")
+      "time of Arrow ColumnarBatch to velox")
   )
 
   override def output: Seq[Attribute] = child.output ++ 
replacedAliasUdf.map(_.toAttribute)
@@ -111,22 +110,26 @@ case class ColumnarPartialProjectExec(original: 
ProjectExec, child: SparkPlan)(
   }
 
   private def getProjectIndexInChildOutput(exprs: Seq[Expression]): Unit = {
-    exprs.foreach {
+    exprs.forall {
       case a: AttributeReference =>
         val index = child.output.indexWhere(s => s.exprId.equals(a.exprId))
         // Some child operator as HashAggregateTransformer will not have udf 
child column
         if (index < 0) {
           UDFAttrNotExists = true
           log.debug(s"Expression $a should exist in child output 
${child.output}")
-          return
+          false
         } else if (!validateDataType(a.dataType)) {
           hasUnsupportedDataType = true
           log.debug(s"Expression $a contains unsupported data type 
${a.dataType}")
+          false
         } else if (!projectIndexInChild.contains(index)) {
           projectAttributes.append(a.toAttribute)
           projectIndexInChild.append(index)
-        }
-      case p => getProjectIndexInChildOutput(p.children)
+          true
+        } else true
+      case p =>
+        getProjectIndexInChildOutput(p.children)
+        true
     }
   }
 
@@ -150,7 +153,7 @@ case class ColumnarPartialProjectExec(original: 
ProjectExec, child: SparkPlan)(
       return ValidationResult.failed("No UDF")
     }
     if (replacedAliasUdf.size > original.output.size) {
-      // e.g. udf1(col) + udf2(col), it will introduce 2 cols for r2c
+      // e.g. udf1(col) + udf2(col), it will introduce 2 cols for a2c
       return ValidationResult.failed("Number of RowToColumn columns is more 
than ProjectExec")
     }
     if (!original.projectList.forall(validateExpression(_))) {
@@ -168,9 +171,8 @@ case class ColumnarPartialProjectExec(original: 
ProjectExec, child: SparkPlan)(
 
   override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
     val totalTime = longMetric("time")
-    val c2r = longMetric("column_to_row_time")
-    val r2c = longMetric("row_to_column_time")
-    val isMutable = canUseMutableProjection()
+    val c2a = longMetric("velox_to_arrow_time")
+    val a2c = longMetric("arrow_to_velox_time")
     child.executeColumnar().mapPartitions {
       batches =>
         val res: Iterator[Iterator[ColumnarBatch]] = new 
Iterator[Iterator[ColumnarBatch]] {
@@ -183,9 +185,8 @@ case class ColumnarPartialProjectExec(original: 
ProjectExec, child: SparkPlan)(
             } else {
               val start = System.currentTimeMillis()
               val childData = ColumnarBatches.select(batch, 
projectIndexInChild.toArray)
-              val projectedBatch = if (isMutable) {
-                getProjectedBatchArrow(childData, c2r, r2c)
-              } else getProjectedBatch(childData, c2r, r2c)
+              val projectedBatch = getProjectedBatchArrow(childData, c2a, a2c)
+
               val batchIterator = projectedBatch.map {
                 b =>
                   if (b.numCols() != 0) {
@@ -214,60 +215,12 @@ case class ColumnarPartialProjectExec(original: 
ProjectExec, child: SparkPlan)(
     }
   }
 
-  // scalastyle:off line.size.limit
-  // String type cannot use MutableProjection
-  // Otherwise will throw java.lang.UnsupportedOperationException: Datatype 
not supported StringType
-  // at 
org.apache.spark.sql.execution.vectorized.MutableColumnarRow.update(MutableColumnarRow.java:224)
-  // at 
org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown
 Source)
-  // scalastyle:on line.size.limit
-  private def canUseMutableProjection(): Boolean = {
-    replacedAliasUdf.forall(
-      r =>
-        r.dataType match {
-          case StringType | BinaryType => false
-          case _ => true
-        })
-  }
-
-  /**
-   * add c2r and r2c for unsupported expression child data c2r get 
Iterator[InternalRow], then call
-   * Spark project, then r2c
-   */
-  private def getProjectedBatch(
-      childData: ColumnarBatch,
-      c2r: SQLMetric,
-      r2c: SQLMetric): Iterator[ColumnarBatch] = {
-    // select part of child output and child data
-    val proj = UnsafeProjection.create(replacedAliasUdf, 
projectAttributes.toSeq)
-    val numOutputRows = new SQLMetric("numOutputRows")
-    val numInputBatches = new SQLMetric("numInputBatches")
-    val rows = VeloxColumnarToRowExec
-      .toRowIterator(
-        Iterator.single[ColumnarBatch](childData),
-        projectAttributes.toSeq,
-        numOutputRows,
-        numInputBatches,
-        c2r)
-      .map(proj)
-
-    val schema =
-      
SparkShimLoader.getSparkShims.structFromAttributes(replacedAliasUdf.map(_.toAttribute))
-    RowToVeloxColumnarExec.toColumnarBatchIterator(
-      rows,
-      schema,
-      numOutputRows,
-      numInputBatches,
-      r2c,
-      childData.numRows())
-    // TODO: should check the size <= 1, but now it has bug, will change 
iterator to empty
-  }
-
   private def getProjectedBatchArrow(
       childData: ColumnarBatch,
       c2a: SQLMetric,
       a2c: SQLMetric): Iterator[ColumnarBatch] = {
     // select part of child output and child data
-    val proj = MutableProjection.create(replacedAliasUdf, 
projectAttributes.toSeq)
+    val proj = ArrowProjection.create(replacedAliasUdf, 
projectAttributes.toSeq)
     val numRows = childData.numRows()
     val start = System.currentTimeMillis()
     val arrowBatch = if (childData.numCols() == 0) {
@@ -279,14 +232,14 @@ case class ColumnarPartialProjectExec(original: 
ProjectExec, child: SparkPlan)(
 
     val schema =
       
SparkShimLoader.getSparkShims.structFromAttributes(replacedAliasUdf.map(_.toAttribute))
-    val vectors: Array[WritableColumnVector] = ArrowWritableColumnVector
+    val vectors: Array[ArrowWritableColumnVector] = ArrowWritableColumnVector
       .allocateColumns(numRows, schema)
       .map {
         vector =>
           vector.setValueCount(numRows)
           vector
       }
-    val targetRow = new MutableColumnarRow(vectors)
+    val targetRow = new ArrowColumnarRow(vectors)
     for (i <- 0 until numRows) {
       targetRow.rowId = i
       proj.target(targetRow).apply(arrowBatch.getRow(i))
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
index 757d4da131..4aaa722c68 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
@@ -70,6 +70,8 @@ abstract class UDFPartialProjectSuite extends 
WholeStageTransformerSuite {
     spark.udf.register("plus_one", plusOne)
     val noArgument = udf(() => 15)
     spark.udf.register("no_argument", noArgument)
+    val concat = udf((x: String) => x + "_concat")
+    spark.udf.register("concat_concat", concat)
 
   }
 
@@ -139,4 +141,10 @@ abstract class UDFPartialProjectSuite extends 
WholeStageTransformerSuite {
     }
   }
 
+  // Regression check for string UDF results: concat_concat returns StringType, and the query is
+  // expected to still plan a ColumnarPartialProjectExec (string values are now written into the
+  // Arrow output through ArrowColumnarRow instead of falling back).
+  test("test concat with string") {
+    runQueryAndCompare("SELECT concat_concat(l_comment), hash(l_partkey) from 
lineitem") {
+      checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+    }
+  }
+
 }
diff --git 
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarRow.java 
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarRow.java
new file mode 100644
index 0000000000..e0ac5858e7
--- /dev/null
+++ 
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarRow.java
@@ -0,0 +1,319 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.vectorized;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.BinaryType;
+import org.apache.spark.sql.types.BooleanType;
+import org.apache.spark.sql.types.ByteType;
+import org.apache.spark.sql.types.CalendarIntervalType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DateType;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.DoubleType;
+import org.apache.spark.sql.types.FloatType;
+import org.apache.spark.sql.types.IntegerType;
+import org.apache.spark.sql.types.LongType;
+import org.apache.spark.sql.types.MapType;
+import org.apache.spark.sql.types.ShortType;
+import org.apache.spark.sql.types.StringType;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.TimestampType;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.sql.vectorized.ColumnarRow;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+import java.math.BigDecimal;
+
+// Copy from Spark MutableColumnarRow mostly but class member columns`type is
+// ArrowWritableColumnVector. And support string and binary type to write, 
Arrow writer does not
+// need to setNotNull before write a value.
+public final class ArrowColumnarRow extends InternalRow {
+  public int rowId;
+  private final ArrowWritableColumnVector[] columns;
+
+  public ArrowColumnarRow(ArrowWritableColumnVector[] writableColumns) {
+    this.columns = writableColumns;
+  }
+
+  @Override
+  public int numFields() {
+    return columns.length;
+  }
+
+  @Override
+  public InternalRow copy() {
+    GenericInternalRow row = new GenericInternalRow(columns.length);
+    for (int i = 0; i < numFields(); i++) {
+      if (isNullAt(i)) {
+        row.setNullAt(i);
+      } else {
+        DataType dt = columns[i].dataType();
+        if (dt instanceof BooleanType) {
+          row.setBoolean(i, getBoolean(i));
+        } else if (dt instanceof ByteType) {
+          row.setByte(i, getByte(i));
+        } else if (dt instanceof ShortType) {
+          row.setShort(i, getShort(i));
+        } else if (dt instanceof IntegerType) {
+          row.setInt(i, getInt(i));
+        } else if (dt instanceof LongType) {
+          row.setLong(i, getLong(i));
+        } else if (dt instanceof FloatType) {
+          row.setFloat(i, getFloat(i));
+        } else if (dt instanceof DoubleType) {
+          row.setDouble(i, getDouble(i));
+        } else if (dt instanceof StringType) {
+          row.update(i, getUTF8String(i).copy());
+        } else if (dt instanceof BinaryType) {
+          row.update(i, getBinary(i));
+        } else if (dt instanceof DecimalType) {
+          DecimalType t = (DecimalType) dt;
+          row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), 
t.precision());
+        } else if (dt instanceof DateType) {
+          row.setInt(i, getInt(i));
+        } else if (dt instanceof TimestampType) {
+          row.setLong(i, getLong(i));
+        } else if (dt instanceof StructType) {
+          row.update(i, getStruct(i, ((StructType) 
dt).fields().length).copy());
+        } else if (dt instanceof ArrayType) {
+          row.update(i, getArray(i).copy());
+        } else if (dt instanceof MapType) {
+          row.update(i, getMap(i).copy());
+        } else {
+          throw new RuntimeException("Not implemented. " + dt);
+        }
+      }
+    }
+    return row;
+  }
+
+  @Override
+  public boolean anyNull() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public boolean isNullAt(int ordinal) {
+    return columns[ordinal].isNullAt(rowId);
+  }
+
+  @Override
+  public boolean getBoolean(int ordinal) {
+    return columns[ordinal].getBoolean(rowId);
+  }
+
+  @Override
+  public byte getByte(int ordinal) {
+    return columns[ordinal].getByte(rowId);
+  }
+
+  @Override
+  public short getShort(int ordinal) {
+    return columns[ordinal].getShort(rowId);
+  }
+
+  @Override
+  public int getInt(int ordinal) {
+    return columns[ordinal].getInt(rowId);
+  }
+
+  @Override
+  public long getLong(int ordinal) {
+    return columns[ordinal].getLong(rowId);
+  }
+
+  @Override
+  public float getFloat(int ordinal) {
+    return columns[ordinal].getFloat(rowId);
+  }
+
+  @Override
+  public double getDouble(int ordinal) {
+    return columns[ordinal].getDouble(rowId);
+  }
+
+  @Override
+  public Decimal getDecimal(int ordinal, int precision, int scale) {
+    return columns[ordinal].getDecimal(rowId, precision, scale);
+  }
+
+  @Override
+  public UTF8String getUTF8String(int ordinal) {
+    return columns[ordinal].getUTF8String(rowId);
+  }
+
+  @Override
+  public byte[] getBinary(int ordinal) {
+    return columns[ordinal].getBinary(rowId);
+  }
+
+  @Override
+  public CalendarInterval getInterval(int ordinal) {
+    return columns[ordinal].getInterval(rowId);
+  }
+
+  @Override
+  public ColumnarRow getStruct(int ordinal, int numFields) {
+    return columns[ordinal].getStruct(rowId);
+  }
+
+  @Override
+  public ColumnarArray getArray(int ordinal) {
+    return columns[ordinal].getArray(rowId);
+  }
+
+  @Override
+  public ColumnarMap getMap(int ordinal) {
+    return columns[ordinal].getMap(rowId);
+  }
+
+  @Override
+  public Object get(int ordinal, DataType dataType) {
+    if (dataType instanceof BooleanType) {
+      return getBoolean(ordinal);
+    } else if (dataType instanceof ByteType) {
+      return getByte(ordinal);
+    } else if (dataType instanceof ShortType) {
+      return getShort(ordinal);
+    } else if (dataType instanceof IntegerType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof LongType) {
+      return getLong(ordinal);
+    } else if (dataType instanceof FloatType) {
+      return getFloat(ordinal);
+    } else if (dataType instanceof DoubleType) {
+      return getDouble(ordinal);
+    } else if (dataType instanceof StringType) {
+      return getUTF8String(ordinal);
+    } else if (dataType instanceof BinaryType) {
+      return getBinary(ordinal);
+    } else if (dataType instanceof DecimalType) {
+      DecimalType t = (DecimalType) dataType;
+      return getDecimal(ordinal, t.precision(), t.scale());
+    } else if (dataType instanceof DateType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof TimestampType) {
+      return getLong(ordinal);
+    } else if (dataType instanceof ArrayType) {
+      return getArray(ordinal);
+    } else if (dataType instanceof StructType) {
+      return getStruct(ordinal, ((StructType) dataType).fields().length);
+    } else if (dataType instanceof MapType) {
+      return getMap(ordinal);
+    } else {
+      throw new UnsupportedOperationException("Datatype not supported " + 
dataType);
+    }
+  }
+
+  @Override
+  public void update(int ordinal, Object value) {
+    if (value == null) {
+      setNullAt(ordinal);
+    } else {
+      DataType dt = columns[ordinal].dataType();
+      if (dt instanceof BooleanType) {
+        setBoolean(ordinal, (boolean) value);
+      } else if (dt instanceof IntegerType) {
+        setInt(ordinal, (int) value);
+      } else if (dt instanceof ShortType) {
+        setShort(ordinal, (short) value);
+      } else if (dt instanceof LongType) {
+        setLong(ordinal, (long) value);
+      } else if (dt instanceof FloatType) {
+        setFloat(ordinal, (float) value);
+      } else if (dt instanceof DoubleType) {
+        setDouble(ordinal, (double) value);
+      } else if (dt instanceof DecimalType) {
+        DecimalType t = (DecimalType) dt;
+        Decimal d = Decimal.apply((BigDecimal) value, t.precision(), 
t.scale());
+        setDecimal(ordinal, d, t.precision());
+      } else if (dt instanceof CalendarIntervalType) {
+        setInterval(ordinal, (CalendarInterval) value);
+      } else if (dt instanceof StringType) {
+        setUTF8String(ordinal, (UTF8String) value);
+      } else if (dt instanceof BinaryType) {
+        setBinary(ordinal, (byte[]) value);
+      } else {
+        throw new UnsupportedOperationException("Datatype not supported " + 
dt);
+      }
+    }
+  }
+
+  @Override
+  public void setNullAt(int ordinal) {
+    columns[ordinal].putNull(rowId);
+  }
+
+  @Override
+  public void setBoolean(int ordinal, boolean value) {
+    columns[ordinal].putBoolean(rowId, value);
+  }
+
+  @Override
+  public void setByte(int ordinal, byte value) {
+    columns[ordinal].putByte(rowId, value);
+  }
+
+  @Override
+  public void setShort(int ordinal, short value) {
+    columns[ordinal].putShort(rowId, value);
+  }
+
+  @Override
+  public void setInt(int ordinal, int value) {
+    columns[ordinal].putInt(rowId, value);
+  }
+
+  @Override
+  public void setLong(int ordinal, long value) {
+    columns[ordinal].putLong(rowId, value);
+  }
+
+  @Override
+  public void setFloat(int ordinal, float value) {
+    columns[ordinal].putFloat(rowId, value);
+  }
+
+  @Override
+  public void setDouble(int ordinal, double value) {
+    columns[ordinal].putDouble(rowId, value);
+  }
+
+  @Override
+  public void setDecimal(int ordinal, Decimal value, int precision) {
+    columns[ordinal].putDecimal(rowId, value, precision);
+  }
+
+  @Override
+  public void setInterval(int ordinal, CalendarInterval value) {
+    columns[ordinal].putInterval(rowId, value);
+  }
+
+  public void setUTF8String(int ordinal, UTF8String value) {
+    columns[ordinal].putBytes(rowId, value.numBytes(), value.getBytes(), 0);
+  }
+
+  public void setBinary(int ordinal, byte[] value) {
+    columns[ordinal].putBytes(rowId, value.length, value, 0);
+  }
+}
diff --git 
a/gluten-arrow/src/main/scala/org/apache/gluten/expression/ArrowProjection.scala
 
b/gluten-arrow/src/main/scala/org/apache/gluten/expression/ArrowProjection.scala
new file mode 100644
index 0000000000..3216d4f3f9
--- /dev/null
+++ 
b/gluten-arrow/src/main/scala/org/apache/gluten/expression/ArrowProjection.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.expression
+
+import org.apache.gluten.vectorized.ArrowColumnarRow
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, 
Expression, ExpressionsEvaluator}
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.types.{DataType, StructType}
+
+/**
+ * A projection that evaluates expressions against an input [[InternalRow]] and writes the
+ * results into a target [[ArrowColumnarRow]].
+ *
+ * Not thread safe.
+ */
+abstract class ArrowProjection extends (InternalRow => ArrowColumnarRow) with ExpressionsEvaluator {
+
+  /** Uses the given row to store the output of the projection. */
+  def target(row: ArrowColumnarRow): ArrowProjection
+
+  /** Returns the row the projection currently writes its output into. */
+  def currentValue: ArrowColumnarRow
+}
+
+/**
+ * Factory methods for [[ArrowProjection]]. Every projection returned here is interpreted and is
+ * *not* thread-safe.
+ */
+object ArrowProjection {
+
+  /**
+   * Returns an ArrowProjection that copies every field of `schema`, in order.
+   *
+   * CAUTION: the returned projection object is *not* thread-safe.
+   */
+  def create(schema: StructType): ArrowProjection = {
+    val fieldTypes: Array[DataType] = schema.fields.map(_.dataType)
+    create(fieldTypes)
+  }
+
+  /**
+   * Returns an ArrowProjection that copies one nullable column per entry of `fields`, reading
+   * input ordinal `i` with type `fields(i)`.
+   *
+   * CAUTION: the returned projection object is *not* thread-safe.
+   */
+  def create(fields: Array[DataType]): ArrowProjection = {
+    val boundRefs = fields.zipWithIndex.map {
+      case (dataType, ordinal) => BoundReference(ordinal, dataType, nullable = true)
+    }
+    create(boundRefs)
+  }
+
+  /** Returns an ArrowProjection for the given sequence of already-bound expressions. */
+  def create(exprs: Seq[Expression]): ArrowProjection =
+    InterpretedArrowProjection.createProjection(exprs)
+
+  /** Returns an ArrowProjection for a single already-bound expression. */
+  def create(expr: Expression): ArrowProjection = create(Seq(expr))
+
+  /**
+   * Returns an ArrowProjection for `exprs` after binding their attribute references to
+   * `inputSchema`.
+   */
+  def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): ArrowProjection = {
+    val bound = bindReferences(exprs, inputSchema)
+    create(bound)
+  }
+}
diff --git 
a/gluten-arrow/src/main/scala/org/apache/gluten/expression/InterpretedArrowProjection.scala
 
b/gluten-arrow/src/main/scala/org/apache/gluten/expression/InterpretedArrowProjection.scala
new file mode 100644
index 0000000000..33a3802f88
--- /dev/null
+++ 
b/gluten-arrow/src/main/scala/org/apache/gluten/expression/InterpretedArrowProjection.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.expression
+
+import org.apache.gluten.vectorized.ArrowColumnarRow
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{BinaryType, DataType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * An [[ArrowProjection]] that is calculated by calling `eval` on each of the specified
+ * expressions and writing every result into the [[ArrowColumnarRow]] installed via `target`.
+ *
+ * String and binary results go through `ArrowColumnarRow.setUTF8String` / `setBinary`; all other
+ * types use the writers from `InternalRow.getWriter`. All results of a row are evaluated before
+ * any of them is written, so the projection stays correct even if the target row is also the
+ * input row.
+ *
+ * Not thread-safe: the target row and the evaluation buffer are per-instance mutable state.
+ *
+ * @param expressions
+ *   a sequence of expressions that determine the value of each column of the output row.
+ */
+class InterpretedArrowProjection(expressions: Seq[Expression]) extends ArrowProjection {
+  def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
+    this(bindReferences(expressions, inputSchema))
+
+  private[this] val subExprEliminationEnabled = SQLConf.get.subexpressionEliminationEnabled
+  private[this] val exprs = prepareExpressions(expressions, subExprEliminationEnabled)
+
+  override def initialize(partitionIndex: Int): Unit = {
+    initializeExprs(exprs, partitionIndex)
+  }
+
+  // (expression, output ordinal) pairs, skipping NoOp placeholders whose column is left as-is.
+  // An Array keeps the per-row indexed access below O(1) whatever Seq the caller passed.
+  private[this] val validExprs: Array[(Expression, Int)] = expressions.zipWithIndex.filter {
+    case (NoOp, _) => false
+    case _ => true
+  }.toArray
+
+  // Evaluated values of the current row, indexed by output ordinal.
+  private[this] val buffer = new Array[Any](expressions.size)
+
+  private[this] var mutableRow: ArrowColumnarRow = null
+
+  override def currentValue: ArrowColumnarRow = mutableRow
+
+  override def target(row: ArrowColumnarRow): ArrowProjection = {
+    mutableRow = row
+    this
+  }
+
+  /** Returns the writer that stores a non-null value of type `dt` into column `ordinal`. */
+  private def writerFor(ordinal: Int, dt: DataType): (ArrowColumnarRow, Any) => Unit = dt match {
+    // The generic InternalRow writers cannot store strings/binary into the Arrow vectors, so use
+    // ArrowColumnarRow's dedicated setters.
+    case _: StringType =>
+      (row, v) => row.setUTF8String(ordinal, v.asInstanceOf[UTF8String])
+    case _: BinaryType =>
+      (row, v) => row.setBinary(ordinal, v.asInstanceOf[Array[Byte]])
+    case _ => InternalRow.getWriter(ordinal, dt)
+  }
+
+  // One writer per valid expression; nullable expressions route null results to setNullAt.
+  private[this] val fieldWriters: Array[Any => Unit] = validExprs.map {
+    case (e, i) =>
+      val writer = writerFor(i, e.dataType)
+      if (e.nullable) {
+        (v: Any) => if (v == null) mutableRow.setNullAt(i) else writer(mutableRow, v)
+      } else {
+        (v: Any) => writer(mutableRow, v)
+      }
+  }
+
+  override def apply(input: InternalRow): ArrowColumnarRow = {
+    if (subExprEliminationEnabled) {
+      runtime.setInput(input)
+    }
+
+    // Store the results into the buffer first, to make the projection atomic (needed when the
+    // target row is also read as input, e.g. by aggregation).
+    var i = 0
+    while (i < validExprs.length) {
+      val (_, ordinal) = validExprs(i)
+      buffer(ordinal) = exprs(ordinal).eval(input)
+      i += 1
+    }
+
+    i = 0
+    while (i < validExprs.length) {
+      val (_, ordinal) = validExprs(i)
+      fieldWriters(i)(buffer(ordinal))
+      i += 1
+    }
+    mutableRow
+  }
+}
+
+/** Factory for [[InterpretedArrowProjection]] instances. */
+object InterpretedArrowProjection {
+
+  /**
+   * Builds an [[ArrowProjection]] that interprets the given expressions row by row.
+   *
+   * @param exprs
+   *   expressions whose attribute references are already bound to the input row.
+   */
+  def createProjection(exprs: Seq[Expression]): ArrowProjection =
+    new InterpretedArrowProjection(exprs)
+}
diff --git 
a/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala
 
b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala
new file mode 100644
index 0000000000..1469f57e64
--- /dev/null
+++ 
b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * A helper trait to evaluate expressions. It owns the sub-expression elimination runtime and
+ * provides hooks to prepare and initialize the expressions an evaluator works on.
+ *
+ * Unlike the Spark 3.4+ version of this trait, `prepareExpressions` does not make fresh copies of
+ * stateful expressions: the given expressions are used as-is (or wrapped by the runtime's
+ * proxies).
+ */
+trait ExpressionsEvaluator {
+
+  /** Sub-expression elimination runtime, sized by `subexpressionEliminationCacheMaxEntries`. */
+  protected lazy val runtime: SubExprEvaluationRuntime =
+    new SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries)
+
+  /**
+   * Returns the expressions to evaluate: wrapped by `runtime.proxyExpressions` when
+   * sub-expression elimination is enabled, otherwise `exprs` unchanged.
+   */
+  protected def prepareExpressions(
+      exprs: Seq[Expression],
+      subExprEliminationEnabled: Boolean): Seq[Expression] =
+    if (!subExprEliminationEnabled) exprs else runtime.proxyExpressions(exprs)
+
+  /**
+   * Initializes internal states given the current partition index. This is used by nondeterministic
+   * expressions to set initial states. The default implementation does nothing.
+   */
+  def initialize(partitionIndex: Int): Unit = {}
+
+  /** Calls `initialize(partitionIndex)` on every [[Nondeterministic]] node inside `exprs`. */
+  protected def initializeExprs(exprs: Seq[Expression], partitionIndex: Int): Unit = {
+    for (expr <- exprs; node <- expr.collect { case n: Nondeterministic => n }) {
+      node.initialize(partitionIndex)
+    }
+  }
+}
diff --git 
a/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala
 
b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala
new file mode 100644
index 0000000000..1469f57e64
--- /dev/null
+++ 
b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * A helper trait to evaluate expressions. It owns the sub-expression elimination runtime and
+ * provides hooks to prepare and initialize the expressions an evaluator works on.
+ *
+ * Unlike the Spark 3.4+ version of this trait, `prepareExpressions` does not make fresh copies of
+ * stateful expressions: the given expressions are used as-is (or wrapped by the runtime's
+ * proxies).
+ */
+trait ExpressionsEvaluator {
+
+  /** Sub-expression elimination runtime, sized by `subexpressionEliminationCacheMaxEntries`. */
+  protected lazy val runtime: SubExprEvaluationRuntime =
+    new SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries)
+
+  /**
+   * Returns the expressions to evaluate: wrapped by `runtime.proxyExpressions` when
+   * sub-expression elimination is enabled, otherwise `exprs` unchanged.
+   */
+  protected def prepareExpressions(
+      exprs: Seq[Expression],
+      subExprEliminationEnabled: Boolean): Seq[Expression] =
+    if (!subExprEliminationEnabled) exprs else runtime.proxyExpressions(exprs)
+
+  /**
+   * Initializes internal states given the current partition index. This is used by nondeterministic
+   * expressions to set initial states. The default implementation does nothing.
+   */
+  def initialize(partitionIndex: Int): Unit = {}
+
+  /** Calls `initialize(partitionIndex)` on every [[Nondeterministic]] node inside `exprs`. */
+  protected def initializeExprs(exprs: Seq[Expression], partitionIndex: Int): Unit = {
+    for (expr <- exprs; node <- expr.collect { case n: Nondeterministic => n }) {
+      node.initialize(partitionIndex)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to