This is an automated email from the ASF dual-hosted git repository.
philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 5a11dd803 [VL] Support regr_slope aggregate function (#5216)
5a11dd803 is described below
commit 5a11dd803f8f05bcf57377470cbf1fa502a8b522
Author: Joey <[email protected]>
AuthorDate: Tue Apr 2 22:38:26 2024 +0800
[VL] Support regr_slope aggregate function (#5216)
---
.../org/apache/gluten/utils/CHExpressionUtil.scala | 1 +
.../execution/HashAggregateExecTransformer.scala | 6 +--
.../gluten/utils/VeloxIntermediateData.scala | 44 ++++++++++++++++++----
.../execution/VeloxAggregateFunctionsSuite.scala | 19 ++++++++++
.../substrait/SubstraitToVeloxPlanValidator.cc | 3 +-
docs/velox-backend-support-progress.md | 1 +
.../apache/gluten/expression/ExpressionNames.scala | 1 +
.../gluten/sql/shims/spark34/Spark34Shims.scala | 5 ++-
8 files changed, 66 insertions(+), 14 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
index f630bc0e5..96dd71164 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
@@ -184,6 +184,7 @@ object CHExpressionUtil {
MAKE_YM_INTERVAL -> DefaultValidator(),
KURTOSIS -> DefaultValidator(),
REGR_R2 -> DefaultValidator(),
+ REGR_SLOPE -> DefaultValidator(),
TO_UTC_TIMESTAMP -> DefaultValidator(),
FROM_UTC_TIMESTAMP -> DefaultValidator()
)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
index d3f9bd78f..0a9904206 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
@@ -147,7 +147,7 @@ abstract class HashAggregateExecTransformer(
val (sparkOrders, sparkTypes) =
aggFunc.aggBufferAttributes.map(attr => (attr.name,
attr.dataType)).unzip
val veloxOrders =
VeloxIntermediateData.veloxIntermediateDataOrder(aggFunc)
- val adjustedOrders = sparkOrders.map(veloxOrders.indexOf(_))
+ val adjustedOrders =
sparkOrders.map(VeloxIntermediateData.getAttrIndex(veloxOrders, _))
sparkTypes.zipWithIndex.foreach {
case (sparkType, idx) =>
val veloxType = veloxTypes(adjustedOrders(idx))
@@ -380,7 +380,7 @@ abstract class HashAggregateExecTransformer(
val (sparkOrders, sparkTypes) =
aggFunc.aggBufferAttributes.map(attr => (attr.name,
attr.dataType)).unzip
val veloxOrders =
VeloxIntermediateData.veloxIntermediateDataOrder(aggFunc)
- val adjustedOrders = veloxOrders.map(sparkOrders.indexOf(_))
+ val adjustedOrders = veloxOrders.map(o =>
sparkOrders.indexOf(o.head))
veloxTypes.zipWithIndex.foreach {
case (veloxType, idx) =>
val adjustedIdx = adjustedOrders(idx)
@@ -392,7 +392,7 @@ abstract class HashAggregateExecTransformer(
// have the column of m4, thus a placeholder m4 with a
value of 0 must be passed
// to Velox, and this value cannot be omitted. Velox will
always read m4 column
// when accessing the intermediate data.
- val extraAttr = AttributeReference(veloxOrders(idx),
veloxType)()
+ val extraAttr = AttributeReference(veloxOrders(idx).head,
veloxType)()
newInputAttributes += extraAttr
val lt = Literal.default(veloxType)
childNodes.add(ExpressionBuilder.makeLiteral(lt.value,
lt.dataType, false))
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
index e32bb8e86..149634a47 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
@@ -25,15 +25,26 @@ import org.apache.spark.sql.types._
import scala.collection.JavaConverters._
object VeloxIntermediateData {
- // Agg functions with inconsistent ordering of intermediate data between
Velox and Spark.
+ // Agg functions with inconsistent ordering of intermediate data between
Velox and Spark. The
+ // strings in the Seq come from the aggBufferAttributes of Spark's
aggregate function, and they
+ // are arranged in the order of fields in Velox's Accumulator. The reason
for using a
+ // two-dimensional Seq is that in some cases, a field in Velox will be
mapped to multiple
+ // Attributes in Spark's aggBufferAttributes. For example, the fourth field
of Velox's RegrSlope
+ // Accumulator is mapped to both xAvg and avg in Spark's RegrSlope
aggBufferAttributes. In this
+ // scenario, when passing the output of Spark's partial aggregation to
Velox, we only need to
+ // take one of them.
// Corr, RegrR2
- private val veloxCorrIntermediateDataOrder: Seq[String] =
- Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg")
+ private val veloxCorrIntermediateDataOrder: Seq[Seq[String]] =
+ Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg").map(Seq(_))
// CovPopulation, CovSample
- private val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n",
"xAvg", "yAvg")
+ private val veloxCovarIntermediateDataOrder: Seq[Seq[String]] =
+ Seq("ck", "n", "xAvg", "yAvg").map(Seq(_))
// Skewness, Kurtosis
- private val veloxCentralMomentAggIntermediateDataOrder: Seq[String] =
- Seq("n", "avg", "m2", "m3", "m4")
+ private val veloxCentralMomentAggIntermediateDataOrder: Seq[Seq[String]] =
+ Seq("n", "avg", "m2", "m3", "m4").map(Seq(_))
+ // RegrSlope
+ private val veloxRegrSlopeIntermediateDataOrder: Seq[Seq[String]] =
+ Seq("ck", "n", "m2", "xAvg:avg", "yAvg").map(attr => attr.split(":").toSeq)
// Agg functions with inconsistent types of intermediate data between Velox
and Spark.
// StddevSamp, StddevPop, VarianceSamp, VariancePop
@@ -47,6 +58,15 @@ object VeloxIntermediateData {
// Skewness, Kurtosis
private val veloxCentralMomentAggIntermediateTypes: Seq[DataType] =
Seq(LongType, DoubleType, DoubleType, DoubleType, DoubleType)
+ // RegrSlope
+ private val veloxRegrSlopeIntermediateTypes: Seq[DataType] =
+ Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType)
+
+ def getAttrIndex(intermediateDataOrder: Seq[Seq[String]], attr: String): Int
=
+ intermediateDataOrder.zipWithIndex
+ .find { case (innerSeq, _) => innerSeq.contains(attr) }
+ .map(_._2)
+ .getOrElse(-1)
/**
* Return the intermediate columns order of Velox aggregation functions,
with special matching
@@ -57,7 +77,7 @@ object VeloxIntermediateData {
* @return
* the intermediate columns order of Velox aggregation functions
*/
- def veloxIntermediateDataOrder(aggFunc: AggregateFunction): Seq[String] = {
+ def veloxIntermediateDataOrder(aggFunc: AggregateFunction): Seq[Seq[String]]
= {
aggFunc match {
case _: PearsonCorrelation =>
veloxCorrIntermediateDataOrder
@@ -65,8 +85,14 @@ object VeloxIntermediateData {
veloxCovarIntermediateDataOrder
case _: Skewness | _: Kurtosis =>
veloxCentralMomentAggIntermediateDataOrder
+ // The reason for using class names to match aggFunc here is because
these aggFuncs come from
+ // certain versions of Spark, and SparkShim is not dependent on the
backend-velox module. It
+ // is not convenient to include Velox-specific logic in SparkShim. Using
class names to match
+ // aggFunc is reliable in this case, as there are no cases of duplicate
names.
+ case _ if aggFunc.getClass.getSimpleName.equals("RegrSlope") =>
+ veloxRegrSlopeIntermediateDataOrder
case _ =>
- aggFunc.aggBufferAttributes.map(_.name)
+ aggFunc.aggBufferAttributes.map(_.name).map(Seq(_))
}
}
@@ -146,6 +172,8 @@ object VeloxIntermediateData {
Some(veloxVarianceIntermediateTypes)
case _: Skewness | _: Kurtosis =>
Some(veloxCentralMomentAggIntermediateTypes)
+ case _ if aggFunc.getClass.getSimpleName.equals("RegrSlope") =>
+ Some(veloxRegrSlopeIntermediateTypes)
case _ if aggFunc.aggBufferAttributes.size > 1 =>
Some(aggFunc.aggBufferAttributes.map(_.dataType))
case _ => None
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
index 4f8cc6c09..6d84d622f 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
@@ -390,6 +390,25 @@ abstract class VeloxAggregateFunctionsSuite extends
VeloxWholeStageTransformerSu
}
}
+ testWithSpecifiedSparkVersion("regr_slope", Some("3.4")) {
+ runQueryAndCompare("""
+ |select regr_slope(l_partkey, l_suppkey) from
lineitem;
+ |""".stripMargin) {
+ checkGlutenOperatorMatch[HashAggregateExecTransformer]
+ }
+ runQueryAndCompare(
+ "select regr_slope(l_partkey, l_suppkey), count(distinct l_orderkey)
from lineitem") {
+ df =>
+ {
+ assert(
+ getExecutedPlan(df).count(
+ plan => {
+ plan.isInstanceOf[HashAggregateExecTransformer]
+ }) == 4)
+ }
+ }
+ }
+
test("first") {
runQueryAndCompare(s"""
|select first(l_linenumber), first(l_linenumber,
true) from lineitem;
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
index cc2d531f5..a302701b4 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -1080,7 +1080,8 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::AggregateRel& ag
"covar_samp",
"approx_distinct",
"skewness",
- "kurtosis"};
+ "kurtosis",
+ "regr_slope"};
auto udfFuncs = UdfLoader::getInstance()->getRegisteredUdafNames();
diff --git a/docs/velox-backend-support-progress.md
b/docs/velox-backend-support-progress.md
index 0f22bc20e..28ac7218a 100644
--- a/docs/velox-backend-support-progress.md
+++ b/docs/velox-backend-support-progress.md
@@ -379,6 +379,7 @@ Gluten supports 199 functions. (Drag to right to see all
data types)
| min | min |
| S | | | | S | S | S | S
| S | | | | | | | |
| | | |
| min_by | |
| S | | | | | | |
| | | | | | | | |
| | | |
| regr_r2 | regr_r2 | regr_r2
| S | | | | S | S | S | S
| S | | | | | | | |
| | | |
+| regr_slope | regr_slope | regr_slope
| S | | | | S | S | S | S
| S | | | | | | | |
| | | |
| skewness | skewness | skewness
| S | | | | S | S | S | S
| S | | | | | | | |
| | | |
| some | |
| | | | | | | |
| | | | | | | | |
| | | |
| std,stddev | stddev |
| S | | | | S | S | S | S
| S | | | | | | | |
| | | |
diff --git
a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
index 59a6810ab..dc1fd3733 100644
---
a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
+++
b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
@@ -49,6 +49,7 @@ object ExpressionNames {
final val APPROX_PERCENTILE = "approx_percentile"
final val SKEWNESS = "skewness"
final val KURTOSIS = "kurtosis"
+ final val REGR_SLOPE = "regr_slope"
// Function names used by Substrait plan.
final val ADD = "add"
diff --git
a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
index 19264f97a..2178b1d17 100644
---
a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
+++
b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions._
-import
org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate,
RegrR2}
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution,
Distribution, KeyGroupedPartitioning, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -80,7 +80,8 @@ class Spark34Shims extends SparkShims {
override def aggregateExpressionMappings: Seq[Sig] = {
Seq(
- Sig[RegrR2](ExpressionNames.REGR_R2)
+ Sig[RegrR2](ExpressionNames.REGR_R2),
+ Sig[RegrSlope](ExpressionNames.REGR_SLOPE)
)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]