This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/spark.git
commit 0608361e5b14dcb08d631701894bd1cea7e39fdd Author: Kris Mok <[email protected]> AuthorDate: Wed Feb 12 15:19:16 2020 +0900 [SPARK-30795][SQL] Spark SQL codegen's code() interpolator should treat escapes like Scala's StringContext.s() ### What changes were proposed in this pull request? This PR proposes to make the `code` string interpolator treat escapes the same way as Scala's builtin `StringContext.s()` string interpolator. This will remove the need for an ugly workaround in `Like` expression's codegen. ### Why are the changes needed? The `code()` string interpolator in Spark SQL's code generator should treat escapes like Scala's builtin `StringContext.s()` interpolator, i.e. it should treat escapes in the code parts, and should not treat escapes in the input arguments. For example, ```scala val arg = "This is an argument." val str = s"This is string part 1. $arg This is string part 2." val code = code"This is string part 1. $arg This is string part 2." assert(code.toString == str) ``` We should expect the `code()` interpolator to produce the same result as the `StringContext.s()` interpolator, where only escapes in the string parts should be treated, while the args should be kept verbatim. But in the current implementation, due to the eager folding of code parts and literal input args, the escape treatment is incorrectly done on both code parts and literal args. That causes a problem when an arg contains escape sequences and wants to preserve that in the final produced code string. For example, in `Like` expression's codegen, there's an ugly workaround for this bug: ```scala // We need double escape to avoid org.codehaus.commons.compiler.CompileException. // '\\' will cause exception 'Single quote must be backslash-escaped in character literal'. // '\"' will cause exception 'Line break in literal not allowed'. val newEscapeChar = if (escapeChar == '\"' || escapeChar == '\\') { s"""\\\\\\$escapeChar""" } else { escapeChar } ``` ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Added a new unit test case in `CodeBlockSuite`. Closes #27544 from rednaxelafx/fix-code-string-interpolator. Authored-by: Kris Mok <[email protected]> Signed-off-by: HyukjinKwon <[email protected]> --- .../spark/sql/catalyst/expressions/codegen/javaCode.scala | 13 +++++++++---- .../spark/sql/catalyst/expressions/regexpExpressions.scala | 13 ++++--------- .../sql/catalyst/expressions/codegen/CodeBlockSuite.scala | 12 ++++++++++++ 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index d9393b9..dff2589 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -223,6 +223,11 @@ object Block { implicit def blocksToBlock(blocks: Seq[Block]): Block = blocks.reduceLeft(_ + _) implicit class BlockHelper(val sc: StringContext) extends AnyVal { + /** + * A string interpolator that retains references to the `JavaCode` inputs, and behaves like + * the Scala builtin StringContext.s() interpolator otherwise, i.e. it will treat escapes in + * the code parts, and will not treat escapes in the input arguments. + */ def code(args: Any*): Block = { sc.checkLengths(args) if (sc.parts.length == 0) { @@ -250,7 +255,7 @@ object Block { val inputs = args.iterator val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) - buf.append(strings.next) + buf.append(StringContext.treatEscapes(strings.next)) while (strings.hasNext) { val input = inputs.next input match { @@ -262,7 +267,7 @@ object Block { case _ => buf.append(input) } - buf.append(strings.next) + buf.append(StringContext.treatEscapes(strings.next)) } codeParts += buf.toString @@ -286,10 +291,10 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends val strings = codeParts.iterator val inputs = blockInputs.iterator val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) - buf.append(StringContext.treatEscapes(strings.next)) + buf.append(strings.next) while (strings.hasNext) { buf.append(inputs.next) - buf.append(StringContext.treatEscapes(strings.next)) + buf.append(strings.next) } buf.toString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 32a653d..ac620b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -158,19 +158,14 @@ case class Like(left: Expression, right: Expression, escapeChar: Char) } else { val pattern = ctx.freshName("pattern") val rightStr = ctx.freshName("rightStr") - // We need double escape to avoid org.codehaus.commons.compiler.CompileException. - // '\\' will cause exception 'Single quote must be backslash-escaped in character literal'. - // '\"' will cause exception 'Line break in literal not allowed'. - val newEscapeChar = if (escapeChar == '\"' || escapeChar == '\\') { - s"""\\\\\\$escapeChar""" - } else { - escapeChar - } + // We need to escape the escapeChar to make sure the generated code is valid. + // Otherwise we'll hit org.codehaus.commons.compiler.CompileException. + val escapedEscapeChar = StringEscapeUtils.escapeJava(escapeChar.toString) nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" String $rightStr = $eval2.toString(); $patternClass $pattern = $patternClass.compile( - $escapeFunc($rightStr, '$newEscapeChar')); + $escapeFunc($rightStr, '$escapedEscapeChar')); ${ev.value} = $pattern.matcher($eval1.toString()).matches(); """ }) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index 55569b6..67e3bc6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -37,6 +37,18 @@ class CodeBlockSuite extends SparkFunSuite { assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value)) } + test("Code parts should be treated for escapes, but string inputs shouldn't be") { + val strlit = raw"\\" + val code = code"""String s = "foo\\bar" + "$strlit";""" + + val builtin = s"""String s = "foo\\bar" + "$strlit";""" + + val expected = raw"""String s = "foo\bar" + "\\";""" + + assert(builtin == expected) + assert(code.asInstanceOf[CodeBlock].toString == expected) + } + test("Block.stripMargin") { val isNull = JavaCode.isNullVariable("expr1_isNull") val value = JavaCode.variable("expr1", IntegerType) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
