This is an automated email from the ASF dual-hosted git repository.
chengchengjin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 932d8a261f [GLUTEN-7359][VL] Optimize string in partial project (#7592)
932d8a261f is described below
commit 932d8a261fa0e0537bd191dd05bc3b7988d31acc
Author: Jin Chengcheng <[email protected]>
AuthorDate: Mon Oct 28 09:32:40 2024 +0800
[GLUTEN-7359][VL] Optimize string in partial project (#7592)
---
.../execution/ColumnarPartialProjectExec.scala | 95 ++----
.../gluten/expression/UDFPartialProjectSuite.scala | 8 +
.../apache/gluten/vectorized/ArrowColumnarRow.java | 319 +++++++++++++++++++++
.../apache/gluten/expression/ArrowProjection.scala | 67 +++++
.../expression/InterpretedArrowProjection.scala | 108 +++++++
.../expressions/ExpressionsEvaluator.scala | 50 ++++
.../expressions/ExpressionsEvaluator.scala | 50 ++++
7 files changed, 626 insertions(+), 71 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
index 1c394103db..d993e399db 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
@@ -18,19 +18,18 @@ package org.apache.gluten.execution
import org.apache.gluten.GlutenConfig
import org.apache.gluten.columnarbatch.{ColumnarBatches, VeloxColumnarBatches}
-import org.apache.gluten.expression.ExpressionUtils
+import org.apache.gluten.expression.{ArrowProjection, ExpressionUtils}
import org.apache.gluten.extension.{GlutenPlan, ValidationResult}
import org.apache.gluten.iterator.Iterators
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.gluten.vectorized.ArrowWritableColumnVector
+import org.apache.gluten.vectorized.{ArrowColumnarRow,
ArrowWritableColumnVector}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute,
AttributeReference, CaseWhen, Coalesce, Expression, If, LambdaFunction,
MutableProjection, NamedExpression, NaNvl, ScalaUDF, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute,
AttributeReference, CaseWhen, Coalesce, Expression, If, LambdaFunction,
NamedExpression, NaNvl, ScalaUDF}
import org.apache.spark.sql.execution.{ExplainUtils, ProjectExec, SparkPlan,
UnaryExecNode}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow,
WritableColumnVector}
import org.apache.spark.sql.hive.HiveUdfUtil
import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType,
DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType,
NullType, ShortType, StringType, TimestampType, YearMonthIntervalType}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
@@ -38,7 +37,7 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch,
ColumnVector}
import scala.collection.mutable.ListBuffer
/**
- * By rule <PartialProhectRule>, the project not offload-able that is changed
to
+ * By rule <PartialProjectRule>, the project not offload-able that is changed
to
* ProjectExecTransformer + ColumnarPartialProjectExec e.g. sum(myudf(a) + b +
hash(c)), child is
* (a, b, c) ColumnarPartialProjectExec (a, b, c, myudf(a) as
_SparkPartialProject1),
* ProjectExecTransformer(_SparkPartialProject1 + b + hash(c))
@@ -64,12 +63,12 @@ case class ColumnarPartialProjectExec(original:
ProjectExec, child: SparkPlan)(
@transient override lazy val metrics = Map(
"time" -> SQLMetrics.createTimingMetric(sparkContext, "total time of
partial project"),
- "column_to_row_time" -> SQLMetrics.createTimingMetric(
+ "velox_to_arrow_time" -> SQLMetrics.createTimingMetric(
sparkContext,
- "time of velox to Arrow ColumnarBatch or UnsafeRow"),
- "row_to_column_time" -> SQLMetrics.createTimingMetric(
+ "time of velox to Arrow ColumnarBatch"),
+ "arrow_to_velox_time" -> SQLMetrics.createTimingMetric(
sparkContext,
- "time of Arrow ColumnarBatch or UnsafeRow to velox")
+ "time of Arrow ColumnarBatch to velox")
)
override def output: Seq[Attribute] = child.output ++
replacedAliasUdf.map(_.toAttribute)
@@ -111,22 +110,26 @@ case class ColumnarPartialProjectExec(original:
ProjectExec, child: SparkPlan)(
}
private def getProjectIndexInChildOutput(exprs: Seq[Expression]): Unit = {
- exprs.foreach {
+ exprs.forall {
case a: AttributeReference =>
val index = child.output.indexWhere(s => s.exprId.equals(a.exprId))
// Some child operator as HashAggregateTransformer will not have udf
child column
if (index < 0) {
UDFAttrNotExists = true
log.debug(s"Expression $a should exist in child output
${child.output}")
- return
+ false
} else if (!validateDataType(a.dataType)) {
hasUnsupportedDataType = true
log.debug(s"Expression $a contains unsupported data type
${a.dataType}")
+ false
} else if (!projectIndexInChild.contains(index)) {
projectAttributes.append(a.toAttribute)
projectIndexInChild.append(index)
- }
- case p => getProjectIndexInChildOutput(p.children)
+ true
+ } else true
+ case p =>
+ getProjectIndexInChildOutput(p.children)
+ true
}
}
@@ -150,7 +153,7 @@ case class ColumnarPartialProjectExec(original:
ProjectExec, child: SparkPlan)(
return ValidationResult.failed("No UDF")
}
if (replacedAliasUdf.size > original.output.size) {
- // e.g. udf1(col) + udf2(col), it will introduce 2 cols for r2c
+ // e.g. udf1(col) + udf2(col), it will introduce 2 cols for a2c
return ValidationResult.failed("Number of RowToColumn columns is more
than ProjectExec")
}
if (!original.projectList.forall(validateExpression(_))) {
@@ -168,9 +171,8 @@ case class ColumnarPartialProjectExec(original:
ProjectExec, child: SparkPlan)(
override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
val totalTime = longMetric("time")
- val c2r = longMetric("column_to_row_time")
- val r2c = longMetric("row_to_column_time")
- val isMutable = canUseMutableProjection()
+ val c2a = longMetric("velox_to_arrow_time")
+ val a2c = longMetric("arrow_to_velox_time")
child.executeColumnar().mapPartitions {
batches =>
val res: Iterator[Iterator[ColumnarBatch]] = new
Iterator[Iterator[ColumnarBatch]] {
@@ -183,9 +185,8 @@ case class ColumnarPartialProjectExec(original:
ProjectExec, child: SparkPlan)(
} else {
val start = System.currentTimeMillis()
val childData = ColumnarBatches.select(batch,
projectIndexInChild.toArray)
- val projectedBatch = if (isMutable) {
- getProjectedBatchArrow(childData, c2r, r2c)
- } else getProjectedBatch(childData, c2r, r2c)
+ val projectedBatch = getProjectedBatchArrow(childData, c2a, a2c)
+
val batchIterator = projectedBatch.map {
b =>
if (b.numCols() != 0) {
@@ -214,60 +215,12 @@ case class ColumnarPartialProjectExec(original:
ProjectExec, child: SparkPlan)(
}
}
- // scalastyle:off line.size.limit
- // String type cannot use MutableProjection
- // Otherwise will throw java.lang.UnsupportedOperationException: Datatype
not supported StringType
- // at
org.apache.spark.sql.execution.vectorized.MutableColumnarRow.update(MutableColumnarRow.java:224)
- // at
org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown
Source)
- // scalastyle:on line.size.limit
- private def canUseMutableProjection(): Boolean = {
- replacedAliasUdf.forall(
- r =>
- r.dataType match {
- case StringType | BinaryType => false
- case _ => true
- })
- }
-
- /**
- * add c2r and r2c for unsupported expression child data c2r get
Iterator[InternalRow], then call
- * Spark project, then r2c
- */
- private def getProjectedBatch(
- childData: ColumnarBatch,
- c2r: SQLMetric,
- r2c: SQLMetric): Iterator[ColumnarBatch] = {
- // select part of child output and child data
- val proj = UnsafeProjection.create(replacedAliasUdf,
projectAttributes.toSeq)
- val numOutputRows = new SQLMetric("numOutputRows")
- val numInputBatches = new SQLMetric("numInputBatches")
- val rows = VeloxColumnarToRowExec
- .toRowIterator(
- Iterator.single[ColumnarBatch](childData),
- projectAttributes.toSeq,
- numOutputRows,
- numInputBatches,
- c2r)
- .map(proj)
-
- val schema =
-
SparkShimLoader.getSparkShims.structFromAttributes(replacedAliasUdf.map(_.toAttribute))
- RowToVeloxColumnarExec.toColumnarBatchIterator(
- rows,
- schema,
- numOutputRows,
- numInputBatches,
- r2c,
- childData.numRows())
- // TODO: should check the size <= 1, but now it has bug, will change
iterator to empty
- }
-
private def getProjectedBatchArrow(
childData: ColumnarBatch,
c2a: SQLMetric,
a2c: SQLMetric): Iterator[ColumnarBatch] = {
// select part of child output and child data
- val proj = MutableProjection.create(replacedAliasUdf,
projectAttributes.toSeq)
+ val proj = ArrowProjection.create(replacedAliasUdf,
projectAttributes.toSeq)
val numRows = childData.numRows()
val start = System.currentTimeMillis()
val arrowBatch = if (childData.numCols() == 0) {
@@ -279,14 +232,14 @@ case class ColumnarPartialProjectExec(original:
ProjectExec, child: SparkPlan)(
val schema =
SparkShimLoader.getSparkShims.structFromAttributes(replacedAliasUdf.map(_.toAttribute))
- val vectors: Array[WritableColumnVector] = ArrowWritableColumnVector
+ val vectors: Array[ArrowWritableColumnVector] = ArrowWritableColumnVector
.allocateColumns(numRows, schema)
.map {
vector =>
vector.setValueCount(numRows)
vector
}
- val targetRow = new MutableColumnarRow(vectors)
+ val targetRow = new ArrowColumnarRow(vectors)
for (i <- 0 until numRows) {
targetRow.rowId = i
proj.target(targetRow).apply(arrowBatch.getRow(i))
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
index 757d4da131..4aaa722c68 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
@@ -70,6 +70,8 @@ abstract class UDFPartialProjectSuite extends
WholeStageTransformerSuite {
spark.udf.register("plus_one", plusOne)
val noArgument = udf(() => 15)
spark.udf.register("no_argument", noArgument)
+ val concat = udf((x: String) => x + "_concat")
+ spark.udf.register("concat_concat", concat)
}
@@ -139,4 +141,10 @@ abstract class UDFPartialProjectSuite extends
WholeStageTransformerSuite {
}
}
+ test("test concat with string") {
+ runQueryAndCompare("SELECT concat_concat(l_comment), hash(l_partkey) from
lineitem") {
+ checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+ }
+ }
+
}
diff --git
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarRow.java
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarRow.java
new file mode 100644
index 0000000000..e0ac5858e7
--- /dev/null
+++
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarRow.java
@@ -0,0 +1,319 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.vectorized;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.BinaryType;
+import org.apache.spark.sql.types.BooleanType;
+import org.apache.spark.sql.types.ByteType;
+import org.apache.spark.sql.types.CalendarIntervalType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DateType;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.DoubleType;
+import org.apache.spark.sql.types.FloatType;
+import org.apache.spark.sql.types.IntegerType;
+import org.apache.spark.sql.types.LongType;
+import org.apache.spark.sql.types.MapType;
+import org.apache.spark.sql.types.ShortType;
+import org.apache.spark.sql.types.StringType;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.TimestampType;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.sql.vectorized.ColumnarRow;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+import java.math.BigDecimal;
+
+// Mostly copied from Spark's MutableColumnarRow, but the `columns` member's type is
+// ArrowWritableColumnVector. It also supports writing string and binary values; the
+// Arrow writer does not need to call setNotNull before writing a value.
+public final class ArrowColumnarRow extends InternalRow {
+ public int rowId;
+ private final ArrowWritableColumnVector[] columns;
+
+ public ArrowColumnarRow(ArrowWritableColumnVector[] writableColumns) {
+ this.columns = writableColumns;
+ }
+
+ @Override
+ public int numFields() {
+ return columns.length;
+ }
+
+ @Override
+ public InternalRow copy() {
+ GenericInternalRow row = new GenericInternalRow(columns.length);
+ for (int i = 0; i < numFields(); i++) {
+ if (isNullAt(i)) {
+ row.setNullAt(i);
+ } else {
+ DataType dt = columns[i].dataType();
+ if (dt instanceof BooleanType) {
+ row.setBoolean(i, getBoolean(i));
+ } else if (dt instanceof ByteType) {
+ row.setByte(i, getByte(i));
+ } else if (dt instanceof ShortType) {
+ row.setShort(i, getShort(i));
+ } else if (dt instanceof IntegerType) {
+ row.setInt(i, getInt(i));
+ } else if (dt instanceof LongType) {
+ row.setLong(i, getLong(i));
+ } else if (dt instanceof FloatType) {
+ row.setFloat(i, getFloat(i));
+ } else if (dt instanceof DoubleType) {
+ row.setDouble(i, getDouble(i));
+ } else if (dt instanceof StringType) {
+ row.update(i, getUTF8String(i).copy());
+ } else if (dt instanceof BinaryType) {
+ row.update(i, getBinary(i));
+ } else if (dt instanceof DecimalType) {
+ DecimalType t = (DecimalType) dt;
+ row.setDecimal(i, getDecimal(i, t.precision(), t.scale()),
t.precision());
+ } else if (dt instanceof DateType) {
+ row.setInt(i, getInt(i));
+ } else if (dt instanceof TimestampType) {
+ row.setLong(i, getLong(i));
+ } else if (dt instanceof StructType) {
+ row.update(i, getStruct(i, ((StructType)
dt).fields().length).copy());
+ } else if (dt instanceof ArrayType) {
+ row.update(i, getArray(i).copy());
+ } else if (dt instanceof MapType) {
+ row.update(i, getMap(i).copy());
+ } else {
+ throw new RuntimeException("Not implemented. " + dt);
+ }
+ }
+ }
+ return row;
+ }
+
+ @Override
+ public boolean anyNull() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean isNullAt(int ordinal) {
+ return columns[ordinal].isNullAt(rowId);
+ }
+
+ @Override
+ public boolean getBoolean(int ordinal) {
+ return columns[ordinal].getBoolean(rowId);
+ }
+
+ @Override
+ public byte getByte(int ordinal) {
+ return columns[ordinal].getByte(rowId);
+ }
+
+ @Override
+ public short getShort(int ordinal) {
+ return columns[ordinal].getShort(rowId);
+ }
+
+ @Override
+ public int getInt(int ordinal) {
+ return columns[ordinal].getInt(rowId);
+ }
+
+ @Override
+ public long getLong(int ordinal) {
+ return columns[ordinal].getLong(rowId);
+ }
+
+ @Override
+ public float getFloat(int ordinal) {
+ return columns[ordinal].getFloat(rowId);
+ }
+
+ @Override
+ public double getDouble(int ordinal) {
+ return columns[ordinal].getDouble(rowId);
+ }
+
+ @Override
+ public Decimal getDecimal(int ordinal, int precision, int scale) {
+ return columns[ordinal].getDecimal(rowId, precision, scale);
+ }
+
+ @Override
+ public UTF8String getUTF8String(int ordinal) {
+ return columns[ordinal].getUTF8String(rowId);
+ }
+
+ @Override
+ public byte[] getBinary(int ordinal) {
+ return columns[ordinal].getBinary(rowId);
+ }
+
+ @Override
+ public CalendarInterval getInterval(int ordinal) {
+ return columns[ordinal].getInterval(rowId);
+ }
+
+ @Override
+ public ColumnarRow getStruct(int ordinal, int numFields) {
+ return columns[ordinal].getStruct(rowId);
+ }
+
+ @Override
+ public ColumnarArray getArray(int ordinal) {
+ return columns[ordinal].getArray(rowId);
+ }
+
+ @Override
+ public ColumnarMap getMap(int ordinal) {
+ return columns[ordinal].getMap(rowId);
+ }
+
+ @Override
+ public Object get(int ordinal, DataType dataType) {
+ if (dataType instanceof BooleanType) {
+ return getBoolean(ordinal);
+ } else if (dataType instanceof ByteType) {
+ return getByte(ordinal);
+ } else if (dataType instanceof ShortType) {
+ return getShort(ordinal);
+ } else if (dataType instanceof IntegerType) {
+ return getInt(ordinal);
+ } else if (dataType instanceof LongType) {
+ return getLong(ordinal);
+ } else if (dataType instanceof FloatType) {
+ return getFloat(ordinal);
+ } else if (dataType instanceof DoubleType) {
+ return getDouble(ordinal);
+ } else if (dataType instanceof StringType) {
+ return getUTF8String(ordinal);
+ } else if (dataType instanceof BinaryType) {
+ return getBinary(ordinal);
+ } else if (dataType instanceof DecimalType) {
+ DecimalType t = (DecimalType) dataType;
+ return getDecimal(ordinal, t.precision(), t.scale());
+ } else if (dataType instanceof DateType) {
+ return getInt(ordinal);
+ } else if (dataType instanceof TimestampType) {
+ return getLong(ordinal);
+ } else if (dataType instanceof ArrayType) {
+ return getArray(ordinal);
+ } else if (dataType instanceof StructType) {
+ return getStruct(ordinal, ((StructType) dataType).fields().length);
+ } else if (dataType instanceof MapType) {
+ return getMap(ordinal);
+ } else {
+ throw new UnsupportedOperationException("Datatype not supported " +
dataType);
+ }
+ }
+
+ @Override
+ public void update(int ordinal, Object value) {
+ if (value == null) {
+ setNullAt(ordinal);
+ } else {
+ DataType dt = columns[ordinal].dataType();
+ if (dt instanceof BooleanType) {
+ setBoolean(ordinal, (boolean) value);
+ } else if (dt instanceof IntegerType) {
+ setInt(ordinal, (int) value);
+ } else if (dt instanceof ShortType) {
+ setShort(ordinal, (short) value);
+ } else if (dt instanceof LongType) {
+ setLong(ordinal, (long) value);
+ } else if (dt instanceof FloatType) {
+ setFloat(ordinal, (float) value);
+ } else if (dt instanceof DoubleType) {
+ setDouble(ordinal, (double) value);
+ } else if (dt instanceof DecimalType) {
+ DecimalType t = (DecimalType) dt;
+ Decimal d = Decimal.apply((BigDecimal) value, t.precision(),
t.scale());
+ setDecimal(ordinal, d, t.precision());
+ } else if (dt instanceof CalendarIntervalType) {
+ setInterval(ordinal, (CalendarInterval) value);
+ } else if (dt instanceof StringType) {
+ setUTF8String(ordinal, (UTF8String) value);
+ } else if (dt instanceof BinaryType) {
+ setBinary(ordinal, (byte[]) value);
+ } else {
+ throw new UnsupportedOperationException("Datatype not supported " +
dt);
+ }
+ }
+ }
+
+ @Override
+ public void setNullAt(int ordinal) {
+ columns[ordinal].putNull(rowId);
+ }
+
+ @Override
+ public void setBoolean(int ordinal, boolean value) {
+ columns[ordinal].putBoolean(rowId, value);
+ }
+
+ @Override
+ public void setByte(int ordinal, byte value) {
+ columns[ordinal].putByte(rowId, value);
+ }
+
+ @Override
+ public void setShort(int ordinal, short value) {
+ columns[ordinal].putShort(rowId, value);
+ }
+
+ @Override
+ public void setInt(int ordinal, int value) {
+ columns[ordinal].putInt(rowId, value);
+ }
+
+ @Override
+ public void setLong(int ordinal, long value) {
+ columns[ordinal].putLong(rowId, value);
+ }
+
+ @Override
+ public void setFloat(int ordinal, float value) {
+ columns[ordinal].putFloat(rowId, value);
+ }
+
+ @Override
+ public void setDouble(int ordinal, double value) {
+ columns[ordinal].putDouble(rowId, value);
+ }
+
+ @Override
+ public void setDecimal(int ordinal, Decimal value, int precision) {
+ columns[ordinal].putDecimal(rowId, value, precision);
+ }
+
+ @Override
+ public void setInterval(int ordinal, CalendarInterval value) {
+ columns[ordinal].putInterval(rowId, value);
+ }
+
+ public void setUTF8String(int ordinal, UTF8String value) {
+ columns[ordinal].putBytes(rowId, value.numBytes(), value.getBytes(), 0);
+ }
+
+ public void setBinary(int ordinal, byte[] value) {
+ columns[ordinal].putBytes(rowId, value.length, value, 0);
+ }
+}
diff --git
a/gluten-arrow/src/main/scala/org/apache/gluten/expression/ArrowProjection.scala
b/gluten-arrow/src/main/scala/org/apache/gluten/expression/ArrowProjection.scala
new file mode 100644
index 0000000000..3216d4f3f9
--- /dev/null
+++
b/gluten-arrow/src/main/scala/org/apache/gluten/expression/ArrowProjection.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.expression
+
+import org.apache.gluten.vectorized.ArrowColumnarRow
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference,
Expression, ExpressionsEvaluator}
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.types.{DataType, StructType}
+
+// Not thread safe.
+abstract class ArrowProjection extends (InternalRow => ArrowColumnarRow) with
ExpressionsEvaluator {
+ def currentValue: ArrowColumnarRow
+
+ /** Uses the given row to store the output of the projection. */
+ def target(row: ArrowColumnarRow): ArrowProjection
+}
+
+/** The factory object for `ArrowProjection`. */
+object ArrowProjection {
+
+ /**
+ * Returns an ArrowProjection for given StructType.
+ *
+ * CAUTION: the returned projection object is *not* thread-safe.
+ */
+ def create(schema: StructType): ArrowProjection =
create(schema.fields.map(_.dataType))
+
+ /**
+ * Returns an ArrowProjection for given Array of DataTypes.
+ *
+ * CAUTION: the returned projection object is *not* thread-safe.
+ */
+ def create(fields: Array[DataType]): ArrowProjection = {
+ create(fields.zipWithIndex.map(x => BoundReference(x._2, x._1, nullable =
true)))
+ }
+
+ /** Returns an ArrowProjection for given sequence of bound Expressions. */
+ def create(exprs: Seq[Expression]): ArrowProjection = {
+ InterpretedArrowProjection.createProjection(exprs)
+ }
+
+ def create(expr: Expression): ArrowProjection = create(Seq(expr))
+
+ /**
+ * Returns an ArrowProjection for given sequence of Expressions, which will
be bound to
+ * `inputSchema`.
+ */
+ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]):
ArrowProjection = {
+ create(bindReferences(exprs, inputSchema))
+ }
+}
diff --git
a/gluten-arrow/src/main/scala/org/apache/gluten/expression/InterpretedArrowProjection.scala
b/gluten-arrow/src/main/scala/org/apache/gluten/expression/InterpretedArrowProjection.scala
new file mode 100644
index 0000000000..33a3802f88
--- /dev/null
+++
b/gluten-arrow/src/main/scala/org/apache/gluten/expression/InterpretedArrowProjection.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.expression
+
+import org.apache.gluten.vectorized.ArrowColumnarRow
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{BinaryType, DataType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A [[ArrowProjection]] that is calculated by calling `eval` on each of the
specified expressions.
+ *
+ * @param expressions
+ * a sequence of expressions that determine the value of each column of the
output row.
+ */
+class InterpretedArrowProjection(expressions: Seq[Expression]) extends
ArrowProjection {
+ def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
+ this(bindReferences(expressions, inputSchema))
+
+ private[this] val subExprEliminationEnabled =
SQLConf.get.subexpressionEliminationEnabled
+ private[this] val exprs = prepareExpressions(expressions,
subExprEliminationEnabled)
+
+ override def initialize(partitionIndex: Int): Unit = {
+ initializeExprs(exprs, partitionIndex)
+ }
+
+ private[this] val validExprs = expressions.zipWithIndex.filter {
+ case (NoOp, _) => false
+ case _ => true
+ }
+
+ private[this] var mutableRow: ArrowColumnarRow = null
+
+ override def currentValue: ArrowColumnarRow = mutableRow
+
+ override def target(row: ArrowColumnarRow): ArrowProjection = {
+ mutableRow = row
+ this
+ }
+
+ private def getStringWriter(ordinal: Int, dt: DataType): (ArrowColumnarRow,
Any) => Unit =
+ dt match {
+ case StringType => (input, v) => input.setUTF8String(ordinal,
v.asInstanceOf[UTF8String])
+ case BinaryType => (input, v) => input.setBinary(ordinal,
v.asInstanceOf[Array[Byte]])
+ case _ => (input, v) => input.update(ordinal, v)
+ }
+
+ private[this] val fieldWriters: Array[Any => Unit] = validExprs.map {
+ case (e, i) =>
+ val writer = if (e.dataType.isInstanceOf[StringType] ||
e.dataType.isInstanceOf[BinaryType]) {
+ getStringWriter(i, e.dataType)
+ } else InternalRow.getWriter(i, e.dataType)
+ if (!e.nullable) { (v: Any) => writer(mutableRow, v) }
+ else {
+ (v: Any) =>
+ {
+ if (v == null) {
+ mutableRow.setNullAt(i)
+ } else {
+ writer(mutableRow, v)
+ }
+ }
+ }
+ }.toArray
+
+ override def apply(input: InternalRow): ArrowColumnarRow = {
+ if (subExprEliminationEnabled) {
+ runtime.setInput(input)
+ }
+
+ var i = 0
+ while (i < validExprs.length) {
+ val (_, ordinal) = validExprs(i)
+ // Store the result into buffer first, to make the projection atomic
(needed by aggregation)
+ fieldWriters(i)(exprs(ordinal).eval(input))
+ i += 1
+ }
+ mutableRow
+ }
+}
+
+/** Helper functions for creating an [[InterpretedArrowProjection]]. */
+object InterpretedArrowProjection {
+
+ /** Returns a [[ArrowProjection]] for given sequence of bound Expressions. */
+ def createProjection(exprs: Seq[Expression]): ArrowProjection = {
+ new InterpretedArrowProjection(exprs)
+ }
+}
diff --git
a/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala
b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala
new file mode 100644
index 0000000000..1469f57e64
--- /dev/null
+++
b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.internal.SQLConf
+
+// A helper trait for evaluating expressions.
+trait ExpressionsEvaluator {
+ protected lazy val runtime =
+ new
SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries)
+
+ protected def prepareExpressions(
+ exprs: Seq[Expression],
+ subExprEliminationEnabled: Boolean): Seq[Expression] = {
+ // We need to make sure that we do not reuse stateful expressions.
+    // Different from Spark 3.4 and above, this does not apply
+    // cleanedExpression to stateful expressions.
+ if (subExprEliminationEnabled) {
+ runtime.proxyExpressions(exprs)
+ } else {
+ exprs
+ }
+ }
+
+ /**
+ * Initializes internal states given the current partition index. This is
used by nondeterministic
+ * expressions to set initial states. The default implementation does
nothing.
+ */
+ def initialize(partitionIndex: Int): Unit = {}
+
+ protected def initializeExprs(exprs: Seq[Expression], partitionIndex: Int):
Unit = {
+ exprs.foreach(_.foreach {
+ case n: Nondeterministic => n.initialize(partitionIndex)
+ case _ =>
+ })
+ }
+}
diff --git
a/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala
b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala
new file mode 100644
index 0000000000..1469f57e64
--- /dev/null
+++
b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.internal.SQLConf
+
+// A helper trait for evaluating expressions.
+trait ExpressionsEvaluator {
+ protected lazy val runtime =
+ new
SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries)
+
+ protected def prepareExpressions(
+ exprs: Seq[Expression],
+ subExprEliminationEnabled: Boolean): Seq[Expression] = {
+ // We need to make sure that we do not reuse stateful expressions.
+    // Different from Spark 3.4 and above, this does not apply
+    // cleanedExpression to stateful expressions.
+ if (subExprEliminationEnabled) {
+ runtime.proxyExpressions(exprs)
+ } else {
+ exprs
+ }
+ }
+
+ /**
+ * Initializes internal states given the current partition index. This is
used by nondeterministic
+ * expressions to set initial states. The default implementation does
nothing.
+ */
+ def initialize(partitionIndex: Int): Unit = {}
+
+ protected def initializeExprs(exprs: Seq[Expression], partitionIndex: Int):
Unit = {
+ exprs.foreach(_.foreach {
+ case n: Nondeterministic => n.initialize(partitionIndex)
+ case _ =>
+ })
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]