Repository: spark Updated Branches: refs/heads/master 388f5a063 -> 88e0c7bbd
[SPARK-24341][SQL] Support only IN subqueries with the same number of items per row ## What changes were proposed in this pull request? Using struct types in subqueries with the `IN` clause can generate invalid plans in `RewritePredicateSubquery`. The root cause is that we do not clearly handle the cases where the outer value is a struct or the output of the inner subquery is a struct. This PR aims to make Spark's behavior the same as that of other RDBMSs; the behavior of Oracle and Postgres was checked. Therefore, we consider a query valid only if the outer value and the subquery output have the same number of fields. This means that: - `(a, b) IN (select c, d from ...)` is a valid query; - `(a, b) IN (select (c, d) from ...)` throws an AnalysisException, because the subquery outputs only one field (of type struct) while the outer value has 2 fields; - `a IN (select (c, d) from ...)` - where `a` is a struct - is a valid query. ## How was this patch tested? Added unit tests (UT). Closes #21403 from mgaido91/SPARK-24313. 
Authored-by: Marco Gaido <marcogaid...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/88e0c7bb Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/88e0c7bb Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/88e0c7bb Branch: refs/heads/master Commit: 88e0c7bbd566240d182332299cf6695890a953ad Parents: 388f5a0 Author: Marco Gaido <marcogaid...@gmail.com> Authored: Tue Aug 7 15:43:41 2018 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Tue Aug 7 15:43:41 2018 +0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/analysis/Analyzer.scala | 19 +++- .../sql/catalyst/analysis/TypeCoercion.scala | 28 +---- .../apache/spark/sql/catalyst/dsl/package.scala | 8 +- .../sql/catalyst/expressions/Canonicalize.scala | 2 - .../sql/catalyst/expressions/predicates.scala | 105 +++++++++++-------- .../sql/catalyst/expressions/subquery.scala | 4 +- .../sql/catalyst/optimizer/expressions.scala | 1 + .../spark/sql/catalyst/optimizer/subquery.scala | 20 ++-- .../spark/sql/catalyst/parser/AstBuilder.scala | 7 +- .../catalyst/analysis/AnalysisErrorSuite.scala | 7 +- .../analysis/ResolveSubquerySuite.scala | 5 +- .../catalyst/optimizer/OptimizeInSuite.scala | 15 +++ .../PullupCorrelatedPredicatesSuite.scala | 4 +- .../catalyst/parser/ExpressionParserSuite.scala | 14 ++- .../inputs/subquery/in-subquery/in-basic.sql | 14 +++ .../subquery/in-subquery/in-basic.sql.out | 70 +++++++++++++ .../negative-cases/subq-input-typecheck.sql.out | 20 ++-- 17 files changed, 240 insertions(+), 103 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b5016fd..d391a93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1433,11 +1433,26 @@ class Analyzer( resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) - case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved => + case InSubquery(values, l @ ListQuery(_, _, exprId, _)) + if values.forall(_.resolved) && !l.resolved => val expr = resolveSubQuery(l, plans)((plan, exprs) => { ListQuery(plan, exprs, exprId, plan.output) }) - In(value, Seq(expr)) + val subqueryOutput = expr.plan.output + val resolvedIn = InSubquery(values, expr.asInstanceOf[ListQuery]) + if (values.length != subqueryOutput.length) { + throw new AnalysisException( + s"""Cannot analyze ${resolvedIn.sql}. + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. 
+ |#columns in left hand side: ${values.length} + |#columns in right hand side: ${subqueryOutput.length} + |Left side columns: + |[${values.map(_.sql).mkString(", ")}] + |Right side columns: + |[${subqueryOutput.map(_.sql).mkString(", ")}]""".stripMargin) + } + resolvedIn } } http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 7dd26b6..648aa9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -449,15 +449,6 @@ object TypeCoercion { * Analysis Exception will be raised at the type checking phase. */ case class InConversion(conf: SQLConf) extends TypeCoercionRule { - private def flattenExpr(expr: Expression): Seq[Expression] = { - expr match { - // Multi columns in IN clause is represented as a CreateNamedStruct. - // flatten the named struct to get the list of expressions. - case cns: CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - } - override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. @@ -465,11 +456,9 @@ object TypeCoercion { // Handle type casting required between value expression and subquery output // in IN subquery. - case i @ In(a, Seq(ListQuery(sub, children, exprId, _))) - if !i.resolved && flattenExpr(a).length == sub.output.length => - // LHS is the value expression of IN subquery. 
- val lhs = flattenExpr(a) - + case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _)) + if !i.resolved && lhs.length == sub.output.length => + // LHS is the value expressions of IN subquery. // RHS is the subquery output. val rhs = sub.output @@ -485,20 +474,13 @@ object TypeCoercion { case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() case (e, _) => e } - val castedLhs = lhs.zip(commonTypes).map { + val newLhs = lhs.zip(commonTypes).map { case (e, dt) if e.dataType != dt => Cast(e, dt) case (e, _) => e } - // Before constructing the In expression, wrap the multi values in LHS - // in a CreatedNamedStruct. - val newLhs = castedLhs match { - case Seq(lhs) => lhs - case _ => CreateStruct(castedLhs) - } - val newSub = Project(castedRhs, sub) - In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output))) + InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output)) } else { i } http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 2b582b5..d3ccd18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -88,7 +88,13 @@ package object dsl { def <=> (other: Expression): Predicate = EqualNullSafe(expr, other) def =!= (other: Expression): Predicate = Not(EqualTo(expr, other)) - def in(list: Expression*): Expression = In(expr, list) + def in(list: Expression*): Expression = list match { + case Seq(l: ListQuery) => expr match { + case c: CreateNamedStruct => InSubquery(c.valExprs, l) + case other => InSubquery(Seq(other), l) + } + case _ => In(expr, list) + } def like(other: Expression): Expression = 
Like(expr, other) def rlike(other: Expression): Expression = RLike(expr, other) http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index 7541f52..fe6db8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -87,8 +87,6 @@ object Canonicalize { case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r) // order the list in the In operator - // In subqueries contain only one element of type ListQuery. So checking that the length > 1 - // we are not reordering In subqueries. case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode())) case _ => e http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f4077f78..149bd79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -138,6 +138,66 @@ case class Not(child: Expression) override def sql: String = s"(NOT ${child.sql})" } +/** + * Evaluates to `true` if `values` are returned in `query`'s result set. 
+ */ +case class InSubquery(values: Seq[Expression], query: ListQuery) + extends Predicate with Unevaluable { + + @transient lazy val value: Expression = if (values.length > 1) { + CreateNamedStruct(values.zipWithIndex.flatMap { + case (v: NamedExpression, _) => Seq(Literal(v.name), v) + case (v, idx) => Seq(Literal(s"_$idx"), v) + }) + } else { + values.head + } + + + override def checkInputDataTypes(): TypeCheckResult = { + val mismatchOpt = !DataType.equalsStructurally(query.dataType, value.dataType, + ignoreNullability = true) + if (mismatchOpt) { + if (values.length != query.childOutputs.length) { + TypeCheckResult.TypeCheckFailure( + s""" + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. + |#columns in left hand side: ${values.length}. + |#columns in right hand side: ${query.childOutputs.length}. + |Left side columns: + |[${values.map(_.sql).mkString(", ")}]. + |Right side columns: + |[${query.childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) + } else { + val mismatchedColumns = values.zip(query.childOutputs).flatMap { + case (l, r) if l.dataType != r.dataType => + Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})") + case _ => None + } + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${values.map(_.dataType.catalogString).mkString(", ")}]. 
+ |Right side: + |[${query.childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) + } + } else { + TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") + } + } + + override def children: Seq[Expression] = values :+ query + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + override def toString: String = s"$value IN ($query)" + override def sql: String = s"(${value.sql} IN (${query.sql}))" +} + /** * Evaluates to `true` if `list` contains `value`. @@ -169,44 +229,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType, ignoreNullability = true)) if (mismatchOpt.isDefined) { - list match { - case ListQuery(_, _, _, childOutputs) :: Nil => - val valExprs = value match { - case cns: CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - if (valExprs.length != childOutputs.length) { - TypeCheckResult.TypeCheckFailure( - s""" - |The number of columns in the left hand side of an IN subquery does not match the - |number of columns in the output of subquery. - |#columns in left hand side: ${valExprs.length}. - |#columns in right hand side: ${childOutputs.length}. - |Left side columns: - |[${valExprs.map(_.sql).mkString(", ")}]. 
- |Right side columns: - |[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin) - } else { - val mismatchedColumns = valExprs.zip(childOutputs).flatMap { - case (l, r) if l.dataType != r.dataType => - Seq(s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})") - case _ => None - } - TypeCheckResult.TypeCheckFailure( - s""" - |The data type of one or more elements in the left hand side of an IN subquery - |is not compatible with the data type of the output of the subquery - |Mismatched columns: - |[${mismatchedColumns.mkString(", ")}] - |Left side: - |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. - |Right side: - |[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) - } - case _ => - TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + - s"${value.dataType.catalogString} != ${mismatchOpt.get.dataType.catalogString}") - } + TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + + s"${value.dataType.catalogString} != ${mismatchOpt.get.dataType.catalogString}") } else { TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") } @@ -307,9 +331,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def sql: String = { - val childrenSQL = children.map(_.sql) - val valueSQL = childrenSQL.head - val listSQL = childrenSQL.tail.mkString(", ") + val valueSQL = value.sql + val listSQL = list.map(_.sql).mkString(", ") s"($valueSQL IN ($listSQL))" } } http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 6acc87a..fc1caed 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -117,10 +117,10 @@ object SubExprUtils extends PredicateHelper { def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = { splitConjunctivePredicates(condition).exists { case _: Exists | Not(_: Exists) => false - case In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => false + case _: InSubquery | Not(_: InSubquery) => false case e => e.find { x => x.isInstanceOf[Not] && e.find { - case In(_, Seq(_: ListQuery)) => true + case _: InSubquery => true case _ => false }.isDefined }.isDefined http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index e7b4730..5629b72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -523,6 +523,7 @@ object NullPropagation extends Rule[LogicalPlan] { // If the value expression is NULL then transform the In expression to null literal. case In(Literal(null, _), _) => Literal.create(null, BooleanType) + case InSubquery(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType) // Non-leaf NullIntolerant expressions will return null, if at least one of its children is // a null literal. 
http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index de89e17..e9b7a8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -42,13 +43,6 @@ import org.apache.spark.sql.types._ * condition. */ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { - private def getValueExpression(e: Expression): Seq[Expression] = { - e match { - case cns : CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - } - private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match { // SPARK-21835: It is possibly that the two sides of the join have conflicting attributes, // the produced join then becomes unresolved and break structural integrity. We should @@ -97,19 +91,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) // Deduplicate conflicting attributes if any. 
dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond)) - case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) => - val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) => + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) // Deduplicate conflicting attributes if any. dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond)) - case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) => + case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. - val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. That is from: @@ -150,9 +144,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { newPlan = dedupJoin( Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))) exists - case In(value, Seq(ListQuery(sub, conditions, _, _))) => + case InSubquery(values, ListQuery(sub, conditions, _, _)) => val exists = AttributeReference("exists", BooleanType, nullable = false)() - val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val inConditions = values.zip(sub.output).map(EqualTo.tupled) val newConditions = (inConditions ++ conditions).reduceLeftOption(And) // Deduplicate conflicting attributes if any. 
newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions)) http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 732d762..7bc1f63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1103,6 +1103,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case not => Not(e) } + def getValueExpressions(e: Expression): Seq[Expression] = e match { + case c: CreateNamedStruct => c.valExprs + case other => Seq(other) + } + // Create the predicate. ctx.kind.getType match { case SqlBaseParser.BETWEEN => @@ -1111,7 +1116,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging GreaterThanOrEqual(e, expression(ctx.lower)), LessThanOrEqual(e, expression(ctx.upper)))) case SqlBaseParser.IN if ctx.query != null => - invertIfNotDefined(In(e, Seq(ListQuery(plan(ctx.query))))) + invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query)))) case SqlBaseParser.IN => invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 0a5194a..9477884 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -534,7 +534,7 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val plan = Project( - Seq(a, Alias(In(a, Seq(ListQuery(LocalRelation(b)))), "c")()), + Seq(a, Alias(InSubquery(Seq(a), ListQuery(LocalRelation(b))), "c")()), LocalRelation(a)) assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil) } @@ -543,12 +543,13 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val c = AttributeReference("c", BooleanType)() - val plan1 = Filter(Cast(Not(In(a, Seq(ListQuery(LocalRelation(b))))), BooleanType), + val plan1 = Filter(Cast(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), BooleanType), LocalRelation(a)) assertAnalysisError(plan1, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) - val plan2 = Filter(Or(Not(In(a, Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c)) + val plan2 = Filter( + Or(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), c), LocalRelation(a, c)) assertAnalysisError(plan2, "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) } http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 1bf8d76..74a8590 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala 
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{In, ListQuery, OuterReference} +import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project} /** @@ -33,7 +33,8 @@ class ResolveSubquerySuite extends AnalysisTest { val t2 = LocalRelation(b) test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { - val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) + val expr = Filter( + InSubquery(Seq(a), ListQuery(Project(Seq(UnresolvedAttribute("a")), t2))), t1) val m = intercept[AnalysisException] { SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr)) }.getMessage http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 86522a6..a36083b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -121,6 +121,21 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("OptimizedIn test: NULL IN (subquery) gets transformed to Filter(null)") { + val subquery = ListQuery(testRelation.select(UnresolvedAttribute("a"))) + val originalQuery = + testRelation + .where(InSubquery(Seq(Literal.create(null, NullType)), subquery)) + .analyze 
+ + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Literal.create(null, BooleanType)) + .analyze + comparePlans(optimized, correctAnswer) + } + test("OptimizedIn test: Inset optimization disabled as " + "list expression contains attribute)") { val originalQuery = http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala index 169b8737..8a5a551 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{In, ListQuery} +import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -42,7 +42,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { .select('c) val outerQuery = testRelation - .where(In('a, Seq(ListQuery(correlatedSubquery)))) + .where(InSubquery(Seq('a), ListQuery(correlatedSubquery))) .select('a).analyze assert(outerQuery.resolved) http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index c37b9f1..781fc1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -154,7 +154,19 @@ class ExpressionParserSuite extends PlanTest { test("in sub-query") { assertEqual( "a in (select b from c)", - In('a, Seq(ListQuery(table("c").select('b))))) + InSubquery(Seq('a), ListQuery(table("c").select('b)))) + + assertEqual( + "(a, b, c) in (select d, e, f from g)", + InSubquery(Seq('a, 'b, 'c), ListQuery(table("g").select('d, 'e, 'f)))) + + assertEqual( + "(a, b) in (select c from d)", + InSubquery(Seq('a, 'b), ListQuery(table("d").select('c)))) + + assertEqual( + "(a) in (select b from c)", + InSubquery(Seq('a), ListQuery(table("c").select('b)))) } test("like expressions") { http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql new file mode 100644 index 0000000..f4ffc20 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql @@ -0,0 +1,14 @@ +create temporary view tab_a as select * from values (1, 1) as tab_a(a1, b1); +create temporary view tab_b as select * from values (1, 1) as tab_b(a2, b2); +create temporary view struct_tab as select struct(col1 as a, col2 as b) as record from + values (1, 1), (1, 2), (2, 1), (2, 2); + +select 1 from tab_a where (a1, b1) not in (select a2, b2 
from tab_b); +-- Invalid query, see SPARK-24341 +select 1 from tab_a where (a1, b1) not in (select (a2, b2) from tab_b); + +-- Aliasing is needed as a workaround for SPARK-24443 +select count(*) from struct_tab where record in + (select (a2 as a, b2 as b) from tab_b); +select count(*) from struct_tab where record not in + (select (a2 as a, b2 as b) from tab_b); http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out new file mode 100644 index 0000000..088db55 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out @@ -0,0 +1,70 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +create temporary view tab_a as select * from values (1, 1) as tab_a(a1, b1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view tab_b as select * from values (1, 1) as tab_b(a2, b2) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +create temporary view struct_tab as select struct(col1 as a, col2 as b) as record from + values (1, 1), (1, 2), (2, 1), (2, 2) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +select 1 from tab_a where (a1, b1) not in (select a2, b2 from tab_b) +-- !query 3 schema +struct<1:int> +-- !query 3 output + + + +-- !query 4 +select 1 from tab_a where (a1, b1) not in (select (a2, b2) from tab_b) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +Cannot analyze (named_struct('a1', tab_a.`a1`, 'b1', tab_a.`b1`) IN (listquery())). 
+The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 2 +#columns in right hand side: 1 +Left side columns: +[tab_a.`a1`, tab_a.`b1`] +Right side columns: +[`named_struct(a2, a2, b2, b2)`]; + + +-- !query 5 +select count(*) from struct_tab where record in + (select (a2 as a, b2 as b) from tab_b) +-- !query 5 schema +struct<count(1):bigint> +-- !query 5 output +1 + + +-- !query 6 +select count(*) from struct_tab where record not in + (select (a2 as a, b2 as b) from tab_b) +-- !query 6 schema +struct<count(1):bigint> +-- !query 6 output +3 http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out index dcd3005..c52e570 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out @@ -92,15 +92,15 @@ t1a IN (SELECT t2a, t2b struct<> -- !query 7 output org.apache.spark.sql.AnalysisException -cannot resolve '(t1.`t1a` IN (listquery(t1.`t1a`)))' due to data type mismatch: +Cannot analyze (t1.`t1a` IN (listquery(t1.`t1a`))). The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. -#columns in left hand side: 1. -#columns in right hand side: 2. +#columns in left hand side: 1 +#columns in right hand side: 2 Left side columns: -[t1.`t1a`]. 
+[t1.`t1a`] Right side columns: -[t2.`t2a`, t2.`t2b`].; +[t2.`t2a`, t2.`t2b`]; -- !query 8 @@ -113,15 +113,15 @@ WHERE struct<> -- !query 8 output org.apache.spark.sql.AnalysisException -cannot resolve '(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`)))' due to data type mismatch: +Cannot analyze (named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`))). The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. -#columns in left hand side: 2. -#columns in right hand side: 1. +#columns in left hand side: 2 +#columns in right hand side: 1 Left side columns: -[t1.`t1a`, t1.`t1b`]. +[t1.`t1a`, t1.`t1b`] Right side columns: -[t2.`t2a`].; +[t2.`t2a`]; -- !query 9 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org