gabotechs commented on code in PR #14412:
URL: https://github.com/apache/datafusion/pull/14412#discussion_r2028725710


##########
datafusion/functions-aggregate/src/string_agg.rs:
##########
@@ -129,52 +172,326 @@ impl AggregateUDFImpl for StringAgg {
 
 #[derive(Debug)]
 pub(crate) struct StringAggAccumulator {
-    values: Option<String>,
+    array_agg_acc: Box<dyn Accumulator>,
     delimiter: String,
 }
 
 impl StringAggAccumulator {
-    pub fn new(delimiter: &str) -> Self {
+    pub fn new(array_agg_acc: Box<dyn Accumulator>, delimiter: &str) -> Self {
         Self {
-            values: None,
+            array_agg_acc,
             delimiter: delimiter.to_string(),
         }
     }
 }
 
 impl Accumulator for StringAggAccumulator {
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let string_array: Vec<_> = as_generic_string_array::<i64>(&values[0])?
-            .iter()
-            .filter_map(|v| v.as_ref().map(ToString::to_string))
-            .collect();
-        if !string_array.is_empty() {
-            let s = string_array.join(self.delimiter.as_str());
-            let v = self.values.get_or_insert("".to_string());
-            if !v.is_empty() {
-                v.push_str(self.delimiter.as_str());
+        self.array_agg_acc.update_batch(&filter_index(values, 1))
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        let scalar = self.array_agg_acc.evaluate()?;
+
+        let ScalarValue::List(list) = scalar else {
+            return internal_err!("Expected a DataType::List while evaluating 
underlying ArrayAggAccumulator, but got {}", scalar.data_type());
+        };
+
+        let string_arr: Vec<_> = match list.value_type() {
+            DataType::LargeUtf8 => 
as_generic_string_array::<i64>(list.values())?
+                .iter()
+                .flatten()
+                .collect(),
+            DataType::Utf8 => as_generic_string_array::<i32>(list.values())?
+                .iter()
+                .flatten()
+                .collect(),
+            _ => {
+                return internal_err!(
+                    "Expected elements to of type Utf8 or LargeUtf8, but got 
{}",
+                    list.value_type()
+                )
             }
-            v.push_str(s.as_str());
+        };
+
+        if string_arr.is_empty() {
+            return Ok(ScalarValue::LargeUtf8(None));
         }
-        Ok(())
+
+        Ok(ScalarValue::LargeUtf8(Some(
+            string_arr.join(&self.delimiter),
+        )))
+    }
+
+    fn size(&self) -> usize {
+        size_of_val(self) - size_of_val(&self.array_agg_acc)
+            + self.array_agg_acc.size()
+            + self.delimiter.capacity()
+    }
+
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        self.array_agg_acc.state()
     }
 
     fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        self.update_batch(values)?;
+        self.array_agg_acc.merge_batch(values)
+    }
+}
+
+fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
+    values
+        .iter()
+        .enumerate()
+        .filter(|(i, _)| *i != index)
+        .map(|(_, v)| v)
+        .cloned()
+        .collect::<Vec<_>>()
+}
+
+#[cfg(test)]
+mod tests {

Review Comment:
   🤔 It looks like there are still some cards that can be played to make 
compilation times faster, then. 👍 Thanks for all that info!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to