Github user dilipbiswal commented on a diff in the pull request:
https://github.com/apache/spark/pull/17713#discussion_r115136441
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
---
@@ -414,4 +352,269 @@ trait CheckAnalysis extends PredicateHelper {
plan.foreach(_.setAnalyzed())
}
+
/**
 * Validates subquery expressions in the plan. Reports a user-facing
 * analysis error on failure.
 *
 * @param plan the operator hosting the subquery expression
 * @param expr the subquery expression to validate
 */
private def checkSubqueryExpression(plan: LogicalPlan, expr: SubqueryExpression): Unit = {
  // A correlated scalar subquery must yield exactly one row per outer row.
  // We enforce this by requiring it to be an aggregation with at least one
  // aggregate expression, and by restricting its GROUP BY columns.
  def assertValidAggregate(conditions: Seq[Expression], query: LogicalPlan,
      agg: Aggregate): Unit = {
    val aggregates = agg.expressions.flatMap(_.collect {
      case a: AggregateExpression => a
    })
    if (aggregates.isEmpty) {
      failAnalysis("The output of a correlated scalar subquery must be aggregated")
    }

    // SPARK-18504/SPARK-18814: GROUP BY columns must be a subset of the
    // correlated columns, otherwise the single-row-per-outer-row guarantee breaks.
    val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
    // Local references of the correlated predicates, excluding anything that
    // already appears in the pulled-up join conditions.
    val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references)
      .filterNot(conditions.flatMap(_.references).contains)
    val correlatedCols = AttributeSet(subqueryColumns)
    val invalidCols = groupByCols -- correlatedCols
    if (invalidCols.nonEmpty) {
      failAnalysis(
        "A GROUP BY clause in a scalar correlated subquery " +
          "cannot contain non-correlated columns: " +
          invalidCols.mkString(","))
    }
  }

  // Peel off SubqueryAlias/Project wrappers the Analyzer inserted so the
  // operator we need to inspect becomes visible.
  def stripWrappers(p: LogicalPlan): LogicalPlan = p match {
    case s: SubqueryAlias => stripWrappers(s.child)
    case p: Project => stripWrappers(p.child)
    case child => child
  }

  // First make sure the correlations appearing in the subquery are
  // valid and allowed by Spark.
  checkCorrelationsInSubquery(expr.plan)

  expr match {
    case ScalarSubquery(query, conditions, _) =>
      // A scalar subquery must produce a single output column.
      if (query.output.size != 1) {
        failAnalysis(
          s"Scalar subquery must return only one column, but got ${query.output.size}")
      }

      if (conditions.nonEmpty) {
        stripWrappers(query) match {
          case a: Aggregate => assertValidAggregate(conditions, query, a)
          case Filter(_, a: Aggregate) => assertValidAggregate(conditions, query, a)
          case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail")
        }

        // Only certain operators may host a subquery expression that
        // contains outer references.
        plan match {
          case _: Filter | _: Aggregate | _: Project => // Ok
          case other => failAnalysis(
            s"Correlated scalar sub-queries can only be used in a " +
              s"Filter/Aggregate/Project: $plan")
        }
      }

    case inOrExistsSubquery =>
      // IN / EXISTS predicates are only supported inside a Filter.
      plan match {
        case _: Filter => // Ok
        case _ => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $plan")
      }
  }

  // Finally, run the full analysis checks on the subquery plan itself.
  checkAnalysis(expr.plan)
}
+
+ /**
+ * Validates to make sure the outer references appearing inside the
subquery
+ * are allowed.
+ */
+ private def checkCorrelationsInSubquery(sub: LogicalPlan): Unit = {
+ // Validate that correlated aggregate expression do not contain a
mixture
+ // of outer and local references.
+ def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = {
+ expr.foreach {
+ case a: AggregateExpression if containsOuter(a) =>
+ val outer = a.collect { case OuterReference(e) => e.toAttribute }
+ val local = a.references -- outer
+ if (local.nonEmpty) {
+ val msg =
+ s"""
+ |Found an aggregate expression in a correlated predicate
that has both
+ |outer and local references, which is not supported yet.
+ |Aggregate expression:
${SubExprUtils.stripOuterReference(a).sql},
+ |Outer references: ${outer.map(_.sql).mkString(", ")},
+ |Local references: ${local.map(_.sql).mkString(", ")}.
+ """.stripMargin.replace("\n", " ").trim()
--- End diff --
@gatorsmile fixed.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]