This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new c40bc7c2 feat: Supports Stddev (#348)
c40bc7c2 is described below

commit c40bc7c26f13773248921646ac4f755025e4823f
Author: Huaxin Gao <[email protected]>
AuthorDate: Mon May 6 17:20:45 2024 -0700

    feat: Supports Stddev (#348)
    
    * feat: Supports Stddev
    
    * fix fmt
    
    * update q39a.sql.out
    
    * address comments
    
    * disable q93a and q93b for now
    
    * address comments
    
    ---------
    
    Co-authored-by: Huaxin Gao <[email protected]>
---
 core/src/execution/datafusion/expressions/mod.rs   |   1 +
 .../src/execution/datafusion/expressions/stddev.rs | 179 +++++++++++++++++++++
 .../execution/datafusion/expressions/variance.rs   |   2 -
 core/src/execution/datafusion/planner.rs           |  25 +++
 core/src/execution/proto/expr.proto                |   8 +
 docs/source/user-guide/expressions.md              |   2 +
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  42 ++++-
 .../apache/comet/exec/CometAggregateSuite.scala    |  43 +++++
 .../apache/spark/sql/CometTPCDSQuerySuite.scala    |   9 +-
 9 files changed, 306 insertions(+), 5 deletions(-)

diff --git a/core/src/execution/datafusion/expressions/mod.rs 
b/core/src/execution/datafusion/expressions/mod.rs
index 78763fc2..10cac169 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -29,6 +29,7 @@ pub mod avg_decimal;
 pub mod bloom_filter_might_contain;
 pub mod covariance;
 pub mod stats;
+pub mod stddev;
 pub mod strings;
 pub mod subquery;
 pub mod sum_decimal;
diff --git a/core/src/execution/datafusion/expressions/stddev.rs 
b/core/src/execution/datafusion/expressions/stddev.rs
new file mode 100644
index 00000000..bbddf9aa
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/stddev.rs
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{any::Any, sync::Arc};
+
+use crate::execution::datafusion::expressions::{
+    stats::StatsType, utils::down_cast_any_ref, variance::VarianceAccumulator,
+};
+use arrow::{
+    array::ArrayRef,
+    datatypes::{DataType, Field},
+};
+use datafusion::logical_expr::Accumulator;
+use datafusion_common::{internal_err, Result, ScalarValue};
+use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, 
PhysicalExpr};
+
+/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression
+/// The implementation mostly is the same as the DataFusion's implementation. 
The reason
+/// we have our own implementation is that DataFusion has UInt64 for 
state_field `count`,
+/// while Spark has Double for count. Also we have added 
`null_on_divide_by_zero`
+/// to be consistent with Spark's implementation.
+#[derive(Debug)]
+pub struct Stddev {
+    name: String,
+    expr: Arc<dyn PhysicalExpr>,
+    stats_type: StatsType,
+    null_on_divide_by_zero: bool,
+}
+
+impl Stddev {
+    /// Create a new STDDEV aggregate function
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+        stats_type: StatsType,
+        null_on_divide_by_zero: bool,
+    ) -> Self {
+        // the result of stddev just support FLOAT64.
+        assert!(matches!(data_type, DataType::Float64));
+        Self {
+            name: name.into(),
+            expr,
+            stats_type,
+            null_on_divide_by_zero,
+        }
+    }
+}
+
+impl AggregateExpr for Stddev {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(&self.name, DataType::Float64, true))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(StddevAccumulator::try_new(
+            self.stats_type,
+            self.null_on_divide_by_zero,
+        )?))
+    }
+
+    fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(StddevAccumulator::try_new(
+            self.stats_type,
+            self.null_on_divide_by_zero,
+        )?))
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![
+            Field::new(
+                format_state_name(&self.name, "count"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "mean"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(format_state_name(&self.name, "m2"), DataType::Float64, 
true),
+        ])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+impl PartialEq<dyn Any> for Stddev {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| {
+                self.name == x.name
+                    && self.expr.eq(&x.expr)
+                    && self.null_on_divide_by_zero == x.null_on_divide_by_zero
+                    && self.stats_type == x.stats_type
+            })
+            .unwrap_or(false)
+    }
+}
+
+/// An accumulator to compute the standard deviation
+#[derive(Debug)]
+pub struct StddevAccumulator {
+    variance: VarianceAccumulator,
+}
+
+impl StddevAccumulator {
+    /// Creates a new `StddevAccumulator`
+    pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> 
Result<Self> {
+        Ok(Self {
+            variance: VarianceAccumulator::try_new(s_type, 
null_on_divide_by_zero)?,
+        })
+    }
+
+    pub fn get_m2(&self) -> f64 {
+        self.variance.get_m2()
+    }
+}
+
+impl Accumulator for StddevAccumulator {
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![
+            ScalarValue::from(self.variance.get_count()),
+            ScalarValue::from(self.variance.get_mean()),
+            ScalarValue::from(self.variance.get_m2()),
+        ])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.variance.update_batch(values)
+    }
+
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.variance.retract_batch(values)
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.variance.merge_batch(states)
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        let variance = self.variance.evaluate()?;
+        match variance {
+            ScalarValue::Float64(Some(e)) => 
Ok(ScalarValue::Float64(Some(e.sqrt()))),
+            ScalarValue::Float64(None) => Ok(ScalarValue::Float64(None)),
+            _ => internal_err!("Variance should be f64"),
+        }
+    }
+
+    fn size(&self) -> usize {
+        std::mem::align_of_val(self) - std::mem::align_of_val(&self.variance) 
+ self.variance.size()
+    }
+}
diff --git a/core/src/execution/datafusion/expressions/variance.rs 
b/core/src/execution/datafusion/expressions/variance.rs
index 6aae01ed..f996c13d 100644
--- a/core/src/execution/datafusion/expressions/variance.rs
+++ b/core/src/execution/datafusion/expressions/variance.rs
@@ -15,8 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines physical expressions that can evaluated at runtime during query 
execution
-
 use std::{any::Any, sync::Arc};
 
 use crate::execution::datafusion::expressions::{stats::StatsType, 
utils::down_cast_any_ref};
diff --git a/core/src/execution/datafusion/planner.rs 
b/core/src/execution/datafusion/planner.rs
index 72174790..6a050eb8 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -71,6 +71,7 @@ use crate::{
                 if_expr::IfExpr,
                 scalar_funcs::create_comet_physical_fun,
                 stats::StatsType,
+                stddev::Stddev,
                 strings::{Contains, EndsWith, Like, StartsWith, 
StringSpaceExec, SubstringExec},
                 subquery::Subquery,
                 sum_decimal::SumDecimal,
@@ -1260,6 +1261,30 @@ impl PhysicalPlanner {
                     ))),
                 }
             }
+            AggExprStruct::Stddev(expr) => {
+                let child = self.create_expr(expr.child.as_ref().unwrap(), 
schema.clone())?;
+                let datatype = 
to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                match expr.stats_type {
+                    0 => Ok(Arc::new(Stddev::new(
+                        child,
+                        "stddev",
+                        datatype,
+                        StatsType::Sample,
+                        expr.null_on_divide_by_zero,
+                    ))),
+                    1 => Ok(Arc::new(Stddev::new(
+                        child,
+                        "stddev_pop",
+                        datatype,
+                        StatsType::Population,
+                        expr.null_on_divide_by_zero,
+                    ))),
+                    stats_type => Err(ExecutionError::GeneralError(format!(
+                        "Unknown StatisticsType {:?} for stddev",
+                        stats_type
+                    ))),
+                }
+            }
         }
     }
 
diff --git a/core/src/execution/proto/expr.proto 
b/core/src/execution/proto/expr.proto
index 042a981f..ee3de865 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -95,6 +95,7 @@ message AggExpr {
     CovSample covSample = 12;
     CovPopulation covPopulation = 13;
     Variance variance = 14;
+    Stddev stddev = 15;
   }
 }
 
@@ -178,6 +179,13 @@ message Variance {
   StatisticsType stats_type = 4;
 }
 
+message Stddev {
+  Expr child = 1;
+  bool null_on_divide_by_zero = 2;
+  DataType datatype = 3;
+  StatisticsType stats_type = 4;
+}
+
 message Literal {
   oneof value {
     bool bool_val = 1;
diff --git a/docs/source/user-guide/expressions.md 
b/docs/source/user-guide/expressions.md
index f67a4ead..38c86c72 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -107,3 +107,5 @@ The following Spark expressions are currently available:
   - CovSample
   - VariancePop
   - VarianceSamp
+  - StddevPop
+  - StddevSamp
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index e77adc9b..1e8877c8 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
-import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, 
Final, First, Last, Max, Min, Partial, Sum, VariancePop, VarianceSamp}
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, 
Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, 
VarianceSamp}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildRight, 
NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
@@ -506,6 +506,46 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde {
           withInfo(aggExpr, child)
           None
         }
+      case std @ StddevSamp(child, nullOnDivideByZero) =>
+        val childExpr = exprToProto(child, inputs, binding)
+        val dataType = serializeDataType(std.dataType)
+
+        if (childExpr.isDefined && dataType.isDefined) {
+          val stdBuilder = ExprOuterClass.Stddev.newBuilder()
+          stdBuilder.setChild(childExpr.get)
+          stdBuilder.setNullOnDivideByZero(nullOnDivideByZero)
+          stdBuilder.setDatatype(dataType.get)
+          stdBuilder.setStatsTypeValue(0)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setStddev(stdBuilder)
+              .build())
+        } else {
+          withInfo(aggExpr, child)
+          None
+        }
+      case std @ StddevPop(child, nullOnDivideByZero) =>
+        val childExpr = exprToProto(child, inputs, binding)
+        val dataType = serializeDataType(std.dataType)
+
+        if (childExpr.isDefined && dataType.isDefined) {
+          val stdBuilder = ExprOuterClass.Stddev.newBuilder()
+          stdBuilder.setChild(childExpr.get)
+          stdBuilder.setNullOnDivideByZero(nullOnDivideByZero)
+          stdBuilder.setDatatype(dataType.get)
+          stdBuilder.setStatsTypeValue(1)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setStddev(stdBuilder)
+              .build())
+        } else {
+          withInfo(aggExpr, child)
+          None
+        }
       case fn =>
         val msg = s"unsupported Spark aggregate function: ${fn.prettyName}"
         emitWarning(msg)
diff --git 
a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 64c031ee..310a24ee 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1169,6 +1169,49 @@ class CometAggregateSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
+  test("stddev_pop and stddev_samp") {
+    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+      Seq(true, false).foreach { cometColumnShuffleEnabled =>
+        withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> 
cometColumnShuffleEnabled.toString) {
+          Seq(true, false).foreach { dictionary =>
+            withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+              Seq(true, false).foreach { nullOnDivideByZero =>
+                withSQLConf(
+                  "spark.sql.legacy.statisticalAggregate" -> 
nullOnDivideByZero.toString) {
+                  val table = "test"
+                  withTable(table) {
+                    sql(s"create table $table(col1 int, col2 int, col3 int, 
col4 float, " +
+                      "col5 double, col6 int) using parquet")
+                    sql(s"insert into $table values(1, null, null, 1.1, 2.2, 
1), " +
+                      "(2, null, null, 3.4, 5.6, 1), (3, null, 4, 7.9, 2.4, 
2)")
+                    val expectedNumOfCometAggregates = 2
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT stddev_samp(col1), stddev_samp(col2), 
stddev_samp(col3), " +
+                        "stddev_samp(col4), stddev_samp(col5) FROM test",
+                      expectedNumOfCometAggregates)
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT stddev_pop(col1), stddev_pop(col2), 
stddev_pop(col3), " +
+                        "stddev_pop(col4), stddev_pop(col5) FROM test",
+                      expectedNumOfCometAggregates)
+                    checkSparkAnswerAndNumOfAggregates(
+                      "SELECT stddev_samp(col1), stddev_samp(col2), 
stddev_samp(col3), " +
+                        "stddev_samp(col4), stddev_samp(col5) FROM test GROUP 
BY col6",
+                      expectedNumOfCometAggregates)
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT stddev_pop(col1), stddev_pop(col2), 
stddev_pop(col3), " +
+                        "stddev_pop(col4), stddev_pop(col5) FROM test GROUP BY 
col6",
+                      expectedNumOfCometAggregates)
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, 
numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)
diff --git 
a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala 
b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala
index cdbd7194..3342d750 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala
@@ -73,8 +73,13 @@ class CometTPCDSQuerySuite
         "q36",
         "q37",
         "q38",
-        "q39a",
-        "q39b",
+        // TODO: https://github.com/apache/datafusion-comet/issues/392
+        //  comment out 39a and 39b for now because the expected result for 
stddev failed:
+        //  expected: 1.5242630430075292, actual: 1.524263043007529.
+        //  Will change the comparison logic to detect floating-point numbers 
and compare
+        //  with epsilon
+        // "q39a",
+        // "q39b",
         "q40",
         "q41",
         "q42",


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to