This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 9cfe96f  feat: Add dictionary binary to shuffle writer (#111)
9cfe96f is described below

commit 9cfe96f27647d7822e09a14115a02d677ab72f41
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Mon Feb 26 16:25:02 2024 -0800

    feat: Add dictionary binary to shuffle writer (#111)
    
    The native shuffle writer can write dictionaries of strings, but dictionaries
    of binary values are not supported. We should add support for them.
---
 core/src/execution/datafusion/shuffle_writer.rs    | 157 +++++++++++++++++----
 core/src/execution/datafusion/spark_hash.rs        |   6 +
 .../org/apache/comet/exec/CometShuffleSuite.scala  |  19 +++
 3 files changed, 153 insertions(+), 29 deletions(-)

diff --git a/core/src/execution/datafusion/shuffle_writer.rs b/core/src/execution/datafusion/shuffle_writer.rs
index fc15fac..f836e3a 100644
--- a/core/src/execution/datafusion/shuffle_writer.rs
+++ b/core/src/execution/datafusion/shuffle_writer.rs
@@ -460,32 +460,32 @@ fn append_columns(
         };
     }
 
-    macro_rules! append_string_dict {
-        ($kt:ident) => {{
+    macro_rules! append_byte_dict {
+        ($kt:ident, $byte_type:ty, $array_type:ty) => {{
             match $kt.as_ref() {
                 DataType::Int8 => {
-                    append_dict!(Int8Type, StringDictionaryBuilder<Int8Type>, 
StringArray)
+                    append_dict!(Int8Type, 
GenericByteDictionaryBuilder<Int8Type, $byte_type>, $array_type)
                 }
                 DataType::Int16 => {
-                    append_dict!(Int16Type, 
StringDictionaryBuilder<Int16Type>, StringArray)
+                    append_dict!(Int16Type,  
GenericByteDictionaryBuilder<Int16Type, $byte_type>, $array_type)
                 }
                 DataType::Int32 => {
-                    append_dict!(Int32Type, 
StringDictionaryBuilder<Int32Type>, StringArray)
+                    append_dict!(Int32Type,  
GenericByteDictionaryBuilder<Int32Type, $byte_type>, $array_type)
                 }
                 DataType::Int64 => {
-                    append_dict!(Int64Type, 
StringDictionaryBuilder<Int64Type>, StringArray)
+                    append_dict!(Int64Type,  
GenericByteDictionaryBuilder<Int64Type, $byte_type>, $array_type)
                 }
                 DataType::UInt8 => {
-                    append_dict!(UInt8Type, 
StringDictionaryBuilder<UInt8Type>, StringArray)
+                    append_dict!(UInt8Type,  
GenericByteDictionaryBuilder<UInt8Type, $byte_type>, $array_type)
                 }
                 DataType::UInt16 => {
-                    append_dict!(UInt16Type, 
StringDictionaryBuilder<UInt16Type>, StringArray)
+                    append_dict!(UInt16Type, 
GenericByteDictionaryBuilder<UInt16Type, $byte_type>, $array_type)
                 }
                 DataType::UInt32 => {
-                    append_dict!(UInt32Type, 
StringDictionaryBuilder<UInt32Type>, StringArray)
+                    append_dict!(UInt32Type, 
GenericByteDictionaryBuilder<UInt32Type, $byte_type>, $array_type)
                 }
                 DataType::UInt64 => {
-                    append_dict!(UInt64Type, 
StringDictionaryBuilder<UInt64Type>, StringArray)
+                    append_dict!(UInt64Type, 
GenericByteDictionaryBuilder<UInt64Type, $byte_type>, $array_type)
                 }
                 _ => unreachable!("Unknown key type for dictionary"),
             }
@@ -522,7 +522,22 @@ fn append_columns(
         DataType::Dictionary(key_type, value_type)
             if matches!(value_type.as_ref(), DataType::Utf8) =>
         {
-            append_string_dict!(key_type)
+            append_byte_dict!(key_type, GenericStringType<i32>, StringArray)
+        }
+        DataType::Dictionary(key_type, value_type)
+            if matches!(value_type.as_ref(), DataType::LargeUtf8) =>
+        {
+            append_byte_dict!(key_type, GenericStringType<i64>, 
LargeStringArray)
+        }
+        DataType::Dictionary(key_type, value_type)
+            if matches!(value_type.as_ref(), DataType::Binary) =>
+        {
+            append_byte_dict!(key_type, GenericBinaryType<i32>, BinaryArray)
+        }
+        DataType::Dictionary(key_type, value_type)
+            if matches!(value_type.as_ref(), DataType::LargeBinary) =>
+        {
+            append_byte_dict!(key_type, GenericBinaryType<i64>, 
LargeBinaryArray)
         }
         DataType::Binary => append!(Binary),
         DataType::LargeBinary => append!(LargeBinary),
@@ -1028,7 +1043,7 @@ macro_rules! primitive_dict_builder_helper {
     };
 }
 
-macro_rules! string_dict_builder_inner_helper {
+macro_rules! byte_dict_builder_inner_helper {
     ($kt:ty, $capacity:ident, $builder:ident) => {
         Box::new($builder::<$kt>::with_capacity(
             $capacity,
@@ -1068,28 +1083,28 @@ fn make_dict_builder(datatype: &DataType, capacity: usize) -> Box<dyn ArrayBuild
         {
             match key_type.as_ref() {
                 DataType::Int8 => {
-                    string_dict_builder_inner_helper!(Int16Type, capacity, 
StringDictionaryBuilder)
+                    byte_dict_builder_inner_helper!(Int8Type, capacity, 
StringDictionaryBuilder)
                 }
                 DataType::Int16 => {
-                    string_dict_builder_inner_helper!(Int16Type, capacity, 
StringDictionaryBuilder)
+                    byte_dict_builder_inner_helper!(Int16Type, capacity, 
StringDictionaryBuilder)
                 }
                 DataType::Int32 => {
-                    string_dict_builder_inner_helper!(Int32Type, capacity, 
StringDictionaryBuilder)
+                    byte_dict_builder_inner_helper!(Int32Type, capacity, 
StringDictionaryBuilder)
                 }
                 DataType::Int64 => {
-                    string_dict_builder_inner_helper!(Int64Type, capacity, 
StringDictionaryBuilder)
+                    byte_dict_builder_inner_helper!(Int64Type, capacity, 
StringDictionaryBuilder)
                 }
                 DataType::UInt8 => {
-                    string_dict_builder_inner_helper!(UInt8Type, capacity, 
StringDictionaryBuilder)
+                    byte_dict_builder_inner_helper!(UInt8Type, capacity, 
StringDictionaryBuilder)
                 }
                 DataType::UInt16 => {
-                    string_dict_builder_inner_helper!(UInt16Type, capacity, 
StringDictionaryBuilder)
+                    byte_dict_builder_inner_helper!(UInt16Type, capacity, 
StringDictionaryBuilder)
                 }
                 DataType::UInt32 => {
-                    string_dict_builder_inner_helper!(UInt32Type, capacity, 
StringDictionaryBuilder)
+                    byte_dict_builder_inner_helper!(UInt32Type, capacity, 
StringDictionaryBuilder)
                 }
                 DataType::UInt64 => {
-                    string_dict_builder_inner_helper!(UInt64Type, capacity, 
StringDictionaryBuilder)
+                    byte_dict_builder_inner_helper!(UInt64Type, capacity, 
StringDictionaryBuilder)
                 }
                 _ => unreachable!(""),
             }
@@ -1098,47 +1113,47 @@ fn make_dict_builder(datatype: &DataType, capacity: usize) -> Box<dyn ArrayBuild
             if matches!(value_type.as_ref(), DataType::LargeUtf8) =>
         {
             match key_type.as_ref() {
-                DataType::Int8 => string_dict_builder_inner_helper!(
-                    Int16Type,
+                DataType::Int8 => byte_dict_builder_inner_helper!(
+                    Int8Type,
                     capacity,
                     LargeStringDictionaryBuilder
                 ),
-                DataType::Int16 => string_dict_builder_inner_helper!(
+                DataType::Int16 => byte_dict_builder_inner_helper!(
                     Int16Type,
                     capacity,
                     LargeStringDictionaryBuilder
                 ),
-                DataType::Int32 => string_dict_builder_inner_helper!(
+                DataType::Int32 => byte_dict_builder_inner_helper!(
                     Int32Type,
                     capacity,
                     LargeStringDictionaryBuilder
                 ),
-                DataType::Int64 => string_dict_builder_inner_helper!(
+                DataType::Int64 => byte_dict_builder_inner_helper!(
                     Int64Type,
                     capacity,
                     LargeStringDictionaryBuilder
                 ),
-                DataType::UInt8 => string_dict_builder_inner_helper!(
+                DataType::UInt8 => byte_dict_builder_inner_helper!(
                     UInt8Type,
                     capacity,
                     LargeStringDictionaryBuilder
                 ),
                 DataType::UInt16 => {
-                    string_dict_builder_inner_helper!(
+                    byte_dict_builder_inner_helper!(
                         UInt16Type,
                         capacity,
                         LargeStringDictionaryBuilder
                     )
                 }
                 DataType::UInt32 => {
-                    string_dict_builder_inner_helper!(
+                    byte_dict_builder_inner_helper!(
                         UInt32Type,
                         capacity,
                         LargeStringDictionaryBuilder
                     )
                 }
                 DataType::UInt64 => {
-                    string_dict_builder_inner_helper!(
+                    byte_dict_builder_inner_helper!(
                         UInt64Type,
                         capacity,
                         LargeStringDictionaryBuilder
@@ -1147,6 +1162,90 @@ fn make_dict_builder(datatype: &DataType, capacity: usize) -> Box<dyn ArrayBuild
                 _ => unreachable!(""),
             }
         }
+        DataType::Dictionary(key_type, value_type)
+            if matches!(value_type.as_ref(), DataType::Binary) =>
+        {
+            match key_type.as_ref() {
+                DataType::Int8 => {
+                    byte_dict_builder_inner_helper!(Int8Type, capacity, 
BinaryDictionaryBuilder)
+                }
+                DataType::Int16 => {
+                    byte_dict_builder_inner_helper!(Int16Type, capacity, 
BinaryDictionaryBuilder)
+                }
+                DataType::Int32 => {
+                    byte_dict_builder_inner_helper!(Int32Type, capacity, 
BinaryDictionaryBuilder)
+                }
+                DataType::Int64 => {
+                    byte_dict_builder_inner_helper!(Int64Type, capacity, 
BinaryDictionaryBuilder)
+                }
+                DataType::UInt8 => {
+                    byte_dict_builder_inner_helper!(UInt8Type, capacity, 
BinaryDictionaryBuilder)
+                }
+                DataType::UInt16 => {
+                    byte_dict_builder_inner_helper!(UInt16Type, capacity, 
BinaryDictionaryBuilder)
+                }
+                DataType::UInt32 => {
+                    byte_dict_builder_inner_helper!(UInt32Type, capacity, 
BinaryDictionaryBuilder)
+                }
+                DataType::UInt64 => {
+                    byte_dict_builder_inner_helper!(UInt64Type, capacity, 
BinaryDictionaryBuilder)
+                }
+                _ => unreachable!(""),
+            }
+        }
+        DataType::Dictionary(key_type, value_type)
+            if matches!(value_type.as_ref(), DataType::LargeBinary) =>
+        {
+            match key_type.as_ref() {
+                DataType::Int8 => byte_dict_builder_inner_helper!(
+                    Int8Type,
+                    capacity,
+                    LargeBinaryDictionaryBuilder
+                ),
+                DataType::Int16 => byte_dict_builder_inner_helper!(
+                    Int16Type,
+                    capacity,
+                    LargeBinaryDictionaryBuilder
+                ),
+                DataType::Int32 => byte_dict_builder_inner_helper!(
+                    Int32Type,
+                    capacity,
+                    LargeBinaryDictionaryBuilder
+                ),
+                DataType::Int64 => byte_dict_builder_inner_helper!(
+                    Int64Type,
+                    capacity,
+                    LargeBinaryDictionaryBuilder
+                ),
+                DataType::UInt8 => byte_dict_builder_inner_helper!(
+                    UInt8Type,
+                    capacity,
+                    LargeBinaryDictionaryBuilder
+                ),
+                DataType::UInt16 => {
+                    byte_dict_builder_inner_helper!(
+                        UInt16Type,
+                        capacity,
+                        LargeBinaryDictionaryBuilder
+                    )
+                }
+                DataType::UInt32 => {
+                    byte_dict_builder_inner_helper!(
+                        UInt32Type,
+                        capacity,
+                        LargeBinaryDictionaryBuilder
+                    )
+                }
+                DataType::UInt64 => {
+                    byte_dict_builder_inner_helper!(
+                        UInt64Type,
+                        capacity,
+                        LargeBinaryDictionaryBuilder
+                    )
+                }
+                _ => unreachable!(""),
+            }
+        }
         t => panic!("Data type {t:?} is not currently supported"),
     }
 }
diff --git a/core/src/execution/datafusion/spark_hash.rs b/core/src/execution/datafusion/spark_hash.rs
index 0413e45..aeefccf 100644
--- a/core/src/execution/datafusion/spark_hash.rs
+++ b/core/src/execution/datafusion/spark_hash.rs
@@ -292,6 +292,12 @@ pub fn create_hashes<'a>(
             DataType::LargeUtf8 => {
                 hash_array!(LargeStringArray, col, hashes_buffer);
             }
+            DataType::Binary => {
+                hash_array!(BinaryArray, col, hashes_buffer);
+            }
+            DataType::LargeBinary => {
+                hash_array!(LargeBinaryArray, col, hashes_buffer);
+            }
             DataType::FixedSizeBinary(_) => {
                 hash_array!(FixedSizeBinaryArray, col, hashes_buffer);
             }
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala
index acd424a..0d7c73d 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala
@@ -63,6 +63,25 @@ abstract class CometShuffleSuiteBase extends CometTestBase with AdaptiveSparkPla
 
   import testImplicits._
 
+  test("Native shuffle with dictionary of binary") {
+    Seq("true", "false").foreach { dictionaryEnabled =>
+      withSQLConf(
+        CometConf.COMET_EXEC_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+        withParquetTable(
+          (0 until 1000).map(i => (i % 5, (i % 5).toString.getBytes())),
+          "tbl",
+          dictionaryEnabled.toBoolean) {
+          val shuffled = sql("SELECT * FROM tbl").repartition(2, $"_2")
+
+          checkCometExchange(shuffled, 1, true)
+          checkSparkAnswer(shuffled)
+        }
+      }
+    }
+  }
+
   test("columnar shuffle on nested struct including nulls") {
     Seq(10, 201).foreach { numPartitions =>
       Seq("1.0", "10.0").foreach { ratio =>

Reply via email to