This is an automated email from the ASF dual-hosted git repository.

huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4547c9c90e3 [SPARK-39812][SQL] Simplify code which constructs `AggregateExpression` with `toAggregateExpression`
4547c9c90e3 is described below

commit 4547c9c90e3d35436afe89b10c794050ed8d04d7
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Sat Jul 23 15:05:14 2022 -0700

    [SPARK-39812][SQL] Simplify code which constructs `AggregateExpression` with `toAggregateExpression`
    
    ### What changes were proposed in this pull request?
    Currently, Spark provides `toAggregateExpression` to simplify this kind of code,
    but many places still construct aggregates with `AggregateExpression.apply` directly.
    
    This PR replaces those direct `AggregateExpression.apply` calls with `toAggregateExpression`.
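    
    For illustration, a minimal sketch of the before/after pattern (the `sketch` helper and `child` are placeholders, not code from this diff):
    
    ```scala
    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}
    
    def sketch(child: Expression): Unit = {
      // Before: building the wrapper by hand, spelling out the default mode.
      val before = AggregateExpression(Count(child), mode = Complete, isDistinct = false)
      // After: the aggregate function builds an equivalent Complete-mode wrapper.
      val after = Count(child).toAggregateExpression()
      // Overloads cover the distinct and filter cases as well, e.g.
      // Count(child).toAggregateExpression(isDistinct = true, filter = None).
    }
    ```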
    
    ### Why are the changes needed?
    Simplify code with `toAggregateExpression`.
    
    ### Does this PR introduce _any_ user-facing change?
    No. This only changes the internal implementation.
    
    ### How was this patch tested?
    N/A
    
    Closes #37224 from beliefer/SPARK-39812.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: huaxingao <huaxin_...@apple.com>
---
 .../main/scala/org/apache/spark/ml/stat/Summarizer.scala |  4 ++--
 .../apache/spark/sql/catalyst/analysis/Analyzer.scala    |  6 +++---
 .../sql/catalyst/optimizer/InjectRuntimeFilter.scala     |  6 +++---
 .../apache/spark/sql/catalyst/optimizer/Optimizer.scala  |  6 +++---
 .../catalyst/optimizer/RewriteDistinctAggregates.scala   |  7 ++-----
 .../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 16 +++++-----------
 .../expressions/aggregate/AggregateExpressionSuite.scala |  2 +-
 .../org/apache/spark/sql/expressions/Aggregator.scala    |  7 +------
 .../spark/sql/expressions/UserDefinedFunction.scala      |  3 +--
 .../scala/org/apache/spark/sql/expressions/udaf.scala    | 12 ++----------
 10 files changed, 23 insertions(+), 46 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
index 7fd99faf0c8..bf9d07338db 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -27,7 +27,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
 import org.apache.spark.sql.catalyst.trees.BinaryLike
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.types._
@@ -256,7 +256,7 @@ private[ml] class SummaryBuilderImpl(
       mutableAggBufferOffset = 0,
       inputAggBufferOffset = 0)
 
-    new Column(AggregateExpression(agg, mode = Complete, isDistinct = false))
+    new Column(agg.toAggregateExpression())
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7667b4fef71..cc79048b7c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2219,9 +2219,9 @@ class Analyzer(override val catalogManager: CatalogManager)
                throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
                   agg.prettyName, "IGNORE NULLS")
             }
-            AggregateExpression(aggFunc, Complete, u.isDistinct, u.filter)
+            aggFunc.toAggregateExpression(u.isDistinct, u.filter)
           } else {
-            AggregateExpression(agg, Complete, u.isDistinct, u.filter)
+            agg.toAggregateExpression(u.isDistinct, u.filter)
           }
        // This function is not an aggregate function, just return the resolved one.
         case other if u.isDistinct =>
@@ -2332,7 +2332,7 @@ class Analyzer(override val catalogManager: CatalogManager)
           aggFunc.name(), "IGNORE NULLS")
       }
       val aggregator = V2Aggregator(aggFunc, arguments)
-      AggregateExpression(aggregator, Complete, u.isDistinct, u.filter)
+      aggregator.toAggregateExpression(u.isDistinct, u.filter)
     }
 
     /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index baaf82c00db..236636ac7ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, Complete}
+import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, PhysicalOperation}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -82,8 +82,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
       } else {
         new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)))
       }
-    val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)
-    val alias = Alias(aggExp, "bloomFilter")()
+
+    val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")()
     val aggregate =
      ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan)))
     val bloomFilterSubquery = ScalarSubquery(aggregate, Nil)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 78fb8b5de88..02a64a22ed4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -2217,7 +2217,7 @@ object RewriteExceptAll extends Rule[LogicalPlan] {
       val modifiedRightPlan = Project(Seq(newColumnRight) ++ right.output, right)
       val unionPlan = Union(modifiedLeftPlan, modifiedRightPlan)
       val aggSumCol =
-        Alias(AggregateExpression(Sum(unionPlan.output.head.toAttribute), Complete, false), "sum")()
+        Alias(Sum(unionPlan.output.head.toAttribute).toAggregateExpression(), "sum")()
       val aggOutputColumns = left.output ++ Seq(aggSumCol)
       val aggregatePlan = Aggregate(left.output, aggOutputColumns, unionPlan)
       val filteredAggPlan = Filter(GreaterThan(aggSumCol.toAttribute, Literal(0L)), aggregatePlan)
@@ -2284,9 +2284,9 @@ object RewriteIntersectAll extends Rule[LogicalPlan] {
 
       // Expressions to compute count and minimum of both the counts.
       val vCol1AggrExpr =
-        Alias(AggregateExpression(Count(unionPlan.output(0)), Complete, false), "vcol1_count")()
+        Alias(Count(unionPlan.output(0)).toAggregateExpression(), "vcol1_count")()
       val vCol2AggrExpr =
-        Alias(AggregateExpression(Count(unionPlan.output(1)), Complete, false), "vcol2_count")()
+        Alias(Count(unionPlan.output(1)).toAggregateExpression(), "vcol2_count")()
       val ifExpression = Alias(If(
         GreaterThan(vCol1AggrExpr.toAttribute, vCol2AggrExpr.toAttribute),
         vCol2AggrExpr.toAttribute,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 22cde964125..4511b3038f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -330,11 +330,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
        val operator = Alias(e.copy(aggregateFunction = af, filter = filterOpt), e.sql)()
 
         // Select the result of the first aggregate in the last aggregate.
-        val result = AggregateExpression(
-          aggregate.First(operator.toAttribute, ignoreNulls = true),
-          mode = Complete,
-          isDistinct = false,
-          filter = Some(EqualTo(gid, regularGroupId)))
+        val result = aggregate.First(operator.toAttribute, ignoreNulls = true)
+          .toAggregateExpression(isDistinct = false, filter = Some(EqualTo(gid, regularGroupId)))
 
        // Some aggregate functions (COUNT) have the special property that they can return a
        // non-null result without any input. We need to make sure we return a result in this case.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 4ac665f9d87..257922eb81c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Max}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.{Cross, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -172,7 +172,7 @@ class AnalysisErrorSuite extends AnalysisTest {
     "distinct aggregate function in window",
     testRelation2.select(
       WindowExpression(
-        AggregateExpression(Count(UnresolvedAttribute("b")), Complete, isDistinct = true),
+        Count(UnresolvedAttribute("b")).toAggregateExpression(isDistinct = true),
         WindowSpecDefinition(
           UnresolvedAttribute("a") :: Nil,
           SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
@@ -183,11 +183,8 @@ class AnalysisErrorSuite extends AnalysisTest {
     "window aggregate function with filter predicate",
     testRelation2.select(
       WindowExpression(
-        AggregateExpression(
-          Count(UnresolvedAttribute("b")),
-          Complete,
-          isDistinct = false,
-          filter = Some(UnresolvedAttribute("b") > 1)),
+        Count(UnresolvedAttribute("b"))
+          .toAggregateExpression(isDistinct = false, filter = Some(UnresolvedAttribute("b") > 1)),
         WindowSpecDefinition(
           UnresolvedAttribute("a") :: Nil,
           SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
@@ -249,10 +246,7 @@ class AnalysisErrorSuite extends AnalysisTest {
   errorTest(
     "nested aggregate functions",
     testRelation.groupBy($"a")(
-      AggregateExpression(
-        Max(AggregateExpression(Count(Literal(1)), Complete, isDistinct = false)),
-        Complete,
-        isDistinct = false)),
+      Max(Count(Literal(1)).toAggregateExpression()).toAggregateExpression()),
     "not allowed to use an aggregate function in the argument of another 
aggregate function." :: Nil
   )
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala
index 410b11eda50..6bab2e1b187 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala
@@ -26,7 +26,7 @@ class AggregateExpressionSuite extends SparkFunSuite {
   test("test references from unresolved aggregate functions") {
     val x = UnresolvedAttribute("x")
     val y = UnresolvedAttribute("y")
-    val actual = AggregateExpression(Sum(Add(x, y)), mode = Complete, isDistinct = false).references
+    val actual = Sum(Add(x, y)).toAggregateExpression().references
     val expected = AttributeSet(x :: y :: Nil)
     assert(expected == actual, s"Expected: $expected. Actual: $actual")
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 62d04cf7f7c..c48c8bbe469 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.expressions
 
 import org.apache.spark.sql.{Encoder, TypedColumn}
 import org.apache.spark.sql.catalyst.encoders.encoderFor
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 
 /**
@@ -98,11 +97,7 @@ abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
     implicit val bEncoder = bufferEncoder
     implicit val cEncoder = outputEncoder
 
-    val expr =
-      AggregateExpression(
-        TypedAggregateExpression(this),
-        Complete,
-        isDistinct = false)
+    val expr = TypedAggregateExpression(this).toAggregateExpression()
 
     new TypedColumn[IN, OUT](expr, encoderFor[OUT])
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 03dc9abf081..a75384fb0f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -21,7 +21,6 @@ import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.{Column, Encoder}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
 import org.apache.spark.sql.execution.aggregate.ScalaAggregator
 import org.apache.spark.sql.types.DataType
 
@@ -143,7 +142,7 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT](
 
   @scala.annotation.varargs
   def apply(exprs: Column*): Column = {
-    Column(AggregateExpression(scalaAggregator(exprs.map(_.expr)), Complete, isDistinct = false))
+    Column(scalaAggregator(exprs.map(_.expr)).toAggregateExpression())
   }
 
  // This is also used by udf.register(...) when it detects a UserDefinedAggregator
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
index 8407b1419af..b387695ef23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.expressions
 
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.{Column, Row}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
 import org.apache.spark.sql.execution.aggregate.ScalaUDAF
 import org.apache.spark.sql.types._
 
@@ -131,11 +130,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
    */
   @scala.annotation.varargs
   def apply(exprs: Column*): Column = {
-    val aggregateExpression =
-      AggregateExpression(
-        ScalaUDAF(exprs.map(_.expr), this),
-        Complete,
-        isDistinct = false)
+    val aggregateExpression = ScalaUDAF(exprs.map(_.expr), this).toAggregateExpression()
     Column(aggregateExpression)
   }
 
@@ -148,10 +143,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
   @scala.annotation.varargs
   def distinct(exprs: Column*): Column = {
     val aggregateExpression =
-      AggregateExpression(
-        ScalaUDAF(exprs.map(_.expr), this),
-        Complete,
-        isDistinct = true)
+      ScalaUDAF(exprs.map(_.expr), this).toAggregateExpression(isDistinct = true)
     Column(aggregateExpression)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
