This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 51f11033f86e [SPARK-48686][SQL] Improve performance of 
ParserUtils.unescapeSQLString
51f11033f86e is described below

commit 51f11033f86ea423b88f80690099b6384df172ac
Author: Josh Rosen <[email protected]>
AuthorDate: Tue Jun 25 18:13:06 2024 +0900

    [SPARK-48686][SQL] Improve performance of ParserUtils.unescapeSQLString
    
    ### What changes were proposed in this pull request?
    
    This PR implements multiple performance optimizations for 
`ParserUtils.unescapeSQLString`:
    
    1. Don't use regex: following https://github.com/apache/spark/pull/31362, 
the existing code uses regexes for parsing escaped character patterns. However, 
in the worst case (which is also the expected common case of "no escaping 
needed") it will perform four regex match attempts per input character, 
resulting in significant garbage creation because the matchers aren't reused.
    2. Skip the StringBuilder allocation for raw strings and for strings that 
don't need any unescaping.
    3. Minor: use Java StringBuilder instead of the Scala version: this removes 
a layer of indirection and may benefit JIT (we've seen positive results in some 
scenarios from this type of switch).
    
    ### Why are the changes needed?
    
    unescapeSQLString showed up as a CPU and allocation hotspot in certain 
testing scenarios. See this flamegraph for an illustration of the relative 
costs of repeated regex matching in the old code:
    
    
![image](https://github.com/apache/spark/assets/50748/e045d9da-da0f-493c-a634-188acaeab1a9)
    
    The new code is almost arbitrarily faster (i.e. it can show ~arbitrary 
relative speedups, depending on the choice of input) for strings that don't 
require unescaping. For strings that _do_ need unescaping, I tested extreme 
cases where _every_ character needs unescaping: in these cases I see ~10-20x 
speedups (depending on the type of escaping). The new code should be faster in 
every scenario.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Correctness is covered by existing unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #47062 from JoshRosen/unescapeSQLString-optimizations.
    
    Authored-by: Josh Rosen <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../spark/sql/catalyst/util/SparkParserUtils.scala | 94 +++++++++++++++-------
 .../sql/catalyst/parser/ParserUtilsSuite.scala     | 12 +++
 2 files changed, 75 insertions(+), 31 deletions(-)

diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala
 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala
index a4ce5fb12034..7597cb1d9087 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala
@@ -16,8 +16,7 @@
  */
 package org.apache.spark.sql.catalyst.util
 
-import java.lang.{Long => JLong}
-import java.nio.CharBuffer
+import java.lang.{Long => JLong, StringBuilder => JStringBuilder}
 
 import org.antlr.v4.runtime.{ParserRuleContext, Token}
 import org.antlr.v4.runtime.misc.Interval
@@ -26,16 +25,10 @@ import org.antlr.v4.runtime.tree.TerminalNode
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
 
 trait SparkParserUtils {
-  val U16_CHAR_PATTERN = """\\u([a-fA-F0-9]{4})(?s).*""".r
-  val U32_CHAR_PATTERN = """\\U([a-fA-F0-9]{8})(?s).*""".r
-  val OCTAL_CHAR_PATTERN = """\\([01][0-7]{2})(?s).*""".r
-  val ESCAPED_CHAR_PATTERN = """\\((?s).)(?s).*""".r
 
   /** Unescape backslash-escaped string enclosed by quotes. */
   def unescapeSQLString(b: String): String = {
-    val sb = new StringBuilder(b.length())
-
-    def appendEscapedChar(n: Char): Unit = {
+    def appendEscapedChar(n: Char, sb: JStringBuilder): Unit = {
       n match {
         case '0' => sb.append('\u0000')
         case 'b' => sb.append('\b')
@@ -50,22 +43,64 @@ trait SparkParserUtils {
       }
     }
 
-    if (b.startsWith("r") || b.startsWith("R")) {
+    def allCharsAreHex(s: String, start: Int, length: Int): Boolean = {
+      val end = start + length
+      var i = start
+      while (i < end) {
+        val c = s.charAt(i)
+        val cIsHex = (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 
'A' && c <= 'F')
+        if (!cIsHex) {
+          return false
+        }
+        i += 1
+      }
+      true
+    }
+
+    def isThreeDigitOctalEscape(s: String, start: Int): Boolean = {
+      val firstChar = s.charAt(start)
+      val secondChar = s.charAt(start + 1)
+      val thirdChar = s.charAt(start + 2)
+      (firstChar == '0' || firstChar == '1') &&
+        (secondChar >= '0' && secondChar <= '7') &&
+        (thirdChar >= '0' && thirdChar <= '7')
+    }
+
+    val isRawString = {
+      val firstChar = b.charAt(0)
+      firstChar == 'r' || firstChar == 'R'
+    }
+
+    if (isRawString) {
+      // Skip the 'r' or 'R' and the first and last quotations enclosing the 
string literal.
       b.substring(2, b.length - 1)
+    } else if (b.indexOf('\\') == -1) {
+      // Fast path for the common case where the string has no escaped 
characters,
+      // in which case we just skip the first and last quotations enclosing 
the string literal.
+      b.substring(1, b.length - 1)
     } else {
+      val sb = new JStringBuilder(b.length())
       // Skip the first and last quotations enclosing the string literal.
-      val charBuffer = CharBuffer.wrap(b, 1, b.length - 1)
-
-      while (charBuffer.remaining() > 0) {
-        charBuffer match {
-          case U16_CHAR_PATTERN(cp) =>
+      var i = 1
+      val length = b.length - 1
+      while (i < length) {
+        val c = b.charAt(i)
+        if (c != '\\' || i + 1 == length) {
+          // Either a regular character or a backslash at the end of the 
string:
+          sb.append(c)
+          i += 1
+        } else {
+          // A backslash followed by at least one character:
+          i += 1
+          val cAfterBackslash = b.charAt(i)
+          if (cAfterBackslash == 'u' && i + 1 + 4 <= length && 
allCharsAreHex(b, i + 1, 4)) {
             // \u0000 style 16-bit unicode character literals.
-            sb.append(Integer.parseInt(cp, 16).toChar)
-            charBuffer.position(charBuffer.position() + 6)
-          case U32_CHAR_PATTERN(cp) =>
+            sb.append(Integer.parseInt(b, i + 1, i + 1 + 4, 16).toChar)
+            i += 1 + 4
+          } else if (cAfterBackslash == 'U' && i + 1 + 8 <= length && 
allCharsAreHex(b, i + 1, 8)) {
             // \U00000000 style 32-bit unicode character literals.
             // Use Long to treat codePoint as unsigned in the range of 32-bit.
-            val codePoint = JLong.parseLong(cp, 16)
+            val codePoint = JLong.parseLong(b, i + 1, i + 1 + 8, 16)
             if (codePoint < 0x10000) {
               sb.append((codePoint & 0xFFFF).toChar)
             } else {
@@ -74,21 +109,18 @@ trait SparkParserUtils {
               sb.append(highSurrogate.toChar)
               sb.append(lowSurrogate.toChar)
             }
-            charBuffer.position(charBuffer.position() + 10)
-          case OCTAL_CHAR_PATTERN(cp) =>
+            i += 1 + 8
+          } else if (i + 3 <= length && isThreeDigitOctalEscape(b, i)) {
             // \000 style character literals.
-            sb.append(Integer.parseInt(cp, 8).toChar)
-            charBuffer.position(charBuffer.position() + 4)
-          case ESCAPED_CHAR_PATTERN(c) =>
-            // escaped character literals.
-            appendEscapedChar(c.charAt(0))
-            charBuffer.position(charBuffer.position() + 2)
-          case _ =>
-            // non-escaped character literals.
-            sb.append(charBuffer.get())
+            sb.append(Integer.parseInt(b, i, i + 3, 8).toChar)
+            i += 3
+          } else {
+            appendEscapedChar(cAfterBackslash, sb)
+            i += 1
+          }
         }
       }
-      sb.toString()
+      sb.toString
     }
   }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
index d9f3067d30e5..218304db3d59 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
@@ -131,6 +131,18 @@ class ParserUtilsSuite extends SparkFunSuite {
         |cd\ef"""".stripMargin) ==
       """ab
         |cdef""".stripMargin)
+
+    // String with an invalid '\' as the last character.
+    assert(unescapeSQLString(""""abc\"""") == "abc\\")
+
+    // Strings containing invalid Unicode escapes with non-hex characters.
+    assert(unescapeSQLString("\"abc\\uXXXXa\"") == "abcuXXXXa")
+    assert(unescapeSQLString("\"abc\\uxxxxa\"") == "abcuxxxxa")
+    assert(unescapeSQLString("\"abc\\UXXXXXXXXa\"") == "abcUXXXXXXXXa")
+    assert(unescapeSQLString("\"abc\\Uxxxxxxxxa\"") == "abcUxxxxxxxxa")
+    // Guard against off-by-one errors in the "all chars are hex" routine:
+    assert(unescapeSQLString("\"abc\\uAAAXa\"") == "abcuAAAXa")
+
     // scalastyle:on nonascii
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to