Repository: spark
Updated Branches:
  refs/heads/master 388f5a063 -> 88e0c7bbd


[SPARK-24341][SQL] Support only IN subqueries with the same number of items per row

## What changes were proposed in this pull request?

Using struct types in subqueries with the `IN` clause can generate invalid 
plans in `RewritePredicateSubquery`: the cases where the outer value is a 
struct, or where the output of the inner subquery is a struct, are not handled 
cleanly.

The PR aims to make Spark's behavior consistent with that of other RDBMSs 
(Oracle and Postgres were checked). We therefore consider valid only queries 
in which the outer value and the subquery output have the same number of 
fields. This means that:

 - `(a, b) IN (select c, d from ...)` is a valid query;
 - `(a, b) IN (select (c, d) from ...)` throws an AnalysisException, since the 
subquery returns a single field of struct type while the outer value has 2 
fields;
 - `a IN (select (c, d) from ...)` - where `a` is a struct - is a valid query. 
All three cases are sketched below.
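
As a sketch of these semantics (assuming a local SparkSession named `spark`; 
the tables t1(a, b), t2(c, d) and the struct view t3 are hypothetical):

    import org.apache.spark.sql.{AnalysisException, SparkSession}

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    Seq((1, 1)).toDF("a", "b").createOrReplaceTempView("t1")
    Seq((1, 2)).toDF("c", "d").createOrReplaceTempView("t2")

    // Valid: same number of fields on both sides.
    spark.sql("SELECT * FROM t1 WHERE (a, b) IN (SELECT c, d FROM t2)").show()

    // Invalid: 2 fields on the left, a single struct field on the right.
    try {
      spark.sql("SELECT * FROM t1 WHERE (a, b) IN (SELECT (c, d) FROM t2)")
    } catch {
      case e: AnalysisException => println(e.getMessage)
    }

    // Valid: a single struct value on the left against a single struct column
    // on the right (aliasing works around SPARK-24443, as in the new SQL tests).
    spark.sql("SELECT struct(a, b) AS s FROM t1").createOrReplaceTempView("t3")
    spark.sql("SELECT * FROM t3 WHERE s IN (SELECT (c AS a, d AS b) FROM t2)").show()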

## How was this patch tested?

Added unit tests.

Closes #21403 from mgaido91/SPARK-24313.

Authored-by: Marco Gaido <marcogaid...@gmail.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/88e0c7bb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/88e0c7bb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/88e0c7bb

Branch: refs/heads/master
Commit: 88e0c7bbd566240d182332299cf6695890a953ad
Parents: 388f5a0
Author: Marco Gaido <marcogaid...@gmail.com>
Authored: Tue Aug 7 15:43:41 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Aug 7 15:43:41 2018 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  19 +++-
 .../sql/catalyst/analysis/TypeCoercion.scala    |  28 +----
 .../apache/spark/sql/catalyst/dsl/package.scala |   8 +-
 .../sql/catalyst/expressions/Canonicalize.scala |   2 -
 .../sql/catalyst/expressions/predicates.scala   | 105 +++++++++++--------
 .../sql/catalyst/expressions/subquery.scala     |   4 +-
 .../sql/catalyst/optimizer/expressions.scala    |   1 +
 .../spark/sql/catalyst/optimizer/subquery.scala |  20 ++--
 .../spark/sql/catalyst/parser/AstBuilder.scala  |   7 +-
 .../catalyst/analysis/AnalysisErrorSuite.scala  |   7 +-
 .../analysis/ResolveSubquerySuite.scala         |   5 +-
 .../catalyst/optimizer/OptimizeInSuite.scala    |  15 +++
 .../PullupCorrelatedPredicatesSuite.scala       |   4 +-
 .../catalyst/parser/ExpressionParserSuite.scala |  14 ++-
 .../inputs/subquery/in-subquery/in-basic.sql    |  14 +++
 .../subquery/in-subquery/in-basic.sql.out       |  70 +++++++++++++
 .../negative-cases/subq-input-typecheck.sql.out |  20 ++--
 17 files changed, 240 insertions(+), 103 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b5016fd..d391a93 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1433,11 +1433,26 @@ class Analyzer(
           resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
         case e @ Exists(sub, _, exprId) if !sub.resolved =>
           resolveSubQuery(e, plans)(Exists(_, _, exprId))
-        case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved =>
+        case InSubquery(values, l @ ListQuery(_, _, exprId, _))
+            if values.forall(_.resolved) && !l.resolved =>
           val expr = resolveSubQuery(l, plans)((plan, exprs) => {
             ListQuery(plan, exprs, exprId, plan.output)
           })
-          In(value, Seq(expr))
+          val subqueryOutput = expr.plan.output
+          val resolvedIn = InSubquery(values, expr.asInstanceOf[ListQuery])
+          if (values.length != subqueryOutput.length) {
+            throw new AnalysisException(
+              s"""Cannot analyze ${resolvedIn.sql}.
+                 |The number of columns in the left hand side of an IN subquery does not match the
+                 |number of columns in the output of subquery.
+                 |#columns in left hand side: ${values.length}
+                 |#columns in right hand side: ${subqueryOutput.length}
+                 |Left side columns:
+                 |[${values.map(_.sql).mkString(", ")}]
+                 |Right side columns:
+                 |[${subqueryOutput.map(_.sql).mkString(", ")}]""".stripMargin)
+          }
+          resolvedIn
       }
     }
 

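A hedged Catalyst-level sketch of the new analysis-time check (patterned on the 
analysis test suites; the relations and attributes are illustrative):

    import org.apache.spark.sql.AnalysisException
    import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery}
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    // One expression on the left, two columns in the subquery output:
    // resolution should now fail with the "number of columns" message above.
    val t1 = LocalRelation('a.int)
    val t2 = LocalRelation('b.int, 'c.int)
    val plan = t1.where(InSubquery(Seq('a), ListQuery(t2.select('b, 'c))))
    try SimpleAnalyzer.execute(plan)
    catch { case e: AnalysisException => println(e.getMessage) }
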
http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 7dd26b6..648aa9e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -449,15 +449,6 @@ object TypeCoercion {
    *    Analysis Exception will be raised at the type checking phase.
    */
   case class InConversion(conf: SQLConf) extends TypeCoercionRule {
-    private def flattenExpr(expr: Expression): Seq[Expression] = {
-      expr match {
-        // Multi columns in IN clause is represented as a CreateNamedStruct.
-        // flatten the named struct to get the list of expressions.
-        case cns: CreateNamedStruct => cns.valExprs
-        case expr => Seq(expr)
-      }
-    }
-
     override protected def coerceTypes(
         plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
@@ -465,11 +456,9 @@ object TypeCoercion {
 
      // Handle type casting required between value expression and subquery output
       // in IN subquery.
-      case i @ In(a, Seq(ListQuery(sub, children, exprId, _)))
-        if !i.resolved && flattenExpr(a).length == sub.output.length =>
-        // LHS is the value expression of IN subquery.
-        val lhs = flattenExpr(a)
-
+      case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _))
+          if !i.resolved && lhs.length == sub.output.length =>
+        // LHS is the value expressions of IN subquery.
         // RHS is the subquery output.
         val rhs = sub.output
 
@@ -485,20 +474,13 @@ object TypeCoercion {
             case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
             case (e, _) => e
           }
-          val castedLhs = lhs.zip(commonTypes).map {
+          val newLhs = lhs.zip(commonTypes).map {
             case (e, dt) if e.dataType != dt => Cast(e, dt)
             case (e, _) => e
           }
 
-          // Before constructing the In expression, wrap the multi values in LHS
-          // in a CreatedNamedStruct.
-          val newLhs = castedLhs match {
-            case Seq(lhs) => lhs
-            case _ => CreateStruct(castedLhs)
-          }
-
           val newSub = Project(castedRhs, sub)
-          In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output)))
+          InSubquery(newLhs, ListQuery(newSub, children, exprId, newSub.output))
         } else {
           i
         }

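A sketch of what the coercion now produces when the two sides' types differ 
(illustrative: int on the left, bigint in the subquery output):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery}
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val t1 = LocalRelation('a.int)
    val t2 = LocalRelation('b.long)

    // After analysis the left side should be Cast(a, bigint) and the subquery
    // should be wrapped in a Project of its (already bigint) output.
    val coerced = t1.where(InSubquery(Seq('a), ListQuery(t2.select('b)))).analyze
    println(coerced)
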
http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 2b582b5..d3ccd18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -88,7 +88,13 @@ package object dsl {
     def <=> (other: Expression): Predicate = EqualNullSafe(expr, other)
     def =!= (other: Expression): Predicate = Not(EqualTo(expr, other))
 
-    def in(list: Expression*): Expression = In(expr, list)
+    def in(list: Expression*): Expression = list match {
+      case Seq(l: ListQuery) => expr match {
+          case c: CreateNamedStruct => InSubquery(c.valExprs, l)
+          case other => InSubquery(Seq(other), l)
+        }
+      case _ => In(expr, list)
+    }
 
     def like(other: Expression): Expression = Like(expr, other)
     def rlike(other: Expression): Expression = RLike(expr, other)

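A small sketch of what the updated DSL helper builds (names illustrative):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.{In, InSubquery, ListQuery}

    // A list that is a single ListQuery now yields an InSubquery...
    val e1 = 'a in ListQuery(table("t").select('b))
    assert(e1.isInstanceOf[InSubquery])

    // ...while a plain expression list still yields an In.
    val e2 = 'a in (1, 2, 3)
    assert(e2.isInstanceOf[In])
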
http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
index 7541f52..fe6db8b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
@@ -87,8 +87,6 @@ object Canonicalize {
     case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r)
 
     // order the list in the In operator
-    // In subqueries contain only one element of type ListQuery. So checking that the length > 1
-    // we are not reordering In subqueries.
    case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode()))
 
     case _ => e

http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index f4077f78..149bd79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -138,6 +138,66 @@ case class Not(child: Expression)
   override def sql: String = s"(NOT ${child.sql})"
 }
 
+/**
+ * Evaluates to `true` if `values` are returned in `query`'s result set.
+ */
+case class InSubquery(values: Seq[Expression], query: ListQuery)
+  extends Predicate with Unevaluable {
+
+  @transient lazy val value: Expression = if (values.length > 1) {
+    CreateNamedStruct(values.zipWithIndex.flatMap {
+      case (v: NamedExpression, _) => Seq(Literal(v.name), v)
+      case (v, idx) => Seq(Literal(s"_$idx"), v)
+    })
+  } else {
+    values.head
+  }
+
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val mismatchOpt = !DataType.equalsStructurally(query.dataType, value.dataType,
+      ignoreNullability = true)
+    if (mismatchOpt) {
+      if (values.length != query.childOutputs.length) {
+        TypeCheckResult.TypeCheckFailure(
+          s"""
+             |The number of columns in the left hand side of an IN subquery does not match the
+             |number of columns in the output of subquery.
+             |#columns in left hand side: ${values.length}.
+             |#columns in right hand side: ${query.childOutputs.length}.
+             |Left side columns:
+             |[${values.map(_.sql).mkString(", ")}].
+             |Right side columns:
+             |[${query.childOutputs.map(_.sql).mkString(", ")}].""".stripMargin)
+      } else {
+        val mismatchedColumns = values.zip(query.childOutputs).flatMap {
+          case (l, r) if l.dataType != r.dataType =>
+            Seq(s"(${l.sql}:${l.dataType.catalogString}, 
${r.sql}:${r.dataType.catalogString})")
+          case _ => None
+        }
+        TypeCheckResult.TypeCheckFailure(
+          s"""
+             |The data type of one or more elements in the left hand side of an IN subquery
+             |is not compatible with the data type of the output of the subquery
+             |Mismatched columns:
+             |[${mismatchedColumns.mkString(", ")}]
+             |Left side:
+             |[${values.map(_.dataType.catalogString).mkString(", ")}].
+             |Right side:
+             |[${query.childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin)
+      }
+    } else {
+      TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
+    }
+  }
+
+  override def children: Seq[Expression] = values :+ query
+  override def nullable: Boolean = children.exists(_.nullable)
+  override def foldable: Boolean = children.forall(_.foldable)
+  override def toString: String = s"$value IN ($query)"
+  override def sql: String = s"(${value.sql} IN (${query.sql}))"
+}
+
 
 /**
  * Evaluates to `true` if `list` contains `value`.
@@ -169,44 +229,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
    val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType,
       ignoreNullability = true))
     if (mismatchOpt.isDefined) {
-      list match {
-        case ListQuery(_, _, _, childOutputs) :: Nil =>
-          val valExprs = value match {
-            case cns: CreateNamedStruct => cns.valExprs
-            case expr => Seq(expr)
-          }
-          if (valExprs.length != childOutputs.length) {
-            TypeCheckResult.TypeCheckFailure(
-              s"""
-                 |The number of columns in the left hand side of an IN subquery does not match the
-                 |number of columns in the output of subquery.
-                 |#columns in left hand side: ${valExprs.length}.
-                 |#columns in right hand side: ${childOutputs.length}.
-                 |Left side columns:
-                 |[${valExprs.map(_.sql).mkString(", ")}].
-                 |Right side columns:
-                 |[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin)
-          } else {
-            val mismatchedColumns = valExprs.zip(childOutputs).flatMap {
-              case (l, r) if l.dataType != r.dataType =>
-                Seq(s"(${l.sql}:${l.dataType.catalogString}, 
${r.sql}:${r.dataType.catalogString})")
-              case _ => None
-            }
-            TypeCheckResult.TypeCheckFailure(
-              s"""
-                 |The data type of one or more elements in the left hand side of an IN subquery
-                 |is not compatible with the data type of the output of the subquery
-                 |Mismatched columns:
-                 |[${mismatchedColumns.mkString(", ")}]
-                 |Left side:
-                 |[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
-                 |Right side:
-                 |[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin)
-          }
-        case _ =>
-          TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
-            s"${value.dataType.catalogString} != ${mismatchOpt.get.dataType.catalogString}")
-      }
+      TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
+        s"${value.dataType.catalogString} != ${mismatchOpt.get.dataType.catalogString}")
     } else {
       TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
     }
@@ -307,9 +331,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
   }
 
   override def sql: String = {
-    val childrenSQL = children.map(_.sql)
-    val valueSQL = childrenSQL.head
-    val listSQL = childrenSQL.tail.mkString(", ")
+    val valueSQL = value.sql
+    val listSQL = list.map(_.sql).mkString(", ")
     s"($valueSQL IN ($listSQL))"
   }
 }

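A hedged sketch of the new expression's type checking (same column count but 
incompatible types; the attributes are illustrative):

    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, InSubquery, ListQuery}
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
    import org.apache.spark.sql.types.{IntegerType, StringType}

    val a = AttributeReference("a", IntegerType)()
    val rel = LocalRelation(AttributeReference("b", StringType)())

    // int vs string: checkInputDataTypes should return the
    // "data type ... is not compatible" failure built above.
    val in = InSubquery(Seq(a), ListQuery(rel, childOutputs = rel.output))
    in.checkInputDataTypes() match {
      case TypeCheckFailure(msg) => println(msg)
      case ok => println(ok)
    }
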
http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 6acc87a..fc1caed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -117,10 +117,10 @@ object SubExprUtils extends PredicateHelper {
   def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = {
     splitConjunctivePredicates(condition).exists {
       case _: Exists | Not(_: Exists) => false
-      case In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => false
+      case _: InSubquery | Not(_: InSubquery) => false
       case e => e.find { x =>
         x.isInstanceOf[Not] && e.find {
-          case In(_, Seq(_: ListQuery)) => true
+          case _: InSubquery => true
           case _ => false
         }.isDefined
       }.isDefined

http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index e7b4730..5629b72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -523,6 +523,7 @@ object NullPropagation extends Rule[LogicalPlan] {
 
      // If the value expression is NULL then transform the In expression to null literal.
       case In(Literal(null, _), _) => Literal.create(null, BooleanType)
+      case InSubquery(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType)
 
      // Non-leaf NullIntolerant expressions will return null, if at least one of its children is
       // a null literal.

http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index de89e17..e9b7a8b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -42,13 +43,6 @@ import org.apache.spark.sql.types._
  *    condition.
  */
object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
-  private def getValueExpression(e: Expression): Seq[Expression] = {
-    e match {
-      case cns : CreateNamedStruct => cns.valExprs
-      case expr => Seq(expr)
-    }
-  }
-
   private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match {
    // SPARK-21835: It is possibly that the two sides of the join have conflicting attributes,
    // the produced join then becomes unresolved and break structural integrity. We should
@@ -97,19 +91,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
           // Deduplicate conflicting attributes if any.
           dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
-        case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) =>
-          val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
+        case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
+          val inConditions = values.zip(sub.output).map(EqualTo.tupled)
          val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
           // Deduplicate conflicting attributes if any.
           dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
-        case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) =>
+        case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
           // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
          // Construct the condition. A NULL in one of the conditions is regarded as a positive
           // result; such a row will be filtered out by the Anti-Join operator.
 
          // Note that will almost certainly be planned as a Broadcast Nested Loop join.
           // Use EXISTS if performance matters to you.
-          val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
+          val inConditions = values.zip(sub.output).map(EqualTo.tupled)
           val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p)
           // Expand the NOT IN expression with the NULL-aware semantic
           // to its full form. That is from:
@@ -150,9 +144,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           newPlan = dedupJoin(
            Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)))
           exists
-        case In(value, Seq(ListQuery(sub, conditions, _, _))) =>
+        case InSubquery(values, ListQuery(sub, conditions, _, _)) =>
+          val exists = AttributeReference("exists", BooleanType, nullable = false)()
-          val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
+          val inConditions = values.zip(sub.output).map(EqualTo.tupled)
          val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
           // Deduplicate conflicting attributes if any.
          newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions))

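A sketch of the rewrite's effect with the new expression (modeled on 
PullupCorrelatedPredicatesSuite; the relations are illustrative):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery}
    import org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val t1 = LocalRelation('a.int, 'b.int)
    val t2 = LocalRelation('c.int, 'd.int)

    // (a, b) IN (select c, d from t2): the value expressions are paired with
    // the subquery output positionally, so the predicate should become a
    // LeftSemi join on a = c AND b = d.
    val query = t1.where(InSubquery(Seq('a, 'b), ListQuery(t2.select('c, 'd)))).analyze
    println(RewritePredicateSubquery(query))
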
http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 732d762..7bc1f63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1103,6 +1103,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
       case not => Not(e)
     }
 
+    def getValueExpressions(e: Expression): Seq[Expression] = e match {
+      case c: CreateNamedStruct => c.valExprs
+      case other => Seq(other)
+    }
+
     // Create the predicate.
     ctx.kind.getType match {
       case SqlBaseParser.BETWEEN =>
@@ -1111,7 +1116,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
           GreaterThanOrEqual(e, expression(ctx.lower)),
           LessThanOrEqual(e, expression(ctx.upper))))
       case SqlBaseParser.IN if ctx.query != null =>
-        invertIfNotDefined(In(e, Seq(ListQuery(plan(ctx.query)))))
+        invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query))))
       case SqlBaseParser.IN =>
         invertIfNotDefined(In(e, ctx.expression.asScala.map(expression)))
       case SqlBaseParser.LIKE =>

http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 0a5194a..9477884 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -534,7 +534,7 @@ class AnalysisErrorSuite extends AnalysisTest {
     val a = AttributeReference("a", IntegerType)()
     val b = AttributeReference("b", IntegerType)()
     val plan = Project(
-      Seq(a, Alias(In(a, Seq(ListQuery(LocalRelation(b)))), "c")()),
+      Seq(a, Alias(InSubquery(Seq(a), ListQuery(LocalRelation(b))), "c")()),
       LocalRelation(a))
    assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil)
   }
@@ -543,12 +543,13 @@ class AnalysisErrorSuite extends AnalysisTest {
     val a = AttributeReference("a", IntegerType)()
     val b = AttributeReference("b", IntegerType)()
     val c = AttributeReference("c", BooleanType)()
-    val plan1 = Filter(Cast(Not(In(a, Seq(ListQuery(LocalRelation(b))))), BooleanType),
+    val plan1 = Filter(Cast(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), BooleanType),
       LocalRelation(a))
     assertAnalysisError(plan1,
       "Null-aware predicate sub-queries cannot be used in nested conditions" 
:: Nil)
 
-    val plan2 = Filter(Or(Not(In(a, Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c))
+    val plan2 = Filter(
+      Or(Not(InSubquery(Seq(a), ListQuery(LocalRelation(b)))), c), LocalRelation(a, c))
     assertAnalysisError(plan2,
       "Null-aware predicate sub-queries cannot be used in nested conditions" 
:: Nil)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
index 1bf8d76..74a8590 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{In, ListQuery, OuterReference}
+import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project}
 
 /**
@@ -33,7 +33,8 @@ class ResolveSubquerySuite extends AnalysisTest {
   val t2 = LocalRelation(b)
 
   test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") {
-    val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1)
+    val expr = Filter(
+      InSubquery(Seq(a), ListQuery(Project(Seq(UnresolvedAttribute("a")), t2))), t1)
     val m = intercept[AnalysisException] {
       SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr))
     }.getMessage

http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index 86522a6..a36083b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -121,6 +121,21 @@ class OptimizeInSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("OptimizedIn test: NULL IN (subquery) gets transformed to 
Filter(null)") {
+    val subquery = ListQuery(testRelation.select(UnresolvedAttribute("a")))
+    val originalQuery =
+      testRelation
+        .where(InSubquery(Seq(Literal.create(null, NullType)), subquery))
+        .analyze
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .where(Literal.create(null, BooleanType))
+        .analyze
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("OptimizedIn test: Inset optimization disabled as " +
     "list expression contains attribute)") {
     val originalQuery =

http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala
index 169b8737..8a5a551 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{In, ListQuery}
+import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -42,7 +42,7 @@ class PullupCorrelatedPredicatesSuite extends PlanTest {
         .select('c)
     val outerQuery =
       testRelation
-        .where(In('a, Seq(ListQuery(correlatedSubquery))))
+        .where(InSubquery(Seq('a), ListQuery(correlatedSubquery)))
         .select('a).analyze
     assert(outerQuery.resolved)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index c37b9f1..781fc1e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -154,7 +154,19 @@ class ExpressionParserSuite extends PlanTest {
   test("in sub-query") {
     assertEqual(
       "a in (select b from c)",
-      In('a, Seq(ListQuery(table("c").select('b)))))
+      InSubquery(Seq('a), ListQuery(table("c").select('b))))
+
+    assertEqual(
+      "(a, b, c) in (select d, e, f from g)",
+      InSubquery(Seq('a, 'b, 'c), ListQuery(table("g").select('d, 'e, 'f))))
+
+    assertEqual(
+      "(a, b) in (select c from d)",
+      InSubquery(Seq('a, 'b), ListQuery(table("d").select('c))))
+
+    assertEqual(
+      "(a) in (select b from c)",
+      InSubquery(Seq('a), ListQuery(table("c").select('b))))
   }
 
   test("like expressions") {

http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql
new file mode 100644
index 0000000..f4ffc20
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-basic.sql
@@ -0,0 +1,14 @@
+create temporary view tab_a as select * from values (1, 1) as tab_a(a1, b1);
+create temporary view tab_b as select * from values (1, 1) as tab_b(a2, b2);
+create temporary view struct_tab as select struct(col1 as a, col2 as b) as record from
+ values (1, 1), (1, 2), (2, 1), (2, 2);
+
+select 1 from tab_a where (a1, b1) not in (select a2, b2 from tab_b);
+-- Invalid query, see SPARK-24341
+select 1 from tab_a where (a1, b1) not in (select (a2, b2) from tab_b);
+
+-- Aliasing is needed as a workaround for SPARK-24443
+select count(*) from struct_tab where record in
+  (select (a2 as a, b2 as b) from tab_b);
+select count(*) from struct_tab where record not in
+  (select (a2 as a, b2 as b) from tab_b);

http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out
new file mode 100644
index 0000000..088db55
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-basic.sql.out
@@ -0,0 +1,70 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 7
+
+
+-- !query 0
+create temporary view tab_a as select * from values (1, 1) as tab_a(a1, b1)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+create temporary view tab_b as select * from values (1, 1) as tab_b(a2, b2)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+create temporary view struct_tab as select struct(col1 as a, col2 as b) as record from
+ values (1, 1), (1, 2), (2, 1), (2, 2)
+-- !query 2 schema
+struct<>
+-- !query 2 output
+
+
+
+-- !query 3
+select 1 from tab_a where (a1, b1) not in (select a2, b2 from tab_b)
+-- !query 3 schema
+struct<1:int>
+-- !query 3 output
+
+
+
+-- !query 4
+select 1 from tab_a where (a1, b1) not in (select (a2, b2) from tab_b)
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.AnalysisException
+Cannot analyze (named_struct('a1', tab_a.`a1`, 'b1', tab_a.`b1`) IN (listquery())).
+The number of columns in the left hand side of an IN subquery does not match the
+number of columns in the output of subquery.
+#columns in left hand side: 2
+#columns in right hand side: 1
+Left side columns:
+[tab_a.`a1`, tab_a.`b1`]
+Right side columns:
+[`named_struct(a2, a2, b2, b2)`];
+
+
+-- !query 5
+select count(*) from struct_tab where record in
+  (select (a2 as a, b2 as b) from tab_b)
+-- !query 5 schema
+struct<count(1):bigint>
+-- !query 5 output
+1
+
+
+-- !query 6
+select count(*) from struct_tab where record not in
+  (select (a2 as a, b2 as b) from tab_b)
+-- !query 6 schema
+struct<count(1):bigint>
+-- !query 6 output
+3

http://git-wip-us.apache.org/repos/asf/spark/blob/88e0c7bb/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out
index dcd3005..c52e570 100644
--- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out
@@ -92,15 +92,15 @@ t1a IN (SELECT t2a, t2b
 struct<>
 -- !query 7 output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(t1.`t1a` IN (listquery(t1.`t1a`)))' due to data type mismatch: 
+Cannot analyze (t1.`t1a` IN (listquery(t1.`t1a`))).
The number of columns in the left hand side of an IN subquery does not match the
 number of columns in the output of subquery.
-#columns in left hand side: 1.
-#columns in right hand side: 2.
+#columns in left hand side: 1
+#columns in right hand side: 2
 Left side columns:
-[t1.`t1a`].
+[t1.`t1a`]
 Right side columns:
-[t2.`t2a`, t2.`t2b`].;
+[t2.`t2a`, t2.`t2b`];
 
 
 -- !query 8
@@ -113,15 +113,15 @@ WHERE
 struct<>
 -- !query 8 output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`)))' due to data type mismatch: 
+Cannot analyze (named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`))).
+The number of columns in the left hand side of an IN subquery does not match the
 number of columns in the output of subquery.
-#columns in left hand side: 2.
-#columns in right hand side: 1.
+#columns in left hand side: 2
+#columns in right hand side: 1
 Left side columns:
-[t1.`t1a`, t1.`t1b`].
+[t1.`t1a`, t1.`t1b`]
 Right side columns:
-[t2.`t2a`].;
+[t2.`t2a`];
 
 
 -- !query 9

