This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new bd94cf7  [SPARK-31227][SQL] Non-nullable null type in complex types should not coerce to nullable type
bd94cf7 is described below

commit bd94cf7bcccd305cac1f301d2cefb043afba35f6
Author: HyukjinKwon <gurwls...@apache.org>
AuthorDate: Thu Mar 26 15:42:54 2020 +0800

    [SPARK-31227][SQL] Non-nullable null type in complex types should not coerce to nullable type
    
    ### What changes were proposed in this pull request?
    
    This PR prevents the non-nullable null type in complex types from being coerced to a nullable type.
    
    A non-nullable null type as a struct field, array element, or map entry can only come from an empty struct, array, or map. Since such a container is empty, there is no need to force nullability when finding a common type.
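
    The same reasoning extends to maps and structs, not only arrays. As an illustration (not part of the patch; this assumes a `spark-shell` session, and that `map()` with no arguments produces an empty `map<null, null>`):

    ```scala
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types._

    // An empty map literal can only ever be empty, so casting it to a concrete,
    // non-nullable key/value type is safe and is expected to resolve after this change.
    spark.range(1)
      .select(map().cast(MapType(StringType, IntegerType, valueContainsNull = false)))
      .printSchema()
    ```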
    
    This PR also reverts and supersedes https://github.com/apache/spark/commit/d7b97a1d0daf65710317321490a833f696a46f21
    
    ### Why are the changes needed?
    
    To make type coercion coherent and consistent. Currently, we already keep the correct nullability even between non-nullable types:
    
    ```scala
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.functions._
    spark.range(1).select(array(lit(1)).cast(ArrayType(IntegerType, false))).printSchema()
    spark.range(1).select(array(lit(1)).cast(ArrayType(DoubleType, false))).printSchema()
    ```
    ```scala
    spark.range(1).selectExpr("concat(array(1), array(1)) as arr").printSchema()
    ```
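
    For reference, the `concat` query above is expected to print a schema that keeps the element type non-nullable (illustrative output):

    ```
    root
     |-- arr: array (nullable = false)
     |    |-- element: integer (containsNull = false)
    ```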
    
    ### Does this PR introduce any user-facing change?
    
    Yes.
    
    ```scala
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.functions._
    spark.range(1).select(array().cast(ArrayType(IntegerType, false))).printSchema()
    ```
    ```scala
    spark.range(1).selectExpr("concat(array(), array(1)) as arr").printSchema()
    ```
    
    **Before:**
    
    ```
    org.apache.spark.sql.AnalysisException: cannot resolve 'array()' due to data type mismatch: cannot cast array<null> to array<int>;;
    'Project [cast(array() as array<int>) AS array()#68]
    +- Range (0, 1, step=1, splits=Some(12))

      at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
      at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$$nestedInanonfun$checkAnalysis$1$2.applyOrElse(CheckAnalysis.scala:149)
      at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$$nestedInanonfun$checkAnalysis$1$2.applyOrElse(CheckAnalysis.scala:140)
      at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$2(TreeNode.scala:333)
      at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:72)
      at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:333)
      at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:330)
      at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:399)
      at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:237)
    ```
    
    ```
    root
     |-- arr: array (nullable = false)
     |    |-- element: integer (containsNull = true)
    ```
    
    **After:**
    
    ```
    root
     |-- array(): array (nullable = false)
     |    |-- element: integer (containsNull = false)
    ```
    
    ```
    root
     |-- arr: array (nullable = false)
     |    |-- element: integer (containsNull = false)
    ```
    
    ### How was this patch tested?
    
    Unit tests were added, and the change was also verified manually.
    
    Closes #27991 from HyukjinKwon/SPARK-31227.
    
    Authored-by: HyukjinKwon <gurwls...@apache.org>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 3bd10ce007832522e38583592b6f358e185cdb7d)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/analysis/TypeCoercion.scala |  2 +-
 .../spark/sql/catalyst/expressions/Cast.scala      | 17 +++----
 .../sql/catalyst/analysis/TypeCoercionSuite.scala  | 54 ++++++++++++++--------
 .../spark/sql/catalyst/expressions/CastSuite.scala | 22 +++++++++
 .../apache/spark/sql/DataFrameFunctionsSuite.scala |  7 +++
 5 files changed, 73 insertions(+), 29 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 5a5d7c6..eb9a4d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -160,7 +160,7 @@ object TypeCoercion {
       }
    case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) =>
       findTypeFunc(kt1, kt2)
-        .filter { kt => Cast.canCastMapKeyNullSafe(kt1, kt) && Cast.canCastMapKeyNullSafe(kt2, kt) }
+        .filter { kt => !Cast.forceNullable(kt1, kt) && !Cast.forceNullable(kt2, kt) }
         .flatMap { kt =>
           findTypeFunc(vt1, vt2).map { vt =>
             MapType(kt, vt, valueContainsNull1 || valueContainsNull2 ||
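
The hunk above can drop the `canCastMapKeyNullSafe` helper because `forceNullable(NullType, _)` now returns `false` (see the `Cast.scala` change below). A REPL-style sketch of the map-key widening predicate, with toy types standing in for Spark's `DataType`s (illustrative only, not the patch's code):

```scala
sealed trait DType
case object TNull extends DType
case object TInt extends DType
case object TString extends DType

// Toy stand-in for Cast.forceNullable after this patch: a non-nullable
// NullType input can only come from an empty container, so it never
// forces nullability on the widened type.
def forceNullable(from: DType, to: DType): Boolean = (from, to) match {
  case (TNull, _) => false // empty array or map case
  case (f, t) if f == t => false
  case (TString, TInt) => true // e.g. "abc" does not parse as an int
  case _ => false
}

// Map keys may never be null, so a widened key type kt is acceptable only
// if neither side's keys could become null when cast to it.
def mapKeyWidenOk(kt1: DType, kt2: DType, kt: DType): Boolean =
  !forceNullable(kt1, kt) && !forceNullable(kt2, kt)

assert(mapKeyWidenOk(TNull, TInt, TInt))    // an empty map's key widens fine
assert(!mapKeyWidenOk(TString, TInt, TInt)) // string keys could turn null
```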
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8177136..8d82956 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -77,7 +77,8 @@ object Cast {
         resolvableNullability(fn || forceNullable(fromType, toType), tn)
 
     case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
-      canCast(fromKey, toKey) && canCastMapKeyNullSafe(fromKey, toKey) &&
+      canCast(fromKey, toKey) &&
+        (!forceNullable(fromKey, toKey)) &&
         canCast(fromValue, toValue) &&
         resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
 
@@ -97,11 +98,6 @@ object Cast {
     case _ => false
   }
 
-  def canCastMapKeyNullSafe(fromType: DataType, toType: DataType): Boolean = {
-    // If the original map key type is NullType, it's OK as the map must be empty.
-    fromType == NullType || !forceNullable(fromType, toType)
-  }
-
   /**
   * Return true if we need to use the `timeZone` information casting `from` type to `to` type.
    * The patterns matched reflect the current implementation in the Cast node.
@@ -210,8 +206,13 @@ object Cast {
     case _ => false  // overflow
   }
 
+  /**
+   * Returns `true` if casting non-nullable values from `from` type to `to` type
+   * may return null. Note that the caller side should take care of input nullability
+   * first and only call this method if the input is not nullable.
+   */
   def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match {
-    case (NullType, _) => true
+    case (NullType, _) => false // empty array or map case
     case (_, _) if from == to => false
 
     case (StringType, BinaryType) => false
@@ -269,7 +270,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
     }
   }
 
-  override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable
+  override def nullable: Boolean = child.nullable || Cast.forceNullable(child.dataType, dataType)
 
   protected def ansiEnabled: Boolean
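
With `forceNullable(NullType, _)` returning `false`, casting a non-nullable `NullType` field no longer forces the target to be nullable. A quick check mirroring the new `CastSuite` assertion (this uses the internal catalyst API, so treat it as illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types._

// A non-nullable NullType field can only occur when there is no value to
// cast, so casting to a non-nullable concrete field is now allowed.
Cast.canCast(
  StructType(StructField("a", NullType, nullable = false) :: Nil),
  StructType(StructField("a", IntegerType, nullable = false) :: Nil))
// expected: true (this returned false before the patch)
```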
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 95005fd..ab21a9e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 
 class TypeCoercionSuite extends AnalysisTest {
+  import TypeCoercionSuite._
 
   // scalastyle:off line.size.limit
  // The following table shows all implicit data type conversions that are not visible to the user.
@@ -99,22 +100,6 @@ class TypeCoercionSuite extends AnalysisTest {
     case _ => Literal.create(null, dataType)
   }
 
-  val integralTypes: Seq[DataType] =
-    Seq(ByteType, ShortType, IntegerType, LongType)
-  val fractionalTypes: Seq[DataType] =
-    Seq(DoubleType, FloatType, DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2))
-  val numericTypes: Seq[DataType] = integralTypes ++ fractionalTypes
-  val atomicTypes: Seq[DataType] =
-    numericTypes ++ Seq(BinaryType, BooleanType, StringType, DateType, TimestampType)
-  val complexTypes: Seq[DataType] =
-    Seq(ArrayType(IntegerType),
-      ArrayType(StringType),
-      MapType(StringType, StringType),
-      new StructType().add("a1", StringType),
-      new StructType().add("a1", StringType).add("a2", IntegerType))
-  val allTypes: Seq[DataType] =
-    atomicTypes ++ complexTypes ++ Seq(NullType, CalendarIntervalType)
-
   // Check whether the type `checkedType` can be cast to all the types in `castableTypes`,
   // but cannot be cast to the other types in `allTypes`.
   private def checkTypeCasting(checkedType: DataType, castableTypes: Seq[DataType]): Unit = {
@@ -497,6 +482,23 @@ class TypeCoercionSuite extends AnalysisTest {
         .add("null", IntegerType, nullable = false),
       Some(new StructType()
         .add("null", IntegerType, nullable = true)))
+
+    widenTest(
+      ArrayType(NullType, containsNull = false),
+      ArrayType(IntegerType, containsNull = false),
+      Some(ArrayType(IntegerType, containsNull = false)))
+
+    widenTest(MapType(NullType, NullType, false),
+      MapType(IntegerType, StringType, false),
+      Some(MapType(IntegerType, StringType, false)))
+
+    widenTest(
+      new StructType()
+        .add("null", NullType, nullable = false),
+      new StructType()
+        .add("null", IntegerType, nullable = false),
+      Some(new StructType()
+        .add("null", IntegerType, nullable = false)))
   }
 
   test("wider common type for decimal and array") {
@@ -728,8 +730,6 @@ class TypeCoercionSuite extends AnalysisTest {
   }
 
   test("cast NullType for expressions that implement ExpectsInputTypes") {
-    import TypeCoercionSuite._
-
     ruleTest(TypeCoercion.ImplicitTypeCasts,
       AnyTypeUnaryExpression(Literal.create(null, NullType)),
       AnyTypeUnaryExpression(Literal.create(null, NullType)))
@@ -740,8 +740,6 @@ class TypeCoercionSuite extends AnalysisTest {
   }
 
   test("cast NullType for binary operators") {
-    import TypeCoercionSuite._
-
     ruleTest(TypeCoercion.ImplicitTypeCasts,
      AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
      AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))
@@ -1548,6 +1546,22 @@ class TypeCoercionSuite extends AnalysisTest {
 
 object TypeCoercionSuite {
 
+  val integralTypes: Seq[DataType] =
+    Seq(ByteType, ShortType, IntegerType, LongType)
+  val fractionalTypes: Seq[DataType] =
+    Seq(DoubleType, FloatType, DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2))
+  val numericTypes: Seq[DataType] = integralTypes ++ fractionalTypes
+  val atomicTypes: Seq[DataType] =
+    numericTypes ++ Seq(BinaryType, BooleanType, StringType, DateType, TimestampType)
+  val complexTypes: Seq[DataType] =
+    Seq(ArrayType(IntegerType),
+      ArrayType(StringType),
+      MapType(StringType, StringType),
+      new StructType().add("a1", StringType),
+      new StructType().add("a1", StringType).add("a2", IntegerType))
+  val allTypes: Seq[DataType] =
+    atomicTypes ++ complexTypes ++ Seq(NullType, CalendarIntervalType)
+
   case class AnyTypeUnaryExpression(child: Expression)
     extends UnaryExpression with ExpectsInputTypes with Unevaluable {
     override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index dd3f362..e25d805 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence
+import org.apache.spark.sql.catalyst.analysis.TypeCoercionSuite
 import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -413,6 +414,14 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
       assert(ret.resolved)
       checkEvaluation(ret, Seq(null, true, false, null))
     }
+
+    {
+      val array = Literal.create(Seq.empty, ArrayType(NullType, containsNull = false))
+      val ret = cast(array, ArrayType(IntegerType, containsNull = false))
+      assert(ret.resolved)
+      checkEvaluation(ret, Seq.empty)
+    }
+
     {
       val ret = cast(array, ArrayType(BooleanType, containsNull = false))
       assert(ret.resolved === false)
@@ -1158,6 +1167,19 @@ class CastSuite extends CastSuiteBase {
       StructType(StructField("a", IntegerType, true) :: Nil)))
   }
 
+  test("SPARK-31227: Non-nullable null type should not coerce to nullable 
type") {
+    TypeCoercionSuite.allTypes.foreach { t =>
+      assert(Cast.canCast(ArrayType(NullType, false), ArrayType(t, false)))
+
+      assert(Cast.canCast(
+        MapType(NullType, NullType, false), MapType(t, t, false)))
+
+      assert(Cast.canCast(
+        StructType(StructField("a", NullType, false) :: Nil),
+        StructType(StructField("a", t, false) :: Nil)))
+    }
+  }
+
   test("Cast should output null for invalid strings when ANSI is not 
enabled.") {
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
       checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 875a671..9100ee3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1533,6 +1533,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     assert(e.getMessage.contains("string, binary or array"))
   }
 
+  test("SPARK-31227: Non-nullable null type should not coerce to nullable type 
in concat") {
+    val actual = spark.range(1).selectExpr("concat(array(), array(1)) as arr")
+    val expected = spark.range(1).selectExpr("array(1) as arr")
+    checkAnswer(actual, expected)
+    assert(actual.schema === expected.schema)
+  }
+
   test("flatten function") {
     // Test cases with a primitive type
     val intDF = Seq(

