This is an automated email from the ASF dual-hosted git repository.

marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 9ae34a913 [GLUTEN-5643] Fix the failure when the pre-project of 
GenerateExec falls back (#6167)
9ae34a913 is described below

commit 9ae34a91379fc833c3db873292e53e589be9e62b
Author: Rong Ma <[email protected]>
AuthorDate: Wed Jun 26 15:17:12 2024 +0800

    [GLUTEN-5643] Fix the failure when the pre-project of GenerateExec falls 
back (#6167)
---
 .../backendsapi/velox/VeloxSparkPlanExecApi.scala  |  4 +-
 .../apache/gluten/expression/DummyExpression.scala | 77 ++++++++++++++++++++++
 .../apache/spark/sql/expression/UDFResolver.scala  |  5 +-
 .../org/apache/gluten/execution/TestOperator.scala | 24 ++++++-
 cpp/velox/substrait/SubstraitToVeloxPlan.cc        | 38 +++++++----
 .../gluten/expression/ExpressionConverter.scala    |  2 +-
 6 files changed, 131 insertions(+), 19 deletions(-)

diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 7b8d523a6..b48da1568 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -852,7 +852,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       Sig[VeloxBloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
       Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG),
       Sig[TransformKeys](TRANSFORM_KEYS),
-      Sig[TransformValues](TRANSFORM_VALUES)
+      Sig[TransformValues](TRANSFORM_VALUES),
+      // For test purpose.
+      Sig[VeloxDummyExpression](VeloxDummyExpression.VELOX_DUMMY_EXPRESSION)
     )
   }
 
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala
new file mode 100644
index 000000000..e2af66b59
--- /dev/null
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.expression
+
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types.DataType
+
+/**
+ * An identity expression: it evaluates to the value of its child and reports the child's data
+ * type. Subclasses attach backend-specific behavior without changing the result Spark sees.
+ */
+abstract class DummyExpression(child: Expression) extends UnaryExpression with Serializable {
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    defineCodeGen(ctx, ev, c => c)
+
+  override def dataType: DataType = child.dataType
+
+  // Interpreted path: delegate to the child so it agrees with the generated code above. `input`
+  // is the row the enclosing operator evaluates against and may have any number of fields, so
+  // reading a fixed ordinal from it is only correct when the child is its sole column.
+  override def eval(input: InternalRow): Any = child.eval(input)
+}
+
+// Wraps an expression to force it to fall back, mocking the behavior of an expression that is
+// supported by Gluten but fails native validation. Intended for tests only.
+case class VeloxDummyExpression(child: Expression)
+  extends DummyExpression(child)
+  with Transformable {
+  override def getTransformer(
+      childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer = {
+    if (childrenTransformers.size != children.size) {
+      throw new IllegalStateException(
+        this.getClass.getSimpleName +
+          ": getTransformer called before children transformer initialized.")
+    }
+
+    // Emits a call to a native function named VELOX_DUMMY_EXPRESSION; presumably no such native
+    // function exists, so validation fails and the enclosing operator falls back.
+    GenericExpressionTransformer(
+      VeloxDummyExpression.VELOX_DUMMY_EXPRESSION,
+      childrenTransformers,
+      this)
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
+}
+
+/** Function name and (un)registration of [[VeloxDummyExpression]] as a SQL function. */
+object VeloxDummyExpression {
+  val VELOX_DUMMY_EXPRESSION = "velox_dummy_expression"
+
+  private val identifier = new FunctionIdentifier(VELOX_DUMMY_EXPRESSION)
+
+  /** Registers `velox_dummy_expression(expr)` in `registry`; it takes exactly one argument. */
+  def registerFunctions(registry: FunctionRegistry): Unit = {
+    registry.registerFunction(
+      identifier,
+      new ExpressionInfo(classOf[VeloxDummyExpression].getName, VELOX_DUMMY_EXPRESSION),
+      (e: Seq[Expression]) => {
+        // Reject a wrong arity explicitly instead of failing in `head` or dropping extra args.
+        require(
+          e.size == 1,
+          s"$VELOX_DUMMY_EXPRESSION expects exactly 1 argument, but got ${e.size}.")
+        VeloxDummyExpression(e.head)
+      }
+    )
+  }
+
+  /** Removes the function added by `registerFunctions` from `registry`. */
+  def unregisterFunctions(registry: FunctionRegistry): Unit = {
+    registry.dropFunction(identifier)
+  }
+}
diff --git 
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
 
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
index 915fc5545..e45e8b6fa 100644
--- 
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
+++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
@@ -27,7 +27,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, 
Expression, ExpressionInfo}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, 
Expression, ExpressionInfo, Unevaluable}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
ExprCode}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -94,7 +94,8 @@ case class UDFExpression(
     dataType: DataType,
     nullable: Boolean,
     children: Seq[Expression])
-  extends Transformable {
+  extends Unevaluable
+  with Transformable {
   override protected def withNewChildrenInternal(
       newChildren: IndexedSeq[Expression]): Expression = {
     this.copy(children = newChildren)
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala 
b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
index a892b6f31..9b47a519c 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
@@ -19,6 +19,7 @@ package org.apache.gluten.execution
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.datasource.ArrowCSVFileFormat
 import org.apache.gluten.execution.datasource.v2.ArrowBatchScanExec
+import org.apache.gluten.expression.VeloxDummyExpression
 import org.apache.gluten.sql.shims.SparkShimLoader
 
 import org.apache.spark.SparkConf
@@ -45,6 +46,12 @@ class TestOperator extends VeloxWholeStageTransformerSuite 
with AdaptiveSparkPla
   override def beforeAll(): Unit = {
     super.beforeAll()
     createTPCHNotNullTables()
+    VeloxDummyExpression.registerFunctions(spark.sessionState.functionRegistry)
+  }
+
+  override def afterAll(): Unit = {
+    
VeloxDummyExpression.unregisterFunctions(spark.sessionState.functionRegistry)
+    super.afterAll()
   }
 
   override protected def sparkConf: SparkConf = {
@@ -66,14 +73,20 @@ class TestOperator extends VeloxWholeStageTransformerSuite 
with AdaptiveSparkPla
 
   test("select_part_column") {
     val df = runQueryAndCompare("select l_shipdate, l_orderkey from lineitem 
limit 1") {
-      df => { assert(df.schema.fields.length == 2) }
+      df =>
+        {
+          assert(df.schema.fields.length == 2)
+        }
     }
     checkLengthAndPlan(df, 1)
   }
 
   test("select_as") {
     val df = runQueryAndCompare("select l_shipdate as my_col from lineitem 
limit 1") {
-      df => { assert(df.schema.fieldNames(0).equals("my_col")) }
+      df =>
+        {
+          assert(df.schema.fieldNames(0).equals("my_col"))
+        }
     }
     checkLengthAndPlan(df, 1)
   }
@@ -1074,6 +1087,13 @@ class TestOperator extends 
VeloxWholeStageTransformerSuite with AdaptiveSparkPla
             // No ProjectExecTransformer is introduced.
             checkSparkOperatorChainMatch[GenerateExecTransformer, 
FilterExecTransformer]
           }
+
+          runQueryAndCompare(
+            s"""
+               |SELECT 
$func(${VeloxDummyExpression.VELOX_DUMMY_EXPRESSION}(a)) from t2;
+               |""".stripMargin) {
+            checkGlutenOperatorMatch[GenerateExecTransformer]
+          }
         }
     }
   }
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc 
b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index 8b8a92624..73047b2f4 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -26,6 +26,7 @@
 #include "utils/ConfigExtractor.h"
 
 #include "config/GlutenConfig.h"
+#include "operators/plannodes/RowVectorStream.h"
 
 namespace gluten {
 namespace {
@@ -710,16 +711,23 @@ core::PlanNodePtr 
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
 namespace {
 
 void extractUnnestFieldExpr(
-    std::shared_ptr<const core::ProjectNode> projNode,
+    std::shared_ptr<const core::PlanNode> child,
     int32_t index,
     std::vector<core::FieldAccessTypedExprPtr>& unnestFields) {
-  auto name = projNode->names()[index];
-  auto expr = projNode->projections()[index];
-  auto type = expr->type();
+  if (auto projNode = std::dynamic_pointer_cast<const 
core::ProjectNode>(child)) {
+    auto name = projNode->names()[index];
+    auto expr = projNode->projections()[index];
+    auto type = expr->type();
 
-  auto unnestFieldExpr = std::make_shared<core::FieldAccessTypedExpr>(type, 
name);
-  VELOX_CHECK_NOT_NULL(unnestFieldExpr, " the key in unnest Operator only 
support field");
-  unnestFields.emplace_back(unnestFieldExpr);
+    auto unnestFieldExpr = std::make_shared<core::FieldAccessTypedExpr>(type, 
name);
+    VELOX_CHECK_NOT_NULL(unnestFieldExpr, " the key in unnest Operator only 
support field");
+    unnestFields.emplace_back(unnestFieldExpr);
+  } else {
+    auto name = child->outputType()->names()[index];
+    auto field = child->outputType()->childAt(index);
+    auto unnestFieldExpr = std::make_shared<core::FieldAccessTypedExpr>(field, 
name);
+    unnestFields.emplace_back(unnestFieldExpr);
+  }
 }
 
 } // namespace
@@ -752,10 +760,13 @@ core::PlanNodePtr 
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
       
SubstraitParser::configSetInOptimization(generateRel.advanced_extension(), 
"injectedProject=");
 
   if (injectedProject) {
-    auto projNode = std::dynamic_pointer_cast<const 
core::ProjectNode>(childNode);
+    // Child should be either ProjectNode or ValueStreamNode in case of 
project fallback.
     VELOX_CHECK(
-        projNode != nullptr && projNode->names().size() > 
requiredChildOutput.size(),
-        "injectedProject is true, but the Project is missing or does not have 
the corresponding projection field")
+        (std::dynamic_pointer_cast<const core::ProjectNode>(childNode) != 
nullptr ||
+         std::dynamic_pointer_cast<const ValueStreamNode>(childNode) != 
nullptr) &&
+            childNode->outputType()->size() > requiredChildOutput.size(),
+        "injectedProject is true, but the ProjectNode or ValueStreamNode (in 
case of projection fallback)"
+        " is missing or does not have the corresponding projection field")
 
     bool isStack = generateRel.has_advanced_extension() &&
         
SubstraitParser::configSetInOptimization(generateRel.advanced_extension(), 
"isStack=");
@@ -768,7 +779,8 @@ core::PlanNodePtr 
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
       //  +- Project [fake_column#128, [1,2,3] AS _pre_0#129]
       //   +- RewrittenNodeWall Scan OneRowRelation[fake_column#128]
       // The last projection column in GeneratorRel's child(Project) is the 
column we need to unnest
-      extractUnnestFieldExpr(projNode, projNode->projections().size() - 1, 
unnest);
+      auto index = childNode->outputType()->size() - 1;
+      extractUnnestFieldExpr(childNode, index, unnest);
     } else {
       // For stack function, e.g. stack(2, 1,2,3), a sample
       // input substrait plan is like the following:
@@ -782,10 +794,10 @@ core::PlanNodePtr 
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
       auto generatorFunc = generator.scalar_function();
       auto numRows = 
SubstraitParser::getLiteralValue<int32_t>(generatorFunc.arguments(0).value().literal());
       auto numFields = 
static_cast<int32_t>(std::ceil((generatorFunc.arguments_size() - 1.0) / 
numRows));
-      auto totalProjectCount = projNode->names().size();
+      auto totalProjectCount = childNode->outputType()->size();
 
       for (auto i = totalProjectCount - numFields; i < totalProjectCount; ++i) 
{
-        extractUnnestFieldExpr(projNode, i, unnest);
+        extractUnnestFieldExpr(childNode, i, unnest);
       }
     }
   } else {
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index b7b0889dc..da5625cd4 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.hive.HiveUDFTransformer
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
-trait Transformable extends Unevaluable {
+trait Transformable {
   def getTransformer(childrenTransformers: Seq[ExpressionTransformer]): 
ExpressionTransformer
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to