gengliangwang commented on code in PR #35896:
URL: https://github.com/apache/spark/pull/35896#discussion_r848179426


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala:
##########
@@ -109,17 +96,107 @@ case class Average(
       Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
   }
 
-  override lazy val updateExpressions: Seq[Expression] = Seq(
+  protected def getUpdateExpressions: Seq[Expression] = Seq(
     /* sum = */
     Add(
       sum,
-      coalesce(child.cast(sumDataType), Literal.default(sumDataType))),
+      coalesce(child.cast(sumDataType), Literal.default(sumDataType)),
+      failOnError = failOnError),
     /* count = */ If(child.isNull, count, count + 1L)
   )
 
+  // The flag `failOnError` won't be shown in the `toString` or `toAggString` methods
+  override def flatArguments: Iterator[Any] = Iterator(child)
+}
+
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(col) FROM VALUES (1), (2), (3) AS tab(col);
+       2.0
+      > SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col);
+       1.5
+  """,
+  group = "agg_funcs",
+  since = "1.0.0")
+case class Average(
+    child: Expression,
+    failOnError: Boolean = SQLConf.get.ansiEnabled) extends AverageBase {
+  def this(child: Expression) = this(child, failOnError = SQLConf.get.ansiEnabled)
+
   override protected def withNewChildInternal(newChild: Expression): Average =
     copy(child = newChild)
 
-  // The flag `failOnError` won't be shown in the `toString` or `toAggString` methods
-  override def flatArguments: Iterator[Any] = Iterator(child)
+  override lazy val updateExpressions: Seq[Expression] = getUpdateExpressions
+
+  override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions
+
+  override lazy val evaluateExpression: Expression = getEvaluateExpression
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Returns the mean calculated from values of a group and the result is null on overflow.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(col) FROM VALUES (1), (2), (3) AS tab(col);
+       2.0
+      > SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col);
+       1.5
+      > SELECT _FUNC_(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col);
+       NULL
+  """,
+  group = "agg_funcs",
+  since = "3.3.0")
+// scalastyle:on line.size.limit
+case class TryAverage(child: Expression) extends AverageBase {
+  override def failOnError: Boolean = resultType match {
+    // Double type won't fail, thus the failOnError is always false
+    // For decimal type, it returns NULL on overflow. It behaves the same as TrySum when
+    // `failOnError` is false.
+    case _: DoubleType | _: DecimalType => false
+    case _ => true
+  }
+
+  private def addTryEvalIfNeeded(expression: Expression): Expression = {
+    if (failOnError) {
+      // The tail expressions are for counting, which doesn't need `TryEval` execution.

Review Comment:
   Removed, thanks.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to