This is an automated email from the ASF dual-hosted git repository.

chengchengjin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new ef011c0ed3 [GLUTEN-9313][VL] ColumnarPartialProject supports built-in but blacklisted function (#9315)
ef011c0ed3 is described below

commit ef011c0ed394bdd0da5126646fcb557c00e6457d
Author: WangGuangxin <[email protected]>
AuthorDate: Thu Apr 17 20:07:15 2025 +0800

    [GLUTEN-9313][VL] ColumnarPartialProject supports built-in but blacklisted function (#9315)
---
 .../execution/ColumnarPartialProjectExec.scala     | 81 ++++++++++++----------
 .../gluten/execution/MiscOperatorSuite.scala       | 12 ++++
 .../gluten/expression/ExpressionMappings.scala     | 14 +++-
 3 files changed, 67 insertions(+), 40 deletions(-)

diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
index 562de034a9..88742b0833 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
@@ -19,7 +19,7 @@ package org.apache.gluten.execution
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.columnarbatch.{ColumnarBatches, VeloxColumnarBatches}
 import org.apache.gluten.config.GlutenConfig
-import org.apache.gluten.expression.{ArrowProjection, ExpressionUtils}
+import org.apache.gluten.expression.{ArrowProjection, ExpressionMappings, 
ExpressionUtils}
 import org.apache.gluten.extension.ValidationResult
 import org.apache.gluten.extension.columnar.transition.Convention
 import org.apache.gluten.iterator.Iterators
@@ -52,15 +52,15 @@ import scala.collection.mutable.ListBuffer
  *   child plan
  */
 case class ColumnarPartialProjectExec(projectList: Seq[NamedExpression], 
child: SparkPlan)(
-    replacedAliasUdf: Seq[Alias])
+    replacedAlias: Seq[Alias])
   extends UnaryExecNode
   with ValidatablePlan {
 
   private val projectAttributes: ListBuffer[Attribute] = ListBuffer()
   private val projectIndexInChild: ListBuffer[Int] = ListBuffer()
-  private var UDFAttrNotExists = false
+  private var attrNotExists = false
   private var hasUnsupportedDataType = false
-  getProjectIndexInChildOutput(replacedAliasUdf)
+  getProjectIndexInChildOutput(replacedAlias)
 
   @transient override lazy val metrics = Map(
     "time" -> SQLMetrics.createTimingMetric(sparkContext, "total time of 
partial project"),
@@ -72,14 +72,13 @@ case class ColumnarPartialProjectExec(projectList: Seq[NamedExpression], child:
       "time of Arrow ColumnarBatch to velox")
   )
 
-  override def output: Seq[Attribute] = child.output ++ 
replacedAliasUdf.map(_.toAttribute)
+  override def output: Seq[Attribute] = child.output ++ 
replacedAlias.map(_.toAttribute)
 
   override def doCanonicalize(): ColumnarPartialProjectExec = {
     super
       .doCanonicalize()
       .asInstanceOf[ColumnarPartialProjectExec]
-      .copy()(replacedAliasUdf =
-        replacedAliasUdf.map(QueryPlan.normalizeExpressions(_, child.output)))
+      .copy()(replacedAlias = 
replacedAlias.map(QueryPlan.normalizeExpressions(_, child.output)))
   }
 
   override def batchType(): Convention.BatchType = 
BackendsApiManager.getSettings.primaryBatchType
@@ -92,7 +91,7 @@ case class ColumnarPartialProjectExec(projectList: Seq[NamedExpression], child:
   }
 
   final override protected def otherCopyArgs: Seq[AnyRef] = {
-    replacedAliasUdf :: Nil
+    replacedAlias :: Nil
   }
 
   private def validateExpression(expr: Expression): Boolean = {
@@ -106,7 +105,7 @@ case class ColumnarPartialProjectExec(projectList: Seq[NamedExpression], child:
         val index = child.output.indexWhere(s => s.exprId.equals(a.exprId))
         // Some child operator as HashAggregateTransformer will not have udf 
child column
         if (index < 0) {
-          UDFAttrNotExists = true
+          attrNotExists = true
           log.debug(s"Expression $a should exist in child output 
${child.output}")
           false
         } else if (
@@ -127,19 +126,22 @@ case class ColumnarPartialProjectExec(projectList: Seq[NamedExpression], child:
   }
 
   override protected def doValidateInternal(): ValidationResult = {
-    if (UDFAttrNotExists) {
-      return ValidationResult.failed("Attribute in the UDF does not exists in 
its child")
+    if (attrNotExists) {
+      return ValidationResult.failed(
+        "Attribute in the partial projected expressions does not exists in its 
child")
     }
     if (hasUnsupportedDataType) {
-      return ValidationResult.failed("Attribute in the UDF contains 
unsupported type")
+      return ValidationResult.failed(
+        "Attribute in the partial projected expressions contains unsupported 
type")
     }
     if (projectAttributes.size == child.output.size) {
-      return ValidationResult.failed("UDF need all the columns in child 
output")
+      return ValidationResult.failed(
+        "The partial projected expressions need all the columns in child 
output")
     }
-    if (replacedAliasUdf.isEmpty) {
-      return ValidationResult.failed("No UDF")
+    if (replacedAlias.isEmpty) {
+      return ValidationResult.failed("No UDF or blacklisted expressions")
     }
-    if (replacedAliasUdf.size > projectList.size) {
+    if (replacedAlias.size > projectList.size) {
       // e.g. udf1(col) + udf2(col), it will introduce 2 cols for a2c
       return ValidationResult.failed("Number of RowToColumn columns is more 
than ProjectExec")
     }
@@ -208,7 +210,7 @@ case class ColumnarPartialProjectExec(projectList: Seq[NamedExpression], child:
       c2a: SQLMetric,
       a2c: SQLMetric): Iterator[ColumnarBatch] = {
     // select part of child output and child data
-    val proj = ArrowProjection.create(replacedAliasUdf, 
projectAttributes.toSeq)
+    val proj = ArrowProjection.create(replacedAlias, projectAttributes.toSeq)
     val numRows = childData.numRows()
     val start = System.currentTimeMillis()
     val arrowBatch = if (childData.numCols() == 0) {
@@ -219,7 +221,7 @@ case class ColumnarPartialProjectExec(projectList: Seq[NamedExpression], child:
     c2a += System.currentTimeMillis() - start
 
     val schema =
-      
SparkShimLoader.getSparkShims.structFromAttributes(replacedAliasUdf.map(_.toAttribute))
+      
SparkShimLoader.getSparkShims.structFromAttributes(replacedAlias.map(_.toAttribute))
     val vectors: Array[ArrowWritableColumnVector] = ArrowWritableColumnVector
       .allocateColumns(numRows, schema)
       .map {
@@ -253,17 +255,17 @@ case class ColumnarPartialProjectExec(projectList: Seq[NamedExpression], child:
        |$formattedNodeName
        |${ExplainUtils.generateFieldString("Output", output)}
        |${ExplainUtils.generateFieldString("Input", child.output)}
-       |${ExplainUtils.generateFieldString("UDF", replacedAliasUdf)}
+       |${ExplainUtils.generateFieldString("UDF", replacedAlias)}
        |${ExplainUtils.generateFieldString("ProjectOutput", projectAttributes)}
        |${ExplainUtils.generateFieldString("ProjectInputIndex", 
projectIndexInChild)}
        |""".stripMargin
   }
 
   override def simpleString(maxFields: Int): String =
-    super.simpleString(maxFields) + " PartialProject " + replacedAliasUdf
+    super.simpleString(maxFields) + " PartialProject " + replacedAlias
 
   override protected def withNewChildInternal(newChild: SparkPlan): 
ColumnarPartialProjectExec = {
-    copy(child = newChild)(replacedAliasUdf)
+    copy(child = newChild)(replacedAlias)
   }
 }
 
@@ -276,12 +278,17 @@ object ColumnarPartialProjectExec {
     HiveUDFTransformer.isHiveUDF(h) && 
!VeloxHiveUDFTransformer.isSupportedHiveUDF(h)
   }
 
-  private def containsUDF(expr: Expression): Boolean = {
+  private def isBlacklistExpression(e: Expression): Boolean = {
+    ExpressionMappings.blacklistExpressionMap.contains(e.getClass)
+  }
+
+  private def containsUDFOrBlacklistExpression(expr: Expression): Boolean = {
     if (expr == null) return false
     expr match {
       case _: ScalaUDF => true
       case h if containsUnsupportedHiveUDF(h) => true
-      case p => p.children.exists(c => containsUDF(c))
+      case e if isBlacklistExpression(e) => true
+      case p => p.children.exists(c => containsUDFOrBlacklistExpression(c))
     }
   }
 
@@ -304,22 +311,22 @@ object ColumnarPartialProjectExec {
     case _ => false
   }
 
-  private def replaceExpressionUDF(
-      expr: Expression,
-      replacedAliasUdf: ListBuffer[Alias]): Expression = {
+  private def replaceExpression(expr: Expression, replacedAlias: 
ListBuffer[Alias]): Expression = {
     if (expr == null) return null
     expr match {
       case u: ScalaUDF =>
-        replaceByAlias(u, replacedAliasUdf)
+        replaceByAlias(u, replacedAlias)
       case h if containsUnsupportedHiveUDF(h) =>
-        replaceByAlias(h, replacedAliasUdf)
+        replaceByAlias(h, replacedAlias)
+      case e if isBlacklistExpression(e) =>
+        replaceByAlias(e, replacedAlias)
       case au @ Alias(_: ScalaUDF, _) =>
-        val replaceIndex = replacedAliasUdf.indexWhere(r => r.exprId == 
au.exprId)
+        val replaceIndex = replacedAlias.indexWhere(r => r.exprId == au.exprId)
         if (replaceIndex == -1) {
-          replacedAliasUdf.append(au)
+          replacedAlias.append(au)
           au.toAttribute
         } else {
-          replacedAliasUdf(replaceIndex).toAttribute
+          replacedAlias(replaceIndex).toAttribute
         }
       // Alias(HiveSimpleUDF) not exists, only be 
Alias(ToPrettyString(HiveSimpleUDF)),
       // so don't process this condition
@@ -330,20 +337,20 @@ object ColumnarPartialProjectExec {
         // else myudf(knownnotnull(cast(l_extendedprice#9 as bigint)))
         // if we extract else branch, and use the data child l_extendedprice,
         // the result is incorrect for null value
-        if (containsUDF(expr)) {
-          replaceByAlias(expr, replacedAliasUdf)
+        if (containsUDFOrBlacklistExpression(expr)) {
+          replaceByAlias(expr, replacedAlias)
         } else expr
-      case p => p.withNewChildren(p.children.map(c => replaceExpressionUDF(c, 
replacedAliasUdf)))
+      case p => p.withNewChildren(p.children.map(c => replaceExpression(c, 
replacedAlias)))
     }
   }
 
   def create(original: ProjectExec): ProjectExecTransformer = {
-    val replacedAliasUdf: ListBuffer[Alias] = ListBuffer()
+    val replacedAlias: ListBuffer[Alias] = ListBuffer()
     val newProjectList = original.projectList.map {
-      p => replaceExpressionUDF(p, 
replacedAliasUdf).asInstanceOf[NamedExpression]
+      p => replaceExpression(p, replacedAlias).asInstanceOf[NamedExpression]
     }
     val partialProject =
-      ColumnarPartialProjectExec(original.projectList, 
original.child)(replacedAliasUdf.toSeq)
+      ColumnarPartialProjectExec(original.projectList, 
original.child)(replacedAlias.toSeq)
     ProjectExecTransformer(newProjectList, partialProject)
   }
 }
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
index b7d7abe882..3e52ac8655 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
@@ -2040,4 +2040,16 @@ class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa
       }
     }
   }
+
+  test("Blacklist expression can be handled by ColumnarPartialProject") {
+    withSQLConf("spark.gluten.expression.blacklist" -> "regexp_replace") {
+      runQueryAndCompare(
+        "SELECT c_custkey, c_name, regexp_replace(c_comment, '\\w', 
'something') FROM customer") {
+        df =>
+          val executedPlan = getExecutedPlan(df)
+          assert(executedPlan.count(_.isInstanceOf[ProjectExec]) == 0)
+          
assert(executedPlan.count(_.isInstanceOf[ColumnarPartialProjectExec]) == 1)
+      }
+    }
+  }
 }
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
index 51e187e9b9..8c91827413 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
@@ -351,12 +351,20 @@ object ExpressionMappings {
     Sig[Right](RIGHT)
   ) ++ SparkShimLoader.getSparkShims.runtimeReplaceableExpressionMappings
 
+  def blacklistExpressionMap: Map[Class[_], String] = {
+    partitionExpressionMapByBlacklist._1
+  }
+
   def expressionsMap: Map[Class[_], String] = {
+    partitionExpressionMapByBlacklist._2
+  }
+
+  private def partitionExpressionMapByBlacklist: (Map[Class[_], String], 
Map[Class[_], String]) = {
     val blacklist = GlutenConfig.get.expressionBlacklist
-    val filtered = (defaultExpressionsMap ++ toMap(
-      
BackendsApiManager.getSparkPlanExecApiInstance.extraExpressionMappings)).filterNot(
+    val (blacklistedExpr, filteredExpr) = (defaultExpressionsMap ++ toMap(
+      
BackendsApiManager.getSparkPlanExecApiInstance.extraExpressionMappings)).partition(
       kv => blacklist.contains(kv._2))
-    filtered
+    (blacklistedExpr, filteredExpr)
   }
 
   // This is needed when generating function support status documentation for 
Spark built-in


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to