This is an automated email from the ASF dual-hosted git repository.

philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new da4711207 [VL] Validate binary expressions with their accepted types 
(#6521)
da4711207 is described below

commit da4711207d3e188f0cc179eaf2bfb13ff4c1b6af
Author: PHILO-HE <[email protected]>
AuthorDate: Wed Aug 14 20:30:13 2024 +0800

    [VL] Validate binary expressions with their accepted types (#6521)
---
 .../sql/GlutenExpressionDataTypesValidation.scala  | 148 ++++++++++++++++++---
 1 file changed, 129 insertions(+), 19 deletions(-)

diff --git 
a/gluten-ut/test/src/test/scala/org/apache/spark/sql/GlutenExpressionDataTypesValidation.scala
 
b/gluten-ut/test/src/test/scala/org/apache/spark/sql/GlutenExpressionDataTypesValidation.scala
index c8b2aaba2..170555e5f 100644
--- 
a/gluten-ut/test/src/test/scala/org/apache/spark/sql/GlutenExpressionDataTypesValidation.scala
+++ 
b/gluten-ut/test/src/test/scala/org/apache/spark/sql/GlutenExpressionDataTypesValidation.scala
@@ -24,11 +24,13 @@ import org.apache.gluten.utils.{BackendTestUtils, 
SystemParameters}
 import org.apache.spark.SparkConf
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
+import 
org.apache.spark.sql.catalyst.expressions.{BinaryArrayExpressionWithImplicitCast,
 _}
 import org.apache.spark.sql.execution.LeafExecNode
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types._
 
+import scala.collection.mutable.Buffer
+
 class GlutenExpressionDataTypesValidation extends WholeStageTransformerSuite {
   protected val resourcePath: String = null
   protected val fileFormat: String = null
@@ -99,6 +101,19 @@ class GlutenExpressionDataTypesValidation extends 
WholeStageTransformerSuite {
     ProjectExecTransformer(namedExpr, DummyPlan())
   }
 
+  /**
+   * Wraps `targetExpr` in a plan obtained from `generateGlutenProjectPlan`, runs that plan's
+   * validation and logs the outcome along with the expression's class name and the data types
+   * of its children. An unresolved expression is reported as a validation failure.
+   */
+  def validateExpr(targetExpr: Expression): Unit = {
+    val plan = generateGlutenProjectPlan(targetExpr)
+    // Rendered as ExprName(childType1, childType2, ...); only built when the message is needed.
+    def signature: String = {
+      val exprName = targetExpr.getClass.getSimpleName
+      val childTypes = targetExpr.children.map(_.dataType.toString).mkString(", ")
+      s"$exprName($childTypes)"
+    }
+    val passed = targetExpr.resolved && plan.doValidate().ok()
+    val verdict = if (passed) "## validation passes: " else "!! validation fails: "
+    logInfo(verdict + signature)
+  }
+
   test("cast") {
     for (from <- allPrimitiveDataTypes ++ allComplexDataTypes) {
       for (to <- allPrimitiveDataTypes ++ allComplexDataTypes) {
@@ -120,21 +135,34 @@ class GlutenExpressionDataTypesValidation extends 
WholeStageTransformerSuite {
   test("unary expressions with expected input types") {
     val functionRegistry = spark.sessionState.functionRegistry
     val sparkBuiltInFunctions = functionRegistry.listFunction()
+    val exceptionalList: Buffer[Expression] = Buffer()
+
     for (func <- sparkBuiltInFunctions) {
       val builder = functionRegistry.lookupFunctionBuilder(func).get
-      var expr: Expression = null
-      try {
-        // Instantiate an expression with null input. Just for obtaining the 
instance for checking
-        // its allowed input types.
-        expr = builder(Seq(null))
-      } catch {
-        // Ignore the exception as some expression builders require more than 
one input.
-        case _: Throwable =>
+      val expr: Expression = {
+        try {
+          // Instantiate an expression with null input. Just for obtaining the 
instance for checking
+          // its allowed input types.
+          builder(Seq(null))
+        } catch {
+          // Ignore the exception as some expression builders require more 
than one input.
+          case _: Throwable => null
+        }
+      }
+      val needsValidation = if (expr == null) {
+        false
+      } else {
+        expr match {
+          // Validated separately.
+          case _: Cast => false
+          case _: ExpectsInputTypes if expr.isInstanceOf[UnaryExpression] => 
true
+          case _ =>
+            exceptionalList += expr
+            false
+        }
       }
-      if (
-        expr != null && expr.isInstanceOf[ExpectsInputTypes] && 
expr.isInstanceOf[UnaryExpression]
-      ) {
-        val acceptedTypes = allPrimitiveDataTypes.filter(
+      if (needsValidation) {
+        val acceptedTypes = allPrimitiveDataTypes ++ 
allComplexDataTypes.filter(
           expr.asInstanceOf[ExpectsInputTypes].inputTypes.head.acceptsType(_))
         if (acceptedTypes.isEmpty) {
           logWarning("Any given type is not accepted for " + 
expr.getClass.getSimpleName)
@@ -144,15 +172,97 @@ class GlutenExpressionDataTypesValidation extends 
WholeStageTransformerSuite {
             val child = generateChildExpression(t)
             // Builds an expression whose child's type is really accepted in 
Spark.
             val targetExpr = builder(Seq(child))
-            val glutenProject = generateGlutenProjectPlan(targetExpr)
-            if (targetExpr.resolved && glutenProject.doValidate().ok()) {
-              logInfo("## validation passes: " + 
targetExpr.getClass.getSimpleName + "(" + t + ")")
-            } else {
-              logInfo("!! validation fails: " + 
targetExpr.getClass.getSimpleName + "(" + t + ")")
-            }
+            validateExpr(targetExpr)
           })
       }
     }
+
+    logWarning("Exceptional list:\n" + exceptionalList.mkString(", "))
   }
 
+  /**
+   * Tells whether `expr` is an [[ImplicitCastInputTypes]] or a [[BinaryOperator]]. The binary
+   * expression test feeds such expressions the same child on both sides.
+   */
+  def hasImplicitCast(expr: Expression): Boolean = expr match {
+    case _: ImplicitCastInputTypes | _: BinaryOperator => true
+    case _ => false
+  }
+
+  // For every built-in function whose builder accepts exactly two inputs and yields a
+  // BinaryExpression with ExpectsInputTypes, validates the expression in Gluten against the
+  // input types it accepts. Every other expression built from two inputs is collected in
+  // exceptionalList and logged at the end.
+  test("binary expressions with expected input types") {
+    val functionRegistry = spark.sessionState.functionRegistry
+    val exceptionalList: Buffer[Expression] = Buffer()
+    // The accepted-type filter must run over this whole set. Writing
+    // `allPrimitiveDataTypes ++ allComplexDataTypes.filter(p)` filters the complex types only,
+    // since method selection binds tighter than the infix `++`.
+    val candidateTypes = allPrimitiveDataTypes ++ allComplexDataTypes
+
+    functionRegistry.listFunction().foreach(
+      func => {
+        val builder = functionRegistry.lookupFunctionBuilder(func).get
+        // Instantiate an expression with null inputs. Just for obtaining the instance for
+        // checking its allowed input types.
+        val probe: Option[Expression] =
+          try {
+            Option(builder(Seq(null, null)))
+          } catch {
+            // Ignore the exception, as some expression builders don't accept exactly two
+            // inputs. Fatal errors are deliberately left to propagate.
+            case scala.util.control.NonFatal(_) => None
+          }
+        val validatable: Option[ExpectsInputTypes] = probe.flatMap {
+          // Requires left/right child's DataType to determine inputTypes.
+          case e: BinaryArrayExpressionWithImplicitCast =>
+            exceptionalList += e
+            None
+          case e: ExpectsInputTypes if e.isInstanceOf[BinaryExpression] => Some(e)
+          case e =>
+            exceptionalList += e
+            None
+        }
+
+        validatable.foreach(
+          expr => {
+            // Candidate types accepted at the given input position; empty when computing them
+            // hits a NullPointerException (presumably from the probe's null children).
+            def acceptedTypesAt(index: Int): Seq[DataType] =
+              try {
+                candidateTypes.filter(expr.inputTypes(index).acceptsType(_))
+              } catch {
+                case _: java.lang.NullPointerException => Seq.empty
+              }
+
+            val acceptedLeftTypes = acceptedTypesAt(0)
+            val acceptedRightTypes = acceptedTypesAt(1)
+            if (acceptedLeftTypes.isEmpty || acceptedRightTypes.isEmpty) {
+              logWarning("Any given type is not accepted for " + expr.getClass.getSimpleName)
+            }
+
+            val leftChildList = acceptedLeftTypes.map(t => generateChildExpression(t))
+            if (hasImplicitCast(expr)) {
+              // Spark's implicit cast makes same input types.
+              leftChildList.foreach(left => validateExpr(builder(Seq(left, left))))
+            } else {
+              val rightChildList = acceptedRightTypes.map(t => generateChildExpression(t))
+              for (left <- leftChildList; right <- rightChildList) {
+                // Builds an expression whose children's types are really accepted in Spark.
+                validateExpr(builder(Seq(left, right)))
+              }
+            }
+          })
+      })
+
+    logWarning("Exceptional list:\n" + exceptionalList.mkString(", "))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to