Repository: spark
Updated Branches:
  refs/heads/master e71e93aaa -> 01fcba2c6


[SPARK-24737][SQL] Type coercion between StructTypes.

## What changes were proposed in this pull request?

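We can support type coercion between `StructType`s that have the same number of
fields, whose field names match pairwise (respecting `spark.sql.caseSensitive`),
and whose corresponding field types are all compatible. A field in the resulting
struct is nullable if either of the corresponding input fields is nullable.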

## How was this patch tested?

Added tests to `TypeCoercionSuite`.

Author: Takuya UESHIN <[email protected]>

Closes #21713 from ueshin/issues/SPARK-24737/structtypecoercion.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/01fcba2c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/01fcba2c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/01fcba2c

Branch: refs/heads/master
Commit: 01fcba2c685be0603a404392685e9d52fb4cb82a
Parents: e71e93a
Author: Takuya UESHIN <[email protected]>
Authored: Fri Jul 6 11:10:50 2018 +0800
Committer: hyukjinkwon <[email protected]>
Committed: Fri Jul 6 11:10:50 2018 +0800

----------------------------------------------------------------------
 .../sql/catalyst/analysis/TypeCoercion.scala    | 69 ++++++---------
 .../catalyst/analysis/TypeCoercionSuite.scala   | 93 ++++++++++++++++++--
 2 files changed, 114 insertions(+), 48 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/01fcba2c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index cf90e6e..b6ca30c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -102,25 +102,7 @@ object TypeCoercion {
     case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) =>
       Some(TimestampType)
 
-    case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if 
t1.sameType(t2) =>
-      Some(StructType(fields1.zip(fields2).map { case (f1, f2) =>
-        // Since `t1.sameType(t2)` is true, two StructTypes have the same 
DataType
-        // except `name` (in case of `spark.sql.caseSensitive=false`) and 
`nullable`.
-        // - Different names: use f1.name
-        // - Different nullabilities: `nullable` is true iff one of them is 
nullable.
-        val dataType = findTightestCommonType(f1.dataType, f2.dataType).get
-        StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable)
-      }))
-
-    case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if 
a1.sameType(a2) =>
-      findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2))
-
-    case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) 
if m1.sameType(m2) =>
-      val keyType = findTightestCommonType(kt1, kt2)
-      val valueType = findTightestCommonType(vt1, vt2)
-      Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2))
-
-    case _ => None
+    case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType)
   }
 
   /** Promotes all the way to StringType. */
@@ -166,6 +148,30 @@ object TypeCoercion {
     case (l, r) => None
   }
 
+  private def findTypeForComplex(
+      t1: DataType,
+      t2: DataType,
+      findTypeFunc: (DataType, DataType) => Option[DataType]): 
Option[DataType] = (t1, t2) match {
+    case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
+      findTypeFunc(et1, et2).map(ArrayType(_, containsNull1 || containsNull2))
+    case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, 
valueContainsNull2)) =>
+      findTypeFunc(kt1, kt2).flatMap { kt =>
+        findTypeFunc(vt1, vt2).map { vt =>
+          MapType(kt, vt, valueContainsNull1 || valueContainsNull2)
+        }
+      }
+    case (StructType(fields1), StructType(fields2)) if fields1.length == 
fields2.length =>
+      val resolver = SQLConf.get.resolver
+      fields1.zip(fields2).foldLeft(Option(new StructType())) {
+        case (Some(struct), (field1, field2)) if resolver(field1.name, 
field2.name) =>
+          findTypeFunc(field1.dataType, field2.dataType).map {
+            dt => struct.add(field1.name, dt, field1.nullable || 
field2.nullable)
+          }
+        case _ => None
+      }
+    case _ => None
+  }
+
   /**
    * Case 2 type widening (see the classdoc comment above for TypeCoercion).
    *
@@ -176,17 +182,7 @@ object TypeCoercion {
     findTightestCommonType(t1, t2)
       .orElse(findWiderTypeForDecimal(t1, t2))
       .orElse(stringPromotion(t1, t2))
-      .orElse((t1, t2) match {
-        case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
-          findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || 
containsNull2))
-        case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, 
valueContainsNull2)) =>
-          findWiderTypeForTwo(kt1, kt2).flatMap { kt =>
-            findWiderTypeForTwo(vt1, vt2).map { vt =>
-              MapType(kt, vt, valueContainsNull1 || valueContainsNull2)
-            }
-          }
-        case _ => None
-      })
+      .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo))
   }
 
   /**
@@ -222,18 +218,7 @@ object TypeCoercion {
       t2: DataType): Option[DataType] = {
     findTightestCommonType(t1, t2)
       .orElse(findWiderTypeForDecimal(t1, t2))
-      .orElse((t1, t2) match {
-        case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
-          findWiderTypeWithoutStringPromotionForTwo(et1, et2)
-            .map(ArrayType(_, containsNull1 || containsNull2))
-        case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, 
valueContainsNull2)) =>
-          findWiderTypeWithoutStringPromotionForTwo(kt1, kt2).flatMap { kt =>
-            findWiderTypeWithoutStringPromotionForTwo(vt1, vt2).map { vt =>
-              MapType(kt, vt, valueContainsNull1 || valueContainsNull2)
-            }
-          }
-        case _ => None
-      })
+      .orElse(findTypeForComplex(t1, t2, 
findWiderTypeWithoutStringPromotionForTwo))
   }
 
   def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): 
Option[DataType] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/01fcba2c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 4e5ca1b..8cc5a23 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -54,7 +54,7 @@ class TypeCoercionSuite extends AnalysisTest {
   // | NullType             | ByteType | ShortType | IntegerType | LongType | 
DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | 
DateType | TimestampType | ArrayType  | MapType  | StructType  | NullType | 
CalendarIntervalType | DecimalType(38, 18) | DoubleType  | IntegerType  |
   // | CalendarIntervalType | X        | X         | X           | X        | 
X          | X         | X          | X          | X           | X          | X 
       | X             | X          | X        | X           | X        | 
CalendarIntervalType | X                   | X           | X            |
   // 
+----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
-  // Note: StructType* is castable only when the internal child types also 
match; otherwise, not castable.
+  // Note: StructType* is castable when all the internal child types are 
castable according to the table.
   // Note: ArrayType* is castable when the element type is castable according 
to the table.
   // Note: MapType* is castable when both the key type and the value type are 
castable according to the table.
   // scalastyle:on line.size.limit
@@ -397,7 +397,7 @@ class TypeCoercionSuite extends AnalysisTest {
     widenTest(
       StructType(Seq(StructField("a", IntegerType, nullable = false))),
       StructType(Seq(StructField("a", DoubleType, nullable = false))),
-      None)
+      Some(StructType(Seq(StructField("a", DoubleType, nullable = false)))))
 
     widenTest(
       StructType(Seq(StructField("a", IntegerType, nullable = false))),
@@ -454,15 +454,18 @@ class TypeCoercionSuite extends AnalysisTest {
     def widenTestWithStringPromotion(
         t1: DataType,
         t2: DataType,
-        expected: Option[DataType]): Unit = {
-      checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected)
+        expected: Option[DataType],
+        isSymmetric: Boolean = true): Unit = {
+      checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected, 
isSymmetric)
     }
 
     def widenTestWithoutStringPromotion(
         t1: DataType,
         t2: DataType,
-        expected: Option[DataType]): Unit = {
-      checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, 
t1, t2, expected)
+        expected: Option[DataType],
+        isSymmetric: Boolean = true): Unit = {
+      checkWidenType(
+        TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, 
expected, isSymmetric)
     }
 
     // Decimal
@@ -492,6 +495,10 @@ class TypeCoercionSuite extends AnalysisTest {
       ArrayType(MapType(IntegerType, FloatType), containsNull = false),
       ArrayType(MapType(LongType, DoubleType), containsNull = false),
       Some(ArrayType(MapType(LongType, DoubleType), containsNull = false)))
+    widenTestWithStringPromotion(
+      ArrayType(new StructType().add("num", ShortType), containsNull = false),
+      ArrayType(new StructType().add("num", LongType), containsNull = false),
+      Some(ArrayType(new StructType().add("num", LongType), containsNull = 
false)))
 
     // MapType
     widenTestWithStringPromotion(
@@ -506,6 +513,64 @@ class TypeCoercionSuite extends AnalysisTest {
       MapType(IntegerType, MapType(ShortType, TimestampType), 
valueContainsNull = false),
       MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = 
false),
       Some(MapType(LongType, MapType(DoubleType, StringType), 
valueContainsNull = false)))
+    widenTestWithStringPromotion(
+      MapType(IntegerType, new StructType().add("num", ShortType), 
valueContainsNull = false),
+      MapType(LongType, new StructType().add("num", LongType), 
valueContainsNull = false),
+      Some(MapType(LongType, new StructType().add("num", LongType), 
valueContainsNull = false)))
+
+    // StructType
+    widenTestWithStringPromotion(
+      new StructType()
+        .add("num", ShortType, nullable = true).add("ts", StringType, nullable 
= false),
+      new StructType()
+        .add("num", DoubleType, nullable = false).add("ts", TimestampType, 
nullable = true),
+      Some(new StructType()
+        .add("num", DoubleType, nullable = true).add("ts", StringType, 
nullable = true)))
+    widenTestWithStringPromotion(
+      new StructType()
+        .add("arr", ArrayType(ShortType, containsNull = false), nullable = 
false),
+      new StructType()
+        .add("arr", ArrayType(DoubleType, containsNull = true), nullable = 
false),
+      Some(new StructType()
+        .add("arr", ArrayType(DoubleType, containsNull = true), nullable = 
false)))
+    widenTestWithStringPromotion(
+      new StructType()
+        .add("map", MapType(ShortType, TimestampType, valueContainsNull = 
true), nullable = false),
+      new StructType()
+        .add("map", MapType(DoubleType, StringType, valueContainsNull = 
false), nullable = false),
+      Some(new StructType()
+        .add("map", MapType(DoubleType, StringType, valueContainsNull = true), 
nullable = false)))
+
+    widenTestWithStringPromotion(
+      new StructType().add("num", IntegerType),
+      new StructType().add("num", LongType).add("str", StringType),
+      None)
+    widenTestWithoutStringPromotion(
+      new StructType().add("num", IntegerType),
+      new StructType().add("num", LongType).add("str", StringType),
+      None)
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      widenTestWithStringPromotion(
+        new StructType().add("a", IntegerType),
+        new StructType().add("A", LongType),
+        None)
+      widenTestWithoutStringPromotion(
+        new StructType().add("a", IntegerType),
+        new StructType().add("A", LongType),
+        None)
+    }
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      widenTestWithStringPromotion(
+        new StructType().add("a", IntegerType),
+        new StructType().add("A", LongType),
+        Some(new StructType().add("a", LongType)),
+        isSymmetric = false)
+      widenTestWithoutStringPromotion(
+        new StructType().add("a", IntegerType),
+        new StructType().add("A", LongType),
+        Some(new StructType().add("a", LongType)),
+        isSymmetric = false)
+    }
 
     // Without string promotion
     widenTestWithoutStringPromotion(IntegerType, StringType, None)
@@ -520,6 +585,14 @@ class TypeCoercionSuite extends AnalysisTest {
       MapType(StringType, IntegerType), MapType(TimestampType, IntegerType), 
None)
     widenTestWithoutStringPromotion(
       MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), 
None)
+    widenTestWithoutStringPromotion(
+      new StructType().add("a", IntegerType),
+      new StructType().add("a", StringType),
+      None)
+    widenTestWithoutStringPromotion(
+      new StructType().add("a", StringType),
+      new StructType().add("a", IntegerType),
+      None)
 
     // String promotion
     widenTestWithStringPromotion(IntegerType, StringType, Some(StringType))
@@ -544,6 +617,14 @@ class TypeCoercionSuite extends AnalysisTest {
       MapType(IntegerType, StringType),
       MapType(IntegerType, TimestampType),
       Some(MapType(IntegerType, StringType)))
+    widenTestWithStringPromotion(
+      new StructType().add("a", IntegerType),
+      new StructType().add("a", StringType),
+      Some(new StructType().add("a", StringType)))
+    widenTestWithStringPromotion(
+      new StructType().add("a", StringType),
+      new StructType().add("a", IntegerType),
+      Some(new StructType().add("a", StringType)))
   }
 
   private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, 
transformed: Expression) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to