This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7c3c7c5a4bd [SPARK-41086][SQL] Use DataFrame ID to semantically validate CollectMetrics 7c3c7c5a4bd is described below commit 7c3c7c5a4bd94c9e05b5e680a5242c2485875633 Author: Rui Wang <rui.w...@databricks.com> AuthorDate: Fri Sep 22 11:07:25 2023 +0800 [SPARK-41086][SQL] Use DataFrame ID to semantically validate CollectMetrics ### What changes were proposed in this pull request? In the existing code, plan matching is used to validate whether two CollectMetrics have the same name but different semantics. However, the plan matching approach is fragile. A better way to tackle this is to use the unique DataFrame ID (the ID of the DataFrame on which `observe` is called). This works because the observe API is only supported by the DataFrame API; SQL has no equivalent syntax. So two CollectMetrics are semantically the same if and only if they have the same name and the same DataFrame ID. ### Why are the changes needed? This replaces a fragile approach with a more stable one. ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? UT ### Was this patch authored or co-authored using generative AI tooling? NO Closes #43010 from amaliujia/another_approch_for_collect_metrics. 
Authored-by: Rui Wang <rui.w...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/connect/planner/SparkConnectPlanner.scala | 6 +-- python/pyspark/sql/connect/plan.py | 1 + .../spark/sql/catalyst/analysis/Analyzer.scala | 4 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 36 ++------------ .../plans/logical/basicLogicalOperators.scala | 3 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 55 +++++++++------------- .../main/scala/org/apache/spark/sql/Dataset.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- 8 files changed, 35 insertions(+), 74 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 924169715f7..dda7a713fa0 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -164,7 +164,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION => transformCachedRemoteRelation(rel.getCachedRemoteRelation) case proto.Relation.RelTypeCase.COLLECT_METRICS => - transformCollectMetrics(rel.getCollectMetrics) + transformCollectMetrics(rel.getCollectMetrics, rel.getCommon.getPlanId) case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse) case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") @@ -1048,12 +1048,12 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { numPartitionsOpt) } - private def transformCollectMetrics(rel: proto.CollectMetrics): LogicalPlan = { + private def transformCollectMetrics(rel: proto.CollectMetrics, planId: Long): LogicalPlan = { val metrics = 
rel.getMetricsList.asScala.toSeq.map { expr => Column(transformExpression(expr)) } - CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput)) + CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput), planId) } private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = { diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index d069081e1af..219545cf646 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -1192,6 +1192,7 @@ class CollectMetrics(LogicalPlan): assert self._child is not None plan = proto.Relation() + plan.common.plan_id = self._child._plan_id plan.collect_metrics.input.CopyFrom(self._child.plan(session)) plan.collect_metrics.name = self._name plan.collect_metrics.metrics.extend([self.col_to_expr(x, session) for x in self._exprs]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cff29de858e..aac85e19721 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3893,9 +3893,9 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper { Window(cleanedWindowExprs, partitionSpec.map(trimAliases), orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child) - case CollectMetrics(name, metrics, child) => + case CollectMetrics(name, metrics, child, dataframeId) => val cleanedMetrics = metrics.map(trimNonTopLevelAliases) - CollectMetrics(name, cleanedMetrics, child) + CollectMetrics(name, cleanedMetrics, child, dataframeId) case Unpivot(ids, values, aliases, variableColumnName, valueColumnNames, child) => val cleanedIds = ids.map(_.map(trimNonTopLevelAliases)) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 3c9a816df26..83b682bc917 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -497,7 +497,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB groupingExprs.foreach(checkValidGroupingExprs) aggregateExprs.foreach(checkValidAggregateExpression) - case CollectMetrics(name, metrics, _) => + case CollectMetrics(name, metrics, _, _) => if (name == null || name.isEmpty) { operator.failAnalysis( errorClass = "INVALID_OBSERVED_METRICS.MISSING_NAME", @@ -1097,17 +1097,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB * are allowed (e.g. self-joins). */ private def checkCollectedMetrics(plan: LogicalPlan): Unit = { - val metricsMap = mutable.Map.empty[String, LogicalPlan] + val metricsMap = mutable.Map.empty[String, CollectMetrics] def check(plan: LogicalPlan): Unit = plan.foreach { node => node match { - case metrics @ CollectMetrics(name, _, _) => - val simplifiedMetrics = simplifyPlanForCollectedMetrics(metrics.canonicalized) + case metrics @ CollectMetrics(name, _, _, dataframeId) => metricsMap.get(name) match { case Some(other) => - val simplifiedOther = simplifyPlanForCollectedMetrics(other.canonicalized) // Exact duplicates are allowed. They can be the result // of a CTE that is used multiple times or a self join. 
- if (simplifiedMetrics != simplifiedOther) { + if (dataframeId != other.dataframeId) { failAnalysis( errorClass = "DUPLICATED_METRICS_NAME", messageParameters = Map("metricName" -> name)) @@ -1126,32 +1124,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB check(plan) } - /** - * This method is only used for checking collected metrics. This method tries to - * remove extra project which only re-assign expr ids from the plan so that we can identify exact - * duplicates metric definition. - */ - def simplifyPlanForCollectedMetrics(plan: LogicalPlan): LogicalPlan = { - plan.resolveOperators { - case p: Project if p.projectList.size == p.child.output.size => - val assignExprIdOnly = p.projectList.zipWithIndex.forall { - case (Alias(attr: AttributeReference, _), index) => - // The input plan of this method is already canonicalized. The attribute id becomes the - // ordinal of this attribute in the child outputs. So an alias-only Project means the - // the id of the aliased attribute is the same as its index in the project list. - attr.exprId.id == index - case (left: AttributeReference, index) => - left.exprId.id == index - case _ => false - } - if (assignExprIdOnly) { - p.child - } else { - p - } - } - } - /** * Validates to make sure the outer references appearing inside the subquery * are allowed. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index efb7dbb44ef..8f976a49a2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -1969,7 +1969,8 @@ trait SupportsSubquery extends LogicalPlan case class CollectMetrics( name: String, metrics: Seq[NamedExpression], - child: LogicalPlan) + child: LogicalPlan, + dataframeId: Long) extends UnaryNode { override lazy val resolved: Boolean = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index ed3137430df..ffc12a2b981 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -779,34 +779,35 @@ class AnalysisSuite extends AnalysisTest with Matchers { val literal = Literal(1).as("lit") // Ok - assert(CollectMetrics("event", literal :: sum :: random_sum :: Nil, testRelation).resolved) + assert(CollectMetrics("event", literal :: sum :: random_sum :: Nil, testRelation, 0).resolved) // Bad name - assert(!CollectMetrics("", sum :: Nil, testRelation).resolved) + assert(!CollectMetrics("", sum :: Nil, testRelation, 0).resolved) assertAnalysisErrorClass( - CollectMetrics("", sum :: Nil, testRelation), + CollectMetrics("", sum :: Nil, testRelation, 0), expectedErrorClass = "INVALID_OBSERVED_METRICS.MISSING_NAME", expectedMessageParameters = Map( - "operator" -> "'CollectMetrics , [sum(a#x) AS sum#xL]\n+- LocalRelation <empty>, [a#x]\n") + "operator" -> + "'CollectMetrics , [sum(a#x) AS sum#xL], 0\n+- LocalRelation <empty>, 
[a#x]\n") ) // No columns - assert(!CollectMetrics("evt", Nil, testRelation).resolved) + assert(!CollectMetrics("evt", Nil, testRelation, 0).resolved) def checkAnalysisError(exprs: Seq[NamedExpression], errors: String*): Unit = { - assertAnalysisError(CollectMetrics("event", exprs, testRelation), errors) + assertAnalysisError(CollectMetrics("event", exprs, testRelation, 0), errors) } // Unwrapped attribute assertAnalysisErrorClass( - CollectMetrics("event", a :: Nil, testRelation), + CollectMetrics("event", a :: Nil, testRelation, 0), expectedErrorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE", expectedMessageParameters = Map("expr" -> "\"a\"") ) // Unwrapped non-deterministic expression assertAnalysisErrorClass( - CollectMetrics("event", Rand(10).as("rnd") :: Nil, testRelation), + CollectMetrics("event", Rand(10).as("rnd") :: Nil, testRelation, 0), expectedErrorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC", expectedMessageParameters = Map("expr" -> "\"rand(10) AS rnd\"") ) @@ -816,7 +817,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { CollectMetrics( "event", Sum(a).toAggregateExpression(isDistinct = true).as("sum") :: Nil, - testRelation), + testRelation, 0), expectedErrorClass = "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_DISTINCT_UNSUPPORTED", expectedMessageParameters = Map("expr" -> "\"sum(DISTINCT a) AS sum\"") @@ -827,7 +828,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { CollectMetrics( "event", Sum(Sum(a).toAggregateExpression()).toAggregateExpression().as("sum") :: Nil, - testRelation), + testRelation, 0), expectedErrorClass = "INVALID_OBSERVED_METRICS.NESTED_AGGREGATES_UNSUPPORTED", expectedMessageParameters = Map("expr" -> "\"sum(sum(a)) AS sum\"") ) @@ -838,7 +839,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { WindowSpecDefinition(Nil, a.asc :: Nil, SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))) assertAnalysisErrorClass( - 
CollectMetrics("event", windowExpr.as("rn") :: Nil, testRelation), + CollectMetrics("event", windowExpr.as("rn") :: Nil, testRelation, 0), expectedErrorClass = "INVALID_OBSERVED_METRICS.WINDOW_EXPRESSIONS_UNSUPPORTED", expectedMessageParameters = Map( "expr" -> @@ -856,14 +857,14 @@ class AnalysisSuite extends AnalysisTest with Matchers { // Same result - duplicate names are allowed assertAnalysisSuccess(Union( - CollectMetrics("evt1", count :: Nil, testRelation) :: - CollectMetrics("evt1", count :: Nil, testRelation) :: Nil)) + CollectMetrics("evt1", count :: Nil, testRelation, 0) :: + CollectMetrics("evt1", count :: Nil, testRelation, 0) :: Nil)) // Same children, structurally different metrics - fail assertAnalysisErrorClass( Union( - CollectMetrics("evt1", count :: Nil, testRelation) :: - CollectMetrics("evt1", sum :: Nil, testRelation) :: Nil), + CollectMetrics("evt1", count :: Nil, testRelation, 0) :: + CollectMetrics("evt1", sum :: Nil, testRelation, 1) :: Nil), expectedErrorClass = "DUPLICATED_METRICS_NAME", expectedMessageParameters = Map("metricName" -> "evt1") ) @@ -873,17 +874,17 @@ class AnalysisSuite extends AnalysisTest with Matchers { val tblB = LocalRelation(b) assertAnalysisErrorClass( Union( - CollectMetrics("evt1", count :: Nil, testRelation) :: - CollectMetrics("evt1", count :: Nil, tblB) :: Nil), + CollectMetrics("evt1", count :: Nil, testRelation, 0) :: + CollectMetrics("evt1", count :: Nil, tblB, 1) :: Nil), expectedErrorClass = "DUPLICATED_METRICS_NAME", expectedMessageParameters = Map("metricName" -> "evt1") ) // Subquery different tree - fail - val subquery = Aggregate(Nil, sum :: Nil, CollectMetrics("evt1", count :: Nil, testRelation)) + val subquery = Aggregate(Nil, sum :: Nil, CollectMetrics("evt1", count :: Nil, testRelation, 0)) val query = Project( b :: ScalarSubquery(subquery, Nil).as("sum") :: Nil, - CollectMetrics("evt1", count :: Nil, tblB)) + CollectMetrics("evt1", count :: Nil, tblB, 1)) assertAnalysisErrorClass( query, 
expectedErrorClass = "DUPLICATED_METRICS_NAME", @@ -895,7 +896,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { case a: AggregateExpression => a.copy(filter = Some(true)) }.asInstanceOf[NamedExpression] assertAnalysisErrorClass( - CollectMetrics("evt1", sumWithFilter :: Nil, testRelation), + CollectMetrics("evt1", sumWithFilter :: Nil, testRelation, 0), expectedErrorClass = "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_FILTER_UNSUPPORTED", expectedMessageParameters = Map("expr" -> "\"sum(a) FILTER (WHERE true) AS sum\"") @@ -1675,18 +1676,4 @@ class AnalysisSuite extends AnalysisTest with Matchers { checkAnalysis(ident2.select($"a"), testRelation.select($"a").analyze) } } - - test("simplifyPlanForCollectedMetrics should handle non alias-only project case") { - val inner = Project( - Seq( - Alias(testRelation2.output(0), "a")(), - testRelation2.output(1), - Alias(testRelation2.output(2), "c")(), - testRelation2.output(3), - testRelation2.output(4) - ), - testRelation2) - val actualPlan = getAnalyzer.simplifyPlanForCollectedMetrics(inner.canonicalized) - assert(actualPlan == testRelation2.canonicalized) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 528904bb29a..f07496e6430 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2218,7 +2218,7 @@ class Dataset[T] private[sql]( */ @varargs def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan { - CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan) + CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 903565a6d59..d851eacd5ab 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -935,7 +935,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("UPDATE TABLE") case _: MergeIntoTable => throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("MERGE INTO TABLE") - case logical.CollectMetrics(name, metrics, child) => + case logical.CollectMetrics(name, metrics, child, _) => execution.CollectMetricsExec(name, metrics, planLater(child)) :: Nil case WriteFiles(child, fileFormat, partitionColumns, bucket, options, staticPartitions) => WriteFilesExec(planLater(child), fileFormat, partitionColumns, bucket, options, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org