This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 4710d62 feat: Port Datafusion Covariance to Comet (#234)
4710d62 is described below
commit 4710d62bae0030b52be6d466861d447b1d21ac47
Author: Huaxin Gao <[email protected]>
AuthorDate: Wed Apr 17 18:07:20 2024 -0700
feat: Port Datafusion Covariance to Comet (#234)
* feat: Port Datafusion Covariance to Comet
* feat: Port Datafusion Covariance to Comet
* fmt
* update EXPRESSIONS.md
* combine COVAR_SAMP and COVAR_POP
* fix fmt
* address comment
---------
Co-authored-by: Huaxin Gao <[email protected]>
Co-authored-by: Liang-Chi Hsieh <[email protected]>
---
EXPRESSIONS.md | 3 +
.../execution/datafusion/expressions/covariance.rs | 311 +++++++++++++++++++++
core/src/execution/datafusion/expressions/mod.rs | 2 +
core/src/execution/datafusion/expressions/stats.rs | 27 ++
core/src/execution/datafusion/planner.rs | 26 ++
core/src/execution/proto/expr.proto | 16 ++
.../org/apache/comet/serde/QueryPlanSerde.scala | 39 ++-
.../apache/comet/exec/CometAggregateSuite.scala | 39 +++
8 files changed, 462 insertions(+), 1 deletion(-)
diff --git a/EXPRESSIONS.md b/EXPRESSIONS.md
index bc03be6..45c3684 100644
--- a/EXPRESSIONS.md
+++ b/EXPRESSIONS.md
@@ -101,3 +101,6 @@ The following Spark expressions are currently available:
+ BitAnd
+ BitOr
+ BitXor
+ + BoolAnd
+ + BoolOr
+ + Covariance
diff --git a/core/src/execution/datafusion/expressions/covariance.rs
b/core/src/execution/datafusion/expressions/covariance.rs
new file mode 100644
index 0000000..5d0e550
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/covariance.rs
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{any::Any, sync::Arc};
+
+use crate::execution::datafusion::expressions::stats::StatsType;
+use arrow::{
+ array::{ArrayRef, Float64Array},
+ compute::cast,
+ datatypes::{DataType, Field},
+};
+use datafusion::logical_expr::Accumulator;
+use datafusion_common::{
+ downcast_value, unwrap_or_internal_err, DataFusionError, Result,
ScalarValue,
+};
+use datafusion_physical_expr::{
+ aggregate::utils::down_cast_any_ref, expressions::format_state_name,
AggregateExpr,
+ PhysicalExpr,
+};
+
+/// COVAR_SAMP and COVAR_POP aggregate expression
+/// The implementation is mostly the same as DataFusion's. The reason
+/// we have our own implementation is that DataFusion uses UInt64 for the
+/// count state field, while Spark uses Double for count.
+#[derive(Debug, Clone)]
+pub struct Covariance {
+ name: String,
+ expr1: Arc<dyn PhysicalExpr>,
+ expr2: Arc<dyn PhysicalExpr>,
+ stats_type: StatsType,
+}
+
+impl Covariance {
+ /// Create a new COVAR aggregate function
+ pub fn new(
+ expr1: Arc<dyn PhysicalExpr>,
+ expr2: Arc<dyn PhysicalExpr>,
+ name: impl Into<String>,
+ data_type: DataType,
+ stats_type: StatsType,
+ ) -> Self {
+ // the result of covariance only supports the FLOAT64 data type.
+ assert!(matches!(data_type, DataType::Float64));
+ Self {
+ name: name.into(),
+ expr1,
+ expr2,
+ stats_type,
+ }
+ }
+}
+
+impl AggregateExpr for Covariance {
+ /// Return a reference to Any that can be used for downcasting
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn field(&self) -> Result<Field> {
+ Ok(Field::new(&self.name, DataType::Float64, true))
+ }
+
+ fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(CovarianceAccumulator::try_new(self.stats_type)?))
+ }
+
+ fn state_fields(&self) -> Result<Vec<Field>> {
+ Ok(vec![
+ Field::new(
+ format_state_name(&self.name, "count"),
+ DataType::Float64,
+ true,
+ ),
+ Field::new(
+ format_state_name(&self.name, "mean1"),
+ DataType::Float64,
+ true,
+ ),
+ Field::new(
+ format_state_name(&self.name, "mean2"),
+ DataType::Float64,
+ true,
+ ),
+ Field::new(
+ format_state_name(&self.name, "algo_const"),
+ DataType::Float64,
+ true,
+ ),
+ ])
+ }
+
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ vec![self.expr1.clone(), self.expr2.clone()]
+ }
+
+ fn name(&self) -> &str {
+ &self.name
+ }
+}
+
+impl PartialEq<dyn Any> for Covariance {
+ fn eq(&self, other: &dyn Any) -> bool {
+ down_cast_any_ref(other)
+ .downcast_ref::<Self>()
+ .map(|x| {
+ self.name == x.name
+ && self.expr1.eq(&x.expr1)
+ && self.expr2.eq(&x.expr2)
+ && self.stats_type == x.stats_type
+ })
+ .unwrap_or(false)
+ }
+}
+
+/// An accumulator to compute covariance
+#[derive(Debug)]
+pub struct CovarianceAccumulator {
+ algo_const: f64,
+ mean1: f64,
+ mean2: f64,
+ count: f64,
+ stats_type: StatsType,
+}
+
+impl CovarianceAccumulator {
+ /// Creates a new `CovarianceAccumulator`
+ pub fn try_new(s_type: StatsType) -> Result<Self> {
+ Ok(Self {
+ algo_const: 0_f64,
+ mean1: 0_f64,
+ mean2: 0_f64,
+ count: 0_f64,
+ stats_type: s_type,
+ })
+ }
+
+ pub fn get_count(&self) -> f64 {
+ self.count
+ }
+
+ pub fn get_mean1(&self) -> f64 {
+ self.mean1
+ }
+
+ pub fn get_mean2(&self) -> f64 {
+ self.mean2
+ }
+
+ pub fn get_algo_const(&self) -> f64 {
+ self.algo_const
+ }
+}
+
+impl Accumulator for CovarianceAccumulator {
+ fn state(&mut self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![
+ ScalarValue::from(self.count),
+ ScalarValue::from(self.mean1),
+ ScalarValue::from(self.mean2),
+ ScalarValue::from(self.algo_const),
+ ])
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let values1 = &cast(&values[0], &DataType::Float64)?;
+ let values2 = &cast(&values[1], &DataType::Float64)?;
+
+ let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
+ let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
+
+ for i in 0..values1.len() {
+ let value1 = if values1.is_valid(i) {
+ arr1.next()
+ } else {
+ None
+ };
+ let value2 = if values2.is_valid(i) {
+ arr2.next()
+ } else {
+ None
+ };
+
+ if value1.is_none() || value2.is_none() {
+ continue;
+ }
+
+ let value1 = unwrap_or_internal_err!(value1);
+ let value2 = unwrap_or_internal_err!(value2);
+ let new_count = self.count + 1.0;
+ let delta1 = value1 - self.mean1;
+ let new_mean1 = delta1 / new_count + self.mean1;
+ let delta2 = value2 - self.mean2;
+ let new_mean2 = delta2 / new_count + self.mean2;
+ let new_c = delta1 * (value2 - new_mean2) + self.algo_const;
+
+ self.count += 1.0;
+ self.mean1 = new_mean1;
+ self.mean2 = new_mean2;
+ self.algo_const = new_c;
+ }
+
+ Ok(())
+ }
+
+ fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let values1 = &cast(&values[0], &DataType::Float64)?;
+ let values2 = &cast(&values[1], &DataType::Float64)?;
+ let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
+ let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
+
+ for i in 0..values1.len() {
+ let value1 = if values1.is_valid(i) {
+ arr1.next()
+ } else {
+ None
+ };
+ let value2 = if values2.is_valid(i) {
+ arr2.next()
+ } else {
+ None
+ };
+
+ if value1.is_none() || value2.is_none() {
+ continue;
+ }
+
+ let value1 = unwrap_or_internal_err!(value1);
+ let value2 = unwrap_or_internal_err!(value2);
+
+ let new_count = self.count - 1.0;
+ let delta1 = self.mean1 - value1;
+ let new_mean1 = delta1 / new_count + self.mean1;
+ let delta2 = self.mean2 - value2;
+ let new_mean2 = delta2 / new_count + self.mean2;
+ let new_c = self.algo_const - delta1 * (new_mean2 - value2);
+
+ self.count -= 1.0;
+ self.mean1 = new_mean1;
+ self.mean2 = new_mean2;
+ self.algo_const = new_c;
+ }
+
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ let counts = downcast_value!(states[0], Float64Array);
+ let means1 = downcast_value!(states[1], Float64Array);
+ let means2 = downcast_value!(states[2], Float64Array);
+ let cs = downcast_value!(states[3], Float64Array);
+
+ for i in 0..counts.len() {
+ let c = counts.value(i);
+ if c == 0.0 {
+ continue;
+ }
+ let new_count = self.count + c;
+ let new_mean1 = self.mean1 * self.count / new_count +
means1.value(i) * c / new_count;
+ let new_mean2 = self.mean2 * self.count / new_count +
means2.value(i) * c / new_count;
+ let delta1 = self.mean1 - means1.value(i);
+ let delta2 = self.mean2 - means2.value(i);
+ let new_c =
+ self.algo_const + cs.value(i) + delta1 * delta2 * self.count *
c / new_count;
+
+ self.count = new_count;
+ self.mean1 = new_mean1;
+ self.mean2 = new_mean2;
+ self.algo_const = new_c;
+ }
+ Ok(())
+ }
+
+ fn evaluate(&mut self) -> Result<ScalarValue> {
+ let count = match self.stats_type {
+ StatsType::Population => self.count,
+ StatsType::Sample => {
+ if self.count > 0.0 {
+ self.count - 1.0
+ } else {
+ self.count
+ }
+ }
+ };
+
+ if count == 0.0 {
+ Ok(ScalarValue::Float64(None))
+ } else {
+ Ok(ScalarValue::Float64(Some(self.algo_const / count)))
+ }
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
+}
diff --git a/core/src/execution/datafusion/expressions/mod.rs
b/core/src/execution/datafusion/expressions/mod.rs
index 69cdf3e..799790c 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -27,6 +27,8 @@ pub use normalize_nan::NormalizeNaNAndZero;
pub mod avg;
pub mod avg_decimal;
pub mod bloom_filter_might_contain;
+pub mod covariance;
+pub mod stats;
pub mod strings;
pub mod subquery;
pub mod sum_decimal;
diff --git a/core/src/execution/datafusion/expressions/stats.rs
b/core/src/execution/datafusion/expressions/stats.rs
new file mode 100644
index 0000000..1f4e64d
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/stats.rs
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/// Enum used for differentiating population and sample for statistical
functions
+#[derive(PartialEq, Eq, Debug, Clone, Copy)]
+pub enum StatsType {
+ /// Population
+ Population,
+ /// Sample
+ Sample,
+}
diff --git a/core/src/execution/datafusion/planner.rs
b/core/src/execution/datafusion/planner.rs
index 052ecc4..ca926bf 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -67,8 +67,10 @@ use crate::{
bloom_filter_might_contain::BloomFilterMightContain,
cast::Cast,
checkoverflow::CheckOverflow,
+ covariance::Covariance,
if_expr::IfExpr,
scalar_funcs::create_comet_physical_fun,
+ stats::StatsType,
strings::{Contains, EndsWith, Like, StartsWith,
StringSpaceExec, SubstringExec},
subquery::Subquery,
sum_decimal::SumDecimal,
@@ -1193,6 +1195,30 @@ impl PhysicalPlanner {
let datatype =
to_arrow_datatype(expr.datatype.as_ref().unwrap());
Ok(Arc::new(BitXor::new(child, "bit_xor", datatype)))
}
+ AggExprStruct::CovSample(expr) => {
+ let child1 = self.create_expr(expr.child1.as_ref().unwrap(),
schema.clone())?;
+ let child2 = self.create_expr(expr.child2.as_ref().unwrap(),
schema.clone())?;
+ let datatype =
to_arrow_datatype(expr.datatype.as_ref().unwrap());
+ Ok(Arc::new(Covariance::new(
+ child1,
+ child2,
+ "covariance",
+ datatype,
+ StatsType::Sample,
+ )))
+ }
+ AggExprStruct::CovPopulation(expr) => {
+ let child1 = self.create_expr(expr.child1.as_ref().unwrap(),
schema.clone())?;
+ let child2 = self.create_expr(expr.child2.as_ref().unwrap(),
schema.clone())?;
+ let datatype =
to_arrow_datatype(expr.datatype.as_ref().unwrap());
+ Ok(Arc::new(Covariance::new(
+ child1,
+ child2,
+ "covariance_pop",
+ datatype,
+ StatsType::Population,
+ )))
+ }
}
}
diff --git a/core/src/execution/proto/expr.proto
b/core/src/execution/proto/expr.proto
index 58f607f..1a6c29c 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -92,6 +92,8 @@ message AggExpr {
BitAndAgg bitAndAgg = 9;
BitOrAgg bitOrAgg = 10;
BitXorAgg bitXorAgg = 11;
+ CovSample covSample = 12;
+ CovPopulation covPopulation = 13;
}
}
@@ -149,6 +151,20 @@ message BitXorAgg {
DataType datatype = 2;
}
+message CovSample {
+ Expr child1 = 1;
+ Expr child2 = 2;
+ bool null_on_divide_by_zero = 3;
+ DataType datatype = 4;
+}
+
+message CovPopulation {
+ Expr child1 = 1;
+ Expr child2 = 2;
+ bool null_on_divide_by_zero = 3;
+ DataType datatype = 4;
+}
+
message Literal {
oneof value {
bool bool_val = 1;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 172a5b5..b62c222 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
-import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, Final, First, Last, Max, Min,
Partial, Sum}
+import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample,
Final, First, Last, Max, Min, Partial, Sum}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.{BuildRight,
NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.plans._
@@ -388,7 +388,44 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde {
} else {
None
}
+ case cov @ CovSample(child1, child2, _) =>
+ val child1Expr = exprToProto(child1, inputs, binding)
+ val child2Expr = exprToProto(child2, inputs, binding)
+ val dataType = serializeDataType(cov.dataType)
+ if (child1Expr.isDefined && child2Expr.isDefined &&
dataType.isDefined) {
+ val covBuilder = ExprOuterClass.CovSample.newBuilder()
+ covBuilder.setChild1(child1Expr.get)
+ covBuilder.setChild2(child2Expr.get)
+ covBuilder.setDatatype(dataType.get)
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setCovSample(covBuilder)
+ .build())
+ } else {
+ None
+ }
+ case cov @ CovPopulation(child1, child2, _) =>
+ val child1Expr = exprToProto(child1, inputs, binding)
+ val child2Expr = exprToProto(child2, inputs, binding)
+ val dataType = serializeDataType(cov.dataType)
+
+ if (child1Expr.isDefined && child2Expr.isDefined &&
dataType.isDefined) {
+ val covBuilder = ExprOuterClass.CovPopulation.newBuilder()
+ covBuilder.setChild1(child1Expr.get)
+ covBuilder.setChild2(child2Expr.get)
+ covBuilder.setDatatype(dataType.get)
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setCovPopulation(covBuilder)
+ .build())
+ } else {
+ None
+ }
case fn =>
emitWarning(s"unsupported Spark aggregate function: $fn")
None
diff --git
a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index b5ed5f4..09c7151 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1065,6 +1065,45 @@ class CometAggregateSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
}
}
+ test("covar_pop and covar_samp") {
+ withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+ Seq(true, false).foreach { cometColumnShuffleEnabled =>
+ withSQLConf(
+ CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key ->
cometColumnShuffleEnabled.toString) {
+ Seq(true, false).foreach { dictionary =>
+ withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+ val table = "test"
+ withTable(table) {
+ sql(
+ s"create table $table(col1 int, col2 int, col3 int, col4
float, col5 double," +
+ " col6 double, col7 int) using parquet")
+ sql(s"insert into $table values(1, 4, null, 1.1, 2.2, null,
1)," +
+ " (2, 5, 6, 3.4, 5.6, null, 1), (3, 6, null, 7.9, 2.4, null,
2)")
+ val expectedNumOfCometAggregates = 2
+ checkSparkAnswerAndNumOfAggregates(
+ "SELECT covar_samp(col1, col2), covar_samp(col1, col3),
covar_samp(col4, col5)," +
+ " covar_samp(col4, col6) FROM test",
+ expectedNumOfCometAggregates)
+ checkSparkAnswerAndNumOfAggregates(
+ "SELECT covar_pop(col1, col2), covar_pop(col1, col3),
covar_pop(col4, col5)," +
+ " covar_pop(col4, col6) FROM test",
+ expectedNumOfCometAggregates)
+ checkSparkAnswerAndNumOfAggregates(
+ "SELECT covar_samp(col1, col2), covar_samp(col1, col3),
covar_samp(col4, col5)," +
+ " covar_samp(col4, col6) FROM test GROUP BY col7",
+ expectedNumOfCometAggregates)
+ checkSparkAnswerAndNumOfAggregates(
+ "SELECT covar_pop(col1, col2), covar_pop(col1, col3),
covar_pop(col4, col5)," +
+ " covar_pop(col4, col6) FROM test GROUP BY col7",
+ expectedNumOfCometAggregates)
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
protected def checkSparkAnswerAndNumOfAggregates(query: String,
numAggregates: Int): Unit = {
val df = sql(query)
checkSparkAnswer(df)