This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
new 512f625 [SPARK-34723][SQL] Correct parameter type for subexpression
elimination under whole-stage
512f625 is described below
commit 512f6258071666c3da07f6e6c6299d653a1cc370
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sat Mar 13 00:05:41 2021 -0800
[SPARK-34723][SQL] Correct parameter type for subexpression elimination
under whole-stage
### What changes were proposed in this pull request?
This patch proposes to fix the incorrect parameter type generated for
subexpression elimination under whole-stage codegen.
### Why are the changes needed?
If the parameter is a byte array, subexpression elimination under
whole-stage codegen will use an incorrect parameter type (for example,
`Class.getName` returns `[B` for `byte[]`, which is not a valid Java type
name in source code) and cause a compile error.
Although Spark can automatically fall back to interpreted mode, we should fix it.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Manually tested with a customer application. Added a unit test.
Closes #31814 from viirya/SPARK-34723.
Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
(cherry picked from commit 86baa36eebf72b72981830ddb8085950507a4bfa)
Signed-off-by: Liang-Chi Hsieh <[email protected]>
---
.../expressions/codegen/CodeGenerator.scala | 3 +-
.../SubexpressionEliminationSuite.scala | 62 ++++++++++++++++++++--
2 files changed, 61 insertions(+), 4 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 1ff4a93..6e6b946 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1089,7 +1089,8 @@ class CodegenContext extends Logging {
// Generate the code for this expression tree and wrap it in a
function.
val fnName = freshName("subExpr")
val inputVars = inputVarsForAllFuncs(i)
- val argList = inputVars.map(v => s"${v.javaType.getName}
${v.variableName}")
+ val argList =
+ inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)}
${v.variableName}")
val returnType = javaType(expr.dataType)
val fn =
s"""
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 0147c6c..65671d2 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.types.{DataType, IntegerType}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType}
-class SubexpressionEliminationSuite extends SparkFunSuite {
+class SubexpressionEliminationSuite extends SparkFunSuite with
ExpressionEvalHelper {
test("Semantic equals and hash") {
val a: AttributeReference = AttributeReference("name", IntegerType)()
val id = {
@@ -253,6 +254,61 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 0)
}
+
+ test("SPARK-34723: Correct parameter type for subexpression elimination
under whole-stage") {
+ withSQLConf(SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1") {
+ val str = BoundReference(0, BinaryType, false)
+ val pos = BoundReference(1, IntegerType, false)
+
+ val substr = new Substring(str, pos)
+
+ val add = Add(Length(substr), Literal(1))
+ val add2 = Add(Length(substr), Literal(2))
+
+ val ctx = new CodegenContext()
+ val exprs = Seq(add, add2)
+
+ val oneVar = ctx.freshVariable("str", BinaryType)
+ val twoVar = ctx.freshVariable("pos", IntegerType)
+ ctx.addMutableState("byte[]", oneVar, forceInline = true, useFreshName =
false)
+ ctx.addMutableState("int", twoVar, useFreshName = false)
+
+ ctx.INPUT_ROW = null
+ ctx.currentVars = Seq(
+ ExprCode(TrueLiteral, oneVar),
+ ExprCode(TrueLiteral, twoVar))
+
+ val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs)
+ ctx.withSubExprEliminationExprs(subExprs.states) {
+ exprs.map(_.genCode(ctx))
+ }
+ val subExprsCode = subExprs.codes.mkString("\n")
+
+ val codeBody = s"""
+ public java.lang.Object generate(Object[] references) {
+ return new TestCode(references);
+ }
+
+ class TestCode {
+ ${ctx.declareMutableStates()}
+
+ public TestCode(Object[] references) {
+ }
+
+ public void initialize(int partitionIndex) {
+ ${subExprsCode}
+ }
+
+ ${ctx.declareAddedFunctions()}
+ }
+ """
+
+ val code = CodeFormatter.stripOverlappingComments(
+ new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
+
+ CodeGenerator.compile(code)
+ }
+ }
}
case class CodegenFallbackExpression(child: Expression)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]