This is an automated email from the ASF dual-hosted git repository.

marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 5b496e86b6 [GLUTEN-7760] Fix udf implicit cast & update doc (#7852)
5b496e86b6 is described below

commit 5b496e86b6e1fbca55866616fa1519a532b07b51
Author: Rong Ma <[email protected]>
AuthorDate: Fri Nov 8 18:44:50 2024 +0800

    [GLUTEN-7760] Fix udf implicit cast & update doc (#7852)
---
 .../spark/sql/hive/VeloxHiveUDFTransformer.scala   |   8 +-
 .../apache/gluten/expression/VeloxUdfSuite.scala   |  27 +++--
 docs/developers/VeloxUDF.md                        |  19 ++--
 .../execution/WholeStageTransformerSuite.scala     | 108 +-------------------
 .../org/apache/spark/sql/GlutenQueryTest.scala     | 109 ++++++++++++++++++++-
 5 files changed, 142 insertions(+), 129 deletions(-)

diff --git 
a/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala
 
b/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala
index d895faa317..b3524e20f0 100644
--- 
a/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala
+++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala
@@ -37,11 +37,11 @@ object VeloxHiveUDFTransformer {
     }
 
     if (UDFResolver.UDFNames.contains(udfClassName)) {
-      UDFResolver
+      val udfExpression = UDFResolver
         .getUdfExpression(udfClassName, udfName)(expr.children)
-        .getTransformer(
-          ExpressionConverter.replaceWithExpressionTransformer(expr.children, 
attributeSeq)
-        )
+      udfExpression.getTransformer(
+        
ExpressionConverter.replaceWithExpressionTransformer(udfExpression.children, 
attributeSeq)
+      )
     } else {
       HiveUDFTransformer.genTransformerFromUDFMappings(udfName, expr, 
attributeSeq)
     }
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
index f85103deb8..61ba927cd4 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
@@ -16,11 +16,13 @@
  */
 package org.apache.gluten.expression
 
+import org.apache.gluten.execution.ProjectExecTransformer
 import org.apache.gluten.tags.{SkipTestTags, UDFTest}
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{GlutenQueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.execution.ProjectExec
 import org.apache.spark.sql.expression.UDFResolver
 
 import java.nio.file.Paths
@@ -158,16 +160,24 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with 
SQLHelper {
                       |AS 'org.apache.spark.sql.hive.execution.UDFStringString'
                       |""".stripMargin)
 
-          val nativeResultWithImplicitConversion =
-            spark.sql(s"""SELECT hive_string_string(col1, 'a') FROM 
$tbl""").collect()
-          val nativeResult =
-            spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM 
$tbl""").collect()
+          val offloadWithImplicitConversionDF =
+            spark.sql(s"""SELECT hive_string_string(col1, 'a') FROM $tbl""")
+          
checkGlutenOperatorMatch[ProjectExecTransformer](offloadWithImplicitConversionDF)
+          val offloadWithImplicitConversionResult = 
offloadWithImplicitConversionDF.collect()
+
+          val offloadDF =
+            spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""")
+          checkGlutenOperatorMatch[ProjectExecTransformer](offloadDF)
+          val offloadResult = offloadDF.collect()
+
           // Unregister native hive udf to fallback.
           
UDFResolver.UDFNames.remove("org.apache.spark.sql.hive.execution.UDFStringString")
-          val fallbackResult =
-            spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM 
$tbl""").collect()
-          
assert(nativeResultWithImplicitConversion.sameElements(fallbackResult))
-          assert(nativeResult.sameElements(fallbackResult))
+          val fallbackDF =
+            spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""")
+          checkSparkOperatorMatch[ProjectExec](fallbackDF)
+          val fallbackResult = fallbackDF.collect()
+          
assert(offloadWithImplicitConversionResult.sameElements(fallbackResult))
+          assert(offloadResult.sameElements(fallbackResult))
 
           // Add an unimplemented udf to the map to test fallback of 
registered native hive udf.
           
UDFResolver.UDFNames.add("org.apache.spark.sql.hive.execution.UDFIntegerToString")
@@ -176,6 +186,7 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with 
SQLHelper {
                       |AS 
'org.apache.spark.sql.hive.execution.UDFIntegerToString'
                       |""".stripMargin)
           val df = spark.sql(s"""select hive_int_to_string(col1) from $tbl""")
+          checkSparkOperatorMatch[ProjectExec](df)
           checkAnswer(df, Seq(Row("1"), Row("2"), Row("3")))
         } finally {
           spark.sql(s"DROP TABLE IF EXISTS $tbl")
diff --git a/docs/developers/VeloxUDF.md b/docs/developers/VeloxUDF.md
index 4f685cc41e..4cbdcfa992 100644
--- a/docs/developers/VeloxUDF.md
+++ b/docs/developers/VeloxUDF.md
@@ -172,22 +172,23 @@ or
 Start `spark-sql` and run query. You need to add jar 
"spark-hive_2.12-<spark.version>-tests.jar" to the classpath for hive udf 
`org.apache.spark.sql.hive.execution.UDFStringString`
 
 ```
+spark-sql (default)> create table tbl as select * from values ('hello');
+Time taken: 3.656 seconds
 spark-sql (default)> CREATE TEMPORARY FUNCTION hive_string_string AS 
'org.apache.spark.sql.hive.execution.UDFStringString';
-Time taken: 0.808 seconds
-spark-sql (default)> select hive_string_string("hello", "world");
+Time taken: 0.047 seconds
+spark-sql (default)> select hive_string_string(col1, 'world') from tbl;
 hello world
-Time taken: 3.208 seconds, Fetched 1 row(s)
+Time taken: 1.217 seconds, Fetched 1 row(s)
 ```
 
 You can verify the offload with "explain".
 ```
-spark-sql (default)> explain select hive_string_string("hello", "world");
-== Physical Plan ==
-VeloxColumnarToRowExec
-+- ^(2) ProjectExecTransformer [hello world AS hive_string_string(hello, 
world)#8]
-   +- ^(2) InputIteratorTransformer[fake_column#9]
+spark-sql (default)> explain select hive_string_string(col1, 'world') from tbl;
+VeloxColumnarToRow
++- ^(2) ProjectExecTransformer 
[HiveSimpleUDF#org.apache.spark.sql.hive.execution.UDFStringString(col1#11,world)
 AS hive_string_string(col1, world)#12]
+   +- ^(2) InputIteratorTransformer[col1#11]
       +- RowToVeloxColumnar
-         +- *(1) Scan OneRowRelation[fake_column#9]
+         +- Scan hive spark_catalog.default.tbl [col1#11], HiveTableRelation 
[`spark_catalog`.`default`.`tbl`, 
org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, Data Cols: [col1#11], 
Partition Cols: []]
 ```
 
 ## Configurations
diff --git 
a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
 
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
index 146d6fde58..fd250834d0 100644
--- 
a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
+++ 
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
@@ -17,15 +17,13 @@
 package org.apache.gluten.execution
 
 import org.apache.gluten.GlutenConfig
-import org.apache.gluten.extension.GlutenPlan
 import org.apache.gluten.test.FallbackUtil
 import org.apache.gluten.utils.Arm
 
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, GlutenQueryTest, Row}
-import org.apache.spark.sql.execution.{CommandResultExec, SparkPlan, 
UnaryExecNode}
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, 
AdaptiveSparkPlanHelper, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.DoubleType
 
@@ -33,7 +31,6 @@ import java.io.File
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.io.Source
-import scala.reflect.ClassTag
 
 case class Table(name: String, partitionColumns: Seq[String])
 
@@ -179,109 +176,6 @@ abstract class WholeStageTransformerSuite
       result
   }
 
-  def checkLengthAndPlan(df: DataFrame, len: Int = 100): Unit = {
-    assert(df.collect().length == len)
-    val executedPlan = getExecutedPlan(df)
-    assert(executedPlan.exists(plan => 
plan.find(_.isInstanceOf[TransformSupport]).isDefined))
-  }
-
-  /**
-   * Get all the children plan of plans.
-   * @param plans:
-   *   the input plans.
-   * @return
-   */
-  def getChildrenPlan(plans: Seq[SparkPlan]): Seq[SparkPlan] = {
-    if (plans.isEmpty) {
-      return Seq()
-    }
-
-    val inputPlans: Seq[SparkPlan] = plans.map {
-      case stage: ShuffleQueryStageExec => stage.plan
-      case plan => plan
-    }
-
-    var newChildren: Seq[SparkPlan] = Seq()
-    inputPlans.foreach {
-      plan =>
-        newChildren = newChildren ++ getChildrenPlan(plan.children)
-        // To avoid duplication of WholeStageCodegenXXX and its children.
-        if (!plan.nodeName.startsWith("WholeStageCodegen")) {
-          newChildren = newChildren :+ plan
-        }
-    }
-    newChildren
-  }
-
-  /**
-   * Get the executed plan of a data frame.
-   * @param df:
-   *   dataframe.
-   * @return
-   *   A sequence of executed plans.
-   */
-  def getExecutedPlan(df: DataFrame): Seq[SparkPlan] = {
-    df.queryExecution.executedPlan match {
-      case exec: AdaptiveSparkPlanExec =>
-        getChildrenPlan(Seq(exec.executedPlan))
-      case cmd: CommandResultExec =>
-        getChildrenPlan(Seq(cmd.commandPhysicalPlan))
-      case plan =>
-        getChildrenPlan(Seq(plan))
-    }
-  }
-
-  /**
-   * Check whether the executed plan of a dataframe contains the expected plan.
-   * @param df:
-   *   the input dataframe.
-   * @param tag:
-   *   class of the expected plan.
-   * @tparam T:
-   *   type of the expected plan.
-   */
-  def checkGlutenOperatorMatch[T <: GlutenPlan](df: DataFrame)(implicit tag: 
ClassTag[T]): Unit = {
-    val executedPlan = getExecutedPlan(df)
-    assert(
-      executedPlan.exists(plan => tag.runtimeClass.isInstance(plan)),
-      s"Expect ${tag.runtimeClass.getSimpleName} exists " +
-        s"in executedPlan:\n ${executedPlan.last}"
-    )
-  }
-
-  def checkSparkOperatorMatch[T <: SparkPlan](df: DataFrame)(implicit tag: 
ClassTag[T]): Unit = {
-    val executedPlan = getExecutedPlan(df)
-    assert(executedPlan.exists(plan => tag.runtimeClass.isInstance(plan)))
-  }
-
-  /**
-   * Check whether the executed plan of a dataframe contains the expected plan 
chain.
-   *
-   * @param df
-   *   : the input dataframe.
-   * @param tag
-   *   : class of the expected plan.
-   * @param childTag
-   *   : class of the expected plan's child.
-   * @tparam T
-   *   : type of the expected plan.
-   * @tparam PT
-   *   : type of the expected plan's child.
-   */
-  def checkSparkOperatorChainMatch[T <: UnaryExecNode, PT <: UnaryExecNode](
-      df: DataFrame)(implicit tag: ClassTag[T], childTag: ClassTag[PT]): Unit 
= {
-    val executedPlan = getExecutedPlan(df)
-    assert(
-      executedPlan.exists(
-        plan =>
-          tag.runtimeClass.isInstance(plan)
-            && childTag.runtimeClass.isInstance(plan.children.head)),
-      s"Expect an operator chain of [${tag.runtimeClass.getSimpleName} ->"
-        + s"${childTag.runtimeClass.getSimpleName}] exists in executedPlan: \n"
-        + s"${executedPlan.last}"
-    )
-  }
-
   /**
    * run a query with native engine as well as vanilla spark then compare the 
result set for
    * correctness check
diff --git 
a/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala 
b/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala
index 53abaa9ac2..164083a8d8 100644
--- a/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala
+++ b/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala
@@ -21,6 +21,8 @@ package org.apache.spark.sql
  *   1. We need to modify the way org.apache.spark.sql.CHQueryTest#compare 
compares double
  */
 import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.execution.TransformSupport
+import org.apache.gluten.extension.GlutenPlan
 import org.apache.gluten.sql.shims.SparkShimLoader
 
 import org.apache.spark.SPARK_VERSION_SHORT
@@ -28,7 +30,8 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.{CommandResultExec, SparkPlan, 
SQLExecution, UnaryExecNode}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, 
ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.storage.StorageLevel
 
@@ -38,6 +41,7 @@ import org.scalatest.Assertions
 import java.util.TimeZone
 
 import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
 import scala.reflect.runtime.universe
 
 abstract class GlutenQueryTest extends PlanTest {
@@ -306,6 +310,109 @@ abstract class GlutenQueryTest extends PlanTest {
       query.queryExecution.executedPlan.missingInput.isEmpty,
       s"The physical plan has missing 
inputs:\n${query.queryExecution.executedPlan}")
   }
+
+  def checkLengthAndPlan(df: DataFrame, len: Int = 100): Unit = {
+    assert(df.collect().length == len)
+    val executedPlan = getExecutedPlan(df)
+    assert(executedPlan.exists(plan => 
plan.find(_.isInstanceOf[TransformSupport]).isDefined))
+  }
+
+  /**
+   * Get all the children plan of plans.
+   * @param plans:
+   *   the input plans.
+   * @return
+   */
+  def getChildrenPlan(plans: Seq[SparkPlan]): Seq[SparkPlan] = {
+    if (plans.isEmpty) {
+      return Seq()
+    }
+
+    val inputPlans: Seq[SparkPlan] = plans.map {
+      case stage: ShuffleQueryStageExec => stage.plan
+      case plan => plan
+    }
+
+    var newChildren: Seq[SparkPlan] = Seq()
+    inputPlans.foreach {
+      plan =>
+        newChildren = newChildren ++ getChildrenPlan(plan.children)
+        // To avoid duplication of WholeStageCodegenXXX and its children.
+        if (!plan.nodeName.startsWith("WholeStageCodegen")) {
+          newChildren = newChildren :+ plan
+        }
+    }
+    newChildren
+  }
+
+  /**
+   * Get the executed plan of a data frame.
+   * @param df:
+   *   dataframe.
+   * @return
+   *   A sequence of executed plans.
+   */
+  def getExecutedPlan(df: DataFrame): Seq[SparkPlan] = {
+    df.queryExecution.executedPlan match {
+      case exec: AdaptiveSparkPlanExec =>
+        getChildrenPlan(Seq(exec.executedPlan))
+      case cmd: CommandResultExec =>
+        getChildrenPlan(Seq(cmd.commandPhysicalPlan))
+      case plan =>
+        getChildrenPlan(Seq(plan))
+    }
+  }
+
+  /**
+   * Check whether the executed plan of a dataframe contains the expected plan 
chain.
+   *
+   * @param df
+   *   : the input dataframe.
+   * @param tag
+   *   : class of the expected plan.
+   * @param childTag
+   *   : class of the expected plan's child.
+   * @tparam T
+   *   : type of the expected plan.
+   * @tparam PT
+   *   : type of the expected plan's child.
+   */
+  def checkSparkOperatorChainMatch[T <: UnaryExecNode, PT <: UnaryExecNode](
+      df: DataFrame)(implicit tag: ClassTag[T], childTag: ClassTag[PT]): Unit 
= {
+    val executedPlan = getExecutedPlan(df)
+    assert(
+      executedPlan.exists(
+        plan =>
+          tag.runtimeClass.isInstance(plan)
+            && childTag.runtimeClass.isInstance(plan.children.head)),
+      s"Expect an operator chain of [${tag.runtimeClass.getSimpleName} ->"
+        + s"${childTag.runtimeClass.getSimpleName}] exists in executedPlan: \n"
+        + s"${executedPlan.last}"
+    )
+  }
+
+  /**
+   * Check whether the executed plan of a dataframe contains the expected plan.
+   * @param df:
+   *   the input dataframe.
+   * @param tag:
+   *   class of the expected plan.
+   * @tparam T:
+   *   type of the expected plan.
+   */
+  def checkGlutenOperatorMatch[T <: GlutenPlan](df: DataFrame)(implicit tag: 
ClassTag[T]): Unit = {
+    val executedPlan = getExecutedPlan(df)
+    assert(
+      executedPlan.exists(plan => tag.runtimeClass.isInstance(plan)),
+      s"Expect ${tag.runtimeClass.getSimpleName} exists " +
+        s"in executedPlan:\n ${executedPlan.last}"
+    )
+  }
+
+  def checkSparkOperatorMatch[T <: SparkPlan](df: DataFrame)(implicit tag: 
ClassTag[T]): Unit = {
+    val executedPlan = getExecutedPlan(df)
+    assert(executedPlan.exists(plan => tag.runtimeClass.isInstance(plan)))
+  }
 }
 
 object GlutenQueryTest extends Assertions {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to