This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git

commit 0608361e5b14dcb08d631701894bd1cea7e39fdd
Author: Kris Mok <[email protected]>
AuthorDate: Wed Feb 12 15:19:16 2020 +0900

    [SPARK-30795][SQL] Spark SQL codegen's code() interpolator should treat 
escapes like Scala's StringContext.s()
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to make the `code` string interpolator treat escapes the 
same way as Scala's builtin `StringContext.s()` string interpolator. This will 
remove the need for an ugly workaround in `Like` expression's codegen.
    
    ### Why are the changes needed?
    
    The `code()` string interpolator in Spark SQL's code generator should handle escapes like Scala's builtin `StringContext.s()` interpolator, i.e. it should process escape sequences in the code parts, and should leave the input arguments verbatim without processing any escape sequences in them.
    
    For example,
    ```scala
    val arg = "This is an argument."
    val str = s"This is string part 1. $arg This is string part 2."
    val code = code"This is string part 1. $arg This is string part 2."
    assert(code.toString == str)
    ```
    We should expect the `code()` interpolator to produce the same result as 
the `StringContext.s()` interpolator, where only escapes in the string parts 
should be treated, while the args should be kept verbatim.
    
    But in the current implementation, because code parts and literal input args are eagerly folded together, escape processing is incorrectly applied to both the code parts and the literal args.
    That causes a problem when an arg contains escape sequences that must be preserved in the final produced code string. For example, in `Like` expression's codegen, there's an ugly workaround for this bug:
    ```scala
          // We need double escape to avoid org.codehaus.commons.compiler.CompileException.
          // '\\' will cause exception 'Single quote must be backslash-escaped in character literal'.
          // '\"' will cause exception 'Line break in literal not allowed'.
          val newEscapeChar = if (escapeChar == '\"' || escapeChar == '\\') {
            s"""\\\\\\$escapeChar"""
          } else {
            escapeChar
          }
    ```
    
    ### Does this PR introduce any user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added a new unit test case in `CodeBlockSuite`.
    
    Closes #27544 from rednaxelafx/fix-code-string-interpolator.
    
    Authored-by: Kris Mok <[email protected]>
    Signed-off-by: HyukjinKwon <[email protected]>
---
 .../spark/sql/catalyst/expressions/codegen/javaCode.scala   | 13 +++++++++----
 .../spark/sql/catalyst/expressions/regexpExpressions.scala  | 13 ++++---------
 .../sql/catalyst/expressions/codegen/CodeBlockSuite.scala   | 12 ++++++++++++
 3 files changed, 25 insertions(+), 13 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
index d9393b9..dff2589 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
@@ -223,6 +223,11 @@ object Block {
   implicit def blocksToBlock(blocks: Seq[Block]): Block = blocks.reduceLeft(_ 
+ _)
 
   implicit class BlockHelper(val sc: StringContext) extends AnyVal {
+    /**
+     * A string interpolator that retains references to the `JavaCode` inputs, 
and behaves like
+     * the Scala builtin StringContext.s() interpolator otherwise, i.e. it 
will treat escapes in
+     * the code parts, and will not treat escapes in the input arguments.
+     */
     def code(args: Any*): Block = {
       sc.checkLengths(args)
       if (sc.parts.length == 0) {
@@ -250,7 +255,7 @@ object Block {
     val inputs = args.iterator
     val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH)
 
-    buf.append(strings.next)
+    buf.append(StringContext.treatEscapes(strings.next))
     while (strings.hasNext) {
       val input = inputs.next
       input match {
@@ -262,7 +267,7 @@ object Block {
         case _ =>
           buf.append(input)
       }
-      buf.append(strings.next)
+      buf.append(StringContext.treatEscapes(strings.next))
     }
     codeParts += buf.toString
 
@@ -286,10 +291,10 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: 
Seq[JavaCode]) extends
     val strings = codeParts.iterator
     val inputs = blockInputs.iterator
     val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH)
-    buf.append(StringContext.treatEscapes(strings.next))
+    buf.append(strings.next)
     while (strings.hasNext) {
       buf.append(inputs.next)
-      buf.append(StringContext.treatEscapes(strings.next))
+      buf.append(strings.next)
     }
     buf.toString
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 32a653d..ac620b1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -158,19 +158,14 @@ case class Like(left: Expression, right: Expression, 
escapeChar: Char)
     } else {
       val pattern = ctx.freshName("pattern")
       val rightStr = ctx.freshName("rightStr")
-      // We need double escape to avoid 
org.codehaus.commons.compiler.CompileException.
-      // '\\' will cause exception 'Single quote must be backslash-escaped in 
character literal'.
-      // '\"' will cause exception 'Line break in literal not allowed'.
-      val newEscapeChar = if (escapeChar == '\"' || escapeChar == '\\') {
-        s"""\\\\\\$escapeChar"""
-      } else {
-        escapeChar
-      }
+      // We need to escape the escapeChar to make sure the generated code is 
valid.
+      // Otherwise we'll hit org.codehaus.commons.compiler.CompileException.
+      val escapedEscapeChar = StringEscapeUtils.escapeJava(escapeChar.toString)
       nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
         s"""
           String $rightStr = $eval2.toString();
           $patternClass $pattern = $patternClass.compile(
-            $escapeFunc($rightStr, '$newEscapeChar'));
+            $escapeFunc($rightStr, '$escapedEscapeChar'));
           ${ev.value} = $pattern.matcher($eval1.toString()).matches();
         """
       })
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
index 55569b6..67e3bc6 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
@@ -37,6 +37,18 @@ class CodeBlockSuite extends SparkFunSuite {
     assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value))
   }
 
+  test("Code parts should be treated for escapes, but string inputs shouldn't 
be") {
+    val strlit = raw"\\"
+    val code = code"""String s = "foo\\bar" + "$strlit";"""
+
+    val builtin = s"""String s = "foo\\bar" + "$strlit";"""
+
+    val expected = raw"""String s = "foo\bar" + "\\";"""
+
+    assert(builtin == expected)
+    assert(code.asInstanceOf[CodeBlock].toString == expected)
+  }
+
   test("Block.stripMargin") {
     val isNull = JavaCode.isNullVariable("expr1_isNull")
     val value = JavaCode.variable("expr1", IntegerType)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to