neilconway commented on code in PR #21058:
URL: https://github.com/apache/datafusion/pull/21058#discussion_r3033661965


##########
datafusion/physical-plan/src/aggregates/mod.rs:
##########
@@ -1942,25 +1945,74 @@ fn evaluate_optional(
         .collect()
 }
 
-fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
-    if group.len() > 64 {
+/// Builds the internal `__grouping_id` array for a single grouping set.
+///
+/// The returned array packs two values into a single integer:
+///
+/// - Low `n` bits (positions 0 .. n-1): the semantic bitmask.  A `1` bit
+///   at position `i` means that the `i`-th grouping column (counting from the
+///   least significant bit, i.e. the *last* column in the `group` slice) is
+///   `NULL` for this grouping set.
+/// - High bits (positions n and above): the duplicate `ordinal`, which
+///   distinguishes multiple occurrences of the same grouping-set pattern.  The
+///   ordinal is `0` for the first occurrence, `1` for the second, and so on.
+///
+/// The integer type is chosen to be the smallest `UInt8 / UInt16 / UInt32 /
+/// UInt64` that can represent both parts.  It matches the type returned by
+/// [`Aggregate::grouping_id_type`].
+fn group_id_array(
+    group: &[bool],
+    ordinal: usize,
+    max_ordinal: usize,
+    batch: &RecordBatch,
+) -> Result<ArrayRef> {
+    let n = group.len();
+    if n > 64 {
         return not_impl_err!(
             "Grouping sets with more than 64 columns are not supported"
         );
     }
-    let group_id = group.iter().fold(0u64, |acc, &is_null| {
+    let ordinal_bits = usize::BITS as usize - max_ordinal.leading_zeros() as 
usize;
+    let total_bits = n + ordinal_bits;
+    if total_bits > 64 {
+        return not_impl_err!(
+            "Grouping sets with {n} columns and a maximum duplicate ordinal of 
\
+             {max_ordinal} require {total_bits} bits, which exceeds 64"
+        );
+    }
+    let semantic_id = group.iter().fold(0u64, |acc, &is_null| {
         (acc << 1) | if is_null { 1 } else { 0 }
     });
+    let full_id = semantic_id | ((ordinal as u64) << n);
     let num_rows = batch.num_rows();
-    if group.len() <= 8 {
-        Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
-    } else if group.len() <= 16 {
-        Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])))
-    } else if group.len() <= 32 {
-        Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])))
+    if total_bits <= 8 {
+        Ok(Arc::new(UInt8Array::from(vec![full_id as u8; num_rows])))
+    } else if total_bits <= 16 {
+        Ok(Arc::new(UInt16Array::from(vec![full_id as u16; num_rows])))
+    } else if total_bits <= 32 {
+        Ok(Arc::new(UInt32Array::from(vec![full_id as u32; num_rows])))
     } else {
-        Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows])))
+        Ok(Arc::new(UInt64Array::from(vec![full_id; num_rows])))
+    }
+}
+
+/// Returns the highest duplicate ordinal across all grouping sets.
+///
+/// The ordinal counts how many times a grouping-set pattern has already
+/// appeared before the current occurrence.  If the same `Vec<bool>` appears
+/// three times the ordinals are 0, 1, 2 and this function returns 2.
+/// Returns 0 when no grouping set is duplicated.
+fn max_duplicate_ordinal(groups: &[Vec<bool>]) -> usize {
+    let mut counts: HashMap<&Vec<bool>, usize> = HashMap::new();
+    for group in groups {
+        *counts.entry(group).or_insert(0) += 1;
     }
+    counts
+        .values()
+        .copied()
+        .max()
+        .unwrap_or(1)
+        .saturating_sub(1)

Review Comment:
   `counts.into_values().max().unwrap_or(0).saturating_sub(1)`



##########
datafusion/physical-plan/src/aggregates/mod.rs:
##########
@@ -1942,25 +1945,74 @@ fn evaluate_optional(
         .collect()
 }
 
-fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
-    if group.len() > 64 {
+/// Builds the internal `__grouping_id` array for a single grouping set.
+///
+/// The returned array packs two values into a single integer:
+///
+/// - Low `n` bits (positions 0 .. n-1): the semantic bitmask.  A `1` bit
+///   at position `i` means that the `i`-th grouping column (counting from the
+///   least significant bit, i.e. the *last* column in the `group` slice) is
+///   `NULL` for this grouping set.
+/// - High bits (positions n and above): the duplicate `ordinal`, which
+///   distinguishes multiple occurrences of the same grouping-set pattern.  The
+///   ordinal is `0` for the first occurrence, `1` for the second, and so on.
+///
+/// The integer type is chosen to be the smallest `UInt8 / UInt16 / UInt32 /
+/// UInt64` that can represent both parts.  It matches the type returned by
+/// [`Aggregate::grouping_id_type`].
+fn group_id_array(
+    group: &[bool],
+    ordinal: usize,
+    max_ordinal: usize,
+    batch: &RecordBatch,
+) -> Result<ArrayRef> {
+    let n = group.len();
+    if n > 64 {
         return not_impl_err!(
             "Grouping sets with more than 64 columns are not supported"
         );
     }
-    let group_id = group.iter().fold(0u64, |acc, &is_null| {
+    let ordinal_bits = usize::BITS as usize - max_ordinal.leading_zeros() as 
usize;
+    let total_bits = n + ordinal_bits;
+    if total_bits > 64 {
+        return not_impl_err!(
+            "Grouping sets with {n} columns and a maximum duplicate ordinal of 
\
+             {max_ordinal} require {total_bits} bits, which exceeds 64"
+        );
+    }
+    let semantic_id = group.iter().fold(0u64, |acc, &is_null| {
         (acc << 1) | if is_null { 1 } else { 0 }
     });
+    let full_id = semantic_id | ((ordinal as u64) << n);
     let num_rows = batch.num_rows();
-    if group.len() <= 8 {
-        Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
-    } else if group.len() <= 16 {
-        Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])))
-    } else if group.len() <= 32 {
-        Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])))
+    if total_bits <= 8 {
+        Ok(Arc::new(UInt8Array::from(vec![full_id as u8; num_rows])))
+    } else if total_bits <= 16 {
+        Ok(Arc::new(UInt16Array::from(vec![full_id as u16; num_rows])))
+    } else if total_bits <= 32 {
+        Ok(Arc::new(UInt32Array::from(vec![full_id as u32; num_rows])))
     } else {
-        Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows])))
+        Ok(Arc::new(UInt64Array::from(vec![full_id; num_rows])))
+    }
+}
+
+/// Returns the highest duplicate ordinal across all grouping sets.
+///
+/// The ordinal counts how many times a grouping-set pattern has already
+/// appeared before the current occurrence.  If the same `Vec<bool>` appears

Review Comment:
   "The current occurrence" doesn't make sense when you read this function on
   its own; that phrase describes the call site, not this function.



##########
datafusion/physical-plan/src/aggregates/mod.rs:
##########
@@ -1942,25 +1945,74 @@ fn evaluate_optional(
         .collect()
 }
 
-fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
-    if group.len() > 64 {
+/// Builds the internal `__grouping_id` array for a single grouping set.
+///
+/// The returned array packs two values into a single integer:
+///
+/// - Low `n` bits (positions 0 .. n-1): the semantic bitmask.  A `1` bit
+///   at position `i` means that the `i`-th grouping column (counting from the
+///   least significant bit, i.e. the *last* column in the `group` slice) is
+///   `NULL` for this grouping set.
+/// - High bits (positions n and above): the duplicate `ordinal`, which
+///   distinguishes multiple occurrences of the same grouping-set pattern.  The
+///   ordinal is `0` for the first occurrence, `1` for the second, and so on.
+///
+/// The integer type is chosen to be the smallest `UInt8 / UInt16 / UInt32 /
+/// UInt64` that can represent both parts.  It matches the type returned by
+/// [`Aggregate::grouping_id_type`].
+fn group_id_array(
+    group: &[bool],
+    ordinal: usize,
+    max_ordinal: usize,
+    batch: &RecordBatch,
+) -> Result<ArrayRef> {
+    let n = group.len();
+    if n > 64 {
         return not_impl_err!(
             "Grouping sets with more than 64 columns are not supported"
         );
     }
-    let group_id = group.iter().fold(0u64, |acc, &is_null| {
+    let ordinal_bits = usize::BITS as usize - max_ordinal.leading_zeros() as 
usize;
+    let total_bits = n + ordinal_bits;
+    if total_bits > 64 {
+        return not_impl_err!(
+            "Grouping sets with {n} columns and a maximum duplicate ordinal of 
\
+             {max_ordinal} require {total_bits} bits, which exceeds 64"
+        );
+    }
+    let semantic_id = group.iter().fold(0u64, |acc, &is_null| {
         (acc << 1) | if is_null { 1 } else { 0 }
     });
+    let full_id = semantic_id | ((ordinal as u64) << n);
     let num_rows = batch.num_rows();
-    if group.len() <= 8 {
-        Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
-    } else if group.len() <= 16 {
-        Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])))
-    } else if group.len() <= 32 {
-        Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])))
+    if total_bits <= 8 {
+        Ok(Arc::new(UInt8Array::from(vec![full_id as u8; num_rows])))
+    } else if total_bits <= 16 {
+        Ok(Arc::new(UInt16Array::from(vec![full_id as u16; num_rows])))
+    } else if total_bits <= 32 {
+        Ok(Arc::new(UInt32Array::from(vec![full_id as u32; num_rows])))
     } else {
-        Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows])))
+        Ok(Arc::new(UInt64Array::from(vec![full_id; num_rows])))
+    }
+}
+
+/// Returns the highest duplicate ordinal across all grouping sets.
+///
+/// The ordinal counts how many times a grouping-set pattern has already
+/// appeared before the current occurrence.  If the same `Vec<bool>` appears
+/// three times the ordinals are 0, 1, 2 and this function returns 2.
+/// Returns 0 when no grouping set is duplicated.
+fn max_duplicate_ordinal(groups: &[Vec<bool>]) -> usize {
+    let mut counts: HashMap<&Vec<bool>, usize> = HashMap::new();

Review Comment:
   `HashMap<&[bool], usize>`



##########
datafusion/optimizer/src/analyzer/resolve_grouping_function.rs:
##########
@@ -204,17 +204,26 @@ fn grouping_function_on_id(
             Expr::Literal(ScalarValue::from(value as u64), None)
         }
     };
-
     let grouping_id_column = 
Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
-    // The grouping call is exactly our internal grouping id
+    // The grouping call is exactly our internal grouping id — mask the ordinal
+    // bits (above position `n`) so only the semantic bitmask is visible.
+    let n = group_by_expr_count;
+    // (1 << n) - 1 masks the low n bits.  Use saturating arithmetic to handle 
n == 0.

Review Comment:
   Although I don't think `wrapping_sub` is necessary anyway: for `n == 0`,
   `1 << 0` is `1`, so `1 - 1` is `0` and the subtraction won't underflow.



##########
datafusion/optimizer/src/analyzer/resolve_grouping_function.rs:
##########
@@ -204,17 +204,26 @@ fn grouping_function_on_id(
             Expr::Literal(ScalarValue::from(value as u64), None)
         }
     };
-
     let grouping_id_column = 
Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
-    // The grouping call is exactly our internal grouping id
+    // The grouping call is exactly our internal grouping id — mask the ordinal
+    // bits (above position `n`) so only the semantic bitmask is visible.
+    let n = group_by_expr_count;
+    // (1 << n) - 1 masks the low n bits.  Use saturating arithmetic to handle 
n == 0.
+    let semantic_mask: u64 = if n >= 64 {
+        u64::MAX
+    } else {
+        (1u64 << n).wrapping_sub(1)
+    };
+    let masked_id =

Review Comment:
   Don't compute this outside the `if` in which it is used.



##########
datafusion/optimizer/src/analyzer/resolve_grouping_function.rs:
##########
@@ -204,17 +204,26 @@ fn grouping_function_on_id(
             Expr::Literal(ScalarValue::from(value as u64), None)
         }
     };
-
     let grouping_id_column = 
Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
-    // The grouping call is exactly our internal grouping id
+    // The grouping call is exactly our internal grouping id — mask the ordinal
+    // bits (above position `n`) so only the semantic bitmask is visible.
+    let n = group_by_expr_count;
+    // (1 << n) - 1 masks the low n bits.  Use saturating arithmetic to handle 
n == 0.

Review Comment:
   The comment says "saturating arithmetic", but the code doesn't use it (it
   calls `wrapping_sub`)?



##########
datafusion/sql/src/unparser/utils.rs:
##########
@@ -247,6 +247,11 @@ fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) 
-> Result<Option<&'a E
                     )
                 }
                 Ordering::Greater => {
+                    if index < grouping_expr.len() + 1 {

Review Comment:
   This seems like dead code? In this branch `index > grouping_expr.len()`,
   which implies `index >= grouping_expr.len() + 1`, so the condition
   `index < grouping_expr.len() + 1` can never be true.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to