GitHub user gatorsmile commented on a diff in the pull request:
https://github.com/apache/spark/pull/16954#discussion_r103850635
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
---
@@ -365,17 +385,66 @@ object TypeCoercion {
}
/**
- * Convert the value and in list expressions to the common operator type
- * by looking at all the argument types and finding the closest one that
- * all the arguments can be cast to. When no common operator type is
found
- * the original expression will be returned and an Analysis Exception
will
- * be raised at type checking phase.
+ * Handles type coercion for both IN expression with subquery and IN
+ * expressions without subquery.
+ * 1. In the first case, find the common type by comparing the left hand
side
+ * expression types against corresponding right hand side expression
derived
+ * from the subquery expression's plan output. Inject appropriate
casts in the
+ * LHS and RHS side of IN expression.
+ *
+ * 2. In the second case, convert the value and in list expressions to
the
+ * common operator type by looking at all the argument types and
finding
+ * the closest one that all the arguments can be cast to. When no
common
+ * operator type is found the original expression will be returned
and an
+ * Analysis Exception will be raised at the type checking phase.
*/
object InConversion extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
+ // Handle type casting required between value expression and
subquery output
+ // in IN subquery.
+ case i @ In(a, Seq(ListQuery(sub, children, exprId))) if !i.resolved
=>
+ // lhs is the value expression of IN subquery.
+ val lhs = a match {
+ // Multi columns in IN clause is represented as a
CreateNamedStruct.
+ // flatten the named struct to get the list of expressions.
+ case cns: CreateNamedStruct => cns.valExprs
+ case expr => Seq(expr)
+ }
+
+ // rhs is the subquery output.
+ val rhs = sub.output
+ require(lhs.length == rhs.length)
+
+ val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
+ findCommonTypeForBinaryComparison(l.dataType, r.dataType)
+ }
+
+ if (commonTypes.length == lhs.length) {
+ val castedRhs = rhs.zip(commonTypes).map {
+ case (e, dt) if e.dataType != dt => Alias(Cast(e, dt),
e.name)()
+ case (e, _) => e
+ }
+ val castedLhs = lhs.zip(commonTypes).map {
+ case (e, dt) if e.dataType != dt => Cast(e, dt)
+ case (e, _) => e
+ }
+
+ // Before constructing the In expression, wrap the multi values
in lhs
+ // in a CreatedNamedStruct.
+ val newLhs = a match {
+ case cns: CreateNamedStruct =>
+ val nameValue = cns.nameExprs.zip(castedLhs).flatMap(pair =>
Seq(pair._1, pair._2))
--- End diff --
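Please destructure the tuple with a pattern-matching lambda, e.g. `{ case (name, value) => Seq(name, value) }`, instead of a `pair` parameter accessed via `pair._1` and `pair._2`.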
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it enabled, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]