Github user hvanhovell commented on a diff in the pull request:
https://github.com/apache/spark/pull/13155#discussion_r66682369
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
---
@@ -1695,16 +1696,205 @@ object RewriteCorrelatedScalarSubquery extends
Rule[LogicalPlan] {
}
/**
 * Statically evaluate `expr` by substituting each attribute reference with the placeholder
 * value bound to its expression ID in `bindings`.
 *
 * A binding of `Some(v)` is turned into a literal of the attribute's own data type, while a
 * binding of `None` is turned into an untyped NULL literal. The substituted expression is then
 * evaluated; a null result is reported as `None`.
 *
 * Every attribute referenced by `expr` must have an entry in `bindings`; a missing entry makes
 * the map lookup throw.
 */
private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = {
  val substituted = expr transform {
    case attr: AttributeReference =>
      bindings(attr.exprId)
        .map(value => Literal.create(value, attr.dataType))
        .getOrElse(Literal.default(NullType))
  }
  Option(substituted.eval())
}
+
/**
 * Statically evaluate an expression containing one or more aggregates as if every aggregate
 * had consumed zero input tuples.
 *
 * AggregateExpressions are Unevaluable, so each one is swapped for its aggregate function's
 * `defaultResult`, or an untyped NULL literal when the function declares none. Attribute
 * references (for example, grouping columns) are swapped for NULL as well. A null evaluation
 * result is reported as `None`.
 */
private def evalAggOnZeroTups(expr: Expression) : Option[Any] = {
  // A fresh untyped NULL literal for each replaced node.
  def nullLiteral: Expression = Literal.default(NullType)

  val zeroInputExpr = expr transform {
    case AggregateExpression(aggFunc, _, _, _) => aggFunc.defaultResult.getOrElse(nullLiteral)
    case _: AttributeReference => nullLiteral
  }
  Option(zeroInputExpr.eval())
}
+
/**
 * Statically evaluate a scalar subquery on an empty input.
 *
 * Returns `Some(v)` when the subquery would yield the non-null value `v` for an outer row that
 * joins with zero inner tuples. Returns `None` when it would yield NULL, or no row at all (for
 * example because a HAVING clause rejects the single row the aggregate produces for empty
 * input).
 *
 * <b>WARNING:</b> This method only covers subqueries that pass the checks under
 * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in
 * CheckAnalysis become less restrictive, this method will need to change.
 */
private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = {
  // Inputs to this method will start with a chain of zero or more SubqueryAlias
  // and Project operators, followed by an optional Filter, followed by an
  // Aggregate. Traverse the operators recursively.
  // An empty result map means "this operator produces no row for empty input".
  def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match {
    case SubqueryAlias(_, child) => evalPlan(child)

    case Filter(condition, child) =>
      val bindings = evalPlan(child)
      if (bindings.isEmpty) {
        bindings
      } else {
        // A NULL condition (evalExpr returns None) rejects the row, just like FALSE.
        val exprResult = evalExpr(condition, bindings).getOrElse(false).asInstanceOf[Boolean]
        if (exprResult) bindings else Map.empty
      }

    case Project(projectList, child) =>
      val bindings = evalPlan(child)
      if (bindings.isEmpty) {
        bindings
      } else {
        projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
      }

    case Aggregate(_, aggExprs, _) =>
      // Some of the expressions under the Aggregate node are the join columns
      // for joining with the outer query block. Fill those expressions in with
      // nulls and statically evaluate the remainder.
      aggExprs.map {
        case ref: AttributeReference => (ref.exprId, None)
        case alias @ Alias(_: AttributeReference, _) => (alias.exprId, None)
        case ne => (ne.exprId, evalAggOnZeroTups(ne))
      }.toMap

    case _ => sys.error(s"Unexpected operator in scalar subquery: $lp")
  }

  val resultMap = evalPlan(plan)

  // By convention, the scalar subquery result is the leftmost field.
  // If a HAVING clause filtered out the only row, resultMap is empty: the subquery then yields
  // no row (i.e. NULL) for empty input, so no COUNT-bug correction is needed. Using apply()
  // here would throw NoSuchElementException in that case.
  resultMap.getOrElse(plan.output.head.exprId, None)
}
+
/**
 * Split the plan for a scalar subquery into three parts:
 *  1. the operators above the innermost query block, listed top-down (a possibly empty
 *     sequence of Project and SubqueryAlias nodes);
 *  2. the optional HAVING clause of the innermost query block (a Filter sitting directly on
 *     top of the Aggregate);
 *  3. the Aggregate below the HAVING clause.
 *
 * Any other operator shape is rejected with `sys.error`.
 */
private def splitSubquery(plan: LogicalPlan) : (Seq[LogicalPlan], Option[Filter], Aggregate) = {
  // Walk down the single-child chain, collecting pass-through operators until the innermost
  // Aggregate (optionally topped by a HAVING Filter) is reached. Tail-recursive, so no
  // mutable loop state, `return`, or unreachable fall-through line is needed.
  @scala.annotation.tailrec
  def descend(
      node: LogicalPlan,
      topPart: Vector[LogicalPlan]): (Seq[LogicalPlan], Option[Filter], Aggregate) = {
    node match {
      case havingPart @ Filter(_, aggPart: Aggregate) =>
        (topPart, Some(havingPart), aggPart)

      case aggPart: Aggregate =>
        // No HAVING clause
        (topPart, None, aggPart)

      case p @ Project(_, child) =>
        descend(child, topPart :+ p)

      case s @ SubqueryAlias(_, child) =>
        descend(child, topPart :+ s)

      case Filter(_, op) =>
        sys.error(s"Correlated subquery has unexpected operator $op below filter")

      case op => sys.error(s"Unexpected operator $op in correlated subquery")
    }
  }

  descend(plan, Vector.empty)
}
+
+
+
// Column name for the generated always-TRUE column that constructLeftJoins adds to a
// subquery's output, so an outer-join row with no match (column is NULL) can be told apart
// from a matching row whose values happen to be NULL.
val ALWAYS_TRUE_COLNAME: String = "alwaysTrue"
+
+ /**
* Construct a new child plan by left joining the given subqueries to a
base plan.
*/
private def constructLeftJoins(
child: LogicalPlan,
subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(query, conditions, _)) =>
- Project(
- currentChild.output :+ query.output.head,
- Join(currentChild, query, LeftOuter,
conditions.reduceOption(And)))
+ val origOutput = query.output.head
+
+ val resultWithZeroTups = evalSubqueryOnZeroTups(query)
+ if (resultWithZeroTups.isEmpty) {
+ // CASE 1: Subquery guaranteed not to have the COUNT bug
+ Project(
+ currentChild.output :+ origOutput,
+ Join(currentChild, query, LeftOuter,
conditions.reduceOption(And)))
+ } else {
+ // Subquery might have the COUNT bug. Add appropriate
corrections.
+ val (topPart, havingNode, aggNode) = splitSubquery(query)
+
+ // The next two cases add a leading column to the outer join
input to make it
+ // possible to distinguish between the case when no tuples join
and the case
+ // when the tuple that joins contains null values.
+ // The leading column always has the value TRUE.
+ val alwaysTrueExprId = NamedExpression.newExprId
+ val alwaysTrueExpr = Alias(Literal.TrueLiteral,
+ ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId)
+ val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME,
+ BooleanType)(exprId = alwaysTrueExprId)
+
+ val aggValRef = query.output.head
+
+ if (!havingNode.isDefined) {
+ // CASE 2: Subquery with no HAVING clause
+ Project(
+ currentChild.output :+
+ Alias(
+ If(IsNull(alwaysTrueRef),
+ Literal(resultWithZeroTups.get, origOutput.dataType),
+ aggValRef), origOutput.name)(exprId =
origOutput.exprId),
+ Join(currentChild,
+ Project(query.output :+ alwaysTrueExpr, query),
+ LeftOuter, conditions.reduceOption(And)))
+
+ } else {
+ // CASE 3: Subquery with HAVING clause. Pull the HAVING clause
above the join.
+ // Need to modify any operators below the join to pass through
all columns
+ // referenced in the HAVING clause.
+ var subqueryRoot : UnaryNode = aggNode
+ val havingInputs : Seq[NamedExpression] = aggNode.output
+
+ topPart.reverse.foreach(
+ _ match {
+ case Project(projList, _) =>
+ subqueryRoot = Project(projList ++ havingInputs,
subqueryRoot)
+ case s@SubqueryAlias(alias, _) => subqueryRoot =
SubqueryAlias(alias, subqueryRoot)
+ case op@_ => sys.error(s"Unexpected operator $op in
corelated subquery")
+ }
+ )
+
+ // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups
+ // WHEN NOT (original HAVING clause expr) THEN CAST(null
AS <type of aggVal>)
+ // ELSE (aggregate value) END AS (original column name)
+ val caseExpr = Alias(CaseWhen(
+ Seq[(Expression, Expression)] (
--- End diff --
Do we need to type the Seq?
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]