QuakeWang commented on code in PR #340: URL: https://github.com/apache/paimon-rust/pull/340#discussion_r3396788809
########## crates/paimon/src/table/aggregator/numeric.rs: ########## @@ -0,0 +1,960 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Numeric aggregators: sum, product, min, max. +//! +//! `sum` operates on every integer / floating / Decimal numeric type. +//! `product` accepts the same numeric family except DECIMAL — basic mode does +//! not yet implement BigDecimal-style scale rebasing for Decimal product, so +//! Decimal columns are rejected at construction. Integer overflow on either +//! aggregator is reported as [`Error::DataInvalid`] so silent wrap cannot +//! produce misleading aggregated values. +//! +//! `min` / `max` extend to every ordered Paimon type: numerics, Decimal, +//! Date, Time, Timestamp, and Char/VarChar. Comparison is by native value +//! order (numeric for numbers, lexicographic for strings). Float NaN is +//! treated as greater than any other value, matching Java's +//! `Float.compare` / `Double.compare`. +//! +//! Reference: Java `FieldSumAgg`, `FieldProductAgg`, `FieldMinAgg`, +//! `FieldMaxAgg` under `org.apache.paimon.mergetree.compact.aggregate`. +//! +//! [`Error::DataInvalid`]: crate::Error::DataInvalid + +use std::sync::Arc; + +use arrow_array::builder::Decimal128Builder; +use arrow_array::{ + Array, ArrayRef, Date32Array, Decimal128Array, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, StringArray, Time32MillisecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, +}; +use arrow_schema::TimeUnit; + +use super::{unsupported_type_error, FieldAggregator}; +use crate::spec::DataType; + +// --------------------------------------------------------------------------- +// Sum +// --------------------------------------------------------------------------- + +/// `sum` accumulator state, parameterized by the column's numeric kind. +#[derive(Debug)] +enum SumState { + I8(Option<i8>), + I16(Option<i16>), + I32(Option<i32>), + I64(Option<i64>), + F32(Option<f32>), + F64(Option<f64>), + Decimal128 { + precision: u8, + scale: i8, + acc: Option<i128>, + }, +} + +#[derive(Debug)] +pub(crate) struct SumAgg { + field_name: String, + state: SumState, +} + +impl SumAgg { + pub(crate) fn new(field_name: &str, data_type: &DataType) -> crate::Result<Self> { + let state = match data_type { + DataType::TinyInt(_) => SumState::I8(None), + DataType::SmallInt(_) => SumState::I16(None), + DataType::Int(_) => SumState::I32(None), + DataType::BigInt(_) => SumState::I64(None), + DataType::Float(_) => SumState::F32(None), + DataType::Double(_) => SumState::F64(None), + DataType::Decimal(d) => SumState::Decimal128 { + precision: decimal_precision(d.precision(), field_name)?, + scale: decimal_scale(d.scale(), field_name)?, + acc: None, + }, + other => return Err(unsupported_type_error("sum", field_name, other)), + }; + Ok(Self { + field_name: field_name.to_string(), + state, + }) + } +} + +impl FieldAggregator for SumAgg { + fn name(&self) -> &'static str { + "sum" + } + + fn reset(&mut self) { + match &mut self.state { + SumState::I8(acc) => *acc = None, + SumState::I16(acc) => *acc = None, + SumState::I32(acc) => *acc = None, + SumState::I64(acc) => *acc = None, + SumState::F32(acc) => *acc = None, + SumState::F64(acc) => *acc = None, + SumState::Decimal128 { acc, .. } => *acc = None, + } + } + + fn agg(&mut self, array: &dyn Array, row_idx: usize) -> crate::Result<()> { + if array.is_null(row_idx) { + return Ok(()); + } + match &mut self.state { + SumState::I8(acc) => { + let v = downcast::<Int8Array>(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_add(v) + .ok_or_else(|| overflow_error("sum", &self.field_name))?, + }); + } + SumState::I16(acc) => { + let v = downcast::<Int16Array>(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_add(v) + .ok_or_else(|| overflow_error("sum", &self.field_name))?, + }); + } + SumState::I32(acc) => { + let v = downcast::<Int32Array>(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_add(v) + .ok_or_else(|| overflow_error("sum", &self.field_name))?, + }); + } + SumState::I64(acc) => { + let v = downcast::<Int64Array>(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_add(v) + .ok_or_else(|| overflow_error("sum", &self.field_name))?, + }); + } + SumState::F32(acc) => { + let v = downcast::<Float32Array>(array, &self.field_name)?.value(row_idx); + *acc = Some(acc.map_or(v, |prev| prev + v)); + } + SumState::F64(acc) => { + let v = downcast::<Float64Array>(array, &self.field_name)?.value(row_idx); + *acc = Some(acc.map_or(v, |prev| prev + v)); + } + SumState::Decimal128 { acc, .. } => { + let v = downcast::<Decimal128Array>(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_add(v) + .ok_or_else(|| overflow_error("sum", &self.field_name))?, + }); + } + } + Ok(()) + } + + fn result(&self) -> crate::Result<ArrayRef> { + Ok(match &self.state { + SumState::I8(acc) => Arc::new(Int8Array::from(vec![*acc])), + SumState::I16(acc) => Arc::new(Int16Array::from(vec![*acc])), + SumState::I32(acc) => Arc::new(Int32Array::from(vec![*acc])), + SumState::I64(acc) => Arc::new(Int64Array::from(vec![*acc])), + SumState::F32(acc) => Arc::new(Float32Array::from(vec![*acc])), + SumState::F64(acc) => Arc::new(Float64Array::from(vec![*acc])), + SumState::Decimal128 { + precision, + scale, + acc, + } => decimal_array(*precision, *scale, *acc, "sum", &self.field_name)?, + }) + } +} + +// --------------------------------------------------------------------------- +// Product +// --------------------------------------------------------------------------- + +#[derive(Debug)] +enum ProductState { + I8(Option<i8>), + I16(Option<i16>), + I32(Option<i32>), + I64(Option<i64>), + F32(Option<f32>), + F64(Option<f64>), + // DECIMAL `product` is intentionally rejected at construction (see + // `ProductAgg::new`); add a variant here when the BigDecimal-style + // scale handling lands. +} + +#[derive(Debug)] +pub(crate) struct ProductAgg { + field_name: String, + state: ProductState, +} + +impl ProductAgg { + pub(crate) fn new(field_name: &str, data_type: &DataType) -> crate::Result<Self> { + let state = match data_type { + DataType::TinyInt(_) => ProductState::I8(None), + DataType::SmallInt(_) => ProductState::I16(None), + DataType::Int(_) => ProductState::I32(None), + DataType::BigInt(_) => ProductState::I64(None), + DataType::Float(_) => ProductState::F32(None), + DataType::Double(_) => ProductState::F64(None), + // Decimal `product` would need BigDecimal-style scale rebasing + // (multiply raw i128, then divide by 10^scale, with precision + // checks). The basic mode does not implement that yet, so we + // reject DECIMAL columns explicitly rather than silently produce + // a scale-shifted result. + DataType::Decimal(_) => { + return Err(crate::Error::ConfigInvalid { + message: format!( + "Aggregate function 'product' on DECIMAL field '{field_name}' is not \ + supported in the basic mode; use a BIGINT/DOUBLE column or wait for a \ + follow-up commit that adds Decimal product semantics aligned with Java \ + BigDecimal" + ), + }); + } + other => return Err(unsupported_type_error("product", field_name, other)), + }; + Ok(Self { + field_name: field_name.to_string(), + state, + }) + } +} + +impl FieldAggregator for ProductAgg { + fn name(&self) -> &'static str { + "product" + } + + fn reset(&mut self) { + match &mut self.state { + ProductState::I8(acc) => *acc = None, + ProductState::I16(acc) => *acc = None, + ProductState::I32(acc) => *acc = None, + ProductState::I64(acc) => *acc = None, + ProductState::F32(acc) => *acc = None, + ProductState::F64(acc) => *acc = None, + } + } + + fn agg(&mut self, array: &dyn Array, row_idx: usize) -> crate::Result<()> { + if array.is_null(row_idx) { + return Ok(()); + } + match &mut self.state { + ProductState::I8(acc) => { + let v = downcast::<Int8Array>(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_mul(v) + .ok_or_else(|| overflow_error("product", &self.field_name))?, + }); + } + ProductState::I16(acc) => { + let v = downcast::<Int16Array>(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_mul(v) + .ok_or_else(|| overflow_error("product", &self.field_name))?, + }); + } + ProductState::I32(acc) => { + let v = downcast::<Int32Array>(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_mul(v) + .ok_or_else(|| overflow_error("product", &self.field_name))?, + }); + } + ProductState::I64(acc) => { + let v = downcast::<Int64Array>(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_mul(v) + .ok_or_else(|| overflow_error("product", &self.field_name))?, + }); + } + ProductState::F32(acc) => { + let v = downcast::<Float32Array>(array, &self.field_name)?.value(row_idx); + *acc = Some(acc.map_or(v, |prev| prev * v)); + } + ProductState::F64(acc) => { + let v = downcast::<Float64Array>(array, &self.field_name)?.value(row_idx); + *acc = Some(acc.map_or(v, |prev| prev * v)); + } + } + Ok(()) + } + + fn result(&self) -> crate::Result<ArrayRef> { + Ok(match &self.state { + ProductState::I8(acc) => Arc::new(Int8Array::from(vec![*acc])), + ProductState::I16(acc) => Arc::new(Int16Array::from(vec![*acc])), + ProductState::I32(acc) => Arc::new(Int32Array::from(vec![*acc])), + ProductState::I64(acc) => Arc::new(Int64Array::from(vec![*acc])), + ProductState::F32(acc) => Arc::new(Float32Array::from(vec![*acc])), + ProductState::F64(acc) => Arc::new(Float64Array::from(vec![*acc])), + }) + } +} + +// --------------------------------------------------------------------------- +// Min / Max — generic comparator-driven implementation +// --------------------------------------------------------------------------- + +/// `min` / `max` accumulator state. Each variant stores `Option<T>` where +/// `None` means "no non-null value seen yet for the current group". +#[derive(Debug)] +enum MinMaxState { + I8(Option<i8>), + I16(Option<i16>), + I32(Option<i32>), + I64(Option<i64>), + F32(Option<f32>), + F64(Option<f64>), + Decimal128 { + precision: u8, + scale: i8, + acc: Option<i128>, + }, + Date32(Option<i32>), + /// Paimon `TIME` is encoded as Arrow `Time32(Millisecond)` regardless of + /// declared precision, so a single accumulator variant suffices. + Time32Ms(Option<i32>), + Timestamp { + unit: TimeUnit, + acc: Option<i64>, + }, + Utf8(Option<String>), +} + +fn make_minmax_state( + field_name: &str, + data_type: &DataType, + op: &str, +) -> crate::Result<MinMaxState> { + Ok(match data_type { + DataType::TinyInt(_) => MinMaxState::I8(None), + DataType::SmallInt(_) => MinMaxState::I16(None), + DataType::Int(_) => MinMaxState::I32(None), + DataType::BigInt(_) => MinMaxState::I64(None), + DataType::Float(_) => MinMaxState::F32(None), + DataType::Double(_) => MinMaxState::F64(None), + DataType::Decimal(d) => MinMaxState::Decimal128 { + precision: decimal_precision(d.precision(), field_name)?, + scale: decimal_scale(d.scale(), field_name)?, + acc: None, + }, + DataType::Date(_) => MinMaxState::Date32(None), + DataType::Time(_) => MinMaxState::Time32Ms(None), + DataType::Timestamp(t) => MinMaxState::Timestamp { + unit: timestamp_time_unit(t.precision())?, + acc: None, + }, + DataType::Char(_) | DataType::VarChar(_) => MinMaxState::Utf8(None), + other => return Err(unsupported_type_error(op, field_name, other)), + }) +} + +fn timestamp_time_unit(precision: u32) -> crate::Result<TimeUnit> { + match precision { + 0..=3 => Ok(TimeUnit::Millisecond), + 4..=6 => Ok(TimeUnit::Microsecond), + 7..=9 => Ok(TimeUnit::Nanosecond), + other => Err(crate::Error::Unsupported { + message: format!("Unsupported TIMESTAMP precision {other} for min/max aggregator"), + }), + } +} + +fn agg_minmax( + state: &mut MinMaxState, + array: &dyn Array, + row_idx: usize, + field_name: &str, + keep_smaller: bool, +) -> crate::Result<()> { + if array.is_null(row_idx) { + return Ok(()); + } + macro_rules! update_primitive { + ($acc:expr, $ty:ty) => {{ + let v = downcast::<$ty>(array, field_name)?.value(row_idx); + *$acc = Some(match *$acc { + None => v, + Some(prev) => { + if (keep_smaller && v < prev) || (!keep_smaller && v > prev) { + v + } else { + prev + } + } + }); + }}; + } + macro_rules! update_float { + ($acc:expr, $ty:ty) => {{ + let v = downcast::<$ty>(array, field_name)?.value(row_idx); + // Match Java `Float.compare` / `Double.compare`, which order NaN + // greater than any other value (including +Infinity). Using + // `total_cmp` makes that ordering explicit and deterministic. + *$acc = Some(match *$acc { + None => v, + Some(prev) => { + let cmp = v.total_cmp(&prev); + let take_new = if keep_smaller { + cmp.is_lt() + } else { + cmp.is_gt() + }; + if take_new { + v + } else { + prev + } + } + }); + }}; + } + match state { + MinMaxState::I8(acc) => update_primitive!(acc, Int8Array), + MinMaxState::I16(acc) => update_primitive!(acc, Int16Array), + MinMaxState::I32(acc) => update_primitive!(acc, Int32Array), + MinMaxState::I64(acc) => update_primitive!(acc, Int64Array), + MinMaxState::F32(acc) => update_float!(acc, Float32Array), + MinMaxState::F64(acc) => update_float!(acc, Float64Array), + MinMaxState::Decimal128 { acc, .. } => update_primitive!(acc, Decimal128Array), + MinMaxState::Date32(acc) => update_primitive!(acc, Date32Array), + MinMaxState::Time32Ms(acc) => update_primitive!(acc, Time32MillisecondArray), + MinMaxState::Timestamp { unit, acc } => match unit { + TimeUnit::Millisecond => update_primitive!(acc, TimestampMillisecondArray), + TimeUnit::Microsecond => update_primitive!(acc, TimestampMicrosecondArray), + TimeUnit::Nanosecond => update_primitive!(acc, TimestampNanosecondArray), + other => { + return Err(crate::Error::DataInvalid { + message: format!( + "Timestamp with unit {other:?} not expected for field '{field_name}'" + ), + source: None, + }); + } + }, + MinMaxState::Utf8(acc) => { + let v = downcast::<StringArray>(array, field_name)?.value(row_idx); + *acc = Some(match acc.take() { + None => v.to_string(), + Some(prev) => { + let take_new = if keep_smaller { + v < prev.as_str() + } else { + v > prev.as_str() + }; + if take_new { + v.to_string() + } else { + prev + } + } + }); + } + } + Ok(()) +} + +fn minmax_result(state: &MinMaxState, agg_name: &str, field_name: &str) -> crate::Result<ArrayRef> { + Ok(match state { + MinMaxState::I8(acc) => Arc::new(Int8Array::from(vec![*acc])), + MinMaxState::I16(acc) => Arc::new(Int16Array::from(vec![*acc])), + MinMaxState::I32(acc) => Arc::new(Int32Array::from(vec![*acc])), + MinMaxState::I64(acc) => Arc::new(Int64Array::from(vec![*acc])), + MinMaxState::F32(acc) => Arc::new(Float32Array::from(vec![*acc])), + MinMaxState::F64(acc) => Arc::new(Float64Array::from(vec![*acc])), + MinMaxState::Decimal128 { + precision, + scale, + acc, + } => decimal_array(*precision, *scale, *acc, agg_name, field_name)?, + MinMaxState::Date32(acc) => Arc::new(Date32Array::from(vec![*acc])), + MinMaxState::Time32Ms(acc) => Arc::new(Time32MillisecondArray::from(vec![*acc])), + MinMaxState::Timestamp { unit, acc } => match unit { + TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(vec![*acc])), + TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(vec![*acc])), + TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(vec![*acc])), + other => { + return Err(crate::Error::DataInvalid { + message: format!( + "Timestamp with unit {other:?} not expected for field '{field_name}'" + ), + source: None, + }); + } + }, + MinMaxState::Utf8(acc) => Arc::new(StringArray::from(vec![acc.clone()])), + }) +} + +fn reset_minmax(state: &mut MinMaxState) { + match state { + MinMaxState::I8(acc) => *acc = None, + MinMaxState::I16(acc) => *acc = None, + MinMaxState::I32(acc) => *acc = None, + MinMaxState::I64(acc) => *acc = None, + MinMaxState::F32(acc) => *acc = None, + MinMaxState::F64(acc) => *acc = None, + MinMaxState::Decimal128 { acc, .. } => *acc = None, + MinMaxState::Date32(acc) => *acc = None, + MinMaxState::Time32Ms(acc) => *acc = None, + MinMaxState::Timestamp { acc, .. } => *acc = None, + MinMaxState::Utf8(acc) => *acc = None, + } +} + +#[derive(Debug)] +pub(crate) struct MinAgg { + field_name: String, + state: MinMaxState, +} + +impl MinAgg { + pub(crate) fn new(field_name: &str, data_type: &DataType) -> crate::Result<Self> { + Ok(Self { + field_name: field_name.to_string(), + state: make_minmax_state(field_name, data_type, "min")?, + }) + } +} + +impl FieldAggregator for MinAgg { + fn name(&self) -> &'static str { + "min" + } + + fn reset(&mut self) { + reset_minmax(&mut self.state); + } + + fn agg(&mut self, array: &dyn Array, row_idx: usize) -> crate::Result<()> { + agg_minmax(&mut self.state, array, row_idx, &self.field_name, true) + } + + fn result(&self) -> crate::Result<ArrayRef> { + minmax_result(&self.state, "min", &self.field_name) + } +} + +#[derive(Debug)] +pub(crate) struct MaxAgg { + field_name: String, + state: MinMaxState, +} + +impl MaxAgg { + pub(crate) fn new(field_name: &str, data_type: &DataType) -> crate::Result<Self> { + Ok(Self { + field_name: field_name.to_string(), + state: make_minmax_state(field_name, data_type, "max")?, + }) + } +} + +impl FieldAggregator for MaxAgg { + fn name(&self) -> &'static str { + "max" + } + + fn reset(&mut self) { + reset_minmax(&mut self.state); + } + + fn agg(&mut self, array: &dyn Array, row_idx: usize) -> crate::Result<()> { + agg_minmax(&mut self.state, array, row_idx, &self.field_name, false) + } + + fn result(&self) -> crate::Result<ArrayRef> { + minmax_result(&self.state, "max", &self.field_name) + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn downcast<'a, T: Array + 'static>( + array: &'a dyn Array, + field_name: &str, +) -> crate::Result<&'a T> { + array + .as_any() + .downcast_ref::<T>() + .ok_or_else(|| crate::Error::DataInvalid { + message: format!( + "Aggregate column '{field_name}' received Arrow array of unexpected \ + type {:?}; expected {}", + array.data_type(), + std::any::type_name::<T>() + ), + source: None, + }) +} + +fn decimal_precision(precision: u32, field_name: &str) -> crate::Result<u8> { + u8::try_from(precision).map_err(|_| crate::Error::Unsupported { + message: format!( + "Decimal precision {precision} on field '{field_name}' exceeds u8 (Arrow limit)" + ), + }) +} + +fn decimal_scale(scale: u32, field_name: &str) -> crate::Result<i8> { + i8::try_from(scale as i32).map_err(|_| crate::Error::Unsupported { + message: format!( + "Decimal scale {scale} on field '{field_name}' is out of i8 range (Arrow limit)" + ), + }) +} + +fn overflow_error(agg_name: &str, field_name: &str) -> crate::Error { + crate::Error::DataInvalid { + message: format!("Aggregate function '{agg_name}' overflowed on field '{field_name}'"), + source: None, + } +} + +fn decimal_array( + precision: u8, + scale: i8, + value: Option<i128>, + agg_name: &str, + field_name: &str, +) -> crate::Result<ArrayRef> { + let mut builder = Decimal128Builder::with_capacity(1) Review Comment: `with_precision_and_scale` only sets DECIMAL metadata; it does not validate that the accumulated raw value fits the declared precision. For example, `DECIMAL(3,2)` with `9.99 + 0.01` can produce raw `1000` under `Decimal128(3,2)`. Please add precision validation or match Java `DecimalUtils.add/fromBigDecimal` semantics, with a boundary test. ########## crates/paimon/src/spec/aggregation.rs: ########## @@ -0,0 +1,686 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; + +use crate::spec::{CoreOptions, DataField, DataType}; + +const MERGE_ENGINE_OPTION: &str = "merge-engine"; +const AGGREGATION_ENGINE: &str = "aggregation"; +const IGNORE_DELETE_OPTION: &str = "ignore-delete"; +const IGNORE_DELETE_SUFFIX: &str = ".ignore-delete"; +const AGGREGATION_REMOVE_RECORD_ON_DELETE_OPTION: &str = "aggregation.remove-record-on-delete"; +const FIELDS_DEFAULT_AGG_FUNCTION_OPTION: &str = "fields.default-aggregate-function"; +const FIELDS_PREFIX: &str = "fields."; +const AGG_FUNCTION_SUFFIX: &str = ".aggregate-function"; +const LIST_AGG_DELIMITER_SUFFIX: &str = ".list-agg-delimiter"; +const IGNORE_RETRACT_SUFFIX: &str = ".ignore-retract"; +const DISTINCT_SUFFIX: &str = ".distinct"; +const SEQUENCE_GROUP_SUFFIX: &str = ".sequence-group"; +const NESTED_KEY_SUFFIX: &str = ".nested-key"; +const COUNT_LIMIT_SUFFIX: &str = ".count-limit"; + +/// Minimal aggregation mode recognized by the current Rust implementation. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum AggregationMode { + Basic, +} + +/// Aggregation-merge-engine option inspection and validation. +/// +/// The basic mode accepts only `merge-engine=aggregation` on a PK table with +/// the following option keys: +/// - `fields.default-aggregate-function` +/// - `fields.<col>.aggregate-function` +/// - `fields.<col>.list-agg-delimiter` +/// +/// All other aggregation-specific knobs (`ignore-retract`, `distinct`, +/// `nested-key`, `count-limit`, `aggregation.remove-record-on-delete`, +/// `sequence-group`, `ignore-delete`) are rejected. Retract rows +/// (DELETE / UPDATE_BEFORE) are rejected at runtime by the merge function. +#[derive(Debug, Clone, Copy)] +pub(crate) struct AggregationConfig<'a> { + options: &'a HashMap<String, String>, +} + +impl<'a> AggregationConfig<'a> { + pub(crate) fn new(options: &'a HashMap<String, String>) -> Self { + Self { options } + } + + pub(crate) fn is_enabled(&self) -> bool { + self.options + .get(MERGE_ENGINE_OPTION) + .is_some_and(|value| value.eq_ignore_ascii_case(AGGREGATION_ENGINE)) + } + + /// Validate options at CREATE TABLE time, using the schema's fields and + /// primary keys to reject typo'd column names, unknown aggregate + /// functions, and function/type pairs that the runtime would refuse. + /// + /// Java upstream rejects unknown columns and unknown function names in + /// `SchemaValidation.validateFieldsPrefix` + `validateMergeFunctionFactory`; + /// the function/type compatibility check is stricter than Java, which + /// defers it to `FieldAggregatorFactory#create` at runtime. Catching all + /// three at CREATE TABLE keeps invalid metadata from being persisted. + pub(crate) fn validate_create_mode( + &self, + primary_keys: &[String], + fields: &[DataField], + ) -> crate::Result<Option<AggregationMode>> { + let mode = match self.validated_mode(!primary_keys.is_empty()) { + Ok(mode) => mode, + Err(unsupported_options) => { + return Err(crate::Error::ConfigInvalid { + message: format!( + "merge-engine=aggregation only supports the basic mode in this build; unsupported options: {}", + unsupported_options.join(", ") + ), + }); + } + }; + if mode.is_some() { + self.validate_field_scoped_options(fields, primary_keys)?; + } + Ok(mode) + } + + /// Validate options at read/write runtime. + pub(crate) fn validate_runtime_mode( + &self, + has_primary_keys: bool, + table_name: &str, + ) -> crate::Result<Option<AggregationMode>> { + match self.validated_mode(has_primary_keys) { + Ok(mode) => Ok(mode), + Err(unsupported_options) => Err(crate::Error::Unsupported { + message: format!( + "Table '{table_name}' uses merge-engine=aggregation options not supported by this build: {}", + unsupported_options.join(", ") + ), + }), + } + } + + fn validated_mode( + &self, + has_primary_keys: bool, + ) -> std::result::Result<Option<AggregationMode>, Vec<String>> { + if !has_primary_keys || !self.is_enabled() { + return Ok(None); + } + + let unsupported_options = self.unsupported_option_keys(); + if !unsupported_options.is_empty() { + return Err(unsupported_options); + } + + Ok(Some(AggregationMode::Basic)) + } + + fn unsupported_option_keys(&self) -> Vec<String> { + let mut keys: Vec<String> = self + .options + .keys() + .filter(|key| is_unsupported_aggregation_option(key)) + .cloned() + .collect(); + keys.sort(); + keys + } + + /// Per-field aggregate function configured via `fields.<col>.aggregate-function`. + pub(crate) fn agg_function_for_field(&self, field_name: &str) -> Option<&str> { + let key = format!("{FIELDS_PREFIX}{field_name}{AGG_FUNCTION_SUFFIX}"); + self.options.get(&key).map(String::as_str) + } + + /// Default aggregate function from `fields.default-aggregate-function`. + pub(crate) fn default_agg_function(&self) -> Option<&str> { + self.options + .get(FIELDS_DEFAULT_AGG_FUNCTION_OPTION) + .map(String::as_str) + } + + /// Schema-aware checks run by [`validate_create_mode`] once the engine is + /// confirmed active. For every `fields.<col>.<known-suffix>` key + /// (currently `aggregate-function` and `list-agg-delimiter`): + /// * the `<col>` segment must name an existing schema field; this catches + /// typo'd column names that would otherwise silently fall back to the + /// default function / default delimiter at read time. + /// + /// For `aggregate-function` keys additionally: + /// * the function name must be one of the supported aggregators + /// * the function must accept the field's declared data type — except for + /// `sequence.field` columns (forced to `last_value` at runtime) and + /// primary-key columns (no aggregator; copied through), where the + /// configured function is ignored by the merge function's priority + /// order (Java `AggregateMergeFunction#getAggFuncName`), so only the + /// function name is validated. + /// + /// `fields.default-aggregate-function` only has its name validated; + /// per-column type compatibility for the default is deferred to runtime + /// because the default applies broadly across columns. + fn validate_field_scoped_options( + &self, + fields: &[DataField], + primary_keys: &[String], + ) -> crate::Result<()> { + // Same source as the read path: `sequence.field` parsed by CoreOptions. + let core_options = CoreOptions::new(self.options); + let sequence_fields = core_options.sequence_fields(); + for (key, value) in self.options { + let Some((col, kind)) = parse_field_scoped_option_key(key) else { + continue; + }; + let Some(field) = fields.iter().find(|f| f.name() == col) else { + let mut available: Vec<&str> = fields.iter().map(DataField::name).collect(); + available.sort(); + return Err(crate::Error::ConfigInvalid { + message: format!( + "Aggregation field '{col}' referenced by '{key}' is not declared in \ + the table schema; available columns: [{}]", + available.join(", ") + ), + }); + }; + if matches!(kind, FieldScopedOptionKind::AggregateFunction) { + let runtime_ignores_function = + sequence_fields.contains(&col) || primary_keys.iter().any(|pk| pk == col); + if runtime_ignores_function { + if !is_known_aggregator_name(value) { + return Err(crate::Error::ConfigInvalid { + message: format!( + "Unknown aggregate function '{value}' for field '{col}'; \ + {SUPPORTED_AGGREGATOR_NAMES_HINT}" + ), + }); + } + } else { + validate_aggregator_for_type(value, col, field.data_type())?; + } + } + } + + if let Some(default) = self + .options + .get(FIELDS_DEFAULT_AGG_FUNCTION_OPTION) + .map(String::as_str) + { + if !is_known_aggregator_name(default) { + return Err(crate::Error::ConfigInvalid { + message: format!( + "Unknown aggregate function '{default}' configured via \ + '{FIELDS_DEFAULT_AGG_FUNCTION_OPTION}'; {SUPPORTED_AGGREGATOR_NAMES_HINT}" + ), + }); + } + } + + Ok(()) + } +} + +/// Field-scoped option suffixes that schema-aware validation recognizes. +/// Each variant maps to a single `fields.<col>.<suffix>` key shape. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum FieldScopedOptionKind { + AggregateFunction, + ListAggDelimiter, +} + +/// Parse the `<col>` segment and option kind out of a +/// `fields.<col>.<known-suffix>` key, or return `None` if `key` doesn't +/// match any known field-scoped option suffix. +fn parse_field_scoped_option_key(key: &str) -> Option<(&str, FieldScopedOptionKind)> { + let inner = key.strip_prefix(FIELDS_PREFIX)?; + for (suffix, kind) in [ + ( + AGG_FUNCTION_SUFFIX, + FieldScopedOptionKind::AggregateFunction, + ), + ( + LIST_AGG_DELIMITER_SUFFIX, + FieldScopedOptionKind::ListAggDelimiter, + ), + ] { + if let Some(col) = inner.strip_suffix(suffix) { + if col.is_empty() { + // `fields..<suffix>` is malformed; treat as "no match" so the + // caller surfaces a typo-style error elsewhere. + continue; + } + return Some((col, kind)); + } + } + None +} + +const SUPPORTED_AGGREGATOR_NAMES_HINT: &str = "supported: sum, product, min, max, last_value, \ + first_value, last_non_null_value, first_non_null_value, bool_and, bool_or, listagg"; + +/// Whether `name` matches one of the basic-mode aggregator identifiers. Must +/// stay in sync with the `match` arms in +/// `crate::table::aggregator::new_aggregator` — guarded by +/// `tests::validation_table_matches_constructors`. +pub(crate) fn is_known_aggregator_name(name: &str) -> bool { + matches!( + name, + "sum" + | "product" + | "min" + | "max" + | "last_value" + | "first_value" + | "last_non_null_value" + | "first_non_null_value" + | "bool_and" + | "bool_or" + | "listagg" + ) +} + +/// Mirror of the per-aggregator type checks in `crate::table::aggregator::*`. +/// `Ok(())` means the runtime `*Agg::new` constructor will accept the given +/// `(name, data_type)` pair. Must stay in sync — guarded by +/// `tests::validation_table_matches_constructors`. +pub(crate) fn validate_aggregator_for_type( + name: &str, + field_name: &str, + dt: &DataType, +) -> crate::Result<()> { + let ok = match name { + "sum" => matches!( + dt, + DataType::TinyInt(_) + | DataType::SmallInt(_) + | DataType::Int(_) + | DataType::BigInt(_) + | DataType::Float(_) + | DataType::Double(_) + | DataType::Decimal(_) + ), + "product" => matches!( + dt, + DataType::TinyInt(_) + | DataType::SmallInt(_) + | DataType::Int(_) + | DataType::BigInt(_) + | DataType::Float(_) + | DataType::Double(_) + ), + "min" | "max" => matches!( + dt, + DataType::TinyInt(_) + | DataType::SmallInt(_) + | DataType::Int(_) + | DataType::BigInt(_) + | DataType::Float(_) + | DataType::Double(_) + | DataType::Decimal(_) + | DataType::Date(_) + | DataType::Time(_) + | DataType::Timestamp(_) + | DataType::Char(_) + | DataType::VarChar(_) + ), + "bool_and" | "bool_or" => matches!(dt, DataType::Boolean(_)), + "listagg" => matches!(dt, DataType::Char(_) | DataType::VarChar(_)), Review Comment: This accepts `CHAR` and bounded `VARCHAR(n)` for `listagg`, but Java `FieldListaggAggFactory` only accepts unbounded `VarCharType` / `STRING`. This can create Rust metadata that Java rejects. Please restrict create-time validation and the runtime constructor, and add rejection tests for `CHAR` / bounded `VARCHAR`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
