This is an automated email from the ASF dual-hosted git repository.

zhli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 2a81010d8 [VL] Enable partial merge mode for HLL (#5754)
2a81010d8 is described below

commit 2a81010d864ba9d716cc3d974c2ccac8354df5c6
Author: Zhen Li <[email protected]>
AuthorDate: Wed May 29 11:13:08 2024 +0800

    [VL] Enable partial merge mode for HLL (#5754)
    
    [VL] Enable partial merge mode for HLL.
---
 .../execution/HashAggregateExecTransformer.scala   | 37 ++++------------------
 .../apache/gluten/extension/HLLRewriteRule.scala   | 22 +++----------
 .../execution/VeloxAggregateFunctionsSuite.scala   | 34 +++-----------------
 3 files changed, 14 insertions(+), 79 deletions(-)

diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
index 01ab56881..4f33ae7c7 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
@@ -20,7 +20,6 @@ import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.expression._
 import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.expression.aggregate.HLLAdapter
 import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode}
 import org.apache.gluten.substrait.{AggregationParams, SubstraitContext}
 import org.apache.gluten.substrait.expression.{AggregateFunctionNode, 
ExpressionBuilder, ExpressionNode, ScalarFunctionNode}
@@ -74,20 +73,6 @@ abstract class HashAggregateExecTransformer(
     TransformContext(childCtx.outputAttributes, output, relNode)
   }
 
-  override protected def checkAggFuncModeSupport(
-      aggFunc: AggregateFunction,
-      mode: AggregateMode): Boolean = {
-    aggFunc match {
-      case _: HLLAdapter =>
-        mode match {
-          case Partial | Final => true
-          case _ => false
-        }
-      case _ =>
-        super.checkAggFuncModeSupport(aggFunc, mode)
-    }
-  }
-
   // Return whether the outputs partial aggregation should be combined for 
Velox computing.
   // When the partial outputs are multiple-column, row construct is needed.
   private def rowConstructNeeded(aggregateExpressions: 
Seq[AggregateExpression]): Boolean = {
@@ -241,21 +226,21 @@ abstract class HashAggregateExecTransformer(
     }
 
     aggregateFunction match {
-      case hllAdapter: HLLAdapter =>
+      case _ if aggregateFunction.aggBufferAttributes.size > 1 =>
+        generateMergeCompanionNode()
+      case _ =>
         aggregateMode match {
-          case Partial =>
-            // For Partial mode output type is binary.
+          case Partial | PartialMerge =>
             val partialNode = ExpressionBuilder.makeAggregateFunction(
               VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, 
aggregateMode),
               childrenNodeList,
               modeKeyWord,
               ConverterUtils.getTypeNode(
-                hllAdapter.inputAggBufferAttributes.head.dataType,
-                hllAdapter.inputAggBufferAttributes.head.nullable)
+                aggregateFunction.inputAggBufferAttributes.head.dataType,
+                aggregateFunction.inputAggBufferAttributes.head.nullable)
             )
             aggregateNodeList.add(partialNode)
           case Final =>
-            // For Final mode output type is long.
             val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
               VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, 
aggregateMode),
               childrenNodeList,
@@ -266,16 +251,6 @@ abstract class HashAggregateExecTransformer(
           case other =>
             throw new GlutenNotSupportException(s"$other is not supported.")
         }
-      case _ if aggregateFunction.aggBufferAttributes.size > 1 =>
-        generateMergeCompanionNode()
-      case _ =>
-        val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
-          VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, 
aggregateMode),
-          childrenNodeList,
-          modeKeyWord,
-          ConverterUtils.getTypeNode(aggregateFunction.dataType, 
aggregateFunction.nullable)
-        )
-        aggregateNodeList.add(aggFunctionNode)
     }
   }
 
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/extension/HLLRewriteRule.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/extension/HLLRewriteRule.scala
index 03819fc10..7bae64ff8 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/extension/HLLRewriteRule.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/extension/HLLRewriteRule.scala
@@ -35,7 +35,7 @@ case class HLLRewriteRule(spark: SparkSession) extends 
Rule[LogicalPlan] {
           case hllExpr @ AggregateExpression(hll: HyperLogLogPlusPlus, _, _, 
_, _)
               if GlutenConfig.getConf.enableNativeHyperLogLogAggregateFunction 
&&
                 GlutenConfig.getConf.enableColumnarHashAgg &&
-                !hasDistinctAggregateFunc(a) && 
isDataTypeSupported(hll.child.dataType) =>
+                isDataTypeSupported(hll.child.dataType) =>
             AggregateExpression(
               HLLAdapter(
                 hll.child,
@@ -51,29 +51,15 @@ case class HLLRewriteRule(spark: SparkSession) extends 
Rule[LogicalPlan] {
     }
   }
 
-  private def hasDistinctAggregateFunc(agg: Aggregate): Boolean = {
-    agg.aggregateExpressions
-      .flatMap(_.collect { case ae: AggregateExpression => ae })
-      .exists(_.isDistinct)
-  }
-
   private def isDataTypeSupported(dataType: DataType): Boolean = {
     // HLL in velox only supports below data types. we should not offload HLL 
to velox, if
     // child's data type is not supported. This prevents the case only partial 
agg is fallen back.
     // As spark and velox have different HLL binary formats, HLL binary 
generated by spark can't
     // be parsed by velox, it would cause the error: 'Unexpected type of HLL'.
     dataType match {
-      case BooleanType => true
-      case ByteType => true
-      case _: CharType => true
-      case DateType => true
-      case DoubleType => true
-      case FloatType => true
-      case IntegerType => true
-      case LongType => true
-      case ShortType => true
-      case StringType => true
-      case _: DecimalType => true
+      case BooleanType | ByteType | ShortType | IntegerType | LongType | 
FloatType | DoubleType |
+          StringType | _: CharType | _: DecimalType | DateType =>
+        true
       case _ => false
     }
   }
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
index ffed63731..4f6f4eb22 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
@@ -553,41 +553,15 @@ abstract class VeloxAggregateFunctionsSuite extends 
VeloxWholeStageTransformerSu
   }
 
   test("approx_count_distinct") {
-    runQueryAndCompare("""
-                         |select approx_count_distinct(l_shipmode) from 
lineitem;
-                         |""".stripMargin) {
-      checkGlutenOperatorMatch[HashAggregateExecTransformer]
-    }
     runQueryAndCompare(
-      "select approx_count_distinct(l_partkey), count(distinct l_orderkey) 
from lineitem") {
-      df =>
-        {
-          assert(
-            getExecutedPlan(df).count(
-              plan => {
-                plan.isInstanceOf[HashAggregateExecTransformer]
-              }) == 0)
-        }
-    }
-  }
-
-  test("approx_count_distinct decimal") {
-    // The data type of l_discount is decimal.
-    runQueryAndCompare("""
-                         |select approx_count_distinct(l_discount) from 
lineitem;
-                         |""".stripMargin) {
+      """
+        |select approx_count_distinct(l_shipmode), 
approx_count_distinct(l_discount) from lineitem;
+        |""".stripMargin) {
       checkGlutenOperatorMatch[HashAggregateExecTransformer]
     }
     runQueryAndCompare(
       "select approx_count_distinct(l_discount), count(distinct l_orderkey) 
from lineitem") {
-      df =>
-        {
-          assert(
-            getExecutedPlan(df).count(
-              plan => {
-                plan.isInstanceOf[HashAggregateExecTransformer]
-              }) == 0)
-        }
+      checkGlutenOperatorMatch[HashAggregateExecTransformer]
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to