This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ac34f1de92c6 [SPARK-48280][SQL][FOLLOW-UP] Add expressions that are 
built via expressionBuilder to Expression Walker
ac34f1de92c6 is described below

commit ac34f1de92c6f5cb53d799f00e550a0a204d9eb2
Author: Mihailo Milosevic <[email protected]>
AuthorDate: Thu Sep 19 11:56:10 2024 +0200

    [SPARK-48280][SQL][FOLLOW-UP] Add expressions that are built via 
expressionBuilder to Expression Walker
    
    ### What changes were proposed in this pull request?
    Adds expressions that are built via ExpressionBuilder to the expression
    walker. This PR also improves the descriptions of methods in the suite.
    
    ### Why are the changes needed?
    While debugging, it was noticed that startsWith, endsWith and contains are
    not tested by this suite, even though these expressions represent the core
    of collation testing.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Test only.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48162 from mihailom-db/expressionwalkerfollowup.
    
    Authored-by: Mihailo Milosevic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/CollationExpressionWalkerSuite.scala | 148 +++++++++++++++++----
 1 file changed, 121 insertions(+), 27 deletions(-)

diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
index 2342722c0bb1..1d23774a5169 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.sql.Timestamp
 
 import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
+import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.variant.ParseJson
 import org.apache.spark.sql.internal.SqlApiConf
@@ -46,7 +47,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi
    *
    * @param inputEntry - List of all input entries that need to be generated
    * @param collationType - Flag defining collation type to use
-   * @return
+   * @return - List of data generated for expression instance creation
    */
   def generateData(
       inputEntry: Seq[Any],
@@ -54,23 +55,11 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi
     inputEntry.map(generateSingleEntry(_, collationType))
   }
 
-  /**
-   * Helper function to generate single entry of data as a string.
-   * @param inputEntry - Single input entry that requires generation
-   * @param collationType - Flag defining collation type to use
-   * @return
-   */
-  def generateDataAsStrings(
-      inputEntry: Seq[AbstractDataType],
-      collationType: CollationType): Seq[Any] = {
-    inputEntry.map(generateInputAsString(_, collationType))
-  }
-
   /**
    * Helper function to generate single entry of data.
    * @param inputEntry - Single input entry that requires generation
    * @param collationType - Flag defining collation type to use
-   * @return
+   * @return - Single input entry data
    */
   def generateSingleEntry(
       inputEntry: Any,
@@ -100,7 +89,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi
    *
    * @param inputType    - Single input literal type that requires generation
    * @param collationType - Flag defining collation type to use
-   * @return
+   * @return - Literal/Expression containing expression ready for evaluation
    */
   def generateLiterals(
       inputType: AbstractDataType,
@@ -116,6 +105,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi
       }
       case BooleanType => Literal(true)
       case _: DatetimeType => Literal(Timestamp.valueOf("2009-07-30 12:58:59"))
+      case DecimalType => Literal((new Decimal).set(5))
       case _: DecimalType => Literal((new Decimal).set(5))
       case _: DoubleType => Literal(5.0)
       case IntegerType | NumericType | IntegralType => Literal(5)
@@ -158,11 +148,15 @@ class CollationExpressionWalkerSuite extends 
SparkFunSuite with SharedSparkSessi
       case MapType =>
         val key = generateLiterals(StringTypeAnyCollation, collationType)
         val value = generateLiterals(StringTypeAnyCollation, collationType)
-        Literal.create(Map(key -> value))
+        CreateMap(Seq(key, value))
       case MapType(keyType, valueType, _) =>
         val key = generateLiterals(keyType, collationType)
         val value = generateLiterals(valueType, collationType)
-        Literal.create(Map(key -> value))
+        CreateMap(Seq(key, value))
+      case AbstractMapType(keyType, valueType) =>
+        val key = generateLiterals(keyType, collationType)
+        val value = generateLiterals(valueType, collationType)
+        CreateMap(Seq(key, value))
       case StructType =>
         CreateNamedStruct(
           Seq(Literal("start"), generateLiterals(StringTypeAnyCollation, 
collationType),
@@ -174,7 +168,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi
    *
    * @param inputType     - Single input type that requires generation
    * @param collationType - Flag defining collation type to use
-   * @return
+   * @return - String representation of a input ready for SQL query
    */
   def generateInputAsString(
       inputType: AbstractDataType,
@@ -189,6 +183,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi
         }
       case BooleanType => "True"
       case _: DatetimeType => "date'2016-04-08'"
+      case DecimalType => "5.0"
       case _: DecimalType => "5.0"
       case _: DoubleType => "5.0"
       case IntegerType | NumericType | IntegralType => "5"
@@ -221,6 +216,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi
       case MapType(keyType, valueType, _) =>
         "map(" + generateInputAsString(keyType, collationType) + ", " +
           generateInputAsString(valueType, collationType) + ")"
+      case AbstractMapType(keyType, valueType) =>
+        "map(" + generateInputAsString(keyType, collationType) + ", " +
+          generateInputAsString(valueType, collationType) + ")"
       case StructType =>
         "named_struct( 'start', " + 
generateInputAsString(StringTypeAnyCollation, collationType) +
           ", 'end', " + generateInputAsString(StringTypeAnyCollation, 
collationType) + ")"
@@ -234,7 +232,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi
    *
    * @param inputType     - Single input type that requires generation
    * @param collationType - Flag defining collation type to use
-   * @return
+   * @return - String representation for SQL query of a inputType
    */
   def generateInputTypeAsStrings(
       inputType: AbstractDataType,
@@ -244,6 +242,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi
       case BinaryType => "BINARY"
       case BooleanType => "BOOLEAN"
       case _: DatetimeType => "DATE"
+      case DecimalType => "DECIMAL(2, 1)"
       case _: DecimalType => "DECIMAL(2, 1)"
       case _: DoubleType => "DOUBLE"
       case IntegerType | NumericType | IntegralType => "INT"
@@ -275,6 +274,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi
       case MapType(keyType, valueType, _) =>
         "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " +
           generateInputTypeAsStrings(valueType, collationType) + ">"
+      case AbstractMapType(keyType, valueType) =>
+        "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " +
+          generateInputTypeAsStrings(valueType, collationType) + ">"
       case StructType =>
         "struct<start:" + generateInputTypeAsStrings(StringTypeAnyCollation, 
collationType) +
           ", end:" +
@@ -287,7 +289,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi
   /**
    * Helper function to extract types of relevance
    * @param inputType
-   * @return
+   * @return - Boolean that represents if inputType has/is a StringType
    */
   def hasStringType(inputType: AbstractDataType): Boolean = {
     inputType match {
@@ -300,7 +302,6 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi
       case AbstractArrayType(elementType) => hasStringType(elementType)
       case TypeCollection(typeCollection) =>
         typeCollection.exists(hasStringType)
-      case StructType => true
       case StructType(fields) => fields.exists(sf => 
hasStringType(sf.dataType))
       case _ => false
     }
@@ -310,7 +311,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi
    * Helper function to replace expected parameters with expected input types.
    * @param inputTypes - Input types generated by ExpectsInputType.inputTypes
    * @param params - Parameters that are read from expression info
-   * @return
+   * @return - List of parameters where Expressions are replaced with input 
types
    */
   def replaceExpressions(inputTypes: Seq[AbstractDataType], params: 
Seq[Class[_]]): Seq[Any] = {
     (inputTypes, params) match {
@@ -325,7 +326,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi
 
   /**
    * Helper method to extract relevant expressions that can be walked over.
-   * @return
+   * @return - (List of relevant expressions that expect input, List of 
expressions to skip)
    */
   def extractRelevantExpressions(): (Array[ExpressionInfo], List[String]) = {
     var expressionCounter = 0
@@ -384,6 +385,47 @@ class CollationExpressionWalkerSuite extends SparkFunSuite 
with SharedSparkSessi
     (funInfos, toSkip)
   }
 
+  /**
+   * Helper method to extract relevant expressions that can be walked over but 
are built with
+   * expression builder.
+   *
+   * @return - (List of expressions that are relevant builders, List of 
expressions to skip)
+   */
+  def extractRelevantBuilders(): (Array[ExpressionInfo], List[String]) = {
+    var builderExpressionCounter = 0
+    val funInfos = spark.sessionState.functionRegistry.listFunction().map { 
funcId =>
+      spark.sessionState.catalog.lookupFunctionInfo(funcId)
+    }.filter(funInfo => {
+      // make sure that there is a constructor.
+      val cl = Utils.classForName(funInfo.getClassName)
+      cl.isAssignableFrom(classOf[ExpressionBuilder])
+    }).filter(funInfo => {
+      builderExpressionCounter = builderExpressionCounter + 1
+      val cl = Utils.classForName(funInfo.getClassName)
+      val method = cl.getMethod("build",
+        Utils.classForName("java.lang.String"),
+        Utils.classForName("scala.collection.Seq"))
+      var input: Seq[Expression] = Seq.empty
+      var i = 0
+      for (_ <- 1 to 10) {
+        input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary)
+        try {
+          method.invoke(null, funInfo.getClassName, 
input).asInstanceOf[ExpectsInputTypes]
+        }
+        catch {
+          case _: Exception => i = i + 1
+        }
+      }
+      if (i == 10) false
+      else true
+    }).toArray
+
+    logInfo("Total number of expression that are built: " + 
builderExpressionCounter)
+    logInfo("Number of extracted expressions of relevance: " + funInfos.length)
+
+    (funInfos, List())
+  }
+
   /**
    * Helper function to generate string of an expression suitable for 
execution.
    * @param expr - Expression that needs to be converted
@@ -441,10 +483,36 @@ class CollationExpressionWalkerSuite extends 
SparkFunSuite with SharedSparkSessi
    * 5) Otherwise, check if exceptions are the same
    */
   test("SPARK-48280: Expression Walker for expression evaluation") {
-    val (funInfos, toSkip) = extractRelevantExpressions()
+    val (funInfosExpr, toSkip) = extractRelevantExpressions()
+    val (funInfosBuild, _) = extractRelevantBuilders()
+    val funInfos = funInfosExpr ++ funInfosBuild
 
     for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) {
-      val cl = Utils.classForName(f.getClassName)
+      val TempCl = Utils.classForName(f.getClassName)
+      val cl = if (TempCl.isAssignableFrom(classOf[ExpressionBuilder])) {
+        val clTemp = Utils.classForName(f.getClassName)
+        val method = clTemp.getMethod("build",
+          Utils.classForName("java.lang.String"),
+          Utils.classForName("scala.collection.Seq"))
+        val instance = {
+          var input: Seq[Expression] = Seq.empty
+          var result: Expression = null
+          for (_ <- 1 to 10) {
+            input = input :+ generateLiterals(StringTypeAnyCollation, 
Utf8Binary)
+            try {
+              val tempResult = method.invoke(null, f.getClassName, input)
+              if (result == null) result = tempResult.asInstanceOf[Expression]
+            }
+            catch {
+              case _: Exception =>
+            }
+          }
+          result
+        }
+        instance.getClass
+      }
+      else Utils.classForName(f.getClassName)
+
       val headConstructor = cl.getConstructors
         .zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => 
a._2)._1
       val params = headConstructor.getParameters.map(p => p.getType)
@@ -526,10 +594,36 @@ class CollationExpressionWalkerSuite extends 
SparkFunSuite with SharedSparkSessi
    * 5) Otherwise, check if exceptions are the same
    */
   test("SPARK-48280: Expression Walker for codeGen generation") {
-    val (funInfos, toSkip) = extractRelevantExpressions()
+    val (funInfosExpr, toSkip) = extractRelevantExpressions()
+    val (funInfosBuild, _) = extractRelevantBuilders()
+    val funInfos = funInfosExpr ++ funInfosBuild
 
     for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) {
-      val cl = Utils.classForName(f.getClassName)
+      val TempCl = Utils.classForName(f.getClassName)
+      val cl = if (TempCl.isAssignableFrom(classOf[ExpressionBuilder])) {
+        val clTemp = Utils.classForName(f.getClassName)
+        val method = clTemp.getMethod("build",
+          Utils.classForName("java.lang.String"),
+          Utils.classForName("scala.collection.Seq"))
+        val instance = {
+          var input: Seq[Expression] = Seq.empty
+          var result: Expression = null
+          for (_ <- 1 to 10) {
+            input = input :+ generateLiterals(StringTypeAnyCollation, 
Utf8Binary)
+            try {
+              val tempResult = method.invoke(null, f.getClassName, input)
+              if (result == null) result = tempResult.asInstanceOf[Expression]
+            }
+            catch {
+              case _: Exception =>
+            }
+          }
+          result
+        }
+        instance.getClass
+      }
+      else Utils.classForName(f.getClassName)
+
       val headConstructor = cl.getConstructors
         .zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => 
a._2)._1
       val params = headConstructor.getParameters.map(p => p.getType)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to