This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 245f0b8  support large-utf8 in groupby (#35)
245f0b8 is described below

commit 245f0b8a68c5763a236aef3e727f0502188d0bfa
Author: Ritchie Vink <[email protected]>
AuthorDate: Sat Apr 24 12:20:16 2021 +0200

    support large-utf8 in groupby (#35)
    
    * support large-utf8 in groupby
    
    * add test
---
 datafusion/src/execution/context.rs            | 51 ++++++++++++++++++++++++++
 datafusion/src/physical_plan/group_scalar.rs   |  9 ++++-
 datafusion/src/physical_plan/hash_aggregate.rs | 18 ++++++++-
 datafusion/src/physical_plan/hash_join.rs      |  3 ++
 datafusion/src/physical_plan/type_coercion.rs  |  2 +-
 5 files changed, 79 insertions(+), 4 deletions(-)

diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index c83ca4d..c394d38 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -1647,6 +1647,57 @@ mod tests {
     }
 
     #[tokio::test]
+    async fn group_by_largeutf8() {
+        {
+            let mut ctx = ExecutionContext::new();
+
+            // input data looks like:
+            // A, 1
+            // B, 2
+            // A, 2
+            // A, 4
+            // C, 1
+            // A, 1
+
+            let str_array: LargeStringArray = vec!["A", "B", "A", "A", "C", "A"]
+                .into_iter()
+                .map(Some)
+                .collect();
+            let str_array = Arc::new(str_array);
+
+            let val_array: Int64Array = vec![1, 2, 2, 4, 1, 1].into();
+            let val_array = Arc::new(val_array);
+
+            let schema = Arc::new(Schema::new(vec![
+                Field::new("str", str_array.data_type().clone(), false),
+                Field::new("val", val_array.data_type().clone(), false),
+            ]));
+
+            let batch =
+                RecordBatch::try_new(schema.clone(), vec![str_array, val_array]).unwrap();
+
+            let provider = MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap();
+            ctx.register_table("t", Arc::new(provider)).unwrap();
+
+            let results =
+                plan_and_collect(&mut ctx, "SELECT str, count(val) FROM t GROUP BY str")
+                    .await
+                    .expect("ran plan correctly");
+
+            let expected = vec![
+                "+-----+------------+",
+                "| str | COUNT(val) |",
+                "+-----+------------+",
+                "| A   | 4          |",
+                "| B   | 1          |",
+                "| C   | 1          |",
+                "+-----+------------+",
+            ];
+            assert_batches_sorted_eq!(expected, &results);
+        }
+    }
+
+    #[tokio::test]
     async fn group_by_dictionary() {
         async fn run_test_case<K: ArrowDictionaryKeyType>() {
             let mut ctx = ExecutionContext::new();
diff --git a/datafusion/src/physical_plan/group_scalar.rs b/datafusion/src/physical_plan/group_scalar.rs
index f4987ae..943386d 100644
--- a/datafusion/src/physical_plan/group_scalar.rs
+++ b/datafusion/src/physical_plan/group_scalar.rs
@@ -37,6 +37,7 @@ pub(crate) enum GroupByScalar {
     Int32(i32),
     Int64(i64),
     Utf8(Box<String>),
+    LargeUtf8(Box<String>),
     Boolean(bool),
     TimeMillisecond(i64),
     TimeMicrosecond(i64),
@@ -74,6 +75,9 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
                 GroupByScalar::TimeNanosecond(*v)
             }
             ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())),
+            ScalarValue::LargeUtf8(Some(v)) => {
+                GroupByScalar::LargeUtf8(Box::new(v.clone()))
+            }
             ScalarValue::Float32(None)
             | ScalarValue::Float64(None)
             | ScalarValue::Boolean(None)
@@ -116,6 +120,7 @@ impl From<&GroupByScalar> for ScalarValue {
             GroupByScalar::UInt32(v) => ScalarValue::UInt32(Some(*v)),
             GroupByScalar::UInt64(v) => ScalarValue::UInt64(Some(*v)),
             GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.to_string())),
+            GroupByScalar::LargeUtf8(v) => ScalarValue::LargeUtf8(Some(v.to_string())),
             GroupByScalar::TimeMillisecond(v) => {
                 ScalarValue::TimestampMillisecond(Some(*v))
             }
@@ -191,14 +196,14 @@ mod tests {
     #[test]
     fn from_scalar_unsupported() {
         // Use any ScalarValue type not supported by GroupByScalar.
-        let scalar_value = ScalarValue::LargeUtf8(Some("1.1".to_string()));
+        let scalar_value = ScalarValue::Binary(Some(vec![1, 2]));
         let result = GroupByScalar::try_from(&scalar_value);
 
         match result {
             Err(DataFusionError::Internal(error_message)) => assert_eq!(
                 error_message,
                 String::from(
-                    "Cannot convert a ScalarValue with associated DataType LargeUtf8"
+                    "Cannot convert a ScalarValue with associated DataType Binary"
                 )
             ),
             _ => panic!("Unexpected result"),
diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs
index fd20b5c..fad4fa5 100644
--- a/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/datafusion/src/physical_plan/hash_aggregate.rs
@@ -59,7 +59,8 @@ use ordered_float::OrderedFloat;
 use pin_project_lite::pin_project;
 
 use arrow::array::{
-    TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
+    LargeStringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
+    TimestampNanosecondArray,
 };
 use async_trait::async_trait;
 
@@ -540,6 +541,14 @@ fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec<u8>) -> Result<(
             // store the string value
             vec.extend_from_slice(value.as_bytes());
         }
+        DataType::LargeUtf8 => {
+            let array = col.as_any().downcast_ref::<LargeStringArray>().unwrap();
+            let value = array.value(row);
+            // store the size
+            vec.extend_from_slice(&value.len().to_le_bytes());
+            // store the string value
+            vec.extend_from_slice(value.as_bytes());
+        }
         DataType::Date32 => {
             let array = col.as_any().downcast_ref::<Date32Array>().unwrap();
             vec.extend_from_slice(&array.value(row).to_le_bytes());
@@ -953,6 +962,9 @@ fn create_batch_from_map(
                     GroupByScalar::Utf8(str) => {
                         Arc::new(StringArray::from(vec![&***str]))
                     }
+                    GroupByScalar::LargeUtf8(str) => {
+                        Arc::new(LargeStringArray::from(vec![&***str]))
+                    }
                     GroupByScalar::Boolean(b) => Arc::new(BooleanArray::from(vec![*b])),
                     GroupByScalar::TimeMillisecond(n) => {
                         Arc::new(TimestampMillisecondArray::from(vec![*n]))
@@ -1103,6 +1115,10 @@ fn create_group_by_value(col: &ArrayRef, row: usize) -> Result<GroupByScalar> {
             let array = col.as_any().downcast_ref::<StringArray>().unwrap();
             Ok(GroupByScalar::Utf8(Box::new(array.value(row).into())))
         }
+        DataType::LargeUtf8 => {
+            let array = col.as_any().downcast_ref::<LargeStringArray>().unwrap();
+            Ok(GroupByScalar::Utf8(Box::new(array.value(row).into())))
+        }
         DataType::Boolean => {
             let array = col.as_any().downcast_ref::<BooleanArray>().unwrap();
             Ok(GroupByScalar::Boolean(array.value(row)))
diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index 401fe65..eb2ec33 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -831,6 +831,9 @@ pub fn create_hashes<'a>(
             DataType::Utf8 => {
                 hash_array!(StringArray, col, str, hashes_buffer, random_state);
             }
+            DataType::LargeUtf8 => {
+                hash_array!(LargeStringArray, col, str, hashes_buffer, 
random_state);
+            }
             _ => {
                 // This is internal because we should have caught this before.
                 return Err(DataFusionError::Internal(
diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs
index 24b51ba..d9f84e7 100644
--- a/datafusion/src/physical_plan/type_coercion.rs
+++ b/datafusion/src/physical_plan/type_coercion.rs
@@ -196,7 +196,7 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
                 | Float64
         ),
         Timestamp(TimeUnit::Nanosecond, None) => matches!(type_from, Timestamp(_, None)),
-        Utf8 => true,
+        Utf8 | LargeUtf8 => true,
         _ => false,
     }
 }

Reply via email to