This is an automated email from the ASF dual-hosted git repository.
philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new da4711207 [VL] Validate binary expressions with their accepted types
(#6521)
da4711207 is described below
commit da4711207d3e188f0cc179eaf2bfb13ff4c1b6af
Author: PHILO-HE <[email protected]>
AuthorDate: Wed Aug 14 20:30:13 2024 +0800
[VL] Validate binary expressions with their accepted types (#6521)
---
.../sql/GlutenExpressionDataTypesValidation.scala | 148 ++++++++++++++++++---
1 file changed, 129 insertions(+), 19 deletions(-)
diff --git
a/gluten-ut/test/src/test/scala/org/apache/spark/sql/GlutenExpressionDataTypesValidation.scala
b/gluten-ut/test/src/test/scala/org/apache/spark/sql/GlutenExpressionDataTypesValidation.scala
index c8b2aaba2..170555e5f 100644
---
a/gluten-ut/test/src/test/scala/org/apache/spark/sql/GlutenExpressionDataTypesValidation.scala
+++
b/gluten-ut/test/src/test/scala/org/apache/spark/sql/GlutenExpressionDataTypesValidation.scala
@@ -24,11 +24,13 @@ import org.apache.gluten.utils.{BackendTestUtils,
SystemParameters}
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
+import
org.apache.spark.sql.catalyst.expressions.{BinaryArrayExpressionWithImplicitCast,
_}
import org.apache.spark.sql.execution.LeafExecNode
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types._
+import scala.collection.mutable.Buffer
+
class GlutenExpressionDataTypesValidation extends WholeStageTransformerSuite {
protected val resourcePath: String = null
protected val fileFormat: String = null
@@ -99,6 +101,19 @@ class GlutenExpressionDataTypesValidation extends
WholeStageTransformerSuite {
ProjectExecTransformer(namedExpr, DummyPlan())
}
+ def validateExpr(targetExpr: Expression): Unit = {
+ val glutenProject = generateGlutenProjectPlan(targetExpr)
+ if (targetExpr.resolved && glutenProject.doValidate().ok()) {
+ logInfo(
+ "## validation passes: " + targetExpr.getClass.getSimpleName + "(" +
+ targetExpr.children.map(_.dataType.toString).mkString(", ") + ")")
+ } else {
+ logInfo(
+ "!! validation fails: " + targetExpr.getClass.getSimpleName + "(" +
+ targetExpr.children.map(_.dataType.toString).mkString(", ") + ")")
+ }
+ }
+
test("cast") {
for (from <- allPrimitiveDataTypes ++ allComplexDataTypes) {
for (to <- allPrimitiveDataTypes ++ allComplexDataTypes) {
@@ -120,21 +135,34 @@ class GlutenExpressionDataTypesValidation extends
WholeStageTransformerSuite {
test("unary expressions with expected input types") {
val functionRegistry = spark.sessionState.functionRegistry
val sparkBuiltInFunctions = functionRegistry.listFunction()
+ val exceptionalList: Buffer[Expression] = Buffer()
+
for (func <- sparkBuiltInFunctions) {
val builder = functionRegistry.lookupFunctionBuilder(func).get
- var expr: Expression = null
- try {
- // Instantiate an expression with null input. Just for obtaining the
instance for checking
- // its allowed input types.
- expr = builder(Seq(null))
- } catch {
- // Ignore the exception as some expression builders require more than
one input.
- case _: Throwable =>
+ val expr: Expression = {
+ try {
+ // Instantiate an expression with null input. Just for obtaining the
instance for checking
+ // its allowed input types.
+ builder(Seq(null))
+ } catch {
+ // Ignore the exception as some expression builders require more
than one input.
+ case _: Throwable => null
+ }
+ }
+ val needsValidation = if (expr == null) {
+ false
+ } else {
+ expr match {
+ // Validated separately.
+ case _: Cast => false
+ case _: ExpectsInputTypes if expr.isInstanceOf[UnaryExpression] =>
true
+ case _ =>
+ exceptionalList += expr
+ false
+ }
}
- if (
- expr != null && expr.isInstanceOf[ExpectsInputTypes] &&
expr.isInstanceOf[UnaryExpression]
- ) {
- val acceptedTypes = allPrimitiveDataTypes.filter(
+ if (needsValidation) {
+ val acceptedTypes = allPrimitiveDataTypes ++
allComplexDataTypes.filter(
expr.asInstanceOf[ExpectsInputTypes].inputTypes.head.acceptsType(_))
if (acceptedTypes.isEmpty) {
logWarning("Any given type is not accepted for " +
expr.getClass.getSimpleName)
@@ -144,15 +172,97 @@ class GlutenExpressionDataTypesValidation extends
WholeStageTransformerSuite {
val child = generateChildExpression(t)
// Builds an expression whose child's type is really accepted in
Spark.
val targetExpr = builder(Seq(child))
- val glutenProject = generateGlutenProjectPlan(targetExpr)
- if (targetExpr.resolved && glutenProject.doValidate().ok()) {
- logInfo("## validation passes: " +
targetExpr.getClass.getSimpleName + "(" + t + ")")
- } else {
- logInfo("!! validation fails: " +
targetExpr.getClass.getSimpleName + "(" + t + ")")
- }
+ validateExpr(targetExpr)
})
}
}
+
+ logWarning("Exceptional list:\n" + exceptionalList.mkString(", "))
}
+ def hasImplicitCast(expr: Expression): Boolean = expr match {
+ case _: ImplicitCastInputTypes => true
+ case _: BinaryOperator => true
+ case _ => false
+ }
+
+ test("binary expressions with expected input types") {
+ val functionRegistry = spark.sessionState.functionRegistry
+ val exceptionalList: Buffer[Expression] = Buffer()
+
+ val sparkBuiltInFunctions = functionRegistry.listFunction()
+ sparkBuiltInFunctions.foreach(
+ func => {
+ val builder = functionRegistry.lookupFunctionBuilder(func).get
+ val expr: Expression = {
+ try {
+ // Instantiate an expression with null input. Just for obtaining
the instance for
+ // checking its allowed input types.
+ builder(Seq(null, null))
+ } catch {
+ // Ignore the exception as some expression builders don't require
exactly two inputs.
+ case _: Throwable => null
+ }
+ }
+ val needsValidation = if (expr == null) {
+ false
+ } else {
+ expr match {
+ // Requires left/right child's DataType to determine inputTypes.
+ case _: BinaryArrayExpressionWithImplicitCast =>
+ exceptionalList += expr
+ false
+ case _: ExpectsInputTypes if expr.isInstanceOf[BinaryExpression]
=> true
+ case _ =>
+ exceptionalList += expr
+ false
+ }
+ }
+
+ if (needsValidation) {
+ var acceptedLeftTypes: Seq[DataType] = Seq.empty
+ var acceptedRightTypes: Seq[DataType] = Seq.empty
+ try {
+ acceptedLeftTypes = allPrimitiveDataTypes ++
allComplexDataTypes.filter(
+
expr.asInstanceOf[ExpectsInputTypes].inputTypes(0).acceptsType(_))
+ acceptedRightTypes = allPrimitiveDataTypes ++
allComplexDataTypes.filter(
+
expr.asInstanceOf[ExpectsInputTypes].inputTypes(1).acceptsType(_))
+ } catch {
+ case _: java.lang.NullPointerException =>
+ }
+
+ if (acceptedLeftTypes.isEmpty || acceptedRightTypes.isEmpty) {
+ logWarning("Any given type is not accepted for " +
expr.getClass.getSimpleName)
+ }
+ val leftChildList = acceptedLeftTypes.map(
+ t => {
+ generateChildExpression(t)
+ })
+ if (hasImplicitCast(expr)) {
+ leftChildList.foreach(
+ left => {
+ // Spark's implicit cast makes same input types.
+ val targetExpr = builder(Seq(left, left))
+ validateExpr(targetExpr)
+ })
+ } else {
+ val rightChildList = acceptedRightTypes.map(
+ t => {
+ generateChildExpression(t)
+ })
+ leftChildList.foreach(
+ left => {
+ rightChildList.foreach(
+ right => {
+ // Builds an expression whose child's type is really
accepted in Spark.
+ val targetExpr = builder(Seq(left, right))
+ validateExpr(targetExpr)
+ })
+ })
+ }
+ }
+ })
+
+ logWarning("Exceptional list:\n" + exceptionalList.mkString(", "))
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]