Jefffrey commented on code in PR #17988:
URL: https://github.com/apache/datafusion/pull/17988#discussion_r2417056627


##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -0,0 +1,839 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::fmt::{Debug, Formatter};
+use std::mem::{size_of, size_of_val};
+use std::sync::Arc;
+
+use arrow::array::{
+    ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, 
PrimitiveBuilder,
+};
+use arrow::buffer::{OffsetBuffer, ScalarBuffer};
+use arrow::{
+    array::{Array, ArrayRef, AsArray},
+    datatypes::{
+        ArrowNativeType, DataType, Decimal128Type, Decimal256Type, 
Decimal32Type,
+        Decimal64Type, Field, FieldRef, Float16Type, Float32Type, Float64Type,
+    },
+};
+
+use arrow::array::ArrowNativeTypeOp;
+
+use datafusion_common::{
+    internal_datafusion_err, internal_err, not_impl_datafusion_err, plan_err,
+    DataFusionError, HashSet, Result, ScalarValue,
+};
+use datafusion_expr::expr::{AggregateFunction, Sort};
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion_expr::type_coercion::aggregates::NUMERICS;
+use datafusion_expr::utils::format_state_name;
+use datafusion_expr::{
+    Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, 
TypeSignature,
+    Volatility,
+};
+use datafusion_expr::{EmitTo, GroupsAccumulator};
+use 
datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
+use 
datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
+use datafusion_functions_aggregate_common::utils::Hashable;
+use datafusion_macros::user_doc;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+
+/// Precision multiplier for linear interpolation calculations.
+///
+/// This value of 1,000,000 was chosen to balance precision with overflow 
safety:
+/// - Provides 6 decimal places of precision for the fractional component
+/// - Small enough to avoid overflow when multiplied with typical numeric 
values
+/// - Sufficient precision for most statistical applications
+///
+/// The interpolation formula: `lower + (upper - lower) * fraction`
+/// is computed as: `lower + ((upper - lower) * (fraction * PRECISION)) / 
PRECISION`
+/// to avoid floating-point operations on integer types while maintaining 
precision.
+const INTERPOLATION_PRECISION: usize = 1_000_000;
+
+create_func!(PercentileCont, percentile_cont_udaf);
+
+/// Computes the exact percentile continuous of a set of numbers
+pub fn percentile_cont(order_by: Sort, percentile: Expr) -> Expr {
+    let expr = order_by.expr.clone();
+    let args = vec![expr, percentile];
+
+    Expr::AggregateFunction(AggregateFunction::new_udf(
+        percentile_cont_udaf(),
+        args,
+        false,
+        None,
+        vec![order_by],
+        None,
+    ))
+}
+
+#[user_doc(
+    doc_section(label = "General Functions"),
+    description = "Returns the exact percentile of input values, interpolating 
between values if needed.",
+    syntax_example = "percentile_cont(percentile) WITHIN GROUP (ORDER BY 
expression)",
+    sql_example = r#"```sql
+> SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM 
table_name;
++----------------------------------------------------------+
+| percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
++----------------------------------------------------------+
+| 45.5                                                     |
++----------------------------------------------------------+
+```
+
+An alternate syntax is also supported:
+```sql
+> SELECT percentile_cont(column_name, 0.75) FROM table_name;
++---------------------------------------+
+| percentile_cont(column_name, 0.75)    |
++---------------------------------------+
+| 45.5                                  |
++---------------------------------------+
+```"#,
+    standard_argument(name = "expression", prefix = "The"),
+    argument(
+        name = "percentile",
+        description = "Percentile to compute. Must be a float value between 0 
and 1 (inclusive)."
+    )
+)]
+/// PERCENTILE_CONT aggregate expression. This uses an exact calculation and 
stores all values
+/// in memory before computing the result. If an approximation is sufficient 
then
+/// APPROX_PERCENTILE_CONT provides a much more efficient solution.
+///
+/// If using the distinct variation, the memory usage will be similarly high 
if the
+/// cardinality is high as it stores all distinct values in memory before 
computing the
+/// result, but if cardinality is low then memory usage will also be lower.
+#[derive(PartialEq, Eq, Hash)]
+pub struct PercentileCont {
+    signature: Signature,

Review Comment:
   Should we consider adding a `quantile_cont` alias? That is the name DuckDB
uses for this function.



##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -0,0 +1,839 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::fmt::{Debug, Formatter};
+use std::mem::{size_of, size_of_val};
+use std::sync::Arc;
+
+use arrow::array::{
+    ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, 
PrimitiveBuilder,
+};
+use arrow::buffer::{OffsetBuffer, ScalarBuffer};
+use arrow::{
+    array::{Array, ArrayRef, AsArray},
+    datatypes::{
+        ArrowNativeType, DataType, Decimal128Type, Decimal256Type, 
Decimal32Type,
+        Decimal64Type, Field, FieldRef, Float16Type, Float32Type, Float64Type,
+    },
+};
+
+use arrow::array::ArrowNativeTypeOp;
+
+use datafusion_common::{
+    internal_datafusion_err, internal_err, not_impl_datafusion_err, plan_err,
+    DataFusionError, HashSet, Result, ScalarValue,
+};
+use datafusion_expr::expr::{AggregateFunction, Sort};
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion_expr::type_coercion::aggregates::NUMERICS;
+use datafusion_expr::utils::format_state_name;
+use datafusion_expr::{
+    Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, 
TypeSignature,
+    Volatility,
+};
+use datafusion_expr::{EmitTo, GroupsAccumulator};
+use 
datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
+use 
datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
+use datafusion_functions_aggregate_common::utils::Hashable;
+use datafusion_macros::user_doc;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+
+/// Precision multiplier for linear interpolation calculations.
+///
+/// This value of 1,000,000 was chosen to balance precision with overflow 
safety:
+/// - Provides 6 decimal places of precision for the fractional component
+/// - Small enough to avoid overflow when multiplied with typical numeric 
values
+/// - Sufficient precision for most statistical applications
+///
+/// The interpolation formula: `lower + (upper - lower) * fraction`
+/// is computed as: `lower + ((upper - lower) * (fraction * PRECISION)) / 
PRECISION`
+/// to avoid floating-point operations on integer types while maintaining 
precision.
+const INTERPOLATION_PRECISION: usize = 1_000_000;
+
+create_func!(PercentileCont, percentile_cont_udaf);
+
+/// Computes the exact percentile continuous of a set of numbers
+pub fn percentile_cont(order_by: Sort, percentile: Expr) -> Expr {
+    let expr = order_by.expr.clone();
+    let args = vec![expr, percentile];
+
+    Expr::AggregateFunction(AggregateFunction::new_udf(
+        percentile_cont_udaf(),
+        args,
+        false,
+        None,
+        vec![order_by],
+        None,
+    ))
+}
+
+#[user_doc(
+    doc_section(label = "General Functions"),
+    description = "Returns the exact percentile of input values, interpolating 
between values if needed.",
+    syntax_example = "percentile_cont(percentile) WITHIN GROUP (ORDER BY 
expression)",
+    sql_example = r#"```sql
+> SELECT percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM 
table_name;
++----------------------------------------------------------+
+| percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
++----------------------------------------------------------+
+| 45.5                                                     |
++----------------------------------------------------------+
+```
+
+An alternate syntax is also supported:
+```sql
+> SELECT percentile_cont(column_name, 0.75) FROM table_name;
++---------------------------------------+
+| percentile_cont(column_name, 0.75)    |
++---------------------------------------+
+| 45.5                                  |
++---------------------------------------+
+```"#,
+    standard_argument(name = "expression", prefix = "The"),
+    argument(
+        name = "percentile",
+        description = "Percentile to compute. Must be a float value between 0 
and 1 (inclusive)."
+    )
+)]
+/// PERCENTILE_CONT aggregate expression. This uses an exact calculation and 
stores all values
+/// in memory before computing the result. If an approximation is sufficient 
then
+/// APPROX_PERCENTILE_CONT provides a much more efficient solution.
+///
+/// If using the distinct variation, the memory usage will be similarly high 
if the
+/// cardinality is high as it stores all distinct values in memory before 
computing the
+/// result, but if cardinality is low then memory usage will also be lower.
+#[derive(PartialEq, Eq, Hash)]
+pub struct PercentileCont {
+    signature: Signature,
+}
+
+impl Debug for PercentileCont {
+    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+        f.debug_struct("PercentileCont")
+            .field("name", &self.name())
+            .field("signature", &self.signature)
+            .finish()
+    }
+}
+
+impl Default for PercentileCont {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl PercentileCont {
+    pub fn new() -> Self {
+        let mut variants = Vec::with_capacity(NUMERICS.len());
+        // Accept any numeric value paired with a float64 percentile
+        for num in NUMERICS {
+            variants.push(TypeSignature::Exact(vec![num.clone(), 
DataType::Float64]));
+        }
+        Self {
+            signature: Signature::one_of(variants, Volatility::Immutable),
+        }
+    }
+
+    fn create_accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
+        let percentile = validate_percentile(&args.exprs[1])?;
+
+        let is_descending = args
+            .order_bys
+            .first()
+            .map(|sort_expr| sort_expr.options.descending)
+            .unwrap_or(false);
+
+        let percentile = if is_descending {
+            1.0 - percentile
+        } else {
+            percentile
+        };
+
+        macro_rules! helper {
+            ($t:ty, $dt:expr) => {
+                if args.is_distinct {
+                    Ok(Box::new(DistinctPercentileContAccumulator::<$t> {
+                        data_type: $dt.clone(),
+                        distinct_values: HashSet::new(),
+                        percentile,
+                    }))
+                } else {
+                    Ok(Box::new(PercentileContAccumulator::<$t> {
+                        data_type: $dt.clone(),
+                        all_values: vec![],
+                        percentile,
+                    }))
+                }
+            };
+        }
+
+        let input_dt = args.exprs[0].data_type(args.schema)?;
+        match input_dt {
+            // For integer types, use Float64 internally since percentile_cont 
returns Float64
+            DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64 => helper!(Float64Type, DataType::Float64),
+            DataType::Float16 => helper!(Float16Type, input_dt),
+            DataType::Float32 => helper!(Float32Type, input_dt),
+            DataType::Float64 => helper!(Float64Type, input_dt),
+            DataType::Decimal32(_, _) => helper!(Decimal32Type, input_dt),
+            DataType::Decimal64(_, _) => helper!(Decimal64Type, input_dt),
+            DataType::Decimal128(_, _) => helper!(Decimal128Type, input_dt),
+            DataType::Decimal256(_, _) => helper!(Decimal256Type, input_dt),
+            _ => Err(DataFusionError::NotImplemented(format!(
+                "PercentileContAccumulator not supported for {} with {}",
+                args.name, input_dt,
+            ))),
+        }
+    }
+}
+
+fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
+    use arrow::array::RecordBatch;
+    use arrow::datatypes::Schema;
+    use datafusion_expr::ColumnarValue;
+
+    let empty_schema = Arc::new(Schema::empty());
+    let batch = RecordBatch::new_empty(Arc::clone(&empty_schema));
+    if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? {
+        Ok(s)
+    } else {
+        internal_err!("Didn't expect ColumnarValue::Array")
+    }
+}
+
+fn validate_percentile(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
+    let percentile = match get_scalar_value(expr)
+        .map_err(|_| not_impl_datafusion_err!("Percentile value for 
'PERCENTILE_CONT' must be a literal"))? {
+        ScalarValue::Float32(Some(value)) => {
+            value as f64
+        }
+        ScalarValue::Float64(Some(value)) => {
+            value
+        }
+        sv => {
+            return plan_err!(
+                "Percentile value for 'PERCENTILE_CONT' must be Float32 or 
Float64 literal (got data type {})",
+                sv.data_type()
+            )
+        }
+    };
+
+    // Ensure the percentile is between 0 and 1.
+    if !(0.0..=1.0).contains(&percentile) {
+        return plan_err!(
+            "Percentile value must be between 0.0 and 1.0 inclusive, 
{percentile} is invalid"
+        );
+    }
+    Ok(percentile)
+}
+
+impl AggregateUDFImpl for PercentileCont {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "percentile_cont"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        if !arg_types[0].is_numeric() {
+            return plan_err!("percentile_cont requires numeric input types");
+        }
+        // PERCENTILE_CONT performs linear interpolation and should return a 
float type
+        // For integer inputs, return Float64 (matching PostgreSQL/DuckDB 
behavior)
+        // For float inputs, preserve the float type
+        match &arg_types[0] {
+            DataType::Float16 | DataType::Float32 | DataType::Float64 => {
+                Ok(arg_types[0].clone())
+            }
+            DataType::Decimal32(_, _)
+            | DataType::Decimal64(_, _)
+            | DataType::Decimal128(_, _)
+            | DataType::Decimal256(_, _) => Ok(arg_types[0].clone()),
+            DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64 => Ok(DataType::Float64),
+            // Shouldn't happen due to signature check, but just in case
+            dt => plan_err!(
+                "percentile_cont does not support input type {}, must be 
numeric",
+                dt
+            ),
+        }
+    }
+
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
+        //Intermediate state is a list of the elements we have collected so far
+        let input_type = args.input_fields[0].data_type().clone();
+        // For integer types, we store as Float64 internally
+        let storage_type = match &input_type {
+            DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64 => DataType::Float64,
+            _ => input_type,
+        };
+
+        let field = Field::new_list_field(storage_type, true);
+        let state_name = if args.is_distinct {
+            "distinct_percentile_cont"
+        } else {
+            "percentile_cont"
+        };
+
+        Ok(vec![Field::new(
+            format_state_name(args.name, state_name),
+            DataType::List(Arc::new(field)),
+            true,
+        )
+        .into()])
+    }
+
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
+        self.create_accumulator(acc_args)
+    }
+
+    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        !args.is_distinct
+    }
+
+    fn create_groups_accumulator(
+        &self,
+        args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        let num_args = args.exprs.len();
+        if num_args != 2 {
+            return internal_err!(
+                "percentile_cont should have 2 args, but found num args:{}",
+                args.exprs.len()
+            );
+        }
+
+        let percentile = validate_percentile(&args.exprs[1])?;

Review Comment:
   A bit of the code here feels duplicated with `create_accumulator()` (for
example, the `validate_percentile(&args.exprs[1])` call); I'm not sure how
much it matters, though.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to