This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 9cfe96f feat: Add dictionary binary to shuffle writer (#111)
9cfe96f is described below
commit 9cfe96f27647d7822e09a14115a02d677ab72f41
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Feb 26 16:25:02 2024 -0800
feat: Add dictionary binary to shuffle writer (#111)
The native shuffle writer can write dictionary-encoded string arrays, but
dictionary-encoded binary arrays are not supported. This commit adds that support.
---
core/src/execution/datafusion/shuffle_writer.rs | 157 +++++++++++++++++----
core/src/execution/datafusion/spark_hash.rs | 6 +
.../org/apache/comet/exec/CometShuffleSuite.scala | 19 +++
3 files changed, 153 insertions(+), 29 deletions(-)
diff --git a/core/src/execution/datafusion/shuffle_writer.rs b/core/src/execution/datafusion/shuffle_writer.rs
index fc15fac..f836e3a 100644
--- a/core/src/execution/datafusion/shuffle_writer.rs
+++ b/core/src/execution/datafusion/shuffle_writer.rs
@@ -460,32 +460,32 @@ fn append_columns(
};
}
- macro_rules! append_string_dict {
- ($kt:ident) => {{
+ macro_rules! append_byte_dict {
+ ($kt:ident, $byte_type:ty, $array_type:ty) => {{
match $kt.as_ref() {
DataType::Int8 => {
- append_dict!(Int8Type, StringDictionaryBuilder<Int8Type>,
StringArray)
+ append_dict!(Int8Type,
GenericByteDictionaryBuilder<Int8Type, $byte_type>, $array_type)
}
DataType::Int16 => {
- append_dict!(Int16Type,
StringDictionaryBuilder<Int16Type>, StringArray)
+ append_dict!(Int16Type,
GenericByteDictionaryBuilder<Int16Type, $byte_type>, $array_type)
}
DataType::Int32 => {
- append_dict!(Int32Type,
StringDictionaryBuilder<Int32Type>, StringArray)
+ append_dict!(Int32Type,
GenericByteDictionaryBuilder<Int32Type, $byte_type>, $array_type)
}
DataType::Int64 => {
- append_dict!(Int64Type,
StringDictionaryBuilder<Int64Type>, StringArray)
+ append_dict!(Int64Type,
GenericByteDictionaryBuilder<Int64Type, $byte_type>, $array_type)
}
DataType::UInt8 => {
- append_dict!(UInt8Type,
StringDictionaryBuilder<UInt8Type>, StringArray)
+ append_dict!(UInt8Type,
GenericByteDictionaryBuilder<UInt8Type, $byte_type>, $array_type)
}
DataType::UInt16 => {
- append_dict!(UInt16Type,
StringDictionaryBuilder<UInt16Type>, StringArray)
+ append_dict!(UInt16Type,
GenericByteDictionaryBuilder<UInt16Type, $byte_type>, $array_type)
}
DataType::UInt32 => {
- append_dict!(UInt32Type,
StringDictionaryBuilder<UInt32Type>, StringArray)
+ append_dict!(UInt32Type,
GenericByteDictionaryBuilder<UInt32Type, $byte_type>, $array_type)
}
DataType::UInt64 => {
- append_dict!(UInt64Type,
StringDictionaryBuilder<UInt64Type>, StringArray)
+ append_dict!(UInt64Type,
GenericByteDictionaryBuilder<UInt64Type, $byte_type>, $array_type)
}
_ => unreachable!("Unknown key type for dictionary"),
}
@@ -522,7 +522,22 @@ fn append_columns(
DataType::Dictionary(key_type, value_type)
if matches!(value_type.as_ref(), DataType::Utf8) =>
{
- append_string_dict!(key_type)
+ append_byte_dict!(key_type, GenericStringType<i32>, StringArray)
+ }
+ DataType::Dictionary(key_type, value_type)
+ if matches!(value_type.as_ref(), DataType::LargeUtf8) =>
+ {
+ append_byte_dict!(key_type, GenericStringType<i64>,
LargeStringArray)
+ }
+ DataType::Dictionary(key_type, value_type)
+ if matches!(value_type.as_ref(), DataType::Binary) =>
+ {
+ append_byte_dict!(key_type, GenericBinaryType<i32>, BinaryArray)
+ }
+ DataType::Dictionary(key_type, value_type)
+ if matches!(value_type.as_ref(), DataType::LargeBinary) =>
+ {
+ append_byte_dict!(key_type, GenericBinaryType<i64>,
LargeBinaryArray)
}
DataType::Binary => append!(Binary),
DataType::LargeBinary => append!(LargeBinary),
@@ -1028,7 +1043,7 @@ macro_rules! primitive_dict_builder_helper {
};
}
-macro_rules! string_dict_builder_inner_helper {
+macro_rules! byte_dict_builder_inner_helper {
($kt:ty, $capacity:ident, $builder:ident) => {
Box::new($builder::<$kt>::with_capacity(
$capacity,
@@ -1068,28 +1083,28 @@ fn make_dict_builder(datatype: &DataType, capacity: usize) -> Box<dyn ArrayBuild
{
match key_type.as_ref() {
DataType::Int8 => {
- string_dict_builder_inner_helper!(Int16Type, capacity,
StringDictionaryBuilder)
+ byte_dict_builder_inner_helper!(Int8Type, capacity,
StringDictionaryBuilder)
}
DataType::Int16 => {
- string_dict_builder_inner_helper!(Int16Type, capacity,
StringDictionaryBuilder)
+ byte_dict_builder_inner_helper!(Int16Type, capacity,
StringDictionaryBuilder)
}
DataType::Int32 => {
- string_dict_builder_inner_helper!(Int32Type, capacity,
StringDictionaryBuilder)
+ byte_dict_builder_inner_helper!(Int32Type, capacity,
StringDictionaryBuilder)
}
DataType::Int64 => {
- string_dict_builder_inner_helper!(Int64Type, capacity,
StringDictionaryBuilder)
+ byte_dict_builder_inner_helper!(Int64Type, capacity,
StringDictionaryBuilder)
}
DataType::UInt8 => {
- string_dict_builder_inner_helper!(UInt8Type, capacity,
StringDictionaryBuilder)
+ byte_dict_builder_inner_helper!(UInt8Type, capacity,
StringDictionaryBuilder)
}
DataType::UInt16 => {
- string_dict_builder_inner_helper!(UInt16Type, capacity,
StringDictionaryBuilder)
+ byte_dict_builder_inner_helper!(UInt16Type, capacity,
StringDictionaryBuilder)
}
DataType::UInt32 => {
- string_dict_builder_inner_helper!(UInt32Type, capacity,
StringDictionaryBuilder)
+ byte_dict_builder_inner_helper!(UInt32Type, capacity,
StringDictionaryBuilder)
}
DataType::UInt64 => {
- string_dict_builder_inner_helper!(UInt64Type, capacity,
StringDictionaryBuilder)
+ byte_dict_builder_inner_helper!(UInt64Type, capacity,
StringDictionaryBuilder)
}
_ => unreachable!(""),
}
@@ -1098,47 +1113,47 @@ fn make_dict_builder(datatype: &DataType, capacity: usize) -> Box<dyn ArrayBuild
if matches!(value_type.as_ref(), DataType::LargeUtf8) =>
{
match key_type.as_ref() {
- DataType::Int8 => string_dict_builder_inner_helper!(
- Int16Type,
+ DataType::Int8 => byte_dict_builder_inner_helper!(
+ Int8Type,
capacity,
LargeStringDictionaryBuilder
),
- DataType::Int16 => string_dict_builder_inner_helper!(
+ DataType::Int16 => byte_dict_builder_inner_helper!(
Int16Type,
capacity,
LargeStringDictionaryBuilder
),
- DataType::Int32 => string_dict_builder_inner_helper!(
+ DataType::Int32 => byte_dict_builder_inner_helper!(
Int32Type,
capacity,
LargeStringDictionaryBuilder
),
- DataType::Int64 => string_dict_builder_inner_helper!(
+ DataType::Int64 => byte_dict_builder_inner_helper!(
Int64Type,
capacity,
LargeStringDictionaryBuilder
),
- DataType::UInt8 => string_dict_builder_inner_helper!(
+ DataType::UInt8 => byte_dict_builder_inner_helper!(
UInt8Type,
capacity,
LargeStringDictionaryBuilder
),
DataType::UInt16 => {
- string_dict_builder_inner_helper!(
+ byte_dict_builder_inner_helper!(
UInt16Type,
capacity,
LargeStringDictionaryBuilder
)
}
DataType::UInt32 => {
- string_dict_builder_inner_helper!(
+ byte_dict_builder_inner_helper!(
UInt32Type,
capacity,
LargeStringDictionaryBuilder
)
}
DataType::UInt64 => {
- string_dict_builder_inner_helper!(
+ byte_dict_builder_inner_helper!(
UInt64Type,
capacity,
LargeStringDictionaryBuilder
@@ -1147,6 +1162,90 @@ fn make_dict_builder(datatype: &DataType, capacity: usize) -> Box<dyn ArrayBuild
_ => unreachable!(""),
}
}
+ DataType::Dictionary(key_type, value_type)
+ if matches!(value_type.as_ref(), DataType::Binary) =>
+ {
+ match key_type.as_ref() {
+ DataType::Int8 => {
+ byte_dict_builder_inner_helper!(Int8Type, capacity,
BinaryDictionaryBuilder)
+ }
+ DataType::Int16 => {
+ byte_dict_builder_inner_helper!(Int16Type, capacity,
BinaryDictionaryBuilder)
+ }
+ DataType::Int32 => {
+ byte_dict_builder_inner_helper!(Int32Type, capacity,
BinaryDictionaryBuilder)
+ }
+ DataType::Int64 => {
+ byte_dict_builder_inner_helper!(Int64Type, capacity,
BinaryDictionaryBuilder)
+ }
+ DataType::UInt8 => {
+ byte_dict_builder_inner_helper!(UInt8Type, capacity,
BinaryDictionaryBuilder)
+ }
+ DataType::UInt16 => {
+ byte_dict_builder_inner_helper!(UInt16Type, capacity,
BinaryDictionaryBuilder)
+ }
+ DataType::UInt32 => {
+ byte_dict_builder_inner_helper!(UInt32Type, capacity,
BinaryDictionaryBuilder)
+ }
+ DataType::UInt64 => {
+ byte_dict_builder_inner_helper!(UInt64Type, capacity,
BinaryDictionaryBuilder)
+ }
+ _ => unreachable!(""),
+ }
+ }
+ DataType::Dictionary(key_type, value_type)
+ if matches!(value_type.as_ref(), DataType::LargeBinary) =>
+ {
+ match key_type.as_ref() {
+ DataType::Int8 => byte_dict_builder_inner_helper!(
+ Int8Type,
+ capacity,
+ LargeBinaryDictionaryBuilder
+ ),
+ DataType::Int16 => byte_dict_builder_inner_helper!(
+ Int16Type,
+ capacity,
+ LargeBinaryDictionaryBuilder
+ ),
+ DataType::Int32 => byte_dict_builder_inner_helper!(
+ Int32Type,
+ capacity,
+ LargeBinaryDictionaryBuilder
+ ),
+ DataType::Int64 => byte_dict_builder_inner_helper!(
+ Int64Type,
+ capacity,
+ LargeBinaryDictionaryBuilder
+ ),
+ DataType::UInt8 => byte_dict_builder_inner_helper!(
+ UInt8Type,
+ capacity,
+ LargeBinaryDictionaryBuilder
+ ),
+ DataType::UInt16 => {
+ byte_dict_builder_inner_helper!(
+ UInt16Type,
+ capacity,
+ LargeBinaryDictionaryBuilder
+ )
+ }
+ DataType::UInt32 => {
+ byte_dict_builder_inner_helper!(
+ UInt32Type,
+ capacity,
+ LargeBinaryDictionaryBuilder
+ )
+ }
+ DataType::UInt64 => {
+ byte_dict_builder_inner_helper!(
+ UInt64Type,
+ capacity,
+ LargeBinaryDictionaryBuilder
+ )
+ }
+ _ => unreachable!(""),
+ }
+ }
t => panic!("Data type {t:?} is not currently supported"),
}
}
diff --git a/core/src/execution/datafusion/spark_hash.rs b/core/src/execution/datafusion/spark_hash.rs
index 0413e45..aeefccf 100644
--- a/core/src/execution/datafusion/spark_hash.rs
+++ b/core/src/execution/datafusion/spark_hash.rs
@@ -292,6 +292,12 @@ pub fn create_hashes<'a>(
DataType::LargeUtf8 => {
hash_array!(LargeStringArray, col, hashes_buffer);
}
+ DataType::Binary => {
+ hash_array!(BinaryArray, col, hashes_buffer);
+ }
+ DataType::LargeBinary => {
+ hash_array!(LargeBinaryArray, col, hashes_buffer);
+ }
DataType::FixedSizeBinary(_) => {
hash_array!(FixedSizeBinaryArray, col, hashes_buffer);
}
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala
index acd424a..0d7c73d 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala
@@ -63,6 +63,25 @@ abstract class CometShuffleSuiteBase extends CometTestBase with AdaptiveSparkPla
import testImplicits._
+ test("Native shuffle with dictionary of binary") {
+ Seq("true", "false").foreach { dictionaryEnabled =>
+ withSQLConf(
+ CometConf.COMET_EXEC_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+ CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+ withParquetTable(
+ (0 until 1000).map(i => (i % 5, (i % 5).toString.getBytes())),
+ "tbl",
+ dictionaryEnabled.toBoolean) {
+ val shuffled = sql("SELECT * FROM tbl").repartition(2, $"_2")
+
+ checkCometExchange(shuffled, 1, true)
+ checkSparkAnswer(shuffled)
+ }
+ }
+ }
+ }
+
test("columnar shuffle on nested struct including nulls") {
Seq(10, 201).foreach { numPartitions =>
Seq("1.0", "10.0").foreach { ratio =>