This is an automated email from the ASF dual-hosted git repository.
huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4547c9c90e3 [SPARK-39812][SQL] Simplify code which construct
`AggregateExpression` with `toAggregateExpression`
4547c9c90e3 is described below
commit 4547c9c90e3d35436afe89b10c794050ed8d04d7
Author: Jiaan Geng <[email protected]>
AuthorDate: Sat Jul 23 15:05:14 2022 -0700
[SPARK-39812][SQL] Simplify code which construct `AggregateExpression` with
`toAggregateExpression`
### What changes were proposed in this pull request?
Currently, Spark provides `toAggregateExpression` to simplify the code.
But we can still find many places that use `AggregateExpression.apply`.
This PR replaces those uses of `AggregateExpression.apply` with
`toAggregateExpression`.
### Why are the changes needed?
Simplify code with `toAggregateExpression`.
### Does this PR introduce _any_ user-facing change?
'No'.
Just change the inner implementation.
### How was this patch tested?
N/A
Closes #37224 from beliefer/SPARK-39812.
Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: huaxingao <[email protected]>
---
.../main/scala/org/apache/spark/ml/stat/Summarizer.scala | 4 ++--
.../apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 +++---
.../sql/catalyst/optimizer/InjectRuntimeFilter.scala | 6 +++---
.../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 6 +++---
.../catalyst/optimizer/RewriteDistinctAggregates.scala | 7 ++-----
.../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 16 +++++-----------
.../expressions/aggregate/AggregateExpressionSuite.scala | 2 +-
.../org/apache/spark/sql/expressions/Aggregator.scala | 7 +------
.../spark/sql/expressions/UserDefinedFunction.scala | 3 +--
.../scala/org/apache/spark/sql/expressions/udaf.scala | 12 ++----------
10 files changed, 23 insertions(+), 46 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
index 7fd99faf0c8..bf9d07338db 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -27,7 +27,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression,
ImplicitCastInputTypes}
-import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
Complete, TypedImperativeAggregate}
+import
org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
@@ -256,7 +256,7 @@ private[ml] class SummaryBuilderImpl(
mutableAggBufferOffset = 0,
inputAggBufferOffset = 0)
- new Column(AggregateExpression(agg, mode = Complete, isDistinct = false))
+ new Column(agg.toAggregateExpression())
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7667b4fef71..cc79048b7c7 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2219,9 +2219,9 @@ class Analyzer(override val catalogManager:
CatalogManager)
throw
QueryCompilationErrors.functionWithUnsupportedSyntaxError(
agg.prettyName, "IGNORE NULLS")
}
- AggregateExpression(aggFunc, Complete, u.isDistinct, u.filter)
+ aggFunc.toAggregateExpression(u.isDistinct, u.filter)
} else {
- AggregateExpression(agg, Complete, u.isDistinct, u.filter)
+ agg.toAggregateExpression(u.isDistinct, u.filter)
}
// This function is not an aggregate function, just return the
resolved one.
case other if u.isDistinct =>
@@ -2332,7 +2332,7 @@ class Analyzer(override val catalogManager:
CatalogManager)
aggFunc.name(), "IGNORE NULLS")
}
val aggregator = V2Aggregator(aggFunc, arguments)
- AggregateExpression(aggregator, Complete, u.isDistinct, u.filter)
+ aggregator.toAggregateExpression(u.isDistinct, u.filter)
}
/**
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index baaf82c00db..236636ac7ea 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions._
-import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
BloomFilterAggregate, Complete}
+import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys,
PhysicalOperation}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
@@ -82,8 +82,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with
PredicateHelper with J
} else {
new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)))
}
- val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct =
false, None)
- val alias = Alias(aggExp, "bloomFilter")()
+
+ val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")()
val aggregate =
ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias),
filterCreationSidePlan)))
val bloomFilterSubquery = ScalarSubquery(aggregate, Nil)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 78fb8b5de88..02a64a22ed4 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -2217,7 +2217,7 @@ object RewriteExceptAll extends Rule[LogicalPlan] {
val modifiedRightPlan = Project(Seq(newColumnRight) ++ right.output,
right)
val unionPlan = Union(modifiedLeftPlan, modifiedRightPlan)
val aggSumCol =
- Alias(AggregateExpression(Sum(unionPlan.output.head.toAttribute),
Complete, false), "sum")()
+ Alias(Sum(unionPlan.output.head.toAttribute).toAggregateExpression(),
"sum")()
val aggOutputColumns = left.output ++ Seq(aggSumCol)
val aggregatePlan = Aggregate(left.output, aggOutputColumns, unionPlan)
val filteredAggPlan = Filter(GreaterThan(aggSumCol.toAttribute,
Literal(0L)), aggregatePlan)
@@ -2284,9 +2284,9 @@ object RewriteIntersectAll extends Rule[LogicalPlan] {
// Expressions to compute count and minimum of both the counts.
val vCol1AggrExpr =
- Alias(AggregateExpression(Count(unionPlan.output(0)), Complete,
false), "vcol1_count")()
+ Alias(Count(unionPlan.output(0)).toAggregateExpression(),
"vcol1_count")()
val vCol2AggrExpr =
- Alias(AggregateExpression(Count(unionPlan.output(1)), Complete,
false), "vcol2_count")()
+ Alias(Count(unionPlan.output(1)).toAggregateExpression(),
"vcol2_count")()
val ifExpression = Alias(If(
GreaterThan(vCol1AggrExpr.toAttribute, vCol2AggrExpr.toAttribute),
vCol2AggrExpr.toAttribute,
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 22cde964125..4511b3038f0 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -330,11 +330,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan]
{
val operator = Alias(e.copy(aggregateFunction = af, filter =
filterOpt), e.sql)()
// Select the result of the first aggregate in the last aggregate.
- val result = AggregateExpression(
- aggregate.First(operator.toAttribute, ignoreNulls = true),
- mode = Complete,
- isDistinct = false,
- filter = Some(EqualTo(gid, regularGroupId)))
+ val result = aggregate.First(operator.toAttribute, ignoreNulls = true)
+ .toAggregateExpression(isDistinct = false, filter =
Some(EqualTo(gid, regularGroupId)))
// Some aggregate functions (COUNT) have the special property that
they can return a
// non-null result without any input. We need to make sure we return a
result in this case.
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 4ac665f9d87..257922eb81c 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
-import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
Complete, Count, Max}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Max}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.{Cross, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
@@ -172,7 +172,7 @@ class AnalysisErrorSuite extends AnalysisTest {
"distinct aggregate function in window",
testRelation2.select(
WindowExpression(
- AggregateExpression(Count(UnresolvedAttribute("b")), Complete,
isDistinct = true),
+ Count(UnresolvedAttribute("b")).toAggregateExpression(isDistinct =
true),
WindowSpecDefinition(
UnresolvedAttribute("a") :: Nil,
SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
@@ -183,11 +183,8 @@ class AnalysisErrorSuite extends AnalysisTest {
"window aggregate function with filter predicate",
testRelation2.select(
WindowExpression(
- AggregateExpression(
- Count(UnresolvedAttribute("b")),
- Complete,
- isDistinct = false,
- filter = Some(UnresolvedAttribute("b") > 1)),
+ Count(UnresolvedAttribute("b"))
+ .toAggregateExpression(isDistinct = false, filter =
Some(UnresolvedAttribute("b") > 1)),
WindowSpecDefinition(
UnresolvedAttribute("a") :: Nil,
SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
@@ -249,10 +246,7 @@ class AnalysisErrorSuite extends AnalysisTest {
errorTest(
"nested aggregate functions",
testRelation.groupBy($"a")(
- AggregateExpression(
- Max(AggregateExpression(Count(Literal(1)), Complete, isDistinct =
false)),
- Complete,
- isDistinct = false)),
+ Max(Count(Literal(1)).toAggregateExpression()).toAggregateExpression()),
"not allowed to use an aggregate function in the argument of another
aggregate function." :: Nil
)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala
index 410b11eda50..6bab2e1b187 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala
@@ -26,7 +26,7 @@ class AggregateExpressionSuite extends SparkFunSuite {
test("test references from unresolved aggregate functions") {
val x = UnresolvedAttribute("x")
val y = UnresolvedAttribute("y")
- val actual = AggregateExpression(Sum(Add(x, y)), mode = Complete,
isDistinct = false).references
+ val actual = Sum(Add(x, y)).toAggregateExpression().references
val expected = AttributeSet(x :: y :: Nil)
assert(expected == actual, s"Expected: $expected. Actual: $actual")
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 62d04cf7f7c..c48c8bbe469 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.expressions
import org.apache.spark.sql.{Encoder, TypedColumn}
import org.apache.spark.sql.catalyst.encoders.encoderFor
-import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
Complete}
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
/**
@@ -98,11 +97,7 @@ abstract class Aggregator[-IN, BUF, OUT] extends
Serializable {
implicit val bEncoder = bufferEncoder
implicit val cEncoder = outputEncoder
- val expr =
- AggregateExpression(
- TypedAggregateExpression(this),
- Complete,
- isDistinct = false)
+ val expr = TypedAggregateExpression(this).toAggregateExpression()
new TypedColumn[IN, OUT](expr, encoderFor[OUT])
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 03dc9abf081..a75384fb0f4 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -21,7 +21,6 @@ import org.apache.spark.annotation.Stable
import org.apache.spark.sql.{Column, Encoder}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
-import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
Complete}
import org.apache.spark.sql.execution.aggregate.ScalaAggregator
import org.apache.spark.sql.types.DataType
@@ -143,7 +142,7 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT](
@scala.annotation.varargs
def apply(exprs: Column*): Column = {
- Column(AggregateExpression(scalaAggregator(exprs.map(_.expr)), Complete,
isDistinct = false))
+ Column(scalaAggregator(exprs.map(_.expr)).toAggregateExpression())
}
// This is also used by udf.register(...) when it detects a
UserDefinedAggregator
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
index 8407b1419af..b387695ef23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.expressions
import org.apache.spark.annotation.Stable
import org.apache.spark.sql.{Column, Row}
-import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
Complete}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.types._
@@ -131,11 +130,7 @@ abstract class UserDefinedAggregateFunction extends
Serializable {
*/
@scala.annotation.varargs
def apply(exprs: Column*): Column = {
- val aggregateExpression =
- AggregateExpression(
- ScalaUDAF(exprs.map(_.expr), this),
- Complete,
- isDistinct = false)
+ val aggregateExpression = ScalaUDAF(exprs.map(_.expr),
this).toAggregateExpression()
Column(aggregateExpression)
}
@@ -148,10 +143,7 @@ abstract class UserDefinedAggregateFunction extends
Serializable {
@scala.annotation.varargs
def distinct(exprs: Column*): Column = {
val aggregateExpression =
- AggregateExpression(
- ScalaUDAF(exprs.map(_.expr), this),
- Complete,
- isDistinct = true)
+ ScalaUDAF(exprs.map(_.expr), this).toAggregateExpression(isDistinct =
true)
Column(aggregateExpression)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]