This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 245f0b8 support large-utf8 in groupby (#35)
245f0b8 is described below
commit 245f0b8a68c5763a236aef3e727f0502188d0bfa
Author: Ritchie Vink <[email protected]>
AuthorDate: Sat Apr 24 12:20:16 2021 +0200
support large-utf8 in groupby (#35)
* support large-utf8 in groupby
* add test
---
datafusion/src/execution/context.rs | 51 ++++++++++++++++++++++++++
datafusion/src/physical_plan/group_scalar.rs | 9 ++++-
datafusion/src/physical_plan/hash_aggregate.rs | 18 ++++++++-
datafusion/src/physical_plan/hash_join.rs | 3 ++
datafusion/src/physical_plan/type_coercion.rs | 2 +-
5 files changed, 79 insertions(+), 4 deletions(-)
diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index c83ca4d..c394d38 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -1647,6 +1647,57 @@ mod tests {
}
#[tokio::test]
+ async fn group_by_largeutf8() {
+ {
+ let mut ctx = ExecutionContext::new();
+
+ // input data looks like:
+ // A, 1
+ // B, 2
+ // A, 2
+ // A, 4
+ // C, 1
+ // A, 1
+
+            let str_array: LargeStringArray = vec!["A", "B", "A", "A", "C", "A"]
+                .into_iter()
+                .map(Some)
+                .collect();
+ let str_array = Arc::new(str_array);
+
+ let val_array: Int64Array = vec![1, 2, 2, 4, 1, 1].into();
+ let val_array = Arc::new(val_array);
+
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("str", str_array.data_type().clone(), false),
+ Field::new("val", val_array.data_type().clone(), false),
+ ]));
+
+            let batch =
+                RecordBatch::try_new(schema.clone(), vec![str_array, val_array]).unwrap();
+
+            let provider = MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap();
+ ctx.register_table("t", Arc::new(provider)).unwrap();
+
+            let results =
+                plan_and_collect(&mut ctx, "SELECT str, count(val) FROM t GROUP BY str")
+ .await
+ .expect("ran plan correctly");
+
+ let expected = vec![
+ "+-----+------------+",
+ "| str | COUNT(val) |",
+ "+-----+------------+",
+ "| A | 4 |",
+ "| B | 1 |",
+ "| C | 1 |",
+ "+-----+------------+",
+ ];
+ assert_batches_sorted_eq!(expected, &results);
+ }
+ }
+
+ #[tokio::test]
async fn group_by_dictionary() {
async fn run_test_case<K: ArrowDictionaryKeyType>() {
let mut ctx = ExecutionContext::new();
diff --git a/datafusion/src/physical_plan/group_scalar.rs b/datafusion/src/physical_plan/group_scalar.rs
index f4987ae..943386d 100644
--- a/datafusion/src/physical_plan/group_scalar.rs
+++ b/datafusion/src/physical_plan/group_scalar.rs
@@ -37,6 +37,7 @@ pub(crate) enum GroupByScalar {
Int32(i32),
Int64(i64),
Utf8(Box<String>),
+ LargeUtf8(Box<String>),
Boolean(bool),
TimeMillisecond(i64),
TimeMicrosecond(i64),
@@ -74,6 +75,9 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
GroupByScalar::TimeNanosecond(*v)
}
ScalarValue::Utf8(Some(v)) =>
GroupByScalar::Utf8(Box::new(v.clone())),
+ ScalarValue::LargeUtf8(Some(v)) => {
+ GroupByScalar::LargeUtf8(Box::new(v.clone()))
+ }
ScalarValue::Float32(None)
| ScalarValue::Float64(None)
| ScalarValue::Boolean(None)
@@ -116,6 +120,7 @@ impl From<&GroupByScalar> for ScalarValue {
GroupByScalar::UInt32(v) => ScalarValue::UInt32(Some(*v)),
GroupByScalar::UInt64(v) => ScalarValue::UInt64(Some(*v)),
GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.to_string())),
+            GroupByScalar::LargeUtf8(v) => ScalarValue::LargeUtf8(Some(v.to_string())),
GroupByScalar::TimeMillisecond(v) => {
ScalarValue::TimestampMillisecond(Some(*v))
}
@@ -191,14 +196,14 @@ mod tests {
#[test]
fn from_scalar_unsupported() {
// Use any ScalarValue type not supported by GroupByScalar.
- let scalar_value = ScalarValue::LargeUtf8(Some("1.1".to_string()));
+ let scalar_value = ScalarValue::Binary(Some(vec![1, 2]));
let result = GroupByScalar::try_from(&scalar_value);
match result {
Err(DataFusionError::Internal(error_message)) => assert_eq!(
error_message,
String::from(
-                    "Cannot convert a ScalarValue with associated DataType LargeUtf8"
+                    "Cannot convert a ScalarValue with associated DataType Binary"
)
),
_ => panic!("Unexpected result"),
diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs
index fd20b5c..fad4fa5 100644
--- a/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/datafusion/src/physical_plan/hash_aggregate.rs
@@ -59,7 +59,8 @@ use ordered_float::OrderedFloat;
use pin_project_lite::pin_project;
use arrow::array::{
-    TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
+    LargeStringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
+    TimestampNanosecondArray,
};
use async_trait::async_trait;
@@ -540,6 +541,14 @@ fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec<u8>) -> Result<(
// store the string value
vec.extend_from_slice(value.as_bytes());
}
+        DataType::LargeUtf8 => {
+            let array = col.as_any().downcast_ref::<LargeStringArray>().unwrap();
+ let value = array.value(row);
+ // store the size
+ vec.extend_from_slice(&value.len().to_le_bytes());
+ // store the string value
+ vec.extend_from_slice(value.as_bytes());
+ }
DataType::Date32 => {
let array = col.as_any().downcast_ref::<Date32Array>().unwrap();
vec.extend_from_slice(&array.value(row).to_le_bytes());
@@ -953,6 +962,9 @@ fn create_batch_from_map(
GroupByScalar::Utf8(str) => {
Arc::new(StringArray::from(vec![&***str]))
}
+ GroupByScalar::LargeUtf8(str) => {
+ Arc::new(LargeStringArray::from(vec![&***str]))
+ }
GroupByScalar::Boolean(b) =>
Arc::new(BooleanArray::from(vec![*b])),
GroupByScalar::TimeMillisecond(n) => {
Arc::new(TimestampMillisecondArray::from(vec![*n]))
@@ -1103,6 +1115,10 @@ fn create_group_by_value(col: &ArrayRef, row: usize) -> Result<GroupByScalar> {
let array = col.as_any().downcast_ref::<StringArray>().unwrap();
Ok(GroupByScalar::Utf8(Box::new(array.value(row).into())))
}
+        DataType::LargeUtf8 => {
+            let array = col.as_any().downcast_ref::<LargeStringArray>().unwrap();
+            Ok(GroupByScalar::Utf8(Box::new(array.value(row).into())))
+        }
DataType::Boolean => {
let array = col.as_any().downcast_ref::<BooleanArray>().unwrap();
Ok(GroupByScalar::Boolean(array.value(row)))
diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index 401fe65..eb2ec33 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -831,6 +831,9 @@ pub fn create_hashes<'a>(
DataType::Utf8 => {
hash_array!(StringArray, col, str, hashes_buffer,
random_state);
}
+            DataType::LargeUtf8 => {
+                hash_array!(LargeStringArray, col, str, hashes_buffer, random_state);
+            }
_ => {
// This is internal because we should have caught this before.
return Err(DataFusionError::Internal(
diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs
index 24b51ba..d9f84e7 100644
--- a/datafusion/src/physical_plan/type_coercion.rs
+++ b/datafusion/src/physical_plan/type_coercion.rs
@@ -196,7 +196,7 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
| Float64
),
Timestamp(TimeUnit::Nanosecond, None) => matches!(type_from,
Timestamp(_, None)),
- Utf8 => true,
+ Utf8 | LargeUtf8 => true,
_ => false,
}
}