This is an automated email from the ASF dual-hosted git repository.
marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 9ae34a913 [GLUTEN-5643] Fix the failure when the pre-project of
GenerateExec falls back (#6167)
9ae34a913 is described below
commit 9ae34a91379fc833c3db873292e53e589be9e62b
Author: Rong Ma <[email protected]>
AuthorDate: Wed Jun 26 15:17:12 2024 +0800
[GLUTEN-5643] Fix the failure when the pre-project of GenerateExec falls
back (#6167)
---
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 4 +-
.../apache/gluten/expression/DummyExpression.scala | 77 ++++++++++++++++++++++
.../apache/spark/sql/expression/UDFResolver.scala | 5 +-
.../org/apache/gluten/execution/TestOperator.scala | 24 ++++++-
cpp/velox/substrait/SubstraitToVeloxPlan.cc | 38 +++++++----
.../gluten/expression/ExpressionConverter.scala | 2 +-
6 files changed, 131 insertions(+), 19 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 7b8d523a6..b48da1568 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -852,7 +852,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
Sig[VeloxBloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG),
Sig[TransformKeys](TRANSFORM_KEYS),
- Sig[TransformValues](TRANSFORM_VALUES)
+ Sig[TransformValues](TRANSFORM_VALUES),
+ // For test purposes only.
+ Sig[VeloxDummyExpression](VeloxDummyExpression.VELOX_DUMMY_EXPRESSION)
)
}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala
b/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala
new file mode 100644
index 000000000..e2af66b59
--- /dev/null
+++
b/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.expression
+
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo,
UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
ExprCode}
+import org.apache.spark.sql.types.DataType
+
+abstract class DummyExpression(child: Expression) extends UnaryExpression with
Serializable {
+ private val accessor: (InternalRow, Int) => Any =
InternalRow.getAccessor(dataType, nullable)
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode =
+ defineCodeGen(ctx, ev, c => c)
+
+ override def dataType: DataType = child.dataType
+
+ override def eval(input: InternalRow): Any = {
+ assert(input.numFields == 1, "The input row of DummyExpression should have
only 1 field.")
+ accessor(input, 0)
+ }
+}
+
+// Can be used as a wrapper that forces the wrapped expression to fall back, mocking the fallback
+// behavior of a supported expression in Gluten that fails native validation.
+case class VeloxDummyExpression(child: Expression)
+ extends DummyExpression(child)
+ with Transformable {
+ override def getTransformer(
+ childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer
= {
+ if (childrenTransformers.size != children.size) {
+ throw new IllegalStateException(
+ this.getClass.getSimpleName +
+ ": getTransformer called before children transformer initialized.")
+ }
+
+ GenericExpressionTransformer(
+ VeloxDummyExpression.VELOX_DUMMY_EXPRESSION,
+ childrenTransformers,
+ this)
+ }
+
+ override protected def withNewChildInternal(newChild: Expression):
Expression = copy(newChild)
+}
+
+object VeloxDummyExpression {
+ val VELOX_DUMMY_EXPRESSION = "velox_dummy_expression"
+
+ private val identifier = new FunctionIdentifier(VELOX_DUMMY_EXPRESSION)
+
+ def registerFunctions(registry: FunctionRegistry): Unit = {
+ registry.registerFunction(
+ identifier,
+ new ExpressionInfo(classOf[VeloxDummyExpression].getName,
VELOX_DUMMY_EXPRESSION),
+ (e: Seq[Expression]) => VeloxDummyExpression(e.head)
+ )
+ }
+
+ def unregisterFunctions(registry: FunctionRegistry): Unit = {
+ registry.dropFunction(identifier)
+ }
+}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
index 915fc5545..e45e8b6fa 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
@@ -27,7 +27,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
Expression, ExpressionInfo}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
Expression, ExpressionInfo, Unevaluable}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
ExprCode}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -94,7 +94,8 @@ case class UDFExpression(
dataType: DataType,
nullable: Boolean,
children: Seq[Expression])
- extends Transformable {
+ extends Unevaluable
+ with Transformable {
override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): Expression = {
this.copy(children = newChildren)
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
index a892b6f31..9b47a519c 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
@@ -19,6 +19,7 @@ package org.apache.gluten.execution
import org.apache.gluten.GlutenConfig
import org.apache.gluten.datasource.ArrowCSVFileFormat
import org.apache.gluten.execution.datasource.v2.ArrowBatchScanExec
+import org.apache.gluten.expression.VeloxDummyExpression
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.SparkConf
@@ -45,6 +46,12 @@ class TestOperator extends VeloxWholeStageTransformerSuite
with AdaptiveSparkPla
override def beforeAll(): Unit = {
super.beforeAll()
createTPCHNotNullTables()
+ VeloxDummyExpression.registerFunctions(spark.sessionState.functionRegistry)
+ }
+
+ override def afterAll(): Unit = {
+
VeloxDummyExpression.unregisterFunctions(spark.sessionState.functionRegistry)
+ super.afterAll()
}
override protected def sparkConf: SparkConf = {
@@ -66,14 +73,20 @@ class TestOperator extends VeloxWholeStageTransformerSuite
with AdaptiveSparkPla
test("select_part_column") {
val df = runQueryAndCompare("select l_shipdate, l_orderkey from lineitem
limit 1") {
- df => { assert(df.schema.fields.length == 2) }
+ df =>
+ {
+ assert(df.schema.fields.length == 2)
+ }
}
checkLengthAndPlan(df, 1)
}
test("select_as") {
val df = runQueryAndCompare("select l_shipdate as my_col from lineitem
limit 1") {
- df => { assert(df.schema.fieldNames(0).equals("my_col")) }
+ df =>
+ {
+ assert(df.schema.fieldNames(0).equals("my_col"))
+ }
}
checkLengthAndPlan(df, 1)
}
@@ -1074,6 +1087,13 @@ class TestOperator extends
VeloxWholeStageTransformerSuite with AdaptiveSparkPla
// No ProjectExecTransformer is introduced.
checkSparkOperatorChainMatch[GenerateExecTransformer,
FilterExecTransformer]
}
+
+ runQueryAndCompare(
+ s"""
+ |SELECT
$func(${VeloxDummyExpression.VELOX_DUMMY_EXPRESSION}(a)) from t2;
+ |""".stripMargin) {
+ checkGlutenOperatorMatch[GenerateExecTransformer]
+ }
}
}
}
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index 8b8a92624..73047b2f4 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -26,6 +26,7 @@
#include "utils/ConfigExtractor.h"
#include "config/GlutenConfig.h"
+#include "operators/plannodes/RowVectorStream.h"
namespace gluten {
namespace {
@@ -710,16 +711,23 @@ core::PlanNodePtr
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
namespace {
void extractUnnestFieldExpr(
- std::shared_ptr<const core::ProjectNode> projNode,
+ std::shared_ptr<const core::PlanNode> child,
int32_t index,
std::vector<core::FieldAccessTypedExprPtr>& unnestFields) {
- auto name = projNode->names()[index];
- auto expr = projNode->projections()[index];
- auto type = expr->type();
+ if (auto projNode = std::dynamic_pointer_cast<const
core::ProjectNode>(child)) {
+ auto name = projNode->names()[index];
+ auto expr = projNode->projections()[index];
+ auto type = expr->type();
- auto unnestFieldExpr = std::make_shared<core::FieldAccessTypedExpr>(type,
name);
- VELOX_CHECK_NOT_NULL(unnestFieldExpr, " the key in unnest Operator only
support field");
- unnestFields.emplace_back(unnestFieldExpr);
+ auto unnestFieldExpr = std::make_shared<core::FieldAccessTypedExpr>(type,
name);
+ VELOX_CHECK_NOT_NULL(unnestFieldExpr, " the key in unnest Operator only
support field");
+ unnestFields.emplace_back(unnestFieldExpr);
+ } else {
+ auto name = child->outputType()->names()[index];
+ auto field = child->outputType()->childAt(index);
+ auto unnestFieldExpr = std::make_shared<core::FieldAccessTypedExpr>(field,
name);
+ unnestFields.emplace_back(unnestFieldExpr);
+ }
}
} // namespace
@@ -752,10 +760,13 @@ core::PlanNodePtr
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
SubstraitParser::configSetInOptimization(generateRel.advanced_extension(),
"injectedProject=");
if (injectedProject) {
- auto projNode = std::dynamic_pointer_cast<const
core::ProjectNode>(childNode);
+ // Child should be either a ProjectNode or, when the pre-project falls back, a ValueStreamNode.
VELOX_CHECK(
- projNode != nullptr && projNode->names().size() >
requiredChildOutput.size(),
- "injectedProject is true, but the Project is missing or does not have
the corresponding projection field")
+ (std::dynamic_pointer_cast<const core::ProjectNode>(childNode) !=
nullptr ||
+ std::dynamic_pointer_cast<const ValueStreamNode>(childNode) !=
nullptr) &&
+ childNode->outputType()->size() > requiredChildOutput.size(),
+ "injectedProject is true, but the ProjectNode or ValueStreamNode (in
case of projection fallback)"
+ " is missing or does not have the corresponding projection field")
bool isStack = generateRel.has_advanced_extension() &&
SubstraitParser::configSetInOptimization(generateRel.advanced_extension(),
"isStack=");
@@ -768,7 +779,8 @@ core::PlanNodePtr
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
// +- Project [fake_column#128, [1,2,3] AS _pre_0#129]
// +- RewrittenNodeWall Scan OneRowRelation[fake_column#128]
// The last projection column in GeneratorRel's child(Project) is the
column we need to unnest
- extractUnnestFieldExpr(projNode, projNode->projections().size() - 1,
unnest);
+ auto index = childNode->outputType()->size() - 1;
+ extractUnnestFieldExpr(childNode, index, unnest);
} else {
// For stack function, e.g. stack(2, 1,2,3), a sample
// input substrait plan is like the following:
@@ -782,10 +794,10 @@ core::PlanNodePtr
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
auto generatorFunc = generator.scalar_function();
auto numRows =
SubstraitParser::getLiteralValue<int32_t>(generatorFunc.arguments(0).value().literal());
auto numFields =
static_cast<int32_t>(std::ceil((generatorFunc.arguments_size() - 1.0) /
numRows));
- auto totalProjectCount = projNode->names().size();
+ auto totalProjectCount = childNode->outputType()->size();
for (auto i = totalProjectCount - numFields; i < totalProjectCount; ++i)
{
- extractUnnestFieldExpr(projNode, i, unnest);
+ extractUnnestFieldExpr(childNode, i, unnest);
}
}
} else {
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index b7b0889dc..da5625cd4 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.hive.HiveUDFTransformer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-trait Transformable extends Unevaluable {
+trait Transformable {
def getTransformer(childrenTransformers: Seq[ExpressionTransformer]):
ExpressionTransformer
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]