This is an automated email from the ASF dual-hosted git repository. huaxingao pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 4547c9c90e3 [SPARK-39812][SQL] Simplify code which construct `AggregateExpression` with `toAggregateExpression` 4547c9c90e3 is described below commit 4547c9c90e3d35436afe89b10c794050ed8d04d7 Author: Jiaan Geng <belie...@163.com> AuthorDate: Sat Jul 23 15:05:14 2022 -0700 [SPARK-39812][SQL] Simplify code which construct `AggregateExpression` with `toAggregateExpression` ### What changes were proposed in this pull request? Currently, Spark provides the `toAggregateExpression` to simplify the code. But we can find many places still use `AggregateExpression.apply`. This PR replaces `AggregateExpression.apply` with `toAggregateExpression`. ### Why are the changes needed? Simplify code with `toAggregateExpression`. ### Does this PR introduce _any_ user-facing change? 'No'. Just change the inner implementation. ### How was this patch tested? N/A Closes #37224 from beliefer/SPARK-39812. Authored-by: Jiaan Geng <belie...@163.com> Signed-off-by: huaxingao <huaxin_...@apple.com> --- .../main/scala/org/apache/spark/ml/stat/Summarizer.scala | 4 ++-- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 +++--- .../sql/catalyst/optimizer/InjectRuntimeFilter.scala | 6 +++--- .../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 6 +++--- .../catalyst/optimizer/RewriteDistinctAggregates.scala | 7 ++----- .../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 16 +++++----------- .../expressions/aggregate/AggregateExpressionSuite.scala | 2 +- .../org/apache/spark/sql/expressions/Aggregator.scala | 7 +------ .../spark/sql/expressions/UserDefinedFunction.scala | 3 +-- .../scala/org/apache/spark/sql/expressions/udaf.scala | 12 ++---------- 10 files changed, 23 insertions(+), 46 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 7fd99faf0c8..bf9d07338db 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -27,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ @@ -256,7 +256,7 @@ private[ml] class SummaryBuilderImpl( mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - new Column(AggregateExpression(agg, mode = Complete, isDistinct = false)) + new Column(agg.toAggregateExpression()) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7667b4fef71..cc79048b7c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2219,9 +2219,9 @@ class Analyzer(override val catalogManager: CatalogManager) throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( agg.prettyName, "IGNORE NULLS") } - AggregateExpression(aggFunc, Complete, u.isDistinct, u.filter) + aggFunc.toAggregateExpression(u.isDistinct, u.filter) } else { - AggregateExpression(agg, Complete, u.isDistinct, u.filter) + agg.toAggregateExpression(u.isDistinct, u.filter) } // This function is not an aggregate function, just return the resolved one. 
case other if u.isDistinct => @@ -2332,7 +2332,7 @@ class Analyzer(override val catalogManager: CatalogManager) aggFunc.name(), "IGNORE NULLS") } val aggregator = V2Aggregator(aggFunc, arguments) - AggregateExpression(aggregator, Complete, u.isDistinct, u.filter) + aggregator.toAggregateExpression(u.isDistinct, u.filter) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index baaf82c00db..236636ac7ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, Complete} +import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, PhysicalOperation} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -82,8 +82,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J } else { new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp))) } - val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None) - val alias = Alias(aggExp, "bloomFilter")() + + val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")() val aggregate = ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan))) val bloomFilterSubquery = ScalarSubquery(aggregate, Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala 
index 78fb8b5de88..02a64a22ed4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -2217,7 +2217,7 @@ object RewriteExceptAll extends Rule[LogicalPlan] { val modifiedRightPlan = Project(Seq(newColumnRight) ++ right.output, right) val unionPlan = Union(modifiedLeftPlan, modifiedRightPlan) val aggSumCol = - Alias(AggregateExpression(Sum(unionPlan.output.head.toAttribute), Complete, false), "sum")() + Alias(Sum(unionPlan.output.head.toAttribute).toAggregateExpression(), "sum")() val aggOutputColumns = left.output ++ Seq(aggSumCol) val aggregatePlan = Aggregate(left.output, aggOutputColumns, unionPlan) val filteredAggPlan = Filter(GreaterThan(aggSumCol.toAttribute, Literal(0L)), aggregatePlan) @@ -2284,9 +2284,9 @@ object RewriteIntersectAll extends Rule[LogicalPlan] { // Expressions to compute count and minimum of both the counts. val vCol1AggrExpr = - Alias(AggregateExpression(Count(unionPlan.output(0)), Complete, false), "vcol1_count")() + Alias(Count(unionPlan.output(0)).toAggregateExpression(), "vcol1_count")() val vCol2AggrExpr = - Alias(AggregateExpression(Count(unionPlan.output(1)), Complete, false), "vcol2_count")() + Alias(Count(unionPlan.output(1)).toAggregateExpression(), "vcol2_count")() val ifExpression = Alias(If( GreaterThan(vCol1AggrExpr.toAttribute, vCol2AggrExpr.toAttribute), vCol2AggrExpr.toAttribute, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 22cde964125..4511b3038f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -330,11 +330,8 @@ object RewriteDistinctAggregates 
extends Rule[LogicalPlan] { val operator = Alias(e.copy(aggregateFunction = af, filter = filterOpt), e.sql)() // Select the result of the first aggregate in the last aggregate. - val result = AggregateExpression( - aggregate.First(operator.toAttribute, ignoreNulls = true), - mode = Complete, - isDistinct = false, - filter = Some(EqualTo(gid, regularGroupId))) + val result = aggregate.First(operator.toAttribute, ignoreNulls = true) + .toAggregateExpression(isDistinct = false, filter = Some(EqualTo(gid, regularGroupId))) // Some aggregate functions (COUNT) have the special property that they can return a // non-null result without any input. We need to make sure we return a result in this case. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 4ac665f9d87..257922eb81c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max} +import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Max} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.{Cross, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ @@ -172,7 +172,7 @@ class AnalysisErrorSuite extends AnalysisTest { "distinct aggregate function in window", testRelation2.select( WindowExpression( - AggregateExpression(Count(UnresolvedAttribute("b")), Complete, isDistinct = true), + 
Count(UnresolvedAttribute("b")).toAggregateExpression(isDistinct = true), WindowSpecDefinition( UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, @@ -183,11 +183,8 @@ class AnalysisErrorSuite extends AnalysisTest { "window aggregate function with filter predicate", testRelation2.select( WindowExpression( - AggregateExpression( - Count(UnresolvedAttribute("b")), - Complete, - isDistinct = false, - filter = Some(UnresolvedAttribute("b") > 1)), + Count(UnresolvedAttribute("b")) + .toAggregateExpression(isDistinct = false, filter = Some(UnresolvedAttribute("b") > 1)), WindowSpecDefinition( UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, @@ -249,10 +246,7 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "nested aggregate functions", testRelation.groupBy($"a")( - AggregateExpression( - Max(AggregateExpression(Count(Literal(1)), Complete, isDistinct = false)), - Complete, - isDistinct = false)), + Max(Count(Literal(1)).toAggregateExpression()).toAggregateExpression()), "not allowed to use an aggregate function in the argument of another aggregate function." 
:: Nil ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala index 410b11eda50..6bab2e1b187 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala @@ -26,7 +26,7 @@ class AggregateExpressionSuite extends SparkFunSuite { test("test references from unresolved aggregate functions") { val x = UnresolvedAttribute("x") val y = UnresolvedAttribute("y") - val actual = AggregateExpression(Sum(Add(x, y)), mode = Complete, isDistinct = false).references + val actual = Sum(Add(x, y)).toAggregateExpression().references val expected = AttributeSet(x :: y :: Nil) assert(expected == actual, s"Expected: $expected. Actual: $actual") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 62d04cf7f7c..c48c8bbe469 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.expressions import org.apache.spark.sql.{Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.encoderFor -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression /** @@ -98,11 +97,7 @@ abstract class Aggregator[-IN, BUF, OUT] extends Serializable { implicit val bEncoder = bufferEncoder implicit val cEncoder = outputEncoder - val expr = - AggregateExpression( - TypedAggregateExpression(this), - Complete, - isDistinct = false) + val expr = 
TypedAggregateExpression(this).toAggregateExpression() new TypedColumn[IN, OUT](expr, encoderFor[OUT]) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 03dc9abf081..a75384fb0f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -21,7 +21,6 @@ import org.apache.spark.annotation.Stable import org.apache.spark.sql.{Column, Encoder} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.ScalaAggregator import org.apache.spark.sql.types.DataType @@ -143,7 +142,7 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT]( @scala.annotation.varargs def apply(exprs: Column*): Column = { - Column(AggregateExpression(scalaAggregator(exprs.map(_.expr)), Complete, isDistinct = false)) + Column(scalaAggregator(exprs.map(_.expr)).toAggregateExpression()) } // This is also used by udf.register(...) 
when it detects a UserDefinedAggregator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 8407b1419af..b387695ef23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Stable import org.apache.spark.sql.{Column, Row} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.types._ @@ -131,11 +130,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { */ @scala.annotation.varargs def apply(exprs: Column*): Column = { - val aggregateExpression = - AggregateExpression( - ScalaUDAF(exprs.map(_.expr), this), - Complete, - isDistinct = false) + val aggregateExpression = ScalaUDAF(exprs.map(_.expr), this).toAggregateExpression() Column(aggregateExpression) } @@ -148,10 +143,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { @scala.annotation.varargs def distinct(exprs: Column*): Column = { val aggregateExpression = - AggregateExpression( - ScalaUDAF(exprs.map(_.expr), this), - Complete, - isDistinct = true) + ScalaUDAF(exprs.map(_.expr), this).toAggregateExpression(isDistinct = true) Column(aggregateExpression) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org