This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new d20d549  [SPARK-34723][SQL] Correct parameter type for subexpression 
elimination under whole-stage
d20d549 is described below

commit d20d549fdf644203a5e68394b9e99e524aaee9e5
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sat Mar 13 00:05:41 2021 -0800

    [SPARK-34723][SQL] Correct parameter type for subexpression elimination 
under whole-stage
    
    ### What changes were proposed in this pull request?
    
    This patch proposes to fix the incorrect parameter type used for subexpression 
elimination under whole-stage codegen.
    
    ### Why are the changes needed?
    
    If the parameter is a byte array, subexpression elimination under 
whole-stage codegen will use an incorrect parameter type and cause a compile error. 
Although Spark can automatically fall back to interpreted mode, we should fix it.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Manually tested with a customer application. Added a unit test.
    
    Closes #31814 from viirya/SPARK-34723.
    
    Authored-by: Liang-Chi Hsieh <[email protected]>
    Signed-off-by: Liang-Chi Hsieh <[email protected]>
    (cherry picked from commit 86baa36eebf72b72981830ddb8085950507a4bfa)
    Signed-off-by: Liang-Chi Hsieh <[email protected]>
---
 .../expressions/codegen/CodeGenerator.scala        |  3 +-
 .../SubexpressionEliminationSuite.scala            | 62 ++++++++++++++++++++--
 2 files changed, 61 insertions(+), 4 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 510696e..74c6ea5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1077,7 +1077,8 @@ class CodegenContext extends Logging {
           // Generate the code for this expression tree and wrap it in a 
function.
           val fnName = freshName("subExpr")
           val inputVars = inputVarsForAllFuncs(i)
-          val argList = inputVars.map(v => s"${v.javaType.getName} 
${v.variableName}")
+          val argList =
+            inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} 
${v.variableName}")
           val returnType = javaType(expr.dataType)
           val fn =
             s"""
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 1fa185c..569928e 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -17,10 +17,11 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.types.{DataType, IntegerType}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType}
 
-class SubexpressionEliminationSuite extends SparkFunSuite {
+class SubexpressionEliminationSuite extends SparkFunSuite with 
ExpressionEvalHelper {
   test("Semantic equals and hash") {
     val a: AttributeReference = AttributeReference("name", IntegerType)()
     val id = {
@@ -161,6 +162,61 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
     // only ifExpr and its predicate expression
     assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 2)
   }
+
+  test("SPARK-34723: Correct parameter type for subexpression elimination 
under whole-stage") {
+    withSQLConf(SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1") {
+      val str = BoundReference(0, BinaryType, false)
+      val pos = BoundReference(1, IntegerType, false)
+
+      val substr = new Substring(str, pos)
+
+      val add = Add(Length(substr), Literal(1))
+      val add2 = Add(Length(substr), Literal(2))
+
+      val ctx = new CodegenContext()
+      val exprs = Seq(add, add2)
+
+      val oneVar = ctx.freshVariable("str", BinaryType)
+      val twoVar = ctx.freshVariable("pos", IntegerType)
+      ctx.addMutableState("byte[]", oneVar, forceInline = true, useFreshName = 
false)
+      ctx.addMutableState("int", twoVar, useFreshName = false)
+
+      ctx.INPUT_ROW = null
+      ctx.currentVars = Seq(
+        ExprCode(TrueLiteral, oneVar),
+        ExprCode(TrueLiteral, twoVar))
+
+      val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs)
+      ctx.withSubExprEliminationExprs(subExprs.states) {
+        exprs.map(_.genCode(ctx))
+      }
+      val subExprsCode = subExprs.codes.mkString("\n")
+
+      val codeBody = s"""
+        public java.lang.Object generate(Object[] references) {
+          return new TestCode(references);
+        }
+
+        class TestCode {
+          ${ctx.declareMutableStates()}
+
+          public TestCode(Object[] references) {
+          }
+
+          public void initialize(int partitionIndex) {
+            ${subExprsCode}
+          }
+
+          ${ctx.declareAddedFunctions()}
+        }
+      """
+
+      val code = CodeFormatter.stripOverlappingComments(
+        new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
+
+      CodeGenerator.compile(code)
+    }
+  }
 }
 
 case class CodegenFallbackExpression(child: Expression)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to