alamb commented on code in PR #23309:
URL: https://github.com/apache/datafusion/pull/23309#discussion_r3521848125


##########
datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs:
##########
@@ -85,29 +114,63 @@ pub(in crate::aggregates) struct 
AggregateHashTable<AggrMode> {
 
     /// Lifecycle-specific state: building stage / outputting stage.
     pub(super) state: AggregateHashTableState,
-
-    pub(super) _mode: PhantomData<AggrMode>,
 }
 
 /// Methods shared by all aggregate hash table modes.
-impl<AggrMode> AggregateHashTable<AggrMode> {
-    pub(super) fn new_with_filters(
+impl AggregateHashTable {
+    pub(in crate::aggregates) fn new(
         agg: &AggregateExec,
         partition: usize,
         output_schema: SchemaRef,
         batch_size: usize,
-        filters: Vec<Option<Arc<dyn PhysicalExpr>>>,
     ) -> Result<Self> {
         if batch_size == 0 {
             return internal_err!("AggregateHashTable requires config 
batch_size >= 1");
         }
 
+        // Infer the internal `AggregateTableMode` based on `AggregateExec`'s 
mode
+        //
+        // TODO(simplification): `AggregateMode` seems bloated for aggregate 
hash
+        // table semantics. Consider remove `AggregateMode` and only use 
`AggregateTableMode`
+        // after the refactor has finished.
+        //
+        // Issue: <https://github.com/apache/datafusion/pull/22729>

Review Comment:
   This seems like a PR (not an issue) -- is it the link you intended?
   
   Note: I agree that `AggregateMode` seems overly complicated. I vaguely 
remember trying to remove it once, but I can't find the PR now (maybe I never 
pushed it).
   
   



##########
datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs:
##########
@@ -31,19 +31,31 @@ use crate::aggregates::group_values::{GroupByMetrics, 
GroupValues, new_group_val
 use crate::aggregates::order::GroupOrdering;
 use crate::aggregates::row_hash::create_group_accumulator;
 use crate::aggregates::{
-    AggregateExec, PhysicalGroupBy, aggregate_expressions, evaluate_group_by,
+    AggregateExec, AggregateMode, PhysicalGroupBy, aggregate_expressions,
+    evaluate_group_by, group_id_array, max_duplicate_ordinal,
 };
 
-/// Marker for raw rows -> partial state aggregation.
+/// Semantic mode for the aggregate hash table.
+///
+/// See [`AggregateHashTable`] comment's 'Mode' section for details.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub(in crate::aggregates) enum AggregateTableMode {
+    /// Raw rows -> partial aggregate state rows.
+    Partial,
+    /// Partial aggregate state rows -> final aggregate value rows.
+    Final,
+    /// Partial aggregate state rows -> merged partial aggregate state rows.
+    PartialReduce,
+    /// Raw rows -> final aggregate value rows.
+    Single,
+}
+
+/// Marker for ordered raw rows -> partial state aggregation.
 pub(in crate::aggregates) struct PartialMarker;

Review Comment:
   Do you plan to remove this `PartialMarker` for the ordered streams too? Maybe 
we could apply the same simplification there.



##########
datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs:
##########
@@ -214,6 +321,239 @@ impl<AggrMode> AggregateHashTable<AggrMode> {
         state.batch_group_indices = Vec::new();
         self.state = AggregateHashTableState::Outputting(state);
     }
+
+    /// Aggregates one input batch according to this table's input semantics.
+    ///
+    /// `Partial` and `Single` update accumulator state from raw input rows.
+    /// `Final` and `PartialReduce` merge accumulator state emitted by an
+    /// earlier aggregate stage.
+    pub(in crate::aggregates) fn aggregate_batch(
+        &mut self,
+        batch: &RecordBatch,
+    ) -> Result<()> {
+        let evaluated_batch = self.evaluate_batch(batch)?;
+        let mode = self.mode;
+        let state = self.state.building_mut();
+
+        let timer = self.group_by_metrics.aggregation_time.timer();
+        for group_values in &evaluated_batch.grouping_set_args {
+            state
+                .group_values
+                .intern(group_values, &mut state.batch_group_indices)?;
+            let group_indices = &state.batch_group_indices;
+            let total_num_groups = state.group_values.len();
+
+            for (acc, values) in state
+                .accumulators
+                .iter_mut()
+                .zip(evaluated_batch.accumulator_args.iter())
+            {
+                match mode {

Review Comment:
   It probably doesn't matter, as this is only one match per column, but I wonder 
if we should hoist the `match` outside the loop and run the loop inside each arm 🤔



##########
datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs:
##########
@@ -214,6 +321,239 @@ impl<AggrMode> AggregateHashTable<AggrMode> {
         state.batch_group_indices = Vec::new();
         self.state = AggregateHashTableState::Outputting(state);
     }
+
+    /// Aggregates one input batch according to this table's input semantics.
+    ///
+    /// `Partial` and `Single` update accumulator state from raw input rows.
+    /// `Final` and `PartialReduce` merge accumulator state emitted by an
+    /// earlier aggregate stage.
+    pub(in crate::aggregates) fn aggregate_batch(
+        &mut self,
+        batch: &RecordBatch,
+    ) -> Result<()> {
+        let evaluated_batch = self.evaluate_batch(batch)?;
+        let mode = self.mode;
+        let state = self.state.building_mut();
+
+        let timer = self.group_by_metrics.aggregation_time.timer();
+        for group_values in &evaluated_batch.grouping_set_args {
+            state
+                .group_values
+                .intern(group_values, &mut state.batch_group_indices)?;
+            let group_indices = &state.batch_group_indices;
+            let total_num_groups = state.group_values.len();
+
+            for (acc, values) in state
+                .accumulators
+                .iter_mut()
+                .zip(evaluated_batch.accumulator_args.iter())
+            {
+                match mode {
+                    AggregateTableMode::Partial | AggregateTableMode::Single 
=> {
+                        acc.update_batch(values, group_indices, 
total_num_groups)?
+                    }
+                    AggregateTableMode::Final | 
AggregateTableMode::PartialReduce => {
+                        acc.merge_batch(values, group_indices, 
total_num_groups)?
+                    }
+                }
+            }
+        }
+        drop(timer);
+
+        Ok(())
+    }
+
+    /// Emits the next output batch according to this table's output semantics.
+    ///
+    /// `Partial` and `PartialReduce` emit accumulator states for downstream
+    /// aggregate stages. `Final` and `Single` evaluate final aggregate values
+    /// for the query result.
+    pub(in crate::aggregates) fn next_output_batch(
+        &mut self,
+    ) -> Result<Option<RecordBatch>> {
+        let output_schema = Arc::clone(&self.output_schema);
+        let batch_size = self.batch_size;
+        // Take ownership of the output state. `emit_next_materialized_batch`
+        // restores `self.state` to `OutputtingMaterialized` or `Done`.
+        match std::mem::replace(&mut self.state, 
AggregateHashTableState::Done) {
+            AggregateHashTableState::Outputting(state) => {
+                if state.group_values.is_empty() {
+                    return Ok(None);
+                }
+
+                let output = self.materialize_output(state, output_schema)?;
+                Ok(self.emit_next_materialized_batch(output, batch_size))
+            }
+            AggregateHashTableState::OutputtingMaterialized(output) => {
+                Ok(self.emit_next_materialized_batch(output, batch_size))
+            }
+            AggregateHashTableState::Done => Ok(None),
+            AggregateHashTableState::Building(_) => {
+                internal_err!("next_output_batch must be called in the 
outputting state")
+            }
+        }
+    }
+
+    fn materialize_output(
+        &self,
+        mut state: AggregateHashTableBuffer,
+        output_schema: SchemaRef,
+    ) -> Result<MaterializedAggregateOutput> {
+        // Final evaluation and partial state emission both consume accumulator
+        // state. Materialize all groups once, then slice on subsequent polls.
+        let emit_to = EmitTo::All;
+        let timer = self.group_by_metrics.emitting_time.timer();
+        let mut output = state.group_values.emit(emit_to)?;
+
+        for acc in state.accumulators.iter_mut() {
+            match self.mode {
+                AggregateTableMode::Partial | 
AggregateTableMode::PartialReduce => {
+                    output.extend(acc.state(emit_to)?)
+                }
+                AggregateTableMode::Final | AggregateTableMode::Single => {
+                    output.push(acc.evaluate(emit_to)?)
+                }
+            }
+        }
+        drop(timer);
+
+        let batch = RecordBatch::try_new(output_schema, output)?;
+        debug_assert!(batch.num_rows() > 0);
+        Ok(MaterializedAggregateOutput::new(batch))
+    }
+
+    fn emit_next_materialized_batch(
+        &mut self,
+        mut output: MaterializedAggregateOutput,
+        batch_size: usize,
+    ) -> Option<RecordBatch> {
+        let batch = output.next_batch(batch_size);
+        if output.is_exhausted() {
+            self.state = AggregateHashTableState::Done;
+        } else {
+            self.state = 
AggregateHashTableState::OutputtingMaterialized(output);
+        }
+        batch
+    }
+
+    pub(in crate::aggregates) fn start_output(&mut self) -> Result<()> {
+        if matches!(
+            self.mode,
+            AggregateTableMode::Partial | AggregateTableMode::Single
+        ) {
+            self.init_empty_grouping_sets()?;
+        }
+        self.start_outputting();
+        Ok(())
+    }
+
+    /// Creates the required empty grouping-set rows when the input is empty.
+    ///
+    /// For example, this query must still produce one grand-total group even 
if
+    /// `t` has no rows:
+    ///
+    /// ```sql
+    /// SELECT COUNT(v)
+    /// FROM t
+    /// GROUP BY GROUPING SETS (());
+    /// ```
+    ///
+    /// The synthetic row is filtered out before accumulator update so 
aggregates
+    /// see the same state they would see for an empty input, rather than a 
real
+    /// null-valued row.
+    pub(super) fn init_empty_grouping_sets(&mut self) -> Result<()> {
+        let state = self.state.building_mut();
+        if !state.group_by.has_grouping_set() || 
!state.group_values.is_empty() {
+            return Ok(());
+        }
+
+        let max_ordinal = max_duplicate_ordinal(state.group_by.groups());
+        let mut ordinals: HashMap<&[bool], usize> = HashMap::new();
+        let group_schema = state.group_by.group_schema(&self.input_schema)?;
+        let n_expr = state.group_by.expr().len();
+        let mut any_interned = false;
+
+        for group in state.group_by.groups() {
+            let ordinal = {
+                let entry = ordinals.entry(group.as_slice()).or_insert(0);
+                let ordinal = *entry;
+                *entry += 1;
+                ordinal
+            };
+
+            if !group.iter().all(|&is_null| is_null) {
+                continue;
+            }
+
+            let mut cols: Vec<ArrayRef> = group_schema
+                .fields()
+                .iter()
+                .take(n_expr)
+                .map(|field| new_null_array(field.data_type(), 1))
+                .collect();
+            cols.push(group_id_array(group, ordinal, max_ordinal, 1)?);
+
+            state
+                .group_values
+                .intern(&cols, &mut state.batch_group_indices)?;
+            any_interned = true;
+        }
+
+        if any_interned {
+            let total_groups = state.group_values.len();
+            let false_filter = BooleanArray::from(vec![false]);
+            for acc in state.accumulators.iter_mut() {
+                let null_args = acc.null_arguments(&self.input_schema)?;
+                let values = EvaluatedAccumulatorArgs {
+                    arguments: null_args,
+                    filter: Some(Arc::new(false_filter.clone())),
+                };
+                acc.update_batch(&values, &[0], total_groups)?;
+            }
+        }
+
+        Ok(())
+    }
+}
+
+/// Hash table used only for converting raw input rows directly into partial
+/// aggregate state rows after partial aggregation has been skipped.
+pub(in crate::aggregates) struct PartialSkipHashTable {

Review Comment:
   I like that this has a named newtype wrapper.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to