maropu commented on a change in pull request #28804:
URL: https://github.com/apache/spark/pull/28804#discussion_r466377228



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -680,6 +688,16 @@ case class HashAggregateExec(
 
   private def doProduceWithKeys(ctx: CodegenContext): String = {
     val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
+    if (skipPartialAggregate) {
+      avoidSpillInPartialAggregateTerm = ctx.
+        addMutableState(CodeGenerator.JAVA_BOOLEAN,
+          "avoidPartialAggregate",
+          term => s"$term = ${Utils.isTesting};")
+    }
+    val childrenConsumed = ctx.
+      addMutableState(CodeGenerator.JAVA_BOOLEAN, "childrenConsumed")
+    rowCountTerm = ctx.
+      addMutableState(CodeGenerator.JAVA_LONG, "rowCount")

Review comment:
       Do we need these two variables even when `skipPartialAggregate` is false?

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
##########
@@ -142,52 +142,57 @@ class SQLMetricsSuite extends SharedSparkSession with 
SQLMetricsTestUtils
   }
 
   test("Aggregate metrics: track avg probe") {
-    // The executed plan looks like:
-    // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, 
count#71L])
-    // +- Exchange hashpartitioning(a#61, 5)
-    //    +- HashAggregate(keys=[a#61], functions=[partial_count(1)], 
output=[a#61, count#76L])
-    //       +- Exchange RoundRobinPartitioning(1)
-    //          +- LocalTableScan [a#61]
-    //
-    // Assume the execution plan with node id is:
-    // Wholestage disabled:
-    // HashAggregate(nodeId = 0)
-    //   Exchange(nodeId = 1)
-    //     HashAggregate(nodeId = 2)
-    //       Exchange (nodeId = 3)
-    //         LocalTableScan(nodeId = 4)
-    //
-    // Wholestage enabled:
-    // WholeStageCodegen(nodeId = 0)
-    //   HashAggregate(nodeId = 1)
-    //     Exchange(nodeId = 2)
-    //       WholeStageCodegen(nodeId = 3)
-    //         HashAggregate(nodeId = 4)
-    //           Exchange(nodeId = 5)
-    //             LocalTableScan(nodeId = 6)
-    Seq(true, false).foreach { enableWholeStage =>
-      val df = generateRandomBytesDF().repartition(1).groupBy('a).count()
-      val nodeIds = if (enableWholeStage) {
-        Set(4L, 1L)
-      } else {
-        Set(2L, 0L)
-      }
-      val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get
-      nodeIds.foreach { nodeId =>
-        val probes = metrics(nodeId)._2("avg hash probe bucket list 
iters").toString
-        if (!probes.contains("\n")) {
-          // It's a single metrics value
-          assert(probes.toDouble > 1.0)
+    if 
(spark.sessionState.conf.getConf(SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED)) {
+      logInfo("Skipping, since partial Aggregation is disabled")
+    } else {
+      // The executed plan looks like:
+      // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, 
count#71L])
+      // +- Exchange hashpartitioning(a#61, 5)
+      //    +- HashAggregate(keys=[a#61], functions=[partial_count(1)], 
output=[a#61, count#76L])
+      //       +- Exchange RoundRobinPartitioning(1)
+      //          +- LocalTableScan [a#61]
+      //
+      // Assume the execution plan with node id is:
+      // Wholestage disabled:
+      // HashAggregate(nodeId = 0)
+      //   Exchange(nodeId = 1)
+      //     HashAggregate(nodeId = 2)
+      //       Exchange (nodeId = 3)
+      //         LocalTableScan(nodeId = 4)
+      //
+      // Wholestage enabled:
+      // WholeStageCodegen(nodeId = 0)
+      //   HashAggregate(nodeId = 1)
+      //     Exchange(nodeId = 2)
+      //       WholeStageCodegen(nodeId = 3)
+      //         HashAggregate(nodeId = 4)
+      //           Exchange(nodeId = 5)
+      //             LocalTableScan(nodeId = 6)
+      Seq(true, false).foreach { enableWholeStage =>
+        val df = generateRandomBytesDF().repartition(1).groupBy('a).count()
+        val nodeIds = if (enableWholeStage) {
+          Set(4L, 1L)
         } else {
-          val mainValue = 
probes.split("\n").apply(1).stripPrefix("(").stripSuffix(")")
-          // Extract min, med, max from the string and strip off everthing 
else.
-          val index = mainValue.indexOf(" (", 0)
-          mainValue.slice(0, index).split(", ").foreach {
-            probe => assert(probe.toDouble > 1.0)
+          Set(2L, 0L)
+        }
+        val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get
+        nodeIds.foreach { nodeId =>
+          val probes = metrics(nodeId)._2("avg hash probe bucket list 
iters").toString
+          if (!probes.contains("\n")) {
+            // It's a single metrics value
+            assert(probes.toDouble > 1.0)
+          } else {
+            val mainValue = 
probes.split("\n").apply(1).stripPrefix("(").stripSuffix(")")
+            // Extract min, med, max from the string and strip off everthing 
else.
+            val index = mainValue.indexOf(" (", 0)
+            mainValue.slice(0, index).split(", ").foreach {
+              probe => assert(probe.toDouble > 1.0)
+            }
           }
         }
       }
     }
+

Review comment:
       nit: remove this blank line.

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
##########
@@ -51,6 +51,52 @@ class WholeStageCodegenSuite extends QueryTest with 
SharedSparkSession
     assert(df.collect() === Array(Row(9, 4.5)))
   }
 
+  test(s"Avoid spill in partial aggregation" ) {

Review comment:
       nit: drop the `s` prefix at the head of the string; nothing is interpolated.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -877,50 +903,111 @@ case class HashAggregateExec(
       ("true", "true", "", "")
     }
 
+    val skipPartialAggregateThreshold = 
sqlContext.conf.skipPartialAggregateThreshold
+    val skipPartialAggRatio = sqlContext.conf.skipPartialAggregateRatio
+
+    val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
     val oomeClassName = classOf[SparkOutOfMemoryError].getName
+    val findOrInsertRegularHashMap: String = {
+      def getAggBufferFromMap = {
+        s"""
+           |// generate grouping key
+           |${unsafeRowKeyCode.code}
+           |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode();
+           |if ($checkFallbackForBytesToBytesMap) {
+           |  // try to get the buffer from hash map
+           |  $unsafeRowBuffer =
+           |    $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, 
$unsafeRowKeyHash);
+           |}
+          """.stripMargin
+      }
 
-    val findOrInsertRegularHashMap: String =
-      s"""
-         |// generate grouping key
-         |${unsafeRowKeyCode.code}
-         |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode();
-         |if ($checkFallbackForBytesToBytesMap) {
-         |  // try to get the buffer from hash map
-         |  $unsafeRowBuffer =
-         |    $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, 
$unsafeRowKeyHash);
-         |}
-         |// Can't allocate buffer from the hash map. Spill the map and 
fallback to sort-based
-         |// aggregation after processing all input rows.
-         |if ($unsafeRowBuffer == null) {
-         |  if ($sorterTerm == null) {
-         |    $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
-         |  } else {
-         |    
$sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
-         |  }
-         |  $resetCounter
-         |  // the hash map had be spilled, it should have enough memory now,
-         |  // try to allocate buffer again.
-         |  $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
-         |    $unsafeRowKeys, $unsafeRowKeyHash);
-         |  if ($unsafeRowBuffer == null) {
-         |    // failed to allocate the first page
-         |    throw new $oomeClassName("No enough memory for aggregation");
-         |  }
-         |}
+      def addToSorter: String = {
+        s"""
+          |if ($sorterTerm == null) {
+          |  $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
+          |} else {
+          |  $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
+          |}
+          |$resetCounter
+          |// the hash map had be spilled, it should have enough memory now,
+          |// try to allocate buffer again.
+          |$unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
+          |  $unsafeRowKeys, $unsafeRowKeyHash);
+          |if ($unsafeRowBuffer == null) {
+          |  // failed to allocate the first page
+          |  throw new $oomeClassName("No enough memory for aggregation");
+          |}""".stripMargin
+      }
+
+      def getHeuristicToAvoidAgg: String = {
+        s"""
+          |!($rowCountTerm < $skipPartialAggregateThreshold) &&
+          |      ((float)$countTerm/$rowCountTerm) > $skipPartialAggRatio;
+          |""".stripMargin
+      }
+
+      if (skipPartialAggregate) {
+        s"""

Review comment:
       How about this?
   ```
         if (skipPartialAggregate) {
           val checkIfPartialAggSkipped =
             s"!($rowCountTerm < $skipPartialAggregateThreshold) && " +
             s"((float)$countTerm/$rowCountTerm) > $skipPartialAggRatio"
           s"""
              |if (!$avoidSpillInPartialAggregateTerm) {
              |  $getAggBufferFromMap
              |  ...
              |  if ($unsafeRowBuffer == null && !$avoidSpillInPartialAggregateTerm) {
              |    $countTerm = $countTerm + $hashMapTerm.getNumRows();
              |    if ($checkIfPartialAggSkipped) {
   ```

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -680,6 +688,16 @@ case class HashAggregateExec(
 
   private def doProduceWithKeys(ctx: CodegenContext): String = {
     val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
+    if (skipPartialAggregate) {
+      avoidSpillInPartialAggregateTerm = ctx.
+        addMutableState(CodeGenerator.JAVA_BOOLEAN,
+          "avoidPartialAggregate",
+          term => s"$term = ${Utils.isTesting};")

Review comment:
       Why is the initial value `${Utils.isTesting}`?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -929,6 +1016,27 @@ case class HashAggregateExec(
       } else {
         findOrInsertRegularHashMap
       }
+      def createEmptyAggBufferAndUpdateMetrics: String = {
+        if (skipPartialAggregate) {
+          val partTerm = metricTerm(ctx, "partialAggSkipped")

Review comment:
       `partTerm` -> `numAggSkippedRows`?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -72,6 +72,8 @@ case class HashAggregateExec(
     "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in 
aggregation build"),
+    "partialAggSkipped" -> SQLMetrics.createMetric(sparkContext,
+      "number of skipped records for partial aggregates"),

Review comment:
       I think this metric is only meaningful for aggregates in partial 
mode, so could we show it only in that mode?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -409,6 +411,12 @@ case class HashAggregateExec(
   private var fastHashMapTerm: String = _
   private var isFastHashMapEnabled: Boolean = false
 
+  private var avoidSpillInPartialAggregateTerm: String = _
+  private val skipPartialAggregate = sqlContext.conf.skipPartialAggregate &&

Review comment:
       `skipPartialAggregate` -> `skipPartialAggregateEnabled`?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -750,18 +768,19 @@ case class HashAggregateExec(
       finishRegularHashMap
     }
 
+    outputFunc = generateResultFunction(ctx)
     val doAggFuncName = ctx.addNewFunction(doAgg,
       s"""
          |private void $doAgg() throws java.io.IOException {
          |  ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+         |  $childrenConsumed = true;

Review comment:
       Why do we need this variable?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -750,18 +768,19 @@ case class HashAggregateExec(
       finishRegularHashMap
     }
 
+    outputFunc = generateResultFunction(ctx)
     val doAggFuncName = ctx.addNewFunction(doAgg,
       s"""
          |private void $doAgg() throws java.io.IOException {
          |  ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+         |  $childrenConsumed = true;
          |  $finishHashMap
          |}
        """.stripMargin)
 
     // generate code for output
     val keyTerm = ctx.freshName("aggKey")
     val bufferTerm = ctx.freshName("aggBuffer")
-    val outputFunc = generateResultFunction(ctx)

Review comment:
       Why did you move this line up to line 771?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -409,6 +411,12 @@ case class HashAggregateExec(
   private var fastHashMapTerm: String = _
   private var isFastHashMapEnabled: Boolean = false
 
+  private var avoidSpillInPartialAggregateTerm: String = _
+  private val skipPartialAggregate = sqlContext.conf.skipPartialAggregate &&
+    AggUtils.areAggExpressionsPartial(modes) && 
find(_.isInstanceOf[ExpandExec]).isEmpty

Review comment:
       Why do we need this check `find(_.isInstanceOf[ExpandExec]).isEmpty`?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
##########
@@ -353,4 +353,8 @@ object AggUtils {
 
     finalAndCompleteAggregate :: Nil
   }
+
+  def areAggExpressionsPartial(modes: Seq[AggregateMode]): Boolean = {

Review comment:
       Could you inline this method in `HashAggregateExec`?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -680,6 +688,16 @@ case class HashAggregateExec(
 
   private def doProduceWithKeys(ctx: CodegenContext): String = {

Review comment:
       Why did you apply this optimization only for the with-key case?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -838,13 +857,20 @@ case class HashAggregateExec(
        |  long $beforeAgg = System.nanoTime();
        |  $doAggFuncName();
        |  $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS);
+       |  if (shouldStop()) return;
+       |}
+       |if (!$childrenConsumed) {
+       |  $doAggFuncName();
+       |  if (shouldStop()) return;

Review comment:
       What does this block do? It looks like the method call on line 858 
sets `childrenConsumed` to true, so this block is never executed?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -877,50 +903,111 @@ case class HashAggregateExec(
       ("true", "true", "", "")
     }
 
+    val skipPartialAggregateThreshold = 
sqlContext.conf.skipPartialAggregateThreshold
+    val skipPartialAggRatio = sqlContext.conf.skipPartialAggregateRatio
+
+    val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
     val oomeClassName = classOf[SparkOutOfMemoryError].getName
+    val findOrInsertRegularHashMap: String = {
+      def getAggBufferFromMap = {
+        s"""
+           |// generate grouping key
+           |${unsafeRowKeyCode.code}
+           |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode();
+           |if ($checkFallbackForBytesToBytesMap) {
+           |  // try to get the buffer from hash map
+           |  $unsafeRowBuffer =
+           |    $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, 
$unsafeRowKeyHash);
+           |}
+          """.stripMargin
+      }
 
-    val findOrInsertRegularHashMap: String =
-      s"""
-         |// generate grouping key
-         |${unsafeRowKeyCode.code}
-         |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode();
-         |if ($checkFallbackForBytesToBytesMap) {
-         |  // try to get the buffer from hash map
-         |  $unsafeRowBuffer =
-         |    $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, 
$unsafeRowKeyHash);
-         |}
-         |// Can't allocate buffer from the hash map. Spill the map and 
fallback to sort-based
-         |// aggregation after processing all input rows.
-         |if ($unsafeRowBuffer == null) {
-         |  if ($sorterTerm == null) {
-         |    $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
-         |  } else {
-         |    
$sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
-         |  }
-         |  $resetCounter
-         |  // the hash map had be spilled, it should have enough memory now,
-         |  // try to allocate buffer again.
-         |  $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
-         |    $unsafeRowKeys, $unsafeRowKeyHash);
-         |  if ($unsafeRowBuffer == null) {
-         |    // failed to allocate the first page
-         |    throw new $oomeClassName("No enough memory for aggregation");
-         |  }
-         |}
+      def addToSorter: String = {
+        s"""
+          |if ($sorterTerm == null) {
+          |  $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
+          |} else {
+          |  $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
+          |}
+          |$resetCounter
+          |// the hash map had be spilled, it should have enough memory now,
+          |// try to allocate buffer again.
+          |$unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
+          |  $unsafeRowKeys, $unsafeRowKeyHash);
+          |if ($unsafeRowBuffer == null) {
+          |  // failed to allocate the first page
+          |  throw new $oomeClassName("No enough memory for aggregation");
+          |}""".stripMargin
+      }
+
+      def getHeuristicToAvoidAgg: String = {
+        s"""
+          |!($rowCountTerm < $skipPartialAggregateThreshold) &&
+          |      ((float)$countTerm/$rowCountTerm) > $skipPartialAggRatio;
+          |""".stripMargin
+      }
+
+      if (skipPartialAggregate) {
+        s"""
+           |if (!$avoidSpillInPartialAggregateTerm) {
+           |  $getAggBufferFromMap
+           |  // Can't allocate buffer from the hash map.
+           |  // Check if we can avoid partial aggregation.
+           |  //  Otherwise, Spill the map and fallback to sort-based
+           |  // aggregation after processing all input rows.
+           |  if ($unsafeRowBuffer == null && 
!$avoidSpillInPartialAggregateTerm) {
+           |    $countTerm = $countTerm + $hashMapTerm.getNumRows();
+           |    boolean skipPartAgg = $getHeuristicToAvoidAgg
+           |    if (skipPartAgg) {
+           |      // Aggregation buffer is created later
+           |      $avoidSpillInPartialAggregateTerm = true;
+           |    } else {
+           |      $addToSorter
+           |    }
+           |  }
+           |}
+       """.stripMargin
+      } else {
+        s"""
+           |  $getAggBufferFromMap
+           |  // Can't allocate buffer from the hash map. Spill the map and 
fallback to sort-based
+           |  // aggregation after processing all input rows.
+           |  if ($unsafeRowBuffer == null) {
+           |    $addToSorter
+           |  }

Review comment:
       nit: please remove the extra indentation.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -72,6 +72,8 @@ case class HashAggregateExec(
     "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in 
aggregation build"),
+    "partialAggSkipped" -> SQLMetrics.createMetric(sparkContext,
+      "number of skipped records for partial aggregates"),

Review comment:
       `partialAggSkipped` -> `numAggSkippedRows`?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -680,6 +688,16 @@ case class HashAggregateExec(
 
   private def doProduceWithKeys(ctx: CodegenContext): String = {
     val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
+    if (skipPartialAggregate) {
+      avoidSpillInPartialAggregateTerm = ctx.
+        addMutableState(CodeGenerator.JAVA_BOOLEAN,
+          "avoidPartialAggregate",
+          term => s"$term = ${Utils.isTesting};")
+    }
+    val childrenConsumed = ctx.
+      addMutableState(CodeGenerator.JAVA_BOOLEAN, "childrenConsumed")
+    rowCountTerm = ctx.
+      addMutableState(CodeGenerator.JAVA_LONG, "rowCount")

Review comment:
       Also, I think we need to initialize the variables.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -877,50 +903,111 @@ case class HashAggregateExec(
       ("true", "true", "", "")
     }
 
+    val skipPartialAggregateThreshold = 
sqlContext.conf.skipPartialAggregateThreshold
+    val skipPartialAggRatio = sqlContext.conf.skipPartialAggregateRatio
+
+    val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
     val oomeClassName = classOf[SparkOutOfMemoryError].getName
+    val findOrInsertRegularHashMap: String = {
+      def getAggBufferFromMap = {
+        s"""
+           |// generate grouping key
+           |${unsafeRowKeyCode.code}
+           |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode();
+           |if ($checkFallbackForBytesToBytesMap) {
+           |  // try to get the buffer from hash map
+           |  $unsafeRowBuffer =
+           |    $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, 
$unsafeRowKeyHash);
+           |}
+          """.stripMargin
+      }
 
-    val findOrInsertRegularHashMap: String =
-      s"""
-         |// generate grouping key
-         |${unsafeRowKeyCode.code}
-         |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode();
-         |if ($checkFallbackForBytesToBytesMap) {
-         |  // try to get the buffer from hash map
-         |  $unsafeRowBuffer =
-         |    $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, 
$unsafeRowKeyHash);
-         |}
-         |// Can't allocate buffer from the hash map. Spill the map and 
fallback to sort-based
-         |// aggregation after processing all input rows.
-         |if ($unsafeRowBuffer == null) {
-         |  if ($sorterTerm == null) {
-         |    $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
-         |  } else {
-         |    
$sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
-         |  }
-         |  $resetCounter
-         |  // the hash map had be spilled, it should have enough memory now,
-         |  // try to allocate buffer again.
-         |  $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
-         |    $unsafeRowKeys, $unsafeRowKeyHash);
-         |  if ($unsafeRowBuffer == null) {
-         |    // failed to allocate the first page
-         |    throw new $oomeClassName("No enough memory for aggregation");
-         |  }
-         |}
+      def addToSorter: String = {
+        s"""
+          |if ($sorterTerm == null) {
+          |  $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
+          |} else {
+          |  $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
+          |}
+          |$resetCounter
+          |// the hash map had be spilled, it should have enough memory now,
+          |// try to allocate buffer again.
+          |$unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
+          |  $unsafeRowKeys, $unsafeRowKeyHash);
+          |if ($unsafeRowBuffer == null) {
+          |  // failed to allocate the first page
+          |  throw new $oomeClassName("No enough memory for aggregation");
+          |}""".stripMargin
+      }
+
+      def getHeuristicToAvoidAgg: String = {
+        s"""
+          |!($rowCountTerm < $skipPartialAggregateThreshold) &&
+          |      ((float)$countTerm/$rowCountTerm) > $skipPartialAggRatio;
+          |""".stripMargin
+      }
+
+      if (skipPartialAggregate) {
+        s"""
+           |if (!$avoidSpillInPartialAggregateTerm) {
+           |  $getAggBufferFromMap
+           |  // Can't allocate buffer from the hash map.
+           |  // Check if we can avoid partial aggregation.
+           |  //  Otherwise, Spill the map and fallback to sort-based
+           |  // aggregation after processing all input rows.
+           |  if ($unsafeRowBuffer == null && 
!$avoidSpillInPartialAggregateTerm) {

Review comment:
       Hasn't the second condition `!$avoidSpillInPartialAggregateTerm` already 
been checked on line 952? 
https://github.com/apache/spark/pull/28804/files#diff-2eb948516b5beaeb746aadac27fbd5b4R952

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
##########
@@ -2196,6 +2196,29 @@ object SQLConf {
       .checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in 
[10, 30].")
       .createWithDefault(16)
 
+  val SKIP_PARTIAL_AGGREGATE_ENABLED =
+    buildConf("spark.sql.aggregate.partialaggregate.skip.enabled")
+      .internal()
+      .doc("Avoid sorter(sort/spill) during partial aggregation")
+      .booleanConf
+      .createWithDefault(true)
+
+  val SKIP_PARTIAL_AGGREGATE_THRESHOLD =
+    buildConf("spark.sql.aggregate.partialaggregate.skip.threshold")

Review comment:
       I think we need to assign more meaningful param names, e.g.,
   ```
   
   spark.sql.aggregate.skipPartialAggregate
   spark.sql.aggregate.skipPartialAggregate.minNumRows
   spark.sql.aggregate.skipPartialAggregate.aggregateRatio
   ```

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
##########
@@ -2196,6 +2196,29 @@ object SQLConf {
       .checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in 
[10, 30].")
       .createWithDefault(16)
 
+  val SKIP_PARTIAL_AGGREGATE_ENABLED =
+    buildConf("spark.sql.aggregate.partialaggregate.skip.enabled")
+      .internal()
+      .doc("Avoid sorter(sort/spill) during partial aggregation")

Review comment:
       Could you describe this in more detail in the doc? For example, does this 
PR intend to support only hash-based aggregates with codegen enabled?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to