This is an automated email from the ASF dual-hosted git repository.

yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 257a7883f215 [SPARK-48615][SQL] Perf improvement for parsing hex string
257a7883f215 is described below

commit 257a7883f2150e037eb05f8c7a111184103ad9a1
Author: Kent Yao <[email protected]>
AuthorDate: Mon Jun 17 09:56:05 2024 +0800

    [SPARK-48615][SQL] Perf improvement for parsing hex string
    
    ### What changes were proposed in this pull request?
    
    Currently, we use two hexadecimal string parsing functions: one uses Apache
    Commons Codec `Hex` to parse X-prefixed literals, and the other uses a built-in
    implementation for the `unhex` function. I benchmarked both against
    `java.util.HexFormat`, which was introduced in JDK 17.
    
    ```
    OpenJDK 64-Bit Server VM 17.0.10+0 on Mac OS X 14.5
    Apple M2 Max
    Cardinality 1000000:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
    ------------------------------------------------------------------------------------------------------------------------
    Apache                                             5050           5100          86          0.2        5050.1       1.0X
    Spark                                              3822           3840          30          0.3        3821.6       1.3X
    Java                                               2462           2522          87          0.4        2462.1       2.1X
    
    OpenJDK 64-Bit Server VM 17.0.10+0 on Mac OS X 14.5
    Apple M2 Max
    Cardinality 2000000:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
    ------------------------------------------------------------------------------------------------------------------------
    Apache                                            10020          10828        1154          0.2        5010.1       1.0X
    Spark                                              6875           6966         144          0.3        3437.7       1.5X
    Java                                               4999           5092          89          0.4        2499.3       2.0X
    
    OpenJDK 64-Bit Server VM 17.0.10+0 on Mac OS X 14.5
    Apple M2 Max
    Cardinality 4000000:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
    ------------------------------------------------------------------------------------------------------------------------
    Apache                                            20090          20433         433          0.2        5022.5       1.0X
    Spark                                             13389          13620         229          0.3        3347.2       1.5X
    Java                                              10023          10069          42          0.4        2505.6       2.0X
    
    OpenJDK 64-Bit Server VM 17.0.10+0 on Mac OS X 14.5
    Apple M2 Max
    Cardinality 8000000:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
    ------------------------------------------------------------------------------------------------------------------------
    Apache                                            40277          43453        2755          0.2        5034.7       1.0X
    Spark                                             27145          27380         311          0.3        3393.1       1.5X
    Java                                              19980          21198        1473          0.4        2497.5       2.0X
    ```
    
    The results show that, by speed, Apache Commons Codec < built-in < Java: each
    step is roughly 1.3-1.6X faster than the previous one, and Java is about 2X
    faster than Apache Commons Codec.
    
    In this PR, we replace both with a single `Hex.unhex` implementation that
    decodes digits with the Java 17 `java.util.HexFormat.fromHexDigit` API.
    
    ### Why are the changes needed?
    
    Performance improvement.
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    benchmarking
    
    existing unit tests in 
org.apache.spark.sql.catalyst.expressions.MathExpressionsSuite
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #46972 from yaooqinn/SPARK-48615.
    
    Authored-by: Kent Yao <[email protected]>
    Signed-off-by: yangjie01 <[email protected]>
---
 .../benchmarks/HexBenchmark-jdk21-results.txt      | 14 ++++
 sql/catalyst/benchmarks/HexBenchmark-results.txt   | 14 ++++
 .../sql/catalyst/expressions/mathExpressions.scala | 94 +++++++++-------------
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  7 +-
 .../sql/catalyst/expressions/HexBenchmark.scala    | 90 +++++++++++++++++++++
 5 files changed, 158 insertions(+), 61 deletions(-)

diff --git a/sql/catalyst/benchmarks/HexBenchmark-jdk21-results.txt 
b/sql/catalyst/benchmarks/HexBenchmark-jdk21-results.txt
new file mode 100644
index 000000000000..afa3efa7a919
--- /dev/null
+++ b/sql/catalyst/benchmarks/HexBenchmark-jdk21-results.txt
@@ -0,0 +1,14 @@
+================================================================================================
+UnHex Comparison
+================================================================================================
+
+OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1021-azure
+AMD EPYC 7763 64-Core Processor
+Cardinality 1000000:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+Common Codecs                                      4755           4766          13          0.2        4755.0       1.0X
+Java                                               4018           4048          45          0.2        4018.3       1.2X
+Spark                                              3473           3476           3          0.3        3472.8       1.4X
+Spark Binary                                       2625           2628           3          0.4        2624.6       1.8X
+
+
diff --git a/sql/catalyst/benchmarks/HexBenchmark-results.txt 
b/sql/catalyst/benchmarks/HexBenchmark-results.txt
new file mode 100644
index 000000000000..55a6a07fed40
--- /dev/null
+++ b/sql/catalyst/benchmarks/HexBenchmark-results.txt
@@ -0,0 +1,14 @@
+================================================================================================
+UnHex Comparison
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1021-azure
+AMD EPYC 7763 64-Core Processor
+Cardinality 1000000:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+Common Codecs                                      4881           4897          25          0.2        4880.8       1.0X
+Java                                               4220           4226           9          0.2        4220.0       1.2X
+Spark                                              3954           3956           2          0.3        3954.5       1.2X
+Spark Binary                                       2738           2750          11          0.4        2737.9       1.8X
+
+
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 6801fc7c257c..20bedeb04098 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.{lang => jl}
+import java.util.HexFormat.fromHexDigit
 import java.util.Locale
 
 import org.apache.spark.QueryContext
@@ -1059,37 +1060,33 @@ object Hex {
   }
 
   def unhex(bytes: Array[Byte]): Array[Byte] = {
-    val out = new Array[Byte]((bytes.length + 1) >> 1)
-    var i = 0
-    var oddShift = 0
-    if ((bytes.length & 0x01) != 0) {
-      // padding with '0'
-      if (bytes(0) < 0) {
-        return null
-      }
-      val v = Hex.unhexDigits(bytes(0))
-      if (v == -1) {
-        return null
-      }
-      out(0) = v
-      i += 1
-      oddShift = 1
+    val length = bytes.length
+    if (length == 0) {
+      return Array.emptyByteArray
     }
-    // two characters form the hex value.
-    while (i < bytes.length) {
-      if (bytes(i) < 0 || bytes(i + 1) < 0) {
-        return null
+    if ((length & 0x1) != 0) {
+      // while length of bytes is odd, loop from the end to beginning w/o the 
head
+      val result = new Array[Byte](length / 2  + 1)
+      var i = result.length - 1
+      while (i > 0) {
+        result(i) = ((fromHexDigit(bytes(i * 2 - 1)) << 4) | 
fromHexDigit(bytes(i * 2))).toByte
+        i -= 1
       }
-      val first = Hex.unhexDigits(bytes(i))
-      val second = Hex.unhexDigits(bytes(i + 1))
-      if (first == -1 || second == -1) {
-        return null
+      // add it 'tailing' head
+      result(0) = fromHexDigit(bytes(0)).toByte
+      result
+    } else {
+      val result = new Array[Byte](length / 2)
+      var i = 0
+      while (i < result.length) {
+        result(i) = ((fromHexDigit(bytes(2 * i)) << 4) | fromHexDigit(bytes(2 
* i + 1))).toByte
+        i += 1
       }
-      out(i / 2 + oddShift) = (((first << 4) | second) & 0xFF).toByte
-      i += 2
+      result
     }
-    out
   }
+
+  def unhex(str: String): Array[Byte] = unhex(str.getBytes())
 }
 
 /**
@@ -1162,41 +1159,26 @@ case class Unhex(child: Expression, failOnError: 
Boolean = false)
   override def dataType: DataType = BinaryType
 
   protected override def nullSafeEval(num: Any): Any = {
-    val result = Hex.unhex(num.asInstanceOf[UTF8String].getBytes)
-    if (failOnError && result == null) {
-      // The failOnError is set only from `ToBinary` function - hence we might 
safely set `hint`
-      // parameter to `try_to_binary`.
-      throw QueryExecutionErrors.invalidInputInConversionError(
-        BinaryType,
-        num.asInstanceOf[UTF8String],
-        UTF8String.fromString("HEX"),
-        "try_to_binary")
+    try {
+      Hex.unhex(num.asInstanceOf[UTF8String].getBytes)
+    } catch {
+      case _: IllegalArgumentException if !failOnError => null
+      case _: IllegalArgumentException =>
+        throw QueryExecutionErrors.invalidInputInConversionError(
+          BinaryType,
+          num.asInstanceOf[UTF8String],
+          UTF8String.fromString("HEX"),
+          "try_to_binary")
     }
-    result
   }
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
-    nullSafeCodeGen(ctx, ev, c => {
-      val hex = Hex.getClass.getName.stripSuffix("$")
-      val maybeFailOnErrorCode = if (failOnError) {
-        val binaryType = ctx.addReferenceObj("to", BinaryType, 
BinaryType.getClass.getName)
-        s"""
-           |if (${ev.value} == null) {
-           |  throw QueryExecutionErrors.invalidInputInConversionError(
-           |    $binaryType,
-           |    $c,
-           |    UTF8String.fromString("HEX"),
-           |    "try_to_binary");
-           |}
-           |""".stripMargin
-      } else {
-        s"${ev.isNull} = ${ev.value} == null;"
-      }
-
+    val expr = ctx.addReferenceObj("this", this)
+    nullSafeCodeGen(ctx, ev, input => {
       s"""
-        ${ev.value} = $hex.unhex($c.getBytes());
-        $maybeFailOnErrorCode
-       """
+        ${ev.value} = (byte[]) $expr.nullSafeEval($input);
+        ${ev.isNull} = ${ev.value} == null;
+      """
     })
   }
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 15c623235ccc..243ff9e8a6a4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -27,8 +27,6 @@ import scala.util.{Left, Right}
 import org.antlr.v4.runtime.{ParserRuleContext, Token}
 import org.antlr.v4.runtime.misc.Interval
 import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
-import org.apache.commons.codec.DecoderException
-import org.apache.commons.codec.binary.Hex
 
 import org.apache.spark.{SparkArithmeticException, SparkException, 
SparkIllegalArgumentException, SparkThrowable}
 import org.apache.spark.internal.{Logging, MDC}
@@ -2788,11 +2786,10 @@ class AstBuilder extends DataTypeAstBuilder with 
SQLConfHelper with Logging {
           Literal(interval, CalendarIntervalType)
         }
       case BINARY_HEX =>
-        val padding = if (value.length % 2 != 0) "0" else ""
         try {
-          Literal(Hex.decodeHex(padding + value), BinaryType)
+          Literal(Hex.unhex(value), BinaryType)
         } catch {
-          case e: DecoderException =>
+          case e: IllegalArgumentException =>
             val ex = QueryParsingErrors.cannotParseValueTypeError("X", value, 
ctx)
             ex.setStackTrace(e.getStackTrace)
             throw ex
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexBenchmark.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexBenchmark.scala
new file mode 100644
index 000000000000..df3fcbb83906
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexBenchmark.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.util.Locale
+
+import org.apache.commons.codec.binary.{Hex => ApacheHex}
+
+import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Benchmark for hex
+ * To run this benchmark:
+ * {{{
+ *   1. without sbt:
+ *      bin/spark-submit --class <this class> --jars <spark core test jar> 
<spark catalyst test jar>
+ *   2. build/sbt "catalyst/Test/runMain <this class>"
+ *   3. generate result:
+ *      SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "catalyst/Test/runMain 
<this class>"
+ *      Results will be written to "benchmarks/HexBenchmark-results.txt".
+ * }}}
+ */
+object HexBenchmark extends BenchmarkBase {
+
+  private val hexStrings = {
+    var tmp = Seq("", "A", "AB", "ABC", "ABCD", "123ABCDEF")
+    tmp = tmp ++ tmp.map(_.toLowerCase(Locale.ROOT))
+    (2 to 4).foreach { i => tmp = tmp ++ tmp.map(x => x * i) }
+    tmp.map(UTF8String.fromString(_).toString)
+  }
+
+  private val hexBin = hexStrings.map(_.getBytes)
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    runBenchmark("UnHex Comparison") {
+      val N = 1_000_000
+      val benchmark = new Benchmark(s"Cardinality $N", N, 3, output = output)
+      benchmark.addCase("Common Codecs") { _ =>
+        (1 to N).foreach(_ => hexStrings.foreach(y => apacheDecodeHex(y)))
+      }
+
+      benchmark.addCase("Java") { _ =>
+        (1 to N).foreach(_ => hexStrings.foreach(y => javaUnhex(y)))
+      }
+
+      benchmark.addCase("Spark") { _ =>
+        (1 to N).foreach(_ => hexStrings.foreach(y => builtinUnHex(y)))
+      }
+
+      benchmark.addCase("Spark Binary") { _ =>
+        (1 to N).foreach(_ => hexBin.foreach(y => builtinUnHex(y)))
+      }
+      benchmark.run()
+    }
+  }
+
+  def apacheDecodeHex(value: String): Array[Byte] = {
+    val padding = if (value.length % 2 != 0) "0" else ""
+    ApacheHex.decodeHex(padding + value)
+  }
+
+  def builtinUnHex(value: String): Array[Byte] = {
+    Hex.unhex(value)
+  }
+
+  def builtinUnHex(value: Array[Byte]): Array[Byte] = {
+    Hex.unhex(value)
+  }
+
+  def javaUnhex(value: String) : Array[Byte] = {
+    val padding = if ((value.length & 0x1) != 0) "0" else ""
+    java.util.HexFormat.of().parseHex(padding + value)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to