This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 51f11033f86e [SPARK-48686][SQL] Improve performance of ParserUtils.unescapeSQLString
51f11033f86e is described below
commit 51f11033f86ea423b88f80690099b6384df172ac
Author: Josh Rosen <[email protected]>
AuthorDate: Tue Jun 25 18:13:06 2024 +0900
[SPARK-48686][SQL] Improve performance of ParserUtils.unescapeSQLString
### What changes were proposed in this pull request?
This PR implements multiple performance optimizations for
`ParserUtils.unescapeSQLString`:
1. Don't use regex: following https://github.com/apache/spark/pull/31362,
the existing code uses regexes to parse escaped-character patterns. In the
worst case, which is also the expected common case of "no escaping needed", it
performs four regex match attempts per input character. Because the matchers
aren't reused, this creates a significant amount of garbage.
2. Skip the StringBuilder allocation for raw strings and for strings that
don't need any unescaping.
3. Minor: use Java StringBuilder instead of the Scala version: this removes
a layer of indirection and may benefit JIT (we've seen positive results in some
scenarios from this type of switch).
### Why are the changes needed?
unescapeSQLString showed up as a CPU and allocation hotspot in certain
testing scenarios. See the flamegraph in the description of pull request
#47062 (not reproduced in this email) for an illustration of the relative
costs of repeated regex matching in the old code.

For strings that don't require unescaping, the new code can be arbitrarily
faster: the relative speedup depends on the choice of input. For strings that
_do_ need unescaping, I tested extreme cases where _every_ character is
escaped. In these cases I see ~10-20x speedups, depending on the type of
escape sequence. The new code should be faster in every scenario.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
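Correctness is covered by existing unit tests. This patch also adds new test
cases to `ParserUtilsSuite` for edge cases: a trailing backslash and invalid
`\u`/`\U` Unicode escapes containing non-hex characters.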
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47062 from JoshRosen/unescapeSQLString-optimizations.
Authored-by: Josh Rosen <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../spark/sql/catalyst/util/SparkParserUtils.scala | 94 +++++++++++++++-------
.../sql/catalyst/parser/ParserUtilsSuite.scala | 12 +++
2 files changed, 75 insertions(+), 31 deletions(-)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala
index a4ce5fb12034..7597cb1d9087 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala
@@ -16,8 +16,7 @@
*/
package org.apache.spark.sql.catalyst.util
-import java.lang.{Long => JLong}
-import java.nio.CharBuffer
+import java.lang.{Long => JLong, StringBuilder => JStringBuilder}
import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.misc.Interval
@@ -26,16 +25,10 @@ import org.antlr.v4.runtime.tree.TerminalNode
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
trait SparkParserUtils {
- val U16_CHAR_PATTERN = """\\u([a-fA-F0-9]{4})(?s).*""".r
- val U32_CHAR_PATTERN = """\\U([a-fA-F0-9]{8})(?s).*""".r
- val OCTAL_CHAR_PATTERN = """\\([01][0-7]{2})(?s).*""".r
- val ESCAPED_CHAR_PATTERN = """\\((?s).)(?s).*""".r
/** Unescape backslash-escaped string enclosed by quotes. */
def unescapeSQLString(b: String): String = {
- val sb = new StringBuilder(b.length())
-
- def appendEscapedChar(n: Char): Unit = {
+ def appendEscapedChar(n: Char, sb: JStringBuilder): Unit = {
n match {
case '0' => sb.append('\u0000')
case 'b' => sb.append('\b')
@@ -50,22 +43,64 @@ trait SparkParserUtils {
}
}
- if (b.startsWith("r") || b.startsWith("R")) {
+ def allCharsAreHex(s: String, start: Int, length: Int): Boolean = {
+ val end = start + length
+ var i = start
+ while (i < end) {
+ val c = s.charAt(i)
+ val cIsHex = (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >=
'A' && c <= 'F')
+ if (!cIsHex) {
+ return false
+ }
+ i += 1
+ }
+ true
+ }
+
+ def isThreeDigitOctalEscape(s: String, start: Int): Boolean = {
+ val firstChar = s.charAt(start)
+ val secondChar = s.charAt(start + 1)
+ val thirdChar = s.charAt(start + 2)
+ (firstChar == '0' || firstChar == '1') &&
+ (secondChar >= '0' && secondChar <= '7') &&
+ (thirdChar >= '0' && thirdChar <= '7')
+ }
+
+ val isRawString = {
+ val firstChar = b.charAt(0)
+ firstChar == 'r' || firstChar == 'R'
+ }
+
+ if (isRawString) {
+ // Skip the 'r' or 'R' and the first and last quotations enclosing the
string literal.
b.substring(2, b.length - 1)
+ } else if (b.indexOf('\\') == -1) {
+ // Fast path for the common case where the string has no escaped
characters,
+ // in which case we just skip the first and last quotations enclosing
the string literal.
+ b.substring(1, b.length - 1)
} else {
+ val sb = new JStringBuilder(b.length())
// Skip the first and last quotations enclosing the string literal.
- val charBuffer = CharBuffer.wrap(b, 1, b.length - 1)
-
- while (charBuffer.remaining() > 0) {
- charBuffer match {
- case U16_CHAR_PATTERN(cp) =>
+ var i = 1
+ val length = b.length - 1
+ while (i < length) {
+ val c = b.charAt(i)
+ if (c != '\\' || i + 1 == length) {
+ // Either a regular character or a backslash at the end of the
string:
+ sb.append(c)
+ i += 1
+ } else {
+ // A backslash followed by at least one character:
+ i += 1
+ val cAfterBackslash = b.charAt(i)
+ if (cAfterBackslash == 'u' && i + 1 + 4 <= length &&
allCharsAreHex(b, i + 1, 4)) {
// \u0000 style 16-bit unicode character literals.
- sb.append(Integer.parseInt(cp, 16).toChar)
- charBuffer.position(charBuffer.position() + 6)
- case U32_CHAR_PATTERN(cp) =>
+ sb.append(Integer.parseInt(b, i + 1, i + 1 + 4, 16).toChar)
+ i += 1 + 4
+ } else if (cAfterBackslash == 'U' && i + 1 + 8 <= length &&
allCharsAreHex(b, i + 1, 8)) {
// \U00000000 style 32-bit unicode character literals.
// Use Long to treat codePoint as unsigned in the range of 32-bit.
- val codePoint = JLong.parseLong(cp, 16)
+ val codePoint = JLong.parseLong(b, i + 1, i + 1 + 8, 16)
if (codePoint < 0x10000) {
sb.append((codePoint & 0xFFFF).toChar)
} else {
@@ -74,21 +109,18 @@ trait SparkParserUtils {
sb.append(highSurrogate.toChar)
sb.append(lowSurrogate.toChar)
}
- charBuffer.position(charBuffer.position() + 10)
- case OCTAL_CHAR_PATTERN(cp) =>
+ i += 1 + 8
+ } else if (i + 3 <= length && isThreeDigitOctalEscape(b, i)) {
// \000 style character literals.
- sb.append(Integer.parseInt(cp, 8).toChar)
- charBuffer.position(charBuffer.position() + 4)
- case ESCAPED_CHAR_PATTERN(c) =>
- // escaped character literals.
- appendEscapedChar(c.charAt(0))
- charBuffer.position(charBuffer.position() + 2)
- case _ =>
- // non-escaped character literals.
- sb.append(charBuffer.get())
+ sb.append(Integer.parseInt(b, i, i + 3, 8).toChar)
+ i += 3
+ } else {
+ appendEscapedChar(cAfterBackslash, sb)
+ i += 1
+ }
}
}
- sb.toString()
+ sb.toString
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
index d9f3067d30e5..218304db3d59 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
@@ -131,6 +131,18 @@ class ParserUtilsSuite extends SparkFunSuite {
|cd\ef"""".stripMargin) ==
"""ab
|cdef""".stripMargin)
+
+ // String with an invalid '\' as the last character.
+ assert(unescapeSQLString(""""abc\"""") == "abc\\")
+
+ // Strings containing invalid Unicode escapes with non-hex characters.
+ assert(unescapeSQLString("\"abc\\uXXXXa\"") == "abcuXXXXa")
+ assert(unescapeSQLString("\"abc\\uxxxxa\"") == "abcuxxxxa")
+ assert(unescapeSQLString("\"abc\\UXXXXXXXXa\"") == "abcUXXXXXXXXa")
+ assert(unescapeSQLString("\"abc\\Uxxxxxxxxa\"") == "abcUxxxxxxxxa")
+ // Guard against off-by-one errors in the "all chars are hex" routine:
+ assert(unescapeSQLString("\"abc\\uAAAXa\"") == "abcuAAAXa")
+
// scalastyle:on nonascii
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]