This is an automated email from the ASF dual-hosted git repository.
yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 257a7883f215 [SPARK-48615][SQL] Perf improvement for parsing hex string
257a7883f215 is described below
commit 257a7883f2150e037eb05f8c7a111184103ad9a1
Author: Kent Yao <[email protected]>
AuthorDate: Mon Jun 17 09:56:05 2024 +0800
[SPARK-48615][SQL] Perf improvement for parsing hex string
### What changes were proposed in this pull request?
Currently, we use two hexadecimal string parsing functions. One uses Apache
Commons Codec's Hex for parsing X-prefixed literals, and the other uses a
builtin implementation for the unhex function. I benchmarked both against
`java.util.HexFormat`, which was introduced in JDK 17.
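For reference, a minimal sketch of the three decoding paths being compared
(even-length input is assumed here, since Apache Commons Codec rejects an odd
number of digits unless the caller pads it; `Hex` is Spark's builtin object in
`org.apache.spark.sql.catalyst.expressions`):

```scala
import org.apache.commons.codec.binary.{Hex => ApacheHex}

import org.apache.spark.sql.catalyst.expressions.Hex

val input = "1C2D3E4F"

// Apache Commons Codec, previously used for X-prefixed literal parsing.
val viaApache: Array[Byte] = ApacheHex.decodeHex(input)

// Spark's builtin implementation, previously used by the unhex function.
val viaSpark: Array[Byte] = Hex.unhex(input.getBytes)

// java.util.HexFormat, introduced in JDK 17 -- the API this patch adopts.
val viaJava: Array[Byte] = java.util.HexFormat.of().parseHex(input)
```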
```
OpenJDK 64-Bit Server VM 17.0.10+0 on Mac OS X 14.5
Apple M2 Max
Cardinality 1000000:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
Apache                                             5050           5100          86          0.2      5050.1       1.0X
Spark                                              3822           3840          30          0.3      3821.6       1.3X
Java                                               2462           2522          87          0.4      2462.1       2.1X

OpenJDK 64-Bit Server VM 17.0.10+0 on Mac OS X 14.5
Apple M2 Max
Cardinality 2000000:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
Apache                                            10020          10828        1154          0.2      5010.1       1.0X
Spark                                              6875           6966         144          0.3      3437.7       1.5X
Java                                               4999           5092          89          0.4      2499.3       2.0X

OpenJDK 64-Bit Server VM 17.0.10+0 on Mac OS X 14.5
Apple M2 Max
Cardinality 4000000:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
Apache                                            20090          20433         433          0.2      5022.5       1.0X
Spark                                             13389          13620         229          0.3      3347.2       1.5X
Java                                              10023          10069          42          0.4      2505.6       2.0X

OpenJDK 64-Bit Server VM 17.0.10+0 on Mac OS X 14.5
Apple M2 Max
Cardinality 8000000:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
Apache                                            40277          43453        2755          0.2      5034.7       1.0X
Spark                                             27145          27380         311          0.3      3393.1       1.5X
Java                                              19980          21198        1473          0.4      2497.5       2.0X
```
The results indicate that, in terms of speed, Apache Commons Codec < builtin < Java:
the Java API is roughly 50% faster than the builtin implementation and about twice
as fast as Apache Commons Codec.
In this PR, we replace both code paths with the Java 17 API.
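For illustration, a sketch of the semantics the new implementation preserves
(using the `Hex.unhex(String)` overload added in this PR):

```scala
import org.apache.spark.sql.catalyst.expressions.Hex

// Even-length input decodes pairwise: 0x1A, 0x2B.
Hex.unhex("1A2B")  // Array[Byte](26, 43)

// Odd-length input still treats the leading digit as a half byte,
// so "ABC" decodes like "0ABC": 0x0A, 0xBC.
Hex.unhex("ABC")   // Array[Byte](10, -68)

// Invalid digits now surface as a NumberFormatException (a subclass of
// IllegalArgumentException) from HexFormat.fromHexDigit; the callers
// (Unhex, AstBuilder) translate it to null or a parse error.
try {
  Hex.unhex("XYZ")
} catch {
  case _: IllegalArgumentException => Array.emptyByteArray
}
```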
### Why are the changes needed?
Performance improvement.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Benchmarking, plus the existing unit tests in
org.apache.spark.sql.catalyst.expressions.MathExpressionsSuite.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #46972 from yaooqinn/SPARK-48615.
Authored-by: Kent Yao <[email protected]>
Signed-off-by: yangjie01 <[email protected]>
---
.../benchmarks/HexBenchmark-jdk21-results.txt | 14 ++++
sql/catalyst/benchmarks/HexBenchmark-results.txt | 14 ++++
.../sql/catalyst/expressions/mathExpressions.scala | 94 +++++++++-------------
.../spark/sql/catalyst/parser/AstBuilder.scala | 7 +-
.../sql/catalyst/expressions/HexBenchmark.scala | 90 +++++++++++++++++++++
5 files changed, 158 insertions(+), 61 deletions(-)
diff --git a/sql/catalyst/benchmarks/HexBenchmark-jdk21-results.txt b/sql/catalyst/benchmarks/HexBenchmark-jdk21-results.txt
new file mode 100644
index 000000000000..afa3efa7a919
--- /dev/null
+++ b/sql/catalyst/benchmarks/HexBenchmark-jdk21-results.txt
@@ -0,0 +1,14 @@
+================================================================================================
+UnHex Comparison
+================================================================================================
+
+OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1021-azure
+AMD EPYC 7763 64-Core Processor
+Cardinality 1000000:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+Common Codecs                                      4755           4766          13          0.2      4755.0       1.0X
+Java                                               4018           4048          45          0.2      4018.3       1.2X
+Spark                                              3473           3476           3          0.3      3472.8       1.4X
+Spark Binary                                       2625           2628           3          0.4      2624.6       1.8X
+
+
diff --git a/sql/catalyst/benchmarks/HexBenchmark-results.txt b/sql/catalyst/benchmarks/HexBenchmark-results.txt
new file mode 100644
index 000000000000..55a6a07fed40
--- /dev/null
+++ b/sql/catalyst/benchmarks/HexBenchmark-results.txt
@@ -0,0 +1,14 @@
+================================================================================================
+UnHex Comparison
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1021-azure
+AMD EPYC 7763 64-Core Processor
+Cardinality 1000000:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+Common Codecs                                      4881           4897          25          0.2      4880.8       1.0X
+Java                                               4220           4226           9          0.2      4220.0       1.2X
+Spark                                              3954           3956           2          0.3      3954.5       1.2X
+Spark Binary                                       2738           2750          11          0.4      2737.9       1.8X
+
+
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 6801fc7c257c..20bedeb04098 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import java.{lang => jl}
+import java.util.HexFormat.fromHexDigit
import java.util.Locale
import org.apache.spark.QueryContext
@@ -1059,37 +1060,33 @@ object Hex {
}
def unhex(bytes: Array[Byte]): Array[Byte] = {
- val out = new Array[Byte]((bytes.length + 1) >> 1)
- var i = 0
- var oddShift = 0
- if ((bytes.length & 0x01) != 0) {
- // padding with '0'
- if (bytes(0) < 0) {
- return null
- }
- val v = Hex.unhexDigits(bytes(0))
- if (v == -1) {
- return null
- }
- out(0) = v
- i += 1
- oddShift = 1
+ val length = bytes.length
+ if (length == 0) {
+ return Array.emptyByteArray
}
- // two characters form the hex value.
- while (i < bytes.length) {
- if (bytes(i) < 0 || bytes(i + 1) < 0) {
- return null
+ if ((length & 0x1) != 0) {
+      // when the length is odd, decode pairs from the end back to the beginning, leaving the head for last
+ val result = new Array[Byte](length / 2 + 1)
+ var i = result.length - 1
+ while (i > 0) {
+        result(i) = ((fromHexDigit(bytes(i * 2 - 1)) << 4) | fromHexDigit(bytes(i * 2))).toByte
+ i -= 1
}
- val first = Hex.unhexDigits(bytes(i))
- val second = Hex.unhexDigits(bytes(i + 1))
- if (first == -1 || second == -1) {
- return null
+      // decode the leftover head digit on its own
+ result(0) = fromHexDigit(bytes(0)).toByte
+ result
+ } else {
+ val result = new Array[Byte](length / 2)
+ var i = 0
+ while (i < result.length) {
+        result(i) = ((fromHexDigit(bytes(2 * i)) << 4) | fromHexDigit(bytes(2 * i + 1))).toByte
+ i += 1
}
- out(i / 2 + oddShift) = (((first << 4) | second) & 0xFF).toByte
- i += 2
+ result
}
- out
}
+
+ def unhex(str: String): Array[Byte] = unhex(str.getBytes())
}
/**
@@ -1162,41 +1159,26 @@ case class Unhex(child: Expression, failOnError: Boolean = false)
override def dataType: DataType = BinaryType
protected override def nullSafeEval(num: Any): Any = {
- val result = Hex.unhex(num.asInstanceOf[UTF8String].getBytes)
- if (failOnError && result == null) {
-      // The failOnError is set only from `ToBinary` function - hence we might safely set `hint`
- // parameter to `try_to_binary`.
- throw QueryExecutionErrors.invalidInputInConversionError(
- BinaryType,
- num.asInstanceOf[UTF8String],
- UTF8String.fromString("HEX"),
- "try_to_binary")
+ try {
+ Hex.unhex(num.asInstanceOf[UTF8String].getBytes)
+ } catch {
+ case _: IllegalArgumentException if !failOnError => null
+ case _: IllegalArgumentException =>
+ throw QueryExecutionErrors.invalidInputInConversionError(
+ BinaryType,
+ num.asInstanceOf[UTF8String],
+ UTF8String.fromString("HEX"),
+ "try_to_binary")
}
- result
}
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- nullSafeCodeGen(ctx, ev, c => {
- val hex = Hex.getClass.getName.stripSuffix("$")
- val maybeFailOnErrorCode = if (failOnError) {
-        val binaryType = ctx.addReferenceObj("to", BinaryType, BinaryType.getClass.getName)
- s"""
- |if (${ev.value} == null) {
- | throw QueryExecutionErrors.invalidInputInConversionError(
- | $binaryType,
- | $c,
- | UTF8String.fromString("HEX"),
- | "try_to_binary");
- |}
- |""".stripMargin
- } else {
- s"${ev.isNull} = ${ev.value} == null;"
- }
-
+ val expr = ctx.addReferenceObj("this", this)
+ nullSafeCodeGen(ctx, ev, input => {
s"""
- ${ev.value} = $hex.unhex($c.getBytes());
- $maybeFailOnErrorCode
- """
+ ${ev.value} = (byte[]) $expr.nullSafeEval($input);
+ ${ev.isNull} = ${ev.value} == null;
+ """
})
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 15c623235ccc..243ff9e8a6a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -27,8 +27,6 @@ import scala.util.{Left, Right}
import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
-import org.apache.commons.codec.DecoderException
-import org.apache.commons.codec.binary.Hex
import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkThrowable}
import org.apache.spark.internal.{Logging, MDC}
@@ -2788,11 +2786,10 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
Literal(interval, CalendarIntervalType)
}
case BINARY_HEX =>
- val padding = if (value.length % 2 != 0) "0" else ""
try {
- Literal(Hex.decodeHex(padding + value), BinaryType)
+ Literal(Hex.unhex(value), BinaryType)
} catch {
- case e: DecoderException =>
+ case e: IllegalArgumentException =>
            val ex = QueryParsingErrors.cannotParseValueTypeError("X", value, ctx)
ex.setStackTrace(e.getStackTrace)
throw ex
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexBenchmark.scala
new file mode 100644
index 000000000000..df3fcbb83906
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexBenchmark.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.util.Locale
+
+import org.apache.commons.codec.binary.{Hex => ApacheHex}
+
+import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Benchmark for hex
+ * To run this benchmark:
+ * {{{
+ * 1. without sbt:
+ *      bin/spark-submit --class <this class> --jars <spark core test jar> <spark catalyst test jar>
+ * 2. build/sbt "catalyst/Test/runMain <this class>"
+ * 3. generate result:
+ *      SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "catalyst/Test/runMain <this class>"
+ * Results will be written to "benchmarks/HexBenchmark-results.txt".
+ * }}}
+ */
+object HexBenchmark extends BenchmarkBase {
+
+ private val hexStrings = {
+ var tmp = Seq("", "A", "AB", "ABC", "ABCD", "123ABCDEF")
+ tmp = tmp ++ tmp.map(_.toLowerCase(Locale.ROOT))
+ (2 to 4).foreach { i => tmp = tmp ++ tmp.map(x => x * i) }
+ tmp.map(UTF8String.fromString(_).toString)
+ }
+
+ private val hexBin = hexStrings.map(_.getBytes)
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ runBenchmark("UnHex Comparison") {
+ val N = 1_000_000
+ val benchmark = new Benchmark(s"Cardinality $N", N, 3, output = output)
+ benchmark.addCase("Common Codecs") { _ =>
+ (1 to N).foreach(_ => hexStrings.foreach(y => apacheDecodeHex(y)))
+ }
+
+ benchmark.addCase("Java") { _ =>
+ (1 to N).foreach(_ => hexStrings.foreach(y => javaUnhex(y)))
+ }
+
+ benchmark.addCase("Spark") { _ =>
+ (1 to N).foreach(_ => hexStrings.foreach(y => builtinUnHex(y)))
+ }
+
+ benchmark.addCase("Spark Binary") { _ =>
+ (1 to N).foreach(_ => hexBin.foreach(y => builtinUnHex(y)))
+ }
+ benchmark.run()
+ }
+ }
+
+ def apacheDecodeHex(value: String): Array[Byte] = {
+ val padding = if (value.length % 2 != 0) "0" else ""
+ ApacheHex.decodeHex(padding + value)
+ }
+
+ def builtinUnHex(value: String): Array[Byte] = {
+ Hex.unhex(value)
+ }
+
+ def builtinUnHex(value: Array[Byte]): Array[Byte] = {
+ Hex.unhex(value)
+ }
+
+  def javaUnhex(value: String): Array[Byte] = {
+ val padding = if ((value.length & 0x1) != 0) "0" else ""
+ java.util.HexFormat.of().parseHex(padding + value)
+ }
+}