This is an automated email from the ASF dual-hosted git repository.

philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 5a11dd803 [VL] Support regr_slope aggregate function (#5216)
5a11dd803 is described below

commit 5a11dd803f8f05bcf57377470cbf1fa502a8b522
Author: Joey <[email protected]>
AuthorDate: Tue Apr 2 22:38:26 2024 +0800

    [VL] Support regr_slope aggregate function (#5216)
---
 .../org/apache/gluten/utils/CHExpressionUtil.scala |  1 +
 .../execution/HashAggregateExecTransformer.scala   |  6 +--
 .../gluten/utils/VeloxIntermediateData.scala       | 44 ++++++++++++++++++----
 .../execution/VeloxAggregateFunctionsSuite.scala   | 19 ++++++++++
 .../substrait/SubstraitToVeloxPlanValidator.cc     |  3 +-
 docs/velox-backend-support-progress.md             |  1 +
 .../apache/gluten/expression/ExpressionNames.scala |  1 +
 .../gluten/sql/shims/spark34/Spark34Shims.scala    |  5 ++-
 8 files changed, 66 insertions(+), 14 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
index f630bc0e5..96dd71164 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
@@ -184,6 +184,7 @@ object CHExpressionUtil {
     MAKE_YM_INTERVAL -> DefaultValidator(),
     KURTOSIS -> DefaultValidator(),
     REGR_R2 -> DefaultValidator(),
+    REGR_SLOPE -> DefaultValidator(),
     TO_UTC_TIMESTAMP -> DefaultValidator(),
     FROM_UTC_TIMESTAMP -> DefaultValidator()
   )
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
index d3f9bd78f..0a9904206 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
@@ -147,7 +147,7 @@ abstract class HashAggregateExecTransformer(
           val (sparkOrders, sparkTypes) =
             aggFunc.aggBufferAttributes.map(attr => (attr.name, 
attr.dataType)).unzip
           val veloxOrders = 
VeloxIntermediateData.veloxIntermediateDataOrder(aggFunc)
-          val adjustedOrders = sparkOrders.map(veloxOrders.indexOf(_))
+          val adjustedOrders = 
sparkOrders.map(VeloxIntermediateData.getAttrIndex(veloxOrders, _))
           sparkTypes.zipWithIndex.foreach {
             case (sparkType, idx) =>
               val veloxType = veloxTypes(adjustedOrders(idx))
@@ -380,7 +380,7 @@ abstract class HashAggregateExecTransformer(
               val (sparkOrders, sparkTypes) =
                 aggFunc.aggBufferAttributes.map(attr => (attr.name, 
attr.dataType)).unzip
               val veloxOrders = 
VeloxIntermediateData.veloxIntermediateDataOrder(aggFunc)
-              val adjustedOrders = veloxOrders.map(sparkOrders.indexOf(_))
+              val adjustedOrders = veloxOrders.map(o => 
sparkOrders.indexOf(o.head))
               veloxTypes.zipWithIndex.foreach {
                 case (veloxType, idx) =>
                   val adjustedIdx = adjustedOrders(idx)
@@ -392,7 +392,7 @@ abstract class HashAggregateExecTransformer(
                     // have the column of m4, thus a placeholder m4 with a 
value of 0 must be passed
                     // to Velox, and this value cannot be omitted. Velox will 
always read m4 column
                     // when accessing the intermediate data.
-                    val extraAttr = AttributeReference(veloxOrders(idx), 
veloxType)()
+                    val extraAttr = AttributeReference(veloxOrders(idx).head, 
veloxType)()
                     newInputAttributes += extraAttr
                     val lt = Literal.default(veloxType)
                     childNodes.add(ExpressionBuilder.makeLiteral(lt.value, 
lt.dataType, false))
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
index e32bb8e86..149634a47 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
@@ -25,15 +25,26 @@ import org.apache.spark.sql.types._
 import scala.collection.JavaConverters._
 
 object VeloxIntermediateData {
-  // Agg functions with inconsistent ordering of intermediate data between 
Velox and Spark.
+  // Agg functions with inconsistent ordering of intermediate data between 
Velox and Spark. The
+  // strings in the Seq come from the aggBufferAttributes of Spark's 
aggregate function, and they
+  // are arranged in the order of fields in Velox's Accumulator. The reason 
for using a
+  // two-dimensional Seq is that in some cases, a field in Velox will be 
mapped to multiple
+  // Attributes in Spark's aggBufferAttributes. For example, the fourth field 
of Velox's RegrSlope
+  // Accumulator is mapped to both xAvg and avg in Spark's RegrSlope 
aggBufferAttributes. In this
+  // scenario, when passing the output of Spark's partial aggregation to 
Velox, we only need to
+  // take one of them.
   // Corr, RegrR2
-  private val veloxCorrIntermediateDataOrder: Seq[String] =
-    Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg")
+  private val veloxCorrIntermediateDataOrder: Seq[Seq[String]] =
+    Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg").map(Seq(_))
   // CovPopulation, CovSample
-  private val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", 
"xAvg", "yAvg")
+  private val veloxCovarIntermediateDataOrder: Seq[Seq[String]] =
+    Seq("ck", "n", "xAvg", "yAvg").map(Seq(_))
   // Skewness, Kurtosis
-  private val veloxCentralMomentAggIntermediateDataOrder: Seq[String] =
-    Seq("n", "avg", "m2", "m3", "m4")
+  private val veloxCentralMomentAggIntermediateDataOrder: Seq[Seq[String]] =
+    Seq("n", "avg", "m2", "m3", "m4").map(Seq(_))
+  // RegrSlope
+  private val veloxRegrSlopeIntermediateDataOrder: Seq[Seq[String]] =
+    Seq("ck", "n", "m2", "xAvg:avg", "yAvg").map(attr => attr.split(":").toSeq)
 
   // Agg functions with inconsistent types of intermediate data between Velox 
and Spark.
   // StddevSamp, StddevPop, VarianceSamp, VariancePop
@@ -47,6 +58,15 @@ object VeloxIntermediateData {
   // Skewness, Kurtosis
   private val veloxCentralMomentAggIntermediateTypes: Seq[DataType] =
     Seq(LongType, DoubleType, DoubleType, DoubleType, DoubleType)
+  // RegrSlope
+  private val veloxRegrSlopeIntermediateTypes: Seq[DataType] =
+    Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType)
+
+  def getAttrIndex(intermediateDataOrder: Seq[Seq[String]], attr: String): Int 
=
+    intermediateDataOrder.zipWithIndex
+      .find { case (innerSeq, _) => innerSeq.contains(attr) }
+      .map(_._2)
+      .getOrElse(-1)
 
   /**
    * Return the intermediate columns order of Velox aggregation functions, 
with special matching
@@ -57,7 +77,7 @@ object VeloxIntermediateData {
    * @return
    *   the intermediate columns order of Velox aggregation functions
    */
-  def veloxIntermediateDataOrder(aggFunc: AggregateFunction): Seq[String] = {
+  def veloxIntermediateDataOrder(aggFunc: AggregateFunction): Seq[Seq[String]] 
= {
     aggFunc match {
       case _: PearsonCorrelation =>
         veloxCorrIntermediateDataOrder
@@ -65,8 +85,14 @@ object VeloxIntermediateData {
         veloxCovarIntermediateDataOrder
       case _: Skewness | _: Kurtosis =>
         veloxCentralMomentAggIntermediateDataOrder
+      // The reason for using class names to match aggFunc here is that 
these aggFunc come from
+      // certain versions of Spark, and SparkShim is not dependent on the 
backend-velox module. It
+      // is not convenient to include Velox-specific logic in SparkShim. Using 
class names to match
+      // aggFunc is reliable in this case, as there are no cases of duplicate 
names.
+      case _ if aggFunc.getClass.getSimpleName.equals("RegrSlope") =>
+        veloxRegrSlopeIntermediateDataOrder
       case _ =>
-        aggFunc.aggBufferAttributes.map(_.name)
+        aggFunc.aggBufferAttributes.map(_.name).map(Seq(_))
     }
   }
 
@@ -146,6 +172,8 @@ object VeloxIntermediateData {
           Some(veloxVarianceIntermediateTypes)
         case _: Skewness | _: Kurtosis =>
           Some(veloxCentralMomentAggIntermediateTypes)
+        case _ if aggFunc.getClass.getSimpleName.equals("RegrSlope") =>
+          Some(veloxRegrSlopeIntermediateTypes)
         case _ if aggFunc.aggBufferAttributes.size > 1 =>
           Some(aggFunc.aggBufferAttributes.map(_.dataType))
         case _ => None
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
index 4f8cc6c09..6d84d622f 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
@@ -390,6 +390,25 @@ abstract class VeloxAggregateFunctionsSuite extends 
VeloxWholeStageTransformerSu
     }
   }
 
+  testWithSpecifiedSparkVersion("regr_slope", Some("3.4")) {
+    runQueryAndCompare("""
+                         |select regr_slope(l_partkey, l_suppkey) from 
lineitem;
+                         |""".stripMargin) {
+      checkGlutenOperatorMatch[HashAggregateExecTransformer]
+    }
+    runQueryAndCompare(
+      "select regr_slope(l_partkey, l_suppkey), count(distinct l_orderkey) 
from lineitem") {
+      df =>
+        {
+          assert(
+            getExecutedPlan(df).count(
+              plan => {
+                plan.isInstanceOf[HashAggregateExecTransformer]
+              }) == 4)
+        }
+    }
+  }
+
   test("first") {
     runQueryAndCompare(s"""
                           |select first(l_linenumber), first(l_linenumber, 
true) from lineitem;
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc 
b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
index cc2d531f5..a302701b4 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -1080,7 +1080,8 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::AggregateRel& ag
       "covar_samp",
       "approx_distinct",
       "skewness",
-      "kurtosis"};
+      "kurtosis",
+      "regr_slope"};
 
   auto udfFuncs = UdfLoader::getInstance()->getRegisteredUdafNames();
 
diff --git a/docs/velox-backend-support-progress.md 
b/docs/velox-backend-support-progress.md
index 0f22bc20e..28ac7218a 100644
--- a/docs/velox-backend-support-progress.md
+++ b/docs/velox-backend-support-progress.md
@@ -379,6 +379,7 @@ Gluten supports 199 functions. (Drag to right to see all 
data types)
 | min                           | min                    |                     
  | S      |                        |         |      | S     | S   | S    | S   
  | S      |      |           |        |         |      |        |          |   
    |      |        |      |
 | min_by                        |                        |                     
  | S      |                        |         |      |       |     |      |     
  |        |      |           |        |         |      |        |          |   
    |      |        |      |
 | regr_r2                       | regr_r2                | regr_r2             
  | S      |                        |         |      | S     | S   | S    | S   
  | S      |      |           |        |         |      |        |          |   
    |      |        |      |
+| regr_slope                    | regr_slope             | regr_slope          
  | S      |                        |         |      | S     | S   | S    | S   
  | S      |      |           |        |         |      |        |          |   
    |      |        |      |
 | skewness                      | skewness               | skewness            
  | S      |                        |         |      | S     | S   | S    | S   
  | S      |      |           |        |         |      |        |          |   
    |      |        |      |
 | some                          |                        |                     
  |        |                        |         |      |       |     |      |     
  |        |      |           |        |         |      |        |          |   
    |      |        |      |
 | std,stddev                    | stddev                 |                     
  | S      |                        |         |      | S     | S   | S    | S   
  | S      |      |           |        |         |      |        |          |   
    |      |        |      |
diff --git 
a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
 
b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
index 59a6810ab..dc1fd3733 100644
--- 
a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
+++ 
b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
@@ -49,6 +49,7 @@ object ExpressionNames {
   final val APPROX_PERCENTILE = "approx_percentile"
   final val SKEWNESS = "skewness"
   final val KURTOSIS = "kurtosis"
+  final val REGR_SLOPE = "regr_slope"
 
   // Function names used by Substrait plan.
   final val ADD = "add"
diff --git 
a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
 
b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
index 19264f97a..2178b1d17 100644
--- 
a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
+++ 
b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions._
-import 
org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, 
RegrR2}
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, 
Distribution, KeyGroupedPartitioning, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -80,7 +80,8 @@ class Spark34Shims extends SparkShims {
 
   override def aggregateExpressionMappings: Seq[Sig] = {
     Seq(
-      Sig[RegrR2](ExpressionNames.REGR_R2)
+      Sig[RegrR2](ExpressionNames.REGR_R2),
+      Sig[RegrSlope](ExpressionNames.REGR_SLOPE)
     )
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to