This is an automated email from the ASF dual-hosted git repository.
chengchengjin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 2d33f383ba [GLUTEN-8668][VL] Support complex type in ColumnarPartialProject (#8669)
2d33f383ba is described below
commit 2d33f383ba9a3a9ea3d3326bd3b4c7a39d5a6b38
Author: WangGuangxin <[email protected]>
AuthorDate: Thu Feb 20 16:54:34 2025 +0800
[GLUTEN-8668][VL] Support complex type in ColumnarPartialProject (#8669)
---
.../execution/ColumnarPartialProjectExec.scala | 34 +--
.../java/org/apache/gluten/udf}/CustomerUDF.java | 2 +-
.../gluten/expression/UDFPartialProjectSuite.scala | 57 ++++++
.../spark/sql/execution/GlutenHiveUDFSuite.scala | 203 ++++++++++++++++++
.../apache/gluten/vectorized/ArrowColumnarRow.java | 18 ++
.../vectorized/ArrowWritableColumnVector.java | 227 ++++++++++++++++++++-
.../expression/InterpretedArrowProjection.scala | 42 ++--
.../sql/hive/execution/GlutenHiveUDFSuite.scala | 139 -------------
8 files changed, 521 insertions(+), 201 deletions(-)
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
index 500c70fefa..9d6077ec62 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
@@ -29,12 +29,11 @@ import org.apache.gluten.vectorized.{ArrowColumnarRow,
ArrowWritableColumnVector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute,
AttributeReference, CaseWhen, Coalesce, Expression, If, LambdaFunction,
NamedExpression, NaNvl, ScalaUDF}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.execution.{ExplainUtils, ProjectExec, SparkPlan,
UnaryExecNode}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.hive.HiveUdfUtil
-import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import scala.collection.mutable.ListBuffer
@@ -59,10 +58,8 @@ case class ColumnarPartialProjectExec(original: ProjectExec,
child: SparkPlan)(
private val projectAttributes: ListBuffer[Attribute] = ListBuffer()
private val projectIndexInChild: ListBuffer[Int] = ListBuffer()
private var UDFAttrNotExists = false
- private var hasUnsupportedDataType = replacedAliasUdf.exists(a =>
!validateDataType(a.dataType))
- if (!hasUnsupportedDataType) {
- getProjectIndexInChildOutput(replacedAliasUdf)
- }
+ private var hasUnsupportedDataType = false
+ getProjectIndexInChildOutput(replacedAliasUdf)
@transient override lazy val metrics = Map(
"time" -> SQLMetrics.createTimingMetric(sparkContext, "total time of
partial project"),
@@ -102,26 +99,6 @@ case class ColumnarPartialProjectExec(original:
ProjectExec, child: SparkPlan)(
.forall(validateExpression)
}
- private def validateDataType(dataType: DataType): Boolean = {
- dataType match {
- case _: BooleanType => true
- case _: ByteType => true
- case _: ShortType => true
- case _: IntegerType => true
- case _: LongType => true
- case _: FloatType => true
- case _: DoubleType => true
- case _: StringType => true
- case _: TimestampType => true
- case _: DateType => true
- case _: BinaryType => true
- case _: DecimalType => true
- case YearMonthIntervalType.DEFAULT => true
- case _: NullType => true
- case _ => false
- }
- }
-
private def getProjectIndexInChildOutput(exprs: Seq[Expression]): Unit = {
exprs.forall {
case a: AttributeReference =>
@@ -131,7 +108,9 @@ case class ColumnarPartialProjectExec(original:
ProjectExec, child: SparkPlan)(
UDFAttrNotExists = true
log.debug(s"Expression $a should exist in child output
${child.output}")
false
- } else if (!validateDataType(a.dataType)) {
+ } else if (
+
BackendsApiManager.getValidatorApiInstance.doSchemaValidate(a.dataType).isDefined
+ ) {
hasUnsupportedDataType = true
log.debug(s"Expression $a contains unsupported data type
${a.dataType}")
false
@@ -253,6 +232,7 @@ case class ColumnarPartialProjectExec(original:
ProjectExec, child: SparkPlan)(
targetRow.rowId = i
proj.target(targetRow).apply(arrowBatch.getRow(i))
}
+ targetRow.finishWriteRow()
val targetBatch = new
ColumnarBatch(vectors.map(_.asInstanceOf[ColumnVector]), numRows)
val start2 = System.currentTimeMillis()
val veloxBatch = VeloxColumnarBatches.toVeloxBatch(
diff --git a/gluten-ut/spark35/src/test/java/org/apache/gluten/execution/CustomerUDF.java b/backends-velox/src/test/java/org/apache/gluten/udf/CustomerUDF.java
similarity index 97%
rename from gluten-ut/spark35/src/test/java/org/apache/gluten/execution/CustomerUDF.java
rename to backends-velox/src/test/java/org/apache/gluten/udf/CustomerUDF.java
index 257bd07021..b677710dbc 100644
--- a/gluten-ut/spark35/src/test/java/org/apache/gluten/execution/CustomerUDF.java
+++ b/backends-velox/src/test/java/org/apache/gluten/udf/CustomerUDF.java
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.execution;
+package org.apache.gluten.udf;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
diff --git a/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
index 30b8cb0f69..3f1a68993f 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
@@ -24,6 +24,8 @@ import org.apache.spark.sql.functions.udf
import java.io.File
+case class MyStruct(a: Long, b: Array[Long])
+
class UDFPartialProjectSuiteRasOff extends UDFPartialProjectSuite {
override protected def sparkConf: SparkConf = {
super.sparkConf
@@ -147,4 +149,59 @@ abstract class UDFPartialProjectSuite extends
WholeStageTransformerSuite {
}
}
+ test("udf with array") {
+ spark.udf.register("array_plus_one", udf((arr: Array[Int]) => arr.map(_ +
1)))
+ runQueryAndCompare("""
+ |SELECT
+ | l_partkey,
+ | sort_array(array_plus_one(array_data)) as
orderkey_arr_plus_one
+ |FROM (
+ | SELECT l_partkey, collect_list(l_orderkey) as
array_data
+ | FROM lineitem
+ | GROUP BY l_partkey
+ |)
+ |""".stripMargin) {
+ checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+ }
+ }
+
+ test("udf with map") {
+ spark.udf.register(
+ "map_value_plus_one",
+ udf((m: Map[String, Long]) => m.map { case (key, value) => key -> (value
+ 1) }))
+ runQueryAndCompare("""
+ |SELECT
+ | l_partkey,
+ | map_value_plus_one(map_data)
+ |FROM (
+ | SELECT l_partkey,
+ | map(
+ | concat('hello', l_orderkey % 2), l_orderkey,
+ | concat('world', l_orderkey % 2), l_orderkey
+ | ) as map_data
+ | FROM lineitem
+ |)
+ |""".stripMargin) {
+ checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+ }
+ }
+
+ test("udf with struct and array") {
+ spark.udf.register("struct_plus_one", udf((m: MyStruct) => MyStruct(m.a +
1, m.b.map(_ + 1))))
+ runQueryAndCompare("""
+ |SELECT
+ | l_partkey,
+ | struct_plus_one(struct_data)
+ |FROM (
+ | SELECT l_partkey,
+ | struct(
+ | l_orderkey % 2 as a,
+ | array(l_orderkey % 2, l_orderkey % 2 + 1,
l_orderkey % 2 + 2) as b
+ | ) as struct_data
+ | FROM lineitem
+ |)
+ |""".stripMargin) {
+ checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+ }
+ }
}
diff --git a/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala b/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala
new file mode 100644
index 0000000000..c8404fb93b
--- /dev/null
+++ b/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.gluten.execution.ColumnarPartialProjectExec
+import org.apache.gluten.udf.CustomerUDF
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.config
+import org.apache.spark.internal.config.UI.UI_ENABLED
+import org.apache.spark.sql.{DataFrame, GlutenQueryTest, Row, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
+import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
+import org.apache.spark.sql.hive.HiveUtils
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
+import org.apache.spark.sql.test.SQLTestUtils
+
+import java.io.File
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+class GlutenHiveUDFSuite extends GlutenQueryTest with SQLTestUtils {
+ private var _spark: SparkSession = _
+
+ override protected def beforeAll(): Unit = {
+ super.beforeAll()
+
+ if (_spark == null) {
+ _spark =
SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()
+ }
+
+ _spark.sparkContext.setLogLevel("info")
+
+ createTestTable()
+ }
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ }
+
+ override protected def spark: SparkSession = _spark
+
+ protected def defaultSparkConf: SparkConf = {
+ val conf = new SparkConf()
+ .set("spark.master", "local[1]")
+ .set("spark.sql.test", "")
+ .set("spark.sql.testkey", "true")
+ .set(SQLConf.CODEGEN_FALLBACK.key, "false")
+ .set(SQLConf.CODEGEN_FACTORY_MODE.key,
CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
+ .set(
+ HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key,
+ "org.apache.spark.sql.hive.execution.PairSerDe")
+ // SPARK-8910
+ .set(UI_ENABLED, false)
+ .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
+ // Hive changed the default of hive.metastore.disallow.incompatible.col.type.changes
+ // from false to true. For details, see the JIRA HIVE-12320 and HIVE-17764.
+ .set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes", "false")
+ // Disable ConvertToLocalRelation for better test coverage. Test cases built on
+ // LocalRelation will exercise the optimization rules better by disabling it as
+ // this rule may potentially block testing of other optimization rules such as
+ // ConstantPropagation etc.
+ .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
+
+ conf.set(
+ StaticSQLConf.WAREHOUSE_PATH,
+ conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName)
+ }
+
+ protected def sparkConf: SparkConf = {
+ defaultSparkConf
+ .set("spark.plugins", "org.apache.gluten.GlutenPlugin")
+ .set("spark.default.parallelism", "1")
+ .set("spark.memory.offHeap.enabled", "true")
+ .set("spark.memory.offHeap.size", "1024MB")
+ .set("spark.gluten.sql.native.writer.enabled", "true")
+ }
+
+ private def withTempFunction(funcName: String)(f: => Unit): Unit = {
+ try f
+ finally sql(s"DROP TEMPORARY FUNCTION IF EXISTS $funcName")
+ }
+
+ private def checkOperatorMatch[T <: SparkPlan](df: DataFrame)(implicit tag:
ClassTag[T]): Unit = {
+ val executedPlan = getExecutedPlan(df)
+ assert(executedPlan.exists(plan => plan.getClass == tag.runtimeClass))
+ }
+
+ private def createTestTable(): Unit = {
+ val table = "lineitem"
+ val tableDir = getClass.getResource("/tpch-data-parquet").getFile
+ val tablePath = new File(tableDir, table).getAbsolutePath
+ val tableDF = spark.read.format("parquet").load(tablePath)
+ tableDF.createOrReplaceTempView(table)
+ }
+
+ test("customer udf") {
+ withTempFunction("testUDF") {
+ sql(s"CREATE TEMPORARY FUNCTION testUDF AS
'${classOf[CustomerUDF].getName}'")
+ val df = sql("select l_partkey, testUDF(l_comment) from lineitem")
+ df.show()
+ checkOperatorMatch[ColumnarPartialProjectExec](df)
+ }
+ }
+
+ test("customer udf wrapped in function") {
+ withTempFunction("testUDF") {
+ sql(s"CREATE TEMPORARY FUNCTION testUDF AS
'${classOf[CustomerUDF].getName}'")
+ val df = sql("select l_partkey, hash(testUDF(l_comment)) from lineitem")
+ df.show()
+ checkOperatorMatch[ColumnarPartialProjectExec](df)
+ }
+ }
+
+ test("example") {
+ withTempFunction("testUDF") {
+ sql("CREATE TEMPORARY FUNCTION testUDF AS
'org.apache.hadoop.hive.ql.udf.UDFSubstr';")
+ val df = sql("select testUDF('l_commen', 1, 5)")
+ df.show()
+ // It should not be converted to ColumnarPartialProjectExec, since
+ // the UDF needs all the columns in child output.
+ assert(!getExecutedPlan(df).exists {
+ case _: ColumnarPartialProjectExec => true
+ case _ => false
+ })
+ }
+ }
+
+ test("udf with array") {
+ withTempFunction("udf_sort_array") {
+ sql("""
+ |CREATE TEMPORARY FUNCTION udf_sort_array AS
+ |'org.apache.hadoop.hive.ql.udf.generic.GenericUDFSortArray';
+ |""".stripMargin)
+
+ val df = sql("""
+ |SELECT
+ | l_orderkey,
+ | l_partkey,
+ | udf_sort_array(array(10, l_orderkey, 1)) as udf_result
+ |FROM lineitem WHERE l_partkey <= 5 and l_orderkey <1000
+ |""".stripMargin)
+
+ checkAnswer(
+ df,
+ Seq(
+ Row(35, 5, mutable.WrappedArray.make(Array(1, 10, 35))),
+ Row(321, 4, mutable.WrappedArray.make(Array(1, 10, 321))),
+ Row(548, 2, mutable.WrappedArray.make(Array(1, 10, 548))),
+ Row(640, 5, mutable.WrappedArray.make(Array(1, 10, 640))),
+ Row(807, 2, mutable.WrappedArray.make(Array(1, 10, 807)))
+ )
+ )
+ checkOperatorMatch[ColumnarPartialProjectExec](df)
+ }
+ }
+
+ test("udf with map") {
+ withTempFunction("udf_str_to_map") {
+ sql("""
+ |CREATE TEMPORARY FUNCTION udf_str_to_map AS
+ |'org.apache.hadoop.hive.ql.udf.generic.GenericUDFStringToMap';
+ |""".stripMargin)
+
+ val df = sql(
+ """
+ |SELECT
+ | l_orderkey,
+ | l_partkey,
+ | udf_str_to_map(
+ | concat_ws(',', array(concat('hello', l_partkey), 'world')),
',', 'l') as udf_result
+ |FROM lineitem WHERE l_partkey <= 5 and l_orderkey <1000
+ |""".stripMargin)
+
+ checkAnswer(
+ df,
+ Seq(
+ Row(321, 4, Map("he" -> "lo4", "wor" -> "d")),
+ Row(35, 5, Map("he" -> "lo5", "wor" -> "d")),
+ Row(548, 2, Map("he" -> "lo2", "wor" -> "d")),
+ Row(640, 5, Map("he" -> "lo5", "wor" -> "d")),
+ Row(807, 2, Map("he" -> "lo2", "wor" -> "d"))
+ )
+ )
+ checkOperatorMatch[ColumnarPartialProjectExec](df)
+ }
+ }
+}
diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarRow.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarRow.java
index e0ac5858e7..907ce1ca73 100644
--- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarRow.java
+++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarRow.java
@@ -16,6 +16,8 @@
*/
package org.apache.gluten.vectorized;
+import org.apache.gluten.exception.GlutenException;
+
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.types.ArrayType;
@@ -316,4 +318,20 @@ public final class ArrowColumnarRow extends InternalRow {
public void setBinary(int ordinal, byte[] value) {
columns[ordinal].putBytes(rowId, value.length, value, 0);
}
+
+ public void writeRow(GenericInternalRow input) {
+ if (input.numFields() != columns.length) {
+ throw new GlutenException(
+ "The numFields of input row should be equal to the number of column
vector!");
+ }
+ for (int i = 0; i < input.numFields(); ++i) {
+ columns[i].write(input, i);
+ }
+ }
+
+ public void finishWriteRow() {
+ for (int i = 0; i < columns.length; ++i) {
+ columns[i].finishWrite();
+ }
+ }
}
diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
index 336d33771b..0d74b7d4ac 100644
--- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
+++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
@@ -46,13 +46,14 @@ import
org.apache.arrow.vector.holders.NullableVarCharHolder;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters;
+import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
+import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVectorShim;
-import org.apache.spark.sql.types.ArrayType;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.*;
import org.apache.spark.sql.utils.SparkArrowUtil;
import org.apache.spark.sql.utils.SparkSchemaUtil;
import org.apache.spark.sql.vectorized.ColumnarArray;
@@ -307,6 +308,12 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
return new DateWriter((DateDayVector) vector);
} else if (vector instanceof TimeStampMicroVector || vector instanceof
TimeStampMicroTZVector) {
return new TimestampMicroWriter((TimeStampVector) vector);
+ } else if (vector instanceof MapVector) {
+ MapVector mapVector = (MapVector) vector;
+ StructVector entries = (StructVector) mapVector.getDataVector();
+ ArrowVectorWriter keyWriter =
createVectorWriter(entries.getChild(MapVector.KEY_NAME));
+ ArrowVectorWriter valueWriter =
createVectorWriter(entries.getChild(MapVector.VALUE_NAME));
+ return new MapWriter(mapVector, entries, keyWriter, valueWriter);
} else if (vector instanceof ListVector) {
ListVector listVector = (ListVector) vector;
ArrowVectorWriter elementVector =
createVectorWriter(listVector.getDataVector());
@@ -742,6 +749,14 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
return accessor.getBinary(rowId);
}
+ public void write(SpecializedGetters input, int ordinal) {
+ writer.write(input, ordinal);
+ }
+
+ public void finishWrite() {
+ writer.finish();
+ }
+
private abstract static class ArrowVectorAccessor {
private final ValueVector vector;
@@ -1394,6 +1409,28 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
void setBytes(int rowId, BigDecimal value) {
throw new UnsupportedOperationException();
}
+
+ protected int count = 0;
+
+ abstract void setValueNullSafe(SpecializedGetters input, int ordinal);
+
+ void write(SpecializedGetters input, int ordinal) {
+ if (input.isNullAt(ordinal)) {
+ setNull(count);
+ } else {
+ setValueNullSafe(input, ordinal);
+ }
+ count = count + 1;
+ }
+
+ void finish() {
+ vector.setValueCount(count);
+ }
+
+ void reset() {
+ vector.reset();
+ count = 0;
+ }
}
private static class BooleanWriter extends ArrowVectorWriter {
@@ -1427,6 +1464,11 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
writer.setSafe(rowId + i, value ? 1 : 0);
}
}
+
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ this.setBoolean(count, input.getBoolean(ordinal));
+ }
}
private static class ByteWriter extends ArrowVectorWriter {
@@ -1467,6 +1509,11 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
writer.setSafe(rowId + i, src[srcIndex + i]);
}
}
+
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ this.setByte(count, input.getByte(ordinal));
+ }
}
private static class ShortWriter extends ArrowVectorWriter {
@@ -1514,6 +1561,11 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
writer.setSafe(rowId + i, src[srcIndex + i]);
}
}
+
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ this.setShort(count, input.getShort(ordinal));
+ }
}
private static class IntWriter extends ArrowVectorWriter {
@@ -1574,6 +1626,11 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
writer.setSafe(rowId + i, tmp);
}
}
+
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ setInt(count, input.getInt(ordinal));
+ }
}
private static class LongWriter extends ArrowVectorWriter {
@@ -1629,6 +1686,11 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
writer.setSafe(rowId, val);
}
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ setLong(count, input.getLong(ordinal));
+ }
+
@Override
void setLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET;
@@ -1680,6 +1742,11 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
writer.setSafe(rowId + i, src[srcIndex + i]);
}
}
+
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ setFloat(count, input.getFloat(ordinal));
+ }
}
private static class DoubleWriter extends ArrowVectorWriter {
@@ -1720,14 +1787,26 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
writer.setSafe(rowId + i, src[srcIndex + i]);
}
}
+
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ setDouble(count, input.getDouble(ordinal));
+ }
}
private static class DecimalWriter extends ArrowVectorWriter {
private final DecimalVector writer;
+ private final int precision;
+ private final int scale;
DecimalWriter(DecimalVector vector) {
super(vector);
this.writer = vector;
+
+ DataType dataType = SparkArrowUtil.fromArrowField(vector.getField());
+ DecimalType decimalType = (DecimalType) dataType;
+ this.precision = decimalType.precision();
+ this.scale = decimalType.scale();
}
@Override
@@ -1757,6 +1836,16 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
}
throw new UnsupportedOperationException();
}
+
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ Decimal decimal = input.getDecimal(ordinal, precision, scale);
+ if (decimal.changePrecision(precision, scale)) {
+ setBytes(count, decimal.toJavaBigDecimal());
+ } else {
+ setNull(count);
+ }
+ }
}
private static class StringWriter extends ArrowVectorWriter {
@@ -1791,6 +1880,12 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
writer.setSafe(rowId, value, offset, length);
rowId++;
}
+
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ UTF8String value = input.getUTF8String(ordinal);
+ setBytes(count, value.numBytes(), value.getBytes(), 0);
+ }
}
private static class BinaryWriter extends ArrowVectorWriter {
@@ -1817,6 +1912,12 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
final void setBytes(int rowId, int count, byte[] src, int srcIndex) {
writer.setSafe(rowId, src, srcIndex, count);
}
+
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ byte[] value = input.getBinary(ordinal);
+ setBytes(count, value.length, value, 0);
+ }
}
private static class DateWriter extends ArrowVectorWriter {
@@ -1850,6 +1951,11 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
writer.setNull(rowId + i);
}
}
+
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ setInt(count, input.getInt(ordinal));
+ }
}
private static class TimestampMicroWriter extends ArrowVectorWriter {
@@ -1883,14 +1989,21 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
writer.setNull(rowId + i);
}
}
+
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ setLong(count, input.getLong(ordinal));
+ }
}
private static class ArrayWriter extends ArrowVectorWriter {
private final ListVector writer;
+ private final ArrowVectorWriter elementWriter;
ArrayWriter(ListVector vector, ArrowVectorWriter elementVector) {
super(vector);
this.writer = vector;
+ this.elementWriter = elementVector;
}
@Override
@@ -1905,18 +2018,46 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
final void setNull(int rowId) {
writer.setNull(rowId);
}
+
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ ArrayData arrayData = input.getArray(ordinal);
+ writer.startNewValue(count);
+ for (int i = 0; i < arrayData.numElements(); ++i) {
+ elementWriter.write(arrayData, i);
+ }
+ writer.endValue(count, arrayData.numElements());
+ }
+
+ @Override
+ void finish() {
+ super.finish();
+ elementWriter.finish();
+ }
+
+ @Override
+ void reset() {
+ super.reset();
+ elementWriter.reset();
+ }
}
private static class StructWriter extends ArrowVectorWriter {
private final StructVector writer;
+ private final ArrowVectorWriter[] childrenWriter;
- StructWriter(StructVector vector, ArrowVectorWriter[] children) {
+ StructWriter(StructVector vector, ArrowVectorWriter[] childrenWriter) {
super(vector);
this.writer = vector;
+ this.childrenWriter = childrenWriter;
}
@Override
void setNull(int rowId) {
+ for (int i = 0; i < childrenWriter.length; ++i) {
+ childrenWriter[i].setNull(rowId);
+ childrenWriter[i].count += 1;
+ }
writer.setNull(rowId);
}
@@ -1924,6 +2065,77 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
void setNotNull(int rowId) {
writer.setIndexDefined(rowId);
}
+
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ InternalRow struct = input.getStruct(ordinal, childrenWriter.length);
+ writer.setIndexDefined(count);
+ for (int i = 0; i < struct.numFields(); ++i) {
+ childrenWriter[i].write(struct, i);
+ }
+ }
+
+ @Override
+ void finish() {
+ super.finish();
+ Arrays.stream(childrenWriter).forEach(c -> c.finish());
+ }
+
+ @Override
+ void reset() {
+ super.reset();
+ Arrays.stream(childrenWriter).forEach(c -> c.reset());
+ }
+ }
+
+ private static class MapWriter extends ArrowVectorWriter {
+ private final MapVector writer;
+ private StructVector structVector;
+ private final ArrowVectorWriter keyWriter;
+ private final ArrowVectorWriter valueWriter;
+
+ MapWriter(
+ MapVector mapVector,
+ StructVector structVector,
+ ArrowVectorWriter mapWriter,
+ ArrowVectorWriter valueWriter) {
+ super(mapVector);
+ this.writer = mapVector;
+ this.structVector = structVector;
+ this.keyWriter = mapWriter;
+ this.valueWriter = valueWriter;
+ }
+
+ @Override
+ void setNull(int rowId) {}
+
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ MapData mapData = input.getMap(ordinal);
+ writer.startNewValue(count);
+ ArrayData keys = mapData.keyArray();
+ ArrayData values = mapData.valueArray();
+ for (int i = 0; i < mapData.numElements(); ++i) {
+ structVector.setIndexDefined(i);
+ keyWriter.write(keys, i);
+ valueWriter.write(values, i);
+ }
+ writer.endValue(count, mapData.numElements());
+ }
+
+ @Override
+ void finish() {
+ super.finish();
+ keyWriter.finish();
+ valueWriter.finish();
+ }
+
+ @Override
+ void reset() {
+ super.reset();
+ keyWriter.reset();
+ valueWriter.reset();
+ }
}
private static class NullWriter extends ArrowVectorWriter {
@@ -1938,5 +2150,10 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
void setNull(int rowId) {
writer.setValueCount(writer.getValueCount() + 1);
}
+
+ @Override
+ void setValueNullSafe(SpecializedGetters input, int ordinal) {
+ setNull(count);
+ }
}
}
diff --git a/gluten-arrow/src/main/scala/org/apache/gluten/expression/InterpretedArrowProjection.scala b/gluten-arrow/src/main/scala/org/apache/gluten/expression/InterpretedArrowProjection.scala
index 33a3802f88..7d073f3b20 100644
--- a/gluten-arrow/src/main/scala/org/apache/gluten/expression/InterpretedArrowProjection.scala
+++ b/gluten-arrow/src/main/scala/org/apache/gluten/expression/InterpretedArrowProjection.scala
@@ -19,12 +19,10 @@ package org.apache.gluten.expression
import org.apache.gluten.vectorized.ArrowColumnarRow
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression,
GenericInternalRow}
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BinaryType, DataType, StringType}
-import org.apache.spark.unsafe.types.UTF8String
/**
* A [[ArrowProjection]] that is calculated by calling `eval` on each of the specified expressions.
@@ -57,43 +55,29 @@ class InterpretedArrowProjection(expressions:
Seq[Expression]) extends ArrowProj
this
}
- private def getStringWriter(ordinal: Int, dt: DataType): (ArrowColumnarRow,
Any) => Unit =
- dt match {
- case StringType => (input, v) => input.setUTF8String(ordinal,
v.asInstanceOf[UTF8String])
- case BinaryType => (input, v) => input.setBinary(ordinal,
v.asInstanceOf[Array[Byte]])
- case _ => (input, v) => input.update(ordinal, v)
- }
+ /** Number of (top level) fields in the resulting row. */
+ private[this] val numFields = validExprs.length
+
+ /** Array that holds the expression results. */
+ private[this] val values = new Array[Any](numFields)
- private[this] val fieldWriters: Array[Any => Unit] = validExprs.map {
- case (e, i) =>
- val writer = if (e.dataType.isInstanceOf[StringType] ||
e.dataType.isInstanceOf[BinaryType]) {
- getStringWriter(i, e.dataType)
- } else InternalRow.getWriter(i, e.dataType)
- if (!e.nullable) { (v: Any) => writer(mutableRow, v) }
- else {
- (v: Any) =>
- {
- if (v == null) {
- mutableRow.setNullAt(i)
- } else {
- writer(mutableRow, v)
- }
- }
- }
- }.toArray
+ /** The row representing the expression results. */
+ private[this] val intermediate = new GenericInternalRow(values)
override def apply(input: InternalRow): ArrowColumnarRow = {
if (subExprEliminationEnabled) {
runtime.setInput(input)
}
+ // Put the expression results in the intermediate row.
var i = 0
- while (i < validExprs.length) {
+ while (i < numFields) {
val (_, ordinal) = validExprs(i)
- // Store the result into buffer first, to make the projection atomic (needed by aggregation)
- fieldWriters(i)(exprs(ordinal).eval(input))
+ values(i) = exprs(ordinal).eval(input)
i += 1
}
+
+ mutableRow.writeRow(intermediate)
mutableRow
}
}
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveUDFSuite.scala
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveUDFSuite.scala
deleted file mode 100644
index 5ee5cdf202..0000000000
---
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveUDFSuite.scala
+++ /dev/null
@@ -1,139 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.hive.execution
-
-import org.apache.gluten.execution.CustomerUDF
-
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
-import org.apache.spark.internal.config
-import org.apache.spark.internal.config.UI.UI_ENABLED
-import org.apache.spark.sql.{GlutenTestsBaseTrait, QueryTest, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
-import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
-import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils}
-import org.apache.spark.sql.hive.client.HiveClient
-import org.apache.spark.sql.hive.test.TestHiveContext
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH
-import org.apache.spark.sql.test.SQLTestUtils
-
-import org.scalatest.BeforeAndAfterAll
-
-import java.io.File
-
-trait GlutenTestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll {
- override protected val enableAutoThreadAudit = false
-
-}
-
-object GlutenTestHive
- extends TestHiveContext(
- new SparkContext(
- System.getProperty("spark.sql.test.master", "local[1]"),
- "TestSQLContext",
- new SparkConf()
- .set("spark.sql.test", "")
- .set(SQLConf.CODEGEN_FALLBACK.key, "false")
- .set(SQLConf.CODEGEN_FACTORY_MODE.key,
CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
- .set(
- HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key,
- "org.apache.spark.sql.hive.execution.PairSerDe")
- .set(WAREHOUSE_PATH.key,
TestHiveContext.makeWarehouseDir().toURI.getPath)
- // SPARK-8910
- .set(UI_ENABLED, false)
- .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
- // Hive changed the default of
hive.metastore.disallow.incompatible.col.type.changes
- // from false to true. For details, see the JIRA HIVE-12320 and
HIVE-17764.
-
.set("spark.hadoop.hive.metastore.disallow.incompatible.col.type.changes",
"false")
- .set("spark.driver.memory", "1G")
- .set("spark.sql.adaptive.enabled", "true")
- .set("spark.sql.shuffle.partitions", "1")
- .set("spark.sql.files.maxPartitionBytes", "134217728")
- .set("spark.memory.offHeap.enabled", "true")
- .set("spark.memory.offHeap.size", "1024MB")
- .set("spark.plugins", "org.apache.gluten.GlutenPlugin")
- .set("spark.shuffle.manager",
"org.apache.spark.shuffle.sort.ColumnarShuffleManager")
- // Disable ConvertToLocalRelation for better test coverage. Test cases
built on
- // LocalRelation will exercise the optimization rules better by
disabling it as
- // this rule may potentially block testing of other optimization rules
such as
- // ConstantPropagation etc.
- .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key,
ConvertToLocalRelation.ruleName)
- ),
- false
- ) {}
-
-class GlutenHiveUDFSuite
- extends QueryTest
- with GlutenTestHiveSingleton
- with SQLTestUtils
- with GlutenTestsBaseTrait {
- override protected lazy val spark: SparkSession = GlutenTestHive.sparkSession
- protected lazy val hiveContext: TestHiveContext = GlutenTestHive
- protected lazy val hiveClient: HiveClient =
-
spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client
-
- override protected def beforeAll(): Unit = {
- super.beforeAll()
- val table = "lineitem"
- val tableDir =
- getClass.getResource("").getPath + "/../../../../../../../../../../../" +
- "/backends-velox/src/test/resources/tpch-data-parquet/"
- val tablePath = new File(tableDir, table).getAbsolutePath
- val tableDF = spark.read.format("parquet").load(tablePath)
- tableDF.createOrReplaceTempView(table)
- }
-
- override protected def afterAll(): Unit = {
- try {
- hiveContext.reset()
- } finally {
- super.afterAll()
- }
- }
-
- override protected def shouldRun(testName: String): Boolean = {
- false
- }
-
- test("customer udf") {
- sql(s"CREATE TEMPORARY FUNCTION testUDF AS
'${classOf[CustomerUDF].getName}'")
- val df = spark.sql("""select testUDF(l_comment)
- | from lineitem""".stripMargin)
- df.show()
- print(df.queryExecution.executedPlan)
- sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
- hiveContext.reset()
- }
-
- test("customer udf wrapped in function") {
- sql(s"CREATE TEMPORARY FUNCTION testUDF AS
'${classOf[CustomerUDF].getName}'")
- val df = spark.sql("""select hash(testUDF(l_comment))
- | from lineitem""".stripMargin)
- df.show()
- print(df.queryExecution.executedPlan)
- sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
- hiveContext.reset()
- }
-
- test("example") {
- spark.sql("CREATE TEMPORARY FUNCTION testUDF AS
'org.apache.hadoop.hive.ql.udf.UDFSubstr';")
- spark.sql("select testUDF('l_commen', 1, 5)").show()
- sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
- hiveContext.reset()
- }
-
-}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]