This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 4710d62  feat: Port Datafusion Covariance to Comet (#234)
4710d62 is described below

commit 4710d62bae0030b52be6d466861d447b1d21ac47
Author: Huaxin Gao <[email protected]>
AuthorDate: Wed Apr 17 18:07:20 2024 -0700

    feat: Port Datafusion Covariance to Comet (#234)
    
    * feat: Port Datafusion Covariance to Comet
    
    * feat: Port Datafusion Covariance to Comet
    
    * fmt
    
    * update EXPRESSIONS.md
    
    * combine COVAR_SAMP and COVAR_POP
    
    * fix fmt
    
    * address comment
    
    ---------
    
    Co-authored-by: Huaxin Gao <[email protected]>
    Co-authored-by: Liang-Chi Hsieh <[email protected]>
---
 EXPRESSIONS.md                                     |   3 +
 .../execution/datafusion/expressions/covariance.rs | 311 +++++++++++++++++++++
 core/src/execution/datafusion/expressions/mod.rs   |   2 +
 core/src/execution/datafusion/expressions/stats.rs |  27 ++
 core/src/execution/datafusion/planner.rs           |  26 ++
 core/src/execution/proto/expr.proto                |  16 ++
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  39 ++-
 .../apache/comet/exec/CometAggregateSuite.scala    |  39 +++
 8 files changed, 462 insertions(+), 1 deletion(-)

diff --git a/EXPRESSIONS.md b/EXPRESSIONS.md
index bc03be6..45c3684 100644
--- a/EXPRESSIONS.md
+++ b/EXPRESSIONS.md
@@ -101,3 +101,6 @@ The following Spark expressions are currently available:
     + BitAnd
     + BitOr
     + BitXor
+    + BoolAnd
+    + BoolOr
+    + Covariance
diff --git a/core/src/execution/datafusion/expressions/covariance.rs b/core/src/execution/datafusion/expressions/covariance.rs
new file mode 100644
index 0000000..5d0e550
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/covariance.rs
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{any::Any, sync::Arc};
+
+use crate::execution::datafusion::expressions::stats::StatsType;
+use arrow::{
+    array::{ArrayRef, Float64Array},
+    compute::cast,
+    datatypes::{DataType, Field},
+};
+use datafusion::logical_expr::Accumulator;
+use datafusion_common::{
+    downcast_value, unwrap_or_internal_err, DataFusionError, Result, 
ScalarValue,
+};
+use datafusion_physical_expr::{
+    aggregate::utils::down_cast_any_ref, expressions::format_state_name, 
AggregateExpr,
+    PhysicalExpr,
+};
+
+/// COVAR_SAMP and COVAR_POP aggregate expression
+/// The implementation mostly is the same as the DataFusion's implementation. The reason
+/// we have our own implementation is that DataFusion has UInt64 for state_field count,
+/// while Spark has Double for count.
+#[derive(Debug, Clone)]
+pub struct Covariance {
+    name: String,
+    expr1: Arc<dyn PhysicalExpr>,
+    expr2: Arc<dyn PhysicalExpr>,
+    stats_type: StatsType,
+}
+
+impl Covariance {
+    /// Create a new COVAR aggregate function
+    pub fn new(
+        expr1: Arc<dyn PhysicalExpr>,
+        expr2: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+        stats_type: StatsType,
+    ) -> Self {
+        // the result of covariance just support FLOAT64 data type.
+        assert!(matches!(data_type, DataType::Float64));
+        Self {
+            name: name.into(),
+            expr1,
+            expr2,
+            stats_type,
+        }
+    }
+}
+
+impl AggregateExpr for Covariance {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(&self.name, DataType::Float64, true))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(CovarianceAccumulator::try_new(self.stats_type)?))
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![
+            Field::new(
+                format_state_name(&self.name, "count"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "mean1"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "mean2"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "algo_const"),
+                DataType::Float64,
+                true,
+            ),
+        ])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr1.clone(), self.expr2.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+impl PartialEq<dyn Any> for Covariance {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| {
+                self.name == x.name
+                    && self.expr1.eq(&x.expr1)
+                    && self.expr2.eq(&x.expr2)
+                    && self.stats_type == x.stats_type
+            })
+            .unwrap_or(false)
+    }
+}
+
+/// An accumulator to compute covariance
+#[derive(Debug)]
+pub struct CovarianceAccumulator {
+    algo_const: f64,
+    mean1: f64,
+    mean2: f64,
+    count: f64,
+    stats_type: StatsType,
+}
+
+impl CovarianceAccumulator {
+    /// Creates a new `CovarianceAccumulator`
+    pub fn try_new(s_type: StatsType) -> Result<Self> {
+        Ok(Self {
+            algo_const: 0_f64,
+            mean1: 0_f64,
+            mean2: 0_f64,
+            count: 0_f64,
+            stats_type: s_type,
+        })
+    }
+
+    pub fn get_count(&self) -> f64 {
+        self.count
+    }
+
+    pub fn get_mean1(&self) -> f64 {
+        self.mean1
+    }
+
+    pub fn get_mean2(&self) -> f64 {
+        self.mean2
+    }
+
+    pub fn get_algo_const(&self) -> f64 {
+        self.algo_const
+    }
+}
+
+impl Accumulator for CovarianceAccumulator {
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![
+            ScalarValue::from(self.count),
+            ScalarValue::from(self.mean1),
+            ScalarValue::from(self.mean2),
+            ScalarValue::from(self.algo_const),
+        ])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values1 = &cast(&values[0], &DataType::Float64)?;
+        let values2 = &cast(&values[1], &DataType::Float64)?;
+
+        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
+        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
+
+        for i in 0..values1.len() {
+            let value1 = if values1.is_valid(i) {
+                arr1.next()
+            } else {
+                None
+            };
+            let value2 = if values2.is_valid(i) {
+                arr2.next()
+            } else {
+                None
+            };
+
+            if value1.is_none() || value2.is_none() {
+                continue;
+            }
+
+            let value1 = unwrap_or_internal_err!(value1);
+            let value2 = unwrap_or_internal_err!(value2);
+            let new_count = self.count + 1.0;
+            let delta1 = value1 - self.mean1;
+            let new_mean1 = delta1 / new_count + self.mean1;
+            let delta2 = value2 - self.mean2;
+            let new_mean2 = delta2 / new_count + self.mean2;
+            let new_c = delta1 * (value2 - new_mean2) + self.algo_const;
+
+            self.count += 1.0;
+            self.mean1 = new_mean1;
+            self.mean2 = new_mean2;
+            self.algo_const = new_c;
+        }
+
+        Ok(())
+    }
+
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values1 = &cast(&values[0], &DataType::Float64)?;
+        let values2 = &cast(&values[1], &DataType::Float64)?;
+        let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
+        let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
+
+        for i in 0..values1.len() {
+            let value1 = if values1.is_valid(i) {
+                arr1.next()
+            } else {
+                None
+            };
+            let value2 = if values2.is_valid(i) {
+                arr2.next()
+            } else {
+                None
+            };
+
+            if value1.is_none() || value2.is_none() {
+                continue;
+            }
+
+            let value1 = unwrap_or_internal_err!(value1);
+            let value2 = unwrap_or_internal_err!(value2);
+
+            let new_count = self.count - 1.0;
+            let delta1 = self.mean1 - value1;
+            let new_mean1 = delta1 / new_count + self.mean1;
+            let delta2 = self.mean2 - value2;
+            let new_mean2 = delta2 / new_count + self.mean2;
+            let new_c = self.algo_const - delta1 * (new_mean2 - value2);
+
+            self.count -= 1.0;
+            self.mean1 = new_mean1;
+            self.mean2 = new_mean2;
+            self.algo_const = new_c;
+        }
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        let counts = downcast_value!(states[0], Float64Array);
+        let means1 = downcast_value!(states[1], Float64Array);
+        let means2 = downcast_value!(states[2], Float64Array);
+        let cs = downcast_value!(states[3], Float64Array);
+
+        for i in 0..counts.len() {
+            let c = counts.value(i);
+            if c == 0.0 {
+                continue;
+            }
+            let new_count = self.count + c;
+            let new_mean1 = self.mean1 * self.count / new_count + 
means1.value(i) * c / new_count;
+            let new_mean2 = self.mean2 * self.count / new_count + 
means2.value(i) * c / new_count;
+            let delta1 = self.mean1 - means1.value(i);
+            let delta2 = self.mean2 - means2.value(i);
+            let new_c =
+                self.algo_const + cs.value(i) + delta1 * delta2 * self.count * 
c / new_count;
+
+            self.count = new_count;
+            self.mean1 = new_mean1;
+            self.mean2 = new_mean2;
+            self.algo_const = new_c;
+        }
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        let count = match self.stats_type {
+            StatsType::Population => self.count,
+            StatsType::Sample => {
+                if self.count > 0.0 {
+                    self.count - 1.0
+                } else {
+                    self.count
+                }
+            }
+        };
+
+        if count == 0.0 {
+            Ok(ScalarValue::Float64(None))
+        } else {
+            Ok(ScalarValue::Float64(Some(self.algo_const / count)))
+        }
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+}
diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs
index 69cdf3e..799790c 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -27,6 +27,8 @@ pub use normalize_nan::NormalizeNaNAndZero;
 pub mod avg;
 pub mod avg_decimal;
 pub mod bloom_filter_might_contain;
+pub mod covariance;
+pub mod stats;
 pub mod strings;
 pub mod subquery;
 pub mod sum_decimal;
diff --git a/core/src/execution/datafusion/expressions/stats.rs b/core/src/execution/datafusion/expressions/stats.rs
new file mode 100644
index 0000000..1f4e64d
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/stats.rs
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/// Enum used for differentiating population and sample for statistical functions
+#[derive(PartialEq, Eq, Debug, Clone, Copy)]
+pub enum StatsType {
+    /// Population
+    Population,
+    /// Sample
+    Sample,
+}
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 052ecc4..ca926bf 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -67,8 +67,10 @@ use crate::{
                 bloom_filter_might_contain::BloomFilterMightContain,
                 cast::Cast,
                 checkoverflow::CheckOverflow,
+                covariance::Covariance,
                 if_expr::IfExpr,
                 scalar_funcs::create_comet_physical_fun,
+                stats::StatsType,
                 strings::{Contains, EndsWith, Like, StartsWith, 
StringSpaceExec, SubstringExec},
                 subquery::Subquery,
                 sum_decimal::SumDecimal,
@@ -1193,6 +1195,30 @@ impl PhysicalPlanner {
                 let datatype = 
to_arrow_datatype(expr.datatype.as_ref().unwrap());
                 Ok(Arc::new(BitXor::new(child, "bit_xor", datatype)))
             }
+            AggExprStruct::CovSample(expr) => {
+                let child1 = self.create_expr(expr.child1.as_ref().unwrap(), 
schema.clone())?;
+                let child2 = self.create_expr(expr.child2.as_ref().unwrap(), 
schema.clone())?;
+                let datatype = 
to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                Ok(Arc::new(Covariance::new(
+                    child1,
+                    child2,
+                    "covariance",
+                    datatype,
+                    StatsType::Sample,
+                )))
+            }
+            AggExprStruct::CovPopulation(expr) => {
+                let child1 = self.create_expr(expr.child1.as_ref().unwrap(), 
schema.clone())?;
+                let child2 = self.create_expr(expr.child2.as_ref().unwrap(), 
schema.clone())?;
+                let datatype = 
to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                Ok(Arc::new(Covariance::new(
+                    child1,
+                    child2,
+                    "covariance_pop",
+                    datatype,
+                    StatsType::Population,
+                )))
+            }
         }
     }
 
diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto
index 58f607f..1a6c29c 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -92,6 +92,8 @@ message AggExpr {
     BitAndAgg bitAndAgg = 9;
     BitOrAgg bitOrAgg = 10;
     BitXorAgg bitXorAgg = 11;
+    CovSample covSample = 12;
+    CovPopulation covPopulation = 13;
   }
 }
 
@@ -149,6 +151,20 @@ message BitXorAgg {
   DataType datatype = 2;
 }
 
+message CovSample {
+  Expr child1 = 1;
+  Expr child2 = 2;
+  bool null_on_divide_by_zero = 3;
+  DataType datatype = 4;
+}
+
+message CovPopulation {
+  Expr child1 = 1;
+  Expr child2 = 2;
+  bool null_on_divide_by_zero = 3;
+  DataType datatype = 4;
+}
+
 message Literal {
   oneof value {
     bool bool_val = 1;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 172a5b5..b62c222 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
-import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, Final, First, Last, Max, Min, 
Partial, Sum}
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, 
Final, First, Last, Max, Min, Partial, Sum}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildRight, 
NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
@@ -388,7 +388,44 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         } else {
           None
         }
+      case cov @ CovSample(child1, child2, _) =>
+        val child1Expr = exprToProto(child1, inputs, binding)
+        val child2Expr = exprToProto(child2, inputs, binding)
+        val dataType = serializeDataType(cov.dataType)
 
+        if (child1Expr.isDefined && child2Expr.isDefined && 
dataType.isDefined) {
+          val covBuilder = ExprOuterClass.CovSample.newBuilder()
+          covBuilder.setChild1(child1Expr.get)
+          covBuilder.setChild2(child2Expr.get)
+          covBuilder.setDatatype(dataType.get)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setCovSample(covBuilder)
+              .build())
+        } else {
+          None
+        }
+      case cov @ CovPopulation(child1, child2, _) =>
+        val child1Expr = exprToProto(child1, inputs, binding)
+        val child2Expr = exprToProto(child2, inputs, binding)
+        val dataType = serializeDataType(cov.dataType)
+
+        if (child1Expr.isDefined && child2Expr.isDefined && 
dataType.isDefined) {
+          val covBuilder = ExprOuterClass.CovPopulation.newBuilder()
+          covBuilder.setChild1(child1Expr.get)
+          covBuilder.setChild2(child2Expr.get)
+          covBuilder.setDatatype(dataType.get)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setCovPopulation(covBuilder)
+              .build())
+        } else {
+          None
+        }
       case fn =>
         emitWarning(s"unsupported Spark aggregate function: $fn")
         None
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index b5ed5f4..09c7151 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1065,6 +1065,45 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("covar_pop and covar_samp") {
+    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+      Seq(true, false).foreach { cometColumnShuffleEnabled =>
+        withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> 
cometColumnShuffleEnabled.toString) {
+          Seq(true, false).foreach { dictionary =>
+            withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+              val table = "test"
+              withTable(table) {
+                sql(
+                  s"create table $table(col1 int, col2 int, col3 int, col4 
float, col5 double," +
+                    " col6 double, col7 int) using parquet")
+                sql(s"insert into $table values(1, 4, null, 1.1, 2.2, null, 
1)," +
+                  " (2, 5, 6, 3.4, 5.6, null, 1), (3, 6, null, 7.9, 2.4, null, 
2)")
+                val expectedNumOfCometAggregates = 2
+                checkSparkAnswerAndNumOfAggregates(
+                  "SELECT covar_samp(col1, col2), covar_samp(col1, col3), 
covar_samp(col4, col5)," +
+                    " covar_samp(col4, col6) FROM test",
+                  expectedNumOfCometAggregates)
+                checkSparkAnswerAndNumOfAggregates(
+                  "SELECT covar_pop(col1, col2), covar_pop(col1, col3), 
covar_pop(col4, col5)," +
+                    " covar_pop(col4, col6) FROM test",
+                  expectedNumOfCometAggregates)
+                checkSparkAnswerAndNumOfAggregates(
+                  "SELECT covar_samp(col1, col2), covar_samp(col1, col3), 
covar_samp(col4, col5)," +
+                    " covar_samp(col4, col6) FROM test GROUP BY col7",
+                  expectedNumOfCometAggregates)
+                checkSparkAnswerAndNumOfAggregates(
+                  "SELECT covar_pop(col1, col2), covar_pop(col1, col3), 
covar_pop(col4, col5)," +
+                    " covar_pop(col4, col6) FROM test GROUP BY col7",
+                  expectedNumOfCometAggregates)
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, 
numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)

Reply via email to