This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new df13ca05c475 [SPARK-48735][SQL] Performance Improvement for BIN 
function
df13ca05c475 is described below

commit df13ca05c475e98bf5c218a4503513065611a47f
Author: Kent Yao <[email protected]>
AuthorDate: Thu Jun 27 21:39:06 2024 +0800

    [SPARK-48735][SQL] Performance Improvement for BIN function
    
    ### What changes were proposed in this pull request?
    
    This PR implements a UTF8String method that converts a long value directly
    to its binary string form, improving the performance of the BIN function.
    It avoids the intermediate java.lang.String encoding/decoding and the
    associated array copying.
    
    ### Why are the changes needed?
    
    performance improvement
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    - new unit tests
    - offline benchmarking (~2x faster)
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #47119 from yaooqinn/SPARK-48735.
    
    Authored-by: Kent Yao <[email protected]>
    Signed-off-by: Kent Yao <[email protected]>
---
 .../org/apache/spark/unsafe/types/UTF8String.java  | 19 +++++++++++++
 .../apache/spark/unsafe/types/UTF8StringSuite.java | 17 ++++++++++++
 .../sql/catalyst/expressions/mathExpressions.scala |  8 ++----
 .../sql-tests/analyzer-results/ansi/math.sql.out   | 28 +++++++++++++++++++
 .../sql-tests/analyzer-results/math.sql.out        | 28 +++++++++++++++++++
 .../src/test/resources/sql-tests/inputs/math.sql   |  5 ++++
 .../resources/sql-tests/results/ansi/math.sql.out  | 32 ++++++++++++++++++++++
 .../test/resources/sql-tests/results/math.sql.out  | 32 ++++++++++++++++++++++
 8 files changed, 164 insertions(+), 5 deletions(-)

diff --git 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 12a7b06232ee..49d3088f8a2f 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -102,6 +102,8 @@ public final class UTF8String implements 
Comparable<UTF8String>, Externalizable,
 
   private static final UTF8String COMMA_UTF8 = UTF8String.fromString(",");
   public static final UTF8String EMPTY_UTF8 = UTF8String.fromString("");
+  public static final UTF8String ZERO_UTF8 = UTF8String.fromString("0");
+
 
   /**
    * Creates an UTF8String from byte array, which should be encoded in UTF-8.
@@ -1867,4 +1869,21 @@ public final class UTF8String implements 
Comparable<UTF8String>, Externalizable,
     in.read((byte[]) base);
   }
 
+  /**
+   * Convert a long value to its binary format stripping leading zeros.
+   */
+  public static UTF8String toBinaryString(long val) {
+    int zeros = Long.numberOfLeadingZeros(val);
+    if (zeros == Long.SIZE) {
+      return UTF8String.ZERO_UTF8;
+    } else {
+      int length = Long.SIZE - zeros;
+      byte[] bytes = new byte[length];
+      do {
+        bytes[--length] = (byte) ((val & 0x1) == 1 ? '1': '0');
+        val >>>= 1;
+      } while (length > 0);
+      return fromBytes(bytes);
+    }
+  }
 }
diff --git 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index f9b351697e8b..07793a24e5ee 100644
--- 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -1110,4 +1110,21 @@ public class UTF8StringSuite {
     testIsValid("0x9C 0x76 0x17", "0xEF 0xBF 0xBD 0x76 0x17");
   }
 
+  @Test
+  public void toBinaryString() {
+    assertEquals(ZERO_UTF8, UTF8String.toBinaryString(0));
+    assertEquals(UTF8String.fromString("1"), UTF8String.toBinaryString(1));
+    assertEquals(UTF8String.fromString("10"), UTF8String.toBinaryString(2));
+    assertEquals(UTF8String.fromString("100"), UTF8String.toBinaryString(4));
+    assertEquals(UTF8String.fromString("111"), UTF8String.toBinaryString(7));
+    assertEquals(
+      
UTF8String.fromString("1111111111111111111111111111111111111111111111111111111111110011"),
+      UTF8String.toBinaryString(-13));
+    assertEquals(
+      
UTF8String.fromString("1000000000000000000000000000000000000000000000000000000000000000"),
+      UTF8String.toBinaryString(Long.MIN_VALUE));
+    assertEquals(
+      
UTF8String.fromString("111111111111111111111111111111111111111111111111111111111111111"),
+      UTF8String.toBinaryString(Long.MAX_VALUE));
+  }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 5981b42aead8..00274a16b888 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1008,11 +1008,10 @@ case class Bin(child: Expression)
   override def dataType: DataType = SQLConf.get.defaultStringType
 
   protected override def nullSafeEval(input: Any): Any =
-    UTF8String.fromString(jl.Long.toBinaryString(input.asInstanceOf[Long]))
+    UTF8String.toBinaryString(input.asInstanceOf[Long])
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, (c) =>
-      s"UTF8String.fromString(java.lang.Long.toBinaryString($c))")
+    defineCodeGen(ctx, ev, c => s"UTF8String.toBinaryString($c)")
   }
 
   override protected def withNewChildInternal(newChild: Expression): Bin = 
copy(child = newChild)
@@ -1021,7 +1020,6 @@ case class Bin(child: Expression)
 object Hex {
   private final val hexDigits =
     Array[Byte]('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 
'C', 'D', 'E', 'F')
-  private final val ZERO_UTF8 = UTF8String.fromBytes(Array[Byte]('0'))
 
   // lookup table to translate '0' -> 0 ... 'F'/'f' -> 15
   val unhexDigits = {
@@ -1053,7 +1051,7 @@ object Hex {
 
   def hex(num: Long): UTF8String = {
     val zeros = jl.Long.numberOfLeadingZeros(num)
-    if (zeros == jl.Long.SIZE) return ZERO_UTF8
+    if (zeros == jl.Long.SIZE) return UTF8String.ZERO_UTF8
     val len = (jl.Long.SIZE - zeros + 3) / 4
     var numBuf = num
     val value = new Array[Byte](len)
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/math.sql.out 
b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/math.sql.out
index 7eb7fcff356a..8d59b678e92f 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/math.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/math.sql.out
@@ -431,3 +431,31 @@ SELECT conv('-9223372036854775807', 36, 10)
 -- !query analysis
 Project [conv(-9223372036854775807, 36, 10, true) AS 
conv(-9223372036854775807, 36, 10)#x]
 +- OneRowRelation
+
+
+-- !query
+SELECT BIN(0)
+-- !query analysis
+Project [bin(cast(0 as bigint)) AS bin(0)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT BIN(25)
+-- !query analysis
+Project [bin(cast(25 as bigint)) AS bin(25)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT BIN(25L)
+-- !query analysis
+Project [bin(25) AS bin(25)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT BIN(25.5)
+-- !query analysis
+Project [bin(cast(25.5 as bigint)) AS bin(25.5)#x]
++- OneRowRelation
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/math.sql.out 
b/sql/core/src/test/resources/sql-tests/analyzer-results/math.sql.out
index e4dd1994b2c9..0d9b9267cd08 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/math.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/math.sql.out
@@ -431,3 +431,31 @@ SELECT conv('-9223372036854775807', 36, 10)
 -- !query analysis
 Project [conv(-9223372036854775807, 36, 10, false) AS 
conv(-9223372036854775807, 36, 10)#x]
 +- OneRowRelation
+
+
+-- !query
+SELECT BIN(0)
+-- !query analysis
+Project [bin(cast(0 as bigint)) AS bin(0)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT BIN(25)
+-- !query analysis
+Project [bin(cast(25 as bigint)) AS bin(25)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT BIN(25L)
+-- !query analysis
+Project [bin(25) AS bin(25)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT BIN(25.5)
+-- !query analysis
+Project [bin(cast(25.5 as bigint)) AS bin(25.5)#x]
++- OneRowRelation
diff --git a/sql/core/src/test/resources/sql-tests/inputs/math.sql 
b/sql/core/src/test/resources/sql-tests/inputs/math.sql
index 96fb0eeef7ac..398a8b3290b1 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/math.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/math.sql
@@ -77,3 +77,8 @@ SELECT conv('9223372036854775808', 10, 16);
 SELECT conv('92233720368547758070', 10, 16);
 SELECT conv('9223372036854775807', 36, 10);
 SELECT conv('-9223372036854775807', 36, 10);
+
+SELECT BIN(0);
+SELECT BIN(25);
+SELECT BIN(25L);
+SELECT BIN(25.5);
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out 
b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out
index 8cd1536d7f72..9b886218f3ad 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out
@@ -797,3 +797,35 @@ org.apache.spark.SparkArithmeticException
     "fragment" : "conv('-9223372036854775807', 36, 10)"
   } ]
 }
+
+
+-- !query
+SELECT BIN(0)
+-- !query schema
+struct<bin(0):string>
+-- !query output
+0
+
+
+-- !query
+SELECT BIN(25)
+-- !query schema
+struct<bin(25):string>
+-- !query output
+11001
+
+
+-- !query
+SELECT BIN(25L)
+-- !query schema
+struct<bin(25):string>
+-- !query output
+11001
+
+
+-- !query
+SELECT BIN(25.5)
+-- !query schema
+struct<bin(25.5):string>
+-- !query output
+11001
diff --git a/sql/core/src/test/resources/sql-tests/results/math.sql.out 
b/sql/core/src/test/resources/sql-tests/results/math.sql.out
index d3df5cb93357..88a857a00f0f 100644
--- a/sql/core/src/test/resources/sql-tests/results/math.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/math.sql.out
@@ -493,3 +493,35 @@ SELECT conv('-9223372036854775807', 36, 10)
 struct<conv(-9223372036854775807, 36, 10):string>
 -- !query output
 18446744073709551615
+
+
+-- !query
+SELECT BIN(0)
+-- !query schema
+struct<bin(0):string>
+-- !query output
+0
+
+
+-- !query
+SELECT BIN(25)
+-- !query schema
+struct<bin(25):string>
+-- !query output
+11001
+
+
+-- !query
+SELECT BIN(25L)
+-- !query schema
+struct<bin(25):string>
+-- !query output
+11001
+
+
+-- !query
+SELECT BIN(25.5)
+-- !query schema
+struct<bin(25.5):string>
+-- !query output
+11001


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to