Repository: spark Updated Branches: refs/heads/master e428b3a95 -> c19c78577
[SQL] [MINOR] correct semanticEquals logic This is a follow-up to https://github.com/apache/spark/pull/6173. For expressions like `Coalesce` that hold a `Seq[Expression]`, a semantic equality check must also check all of their children for semantic equality. Also, we can simply use `Seq[(Expression, NamedExpression)]` instead of `Map[Expression, NamedExpression]`, since we only ever search it with `find`. chenghao-intel, I agree that we can probably never implement `semanticEquals` in a fully general way, but I think we have already done something similar in `TreeNode`, so we can use similar logic here. Then we can handle something like `Coalesce(children: Seq[Expression])` correctly. Author: Wenchen Fan <[email protected]> Closes #6261 from cloud-fan/tmp and squashes the following commits: 4daef88 [Wenchen Fan] address comments dd8fbd9 [Wenchen Fan] correct semanticEquals Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c19c7857 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c19c7857 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c19c7857 Branch: refs/heads/master Commit: c19c78577a211eefe1112ebd4670a4ce7c3cc3be Parents: e428b3a Author: Wenchen Fan <[email protected]> Authored: Fri Jun 12 16:38:28 2015 +0800 Committer: Cheng Lian <[email protected]> Committed: Fri Jun 12 16:38:28 2015 +0800 ---------------------------------------------------------------------- .../sql/catalyst/expressions/Expression.scala | 13 +++++++++---- .../spark/sql/catalyst/planning/patterns.scala | 18 ++++++++---------- .../spark/sql/execution/GeneratedAggregate.scala | 14 +++++++------- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- 4 files changed, 25 insertions(+), 22 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/c19c7857/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 8c1e4d7..0b9f621 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -136,12 +136,17 @@ abstract class Expression extends TreeNode[Expression] { * cosmetically (i.e. capitalization of names in attributes may be different). */ def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && { + def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = { + elements1.length == elements2.length && elements1.zip(elements2).forall { + case (e1: Expression, e2: Expression) => e1 semanticEquals e2 + case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2 + case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq) + case (i1, i2) => i1 == i2 + } + } val elements1 = this.productIterator.toSeq val elements2 = other.asInstanceOf[Product].productIterator.toSeq - elements1.length == elements2.length && elements1.zip(elements2).forall { - case (e1: Expression, e2: Expression) => e1 semanticEquals e2 - case (i1, i2) => i1 == i2 - } + checkSemantic(elements1, elements2) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/c19c7857/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 1dd75a8..3b6f8bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -143,11 +143,11 @@ object PartialAggregation { // We need to pass all grouping expressions though so the grouping can happen a second // time. However some of them might be unnamed so we alias them allowing them to be // referenced in the second aggregation. - val namedGroupingExpressions: Map[Expression, NamedExpression] = + val namedGroupingExpressions: Seq[(Expression, NamedExpression)] = groupingExpressions.filter(!_.isInstanceOf[Literal]).map { case n: NamedExpression => (n, n) case other => (other, Alias(other, "PartialGroup")()) - }.toMap + } // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. @@ -160,17 +160,15 @@ object PartialAggregation { // resolving struct field accesses, because `GetField` is not a `NamedExpression`. // (Should we just turn `GetField` into a `NamedExpression`?) val trimmed = e.transform { case Alias(g: ExtractValue, _) => g } - namedGroupingExpressions - .find { case (k, v) => k semanticEquals trimmed } - .map(_._2.toAttribute) - .getOrElse(e) + namedGroupingExpressions.collectFirst { + case (expr, ne) if expr semanticEquals trimmed => ne.toAttribute + }.getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] - val partialComputation = - (namedGroupingExpressions.values ++ - partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq + val partialComputation = namedGroupingExpressions.map(_._2) ++ + partialEvaluations.values.flatMap(_.partialEvaluations) - val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq + val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) Some( (namedGroupingAttributes, http://git-wip-us.apache.org/repos/asf/spark/blob/c19c7857/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala ---------------------------------------------------------------------- diff 
--git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index af37917..1c40a92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -214,18 +214,18 @@ case class GeneratedAggregate( }.toMap val namedGroups = groupingExpressions.zipWithIndex.map { - case (ne: NamedExpression, _) => (ne, ne) - case (e, i) => (e, Alias(e, s"GroupingExpr$i")()) + case (ne: NamedExpression, _) => (ne, ne.toAttribute) + case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute) } - val groupMap: Map[Expression, Attribute] = - namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap - // The set of expressions that produce the final output given the aggregation buffer and the // grouping expressions. val resultExpressions = aggregateExpressions.map(_.transform { case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e)) - case e: Expression if groupMap.contains(e) => groupMap(e) + case e: Expression => + namedGroups.collectFirst { + case (expr, attr) if expr semanticEquals e => attr + }.getOrElse(e) }) val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema) @@ -265,7 +265,7 @@ case class GeneratedAggregate( val resultProjectionBuilder = newMutableProjection( resultExpressions, - (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq) + namedGroups.map(_._2) ++ computationSchema) log.info(s"Result Projection: ${resultExpressions.mkString(",")}") val joinedRow = new JoinedRow3 http://git-wip-us.apache.org/repos/asf/spark/blob/c19c7857/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 14ecd4e..6898d58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -697,7 +697,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } - ignore("cartesian product join") { + test("cartesian product join") { checkAnswer( testData3.join(testData3), Row(1, null, 1, null) :: --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
