This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f718b025d87 [SPARK-43802][SQL] Fix codegen for unhex and unbase64 with
failOnError=true
f718b025d87 is described below
commit f718b025d87ae3726210c60ff71cb34917b32f51
Author: Adam Binford <[email protected]>
AuthorDate: Fri May 26 20:37:14 2023 +0300
[SPARK-43802][SQL] Fix codegen for unhex and unbase64 with failOnError=true
### What changes were proposed in this pull request?
Fixes an error with codegen for the unhex and unbase64 expressions when
failOnError is enabled, introduced in https://github.com/apache/spark/pull/37483.
### Why are the changes needed?
Codegen fails and Spark falls back to interpreted evaluation:
```
Caused by: org.codehaus.commons.compiler.CompileException: File
'generated.java', Line 47, Column 1: failed to compile:
org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 47,
Column 1: Unknown variable or type "BASE64"
```
in the code block:
```
/* 107 */ if
(!org.apache.spark.sql.catalyst.expressions.UnBase64.isValidBase64(project_value_1))
{
/* 108 */ throw
QueryExecutionErrors.invalidInputInConversionError(
/* 109 */ ((org.apache.spark.sql.types.BinaryType$)
references[1] /* to */),
/* 110 */ project_value_1,
/* 111 */ BASE64,
/* 112 */ "try_to_binary");
/* 113 */ }
```
### Does this PR introduce _any_ user-facing change?
Bug fix.
### How was this patch tested?
Added to the existing tests to evaluate an expression with failOnError
enabled, to exercise that path of the codegen.
Closes #41317 from Kimahriman/bug-to-binary-codegen.
Authored-by: Adam Binford <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../sql/catalyst/expressions/mathExpressions.scala | 3 +-
.../catalyst/expressions/stringExpressions.scala | 3 +-
.../expressions/MathExpressionsSuite.scala | 3 ++
.../expressions/StringExpressionsSuite.scala | 4 +-
.../sql/errors/QueryExecutionErrorsSuite.scala | 46 ++++++++++++++++------
5 files changed, 43 insertions(+), 16 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index dcc821a24ea..add59a38b72 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1172,14 +1172,13 @@ case class Unhex(child: Expression, failOnError:
Boolean = false)
nullSafeCodeGen(ctx, ev, c => {
val hex = Hex.getClass.getName.stripSuffix("$")
val maybeFailOnErrorCode = if (failOnError) {
- val format = UTF8String.fromString("BASE64");
val binaryType = ctx.addReferenceObj("to", BinaryType,
BinaryType.getClass.getName)
s"""
|if (${ev.value} == null) {
| throw QueryExecutionErrors.invalidInputInConversionError(
| $binaryType,
| $c,
- | $format,
+ | UTF8String.fromString("HEX"),
| "try_to_binary");
|}
|""".stripMargin
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 347dff0f4c4..03596ac40b1 100755
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -2472,14 +2472,13 @@ case class UnBase64(child: Expression, failOnError:
Boolean = false)
nullSafeCodeGen(ctx, ev, child => {
val maybeValidateInputCode = if (failOnError) {
val unbase64 = UnBase64.getClass.getName.stripSuffix("$")
- val format = UTF8String.fromString("BASE64");
val binaryType = ctx.addReferenceObj("to", BinaryType,
BinaryType.getClass.getName)
s"""
|if (!$unbase64.isValidBase64($child)) {
| throw QueryExecutionErrors.invalidInputInConversionError(
| $binaryType,
| $child,
- | $format,
+ | UTF8String.fromString("BASE64"),
| "try_to_binary");
|}
""".stripMargin
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index 437f7ddee01..823a6d2ce86 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -615,6 +615,9 @@ class MathExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
checkEvaluation(Unhex(Literal("GG")), null)
checkEvaluation(Unhex(Literal("123")), Array[Byte](1, 35))
checkEvaluation(Unhex(Literal("12345")), Array[Byte](1, 35, 69))
+
+ // failOnError
+ checkEvaluation(Unhex(Literal("12345"), true), Array[Byte](1, 35, 69))
// scalastyle:off
// Turn off scala style for non-ascii chars
checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")),
"三重的".getBytes(StandardCharsets.UTF_8))
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index a27af7d2439..f320012d131 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -468,7 +468,9 @@ class StringExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
checkEvaluation(Base64(UnBase64(Literal("AQIDBA=="))), "AQIDBA==",
create_row("abdef"))
checkEvaluation(Base64(UnBase64(Literal(""))), "", create_row("abdef"))
checkEvaluation(Base64(UnBase64(Literal.create(null, StringType))), null,
create_row("abdef"))
- checkEvaluation(Base64(UnBase64(a)), "AQIDBA==", create_row("AQIDBA=="))
+
+ // failOnError
+ checkEvaluation(Base64(UnBase64(a, true)), "AQIDBA==",
create_row("AQIDBA=="))
checkEvaluation(Base64(b), "AQIDBA==", create_row(bytes))
checkEvaluation(Base64(b), "", create_row(Array.empty[Byte]))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index 4bfab92ccb1..c37722133cb 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame,
Dataset, QueryTest, R
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.{Parameter, UnresolvedGenerator}
import org.apache.spark.sql.catalyst.expressions.{Grouping, Literal, RowNumber}
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.expressions.objects.InitializeJavaBean
import org.apache.spark.sql.catalyst.util.BadRecordException
@@ -57,17 +58,40 @@ class QueryExecutionErrorsSuite
import testImplicits._
- test("CONVERSION_INVALID_INPUT: to_binary conversion function") {
- checkError(
- exception = intercept[SparkIllegalArgumentException] {
- sql("select to_binary('???', 'base64')").collect()
- },
- errorClass = "CONVERSION_INVALID_INPUT",
- parameters = Map(
- "str" -> "'???'",
- "fmt" -> "'BASE64'",
- "targetType" -> "\"BINARY\"",
- "suggestion" -> "`try_to_binary`"))
+ test("CONVERSION_INVALID_INPUT: to_binary conversion function base64") {
+ for (codegenMode <- Seq(CODEGEN_ONLY, NO_CODEGEN)) {
+ withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode.toString) {
+ val exception = intercept[SparkException] {
+ Seq(("???")).toDF("a").selectExpr("to_binary(a, 'base64')").collect()
+ }.getCause.asInstanceOf[SparkIllegalArgumentException]
+ checkError(
+ exception,
+ errorClass = "CONVERSION_INVALID_INPUT",
+ parameters = Map(
+ "str" -> "'???'",
+ "fmt" -> "'BASE64'",
+ "targetType" -> "\"BINARY\"",
+ "suggestion" -> "`try_to_binary`"))
+ }
+ }
+ }
+
+ test("CONVERSION_INVALID_INPUT: to_binary conversion function hex") {
+ for (codegenMode <- Seq(CODEGEN_ONLY, NO_CODEGEN)) {
+ withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode.toString) {
+ val exception = intercept[SparkException] {
+ Seq(("???")).toDF("a").selectExpr("to_binary(a, 'hex')").collect()
+ }.getCause.asInstanceOf[SparkIllegalArgumentException]
+ checkError(
+ exception,
+ errorClass = "CONVERSION_INVALID_INPUT",
+ parameters = Map(
+ "str" -> "'???'",
+ "fmt" -> "'HEX'",
+ "targetType" -> "\"BINARY\"",
+ "suggestion" -> "`try_to_binary`"))
+ }
+ }
}
private def getAesInputs(): (DataFrame, DataFrame) = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]