Repository: spark Updated Branches: refs/heads/branch-2.4 53eb85854 -> 3c78ea258
[SPARK-25522][SQL] Improve type promotion for input arguments of elementAt function ## What changes were proposed in this pull request? In ElementAt, when the first argument is MapType, we should coerce the key type and the second argument based on findTightestCommonType. This is not happening currently. We may produce wrong output as we will incorrectly downcast the right hand side double expression to int. ```SQL spark-sql> select element_at(map(1,"one", 2, "two"), 2.2); two ``` Also, when the first argument is ArrayType, the second argument should be an integer type or a smaller integral type that can be safely cast to an integer type. Currently we may do an unsafe cast. In the following case, we should fail with an error as 1.24 is not an integer index. Instead, we currently downcast it to int and return a result. ```SQL spark-sql> select element_at(array(1,2), 1.24D); 1 ``` This PR also supports implicit cast between two MapTypes. I have followed similar logic that exists today to do implicit casts between two array types. ## How was this patch tested? Added new tests in DataFrameFunctionsSuite, TypeCoercionSuite. Closes #22544 from dilipbiswal/SPARK-25522. 
Authored-by: Dilip Biswal <dbis...@us.ibm.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit d03e0af80d7659f12821cc2442efaeaee94d3985) Signed-off-by: Wenchen Fan <wenc...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3c78ea25 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3c78ea25 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3c78ea25 Branch: refs/heads/branch-2.4 Commit: 3c78ea2589e1e2f3824ae6fa273eceaee3934391 Parents: 53eb858 Author: Dilip Biswal <dbis...@us.ibm.com> Authored: Thu Sep 27 15:04:59 2018 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Thu Sep 27 19:50:01 2018 +0800 ---------------------------------------------------------------------- .../sql/catalyst/analysis/TypeCoercion.scala | 19 +++++ .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../expressions/collectionOperations.scala | 37 ++++++---- .../catalyst/analysis/TypeCoercionSuite.scala | 43 +++++++++-- .../spark/sql/DataFrameFunctionsSuite.scala | 75 +++++++++++++++++++- 5 files changed, 154 insertions(+), 22 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/3c78ea25/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 49d286f..72ac80e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -950,6 +950,25 @@ object TypeCoercion { if !Cast.forceNullable(fromType, toType) => implicitCast(fromType, toType).map(ArrayType(_, 
false)).orNull + // Implicit cast between Map types. + // Follows the same semantics of implicit casting between two array types. + // Refer to documentation above. Make sure that both key and values + // can not be null after the implicit cast operation by calling forceNullable + // method. + case (MapType(fromKeyType, fromValueType, fn), MapType(toKeyType, toValueType, tn)) + if !Cast.forceNullable(fromKeyType, toKeyType) && Cast.resolvableNullability(fn, tn) => + if (Cast.forceNullable(fromValueType, toValueType) && !tn) { + null + } else { + val newKeyType = implicitCast(fromKeyType, toKeyType).orNull + val newValueType = implicitCast(fromValueType, toValueType).orNull + if (newKeyType != null && newValueType != null) { + MapType(newKeyType, newValueType, tn) + } else { + null + } + } + case _ => null } Option(ret) http://git-wip-us.apache.org/repos/asf/spark/blob/3c78ea25/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8f77799..ee463bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -183,7 +183,7 @@ object Cast { case _ => false } - private def resolvableNullability(from: Boolean, to: Boolean) = !from || to + def resolvableNullability(from: Boolean, to: Boolean): Boolean = !from || to } /** http://git-wip-us.apache.org/repos/asf/spark/blob/3c78ea25/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 9cc7dba..b24d748 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2154,21 +2154,34 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti } override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(ArrayType, MapType), - left.dataType match { - case _: ArrayType => IntegerType - case _: MapType => mapKeyType - case _ => AnyDataType // no match for a wrong 'left' expression type - } - ) + (left.dataType, right.dataType) match { + case (arr: ArrayType, e2: IntegralType) if (e2 != LongType) => + Seq(arr, IntegerType) + case (MapType(keyType, valueType, hasNull), e2) => + TypeCoercion.findTightestCommonType(keyType, e2) match { + case Some(dt) => Seq(MapType(dt, valueType, hasNull), dt) + case _ => Seq.empty + } + case (l, r) => Seq.empty + + } } override def checkInputDataTypes(): TypeCheckResult = { - super.checkInputDataTypes() match { - case f: TypeCheckResult.TypeCheckFailure => f - case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] => - TypeUtils.checkForOrderingExpr(mapKeyType, s"function $prettyName") - case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess + (left.dataType, right.dataType) match { + case (_: ArrayType, e2) if e2 != IntegerType => + TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${ArrayType.simpleString} followed by a ${IntegerType.simpleString}, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") + case (MapType(e1, _, _), e2) if (!e2.sameType(e1)) => + TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${MapType.simpleString} followed by a value of same key type, but it's " 
+ + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") + case (e1, _) if (!e1.isInstanceOf[MapType] && !e1.isInstanceOf[ArrayType]) => + TypeCheckResult.TypeCheckFailure(s"The first argument to function $prettyName should " + + s"have been ${ArrayType.simpleString} or ${MapType.simpleString} type, but its " + + s"${left.dataType.catalogString} type.") + case _ => TypeCheckResult.TypeCheckSuccess } } http://git-wip-us.apache.org/repos/asf/spark/blob/3c78ea25/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 0594673..0eba1c5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -257,12 +257,43 @@ class TypeCoercionSuite extends AnalysisTest { shouldNotCast(checkedType, IntegralType) } - test("implicit type cast - MapType(StringType, StringType)") { - val checkedType = MapType(StringType, StringType) - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) - shouldNotCast(checkedType, DecimalType) - shouldNotCast(checkedType, NumericType) - shouldNotCast(checkedType, IntegralType) + test("implicit type cast between two Map types") { + val sourceType = MapType(IntegerType, IntegerType, true) + val castableTypes = numericTypes ++ Seq(StringType).filter(!Cast.forceNullable(IntegerType, _)) + val targetTypes = numericTypes.filter(!Cast.forceNullable(IntegerType, _)).map { t => + MapType(t, sourceType.valueType, valueContainsNull = true) + } + val nonCastableTargetTypes = allTypes.filterNot(castableTypes.contains(_)).map {t => + MapType(t, sourceType.valueType, valueContainsNull = true) + 
} + + // Tests that its possible to setup implicit casts between two map types when + // source map's key type is integer and the target map's key type are either Byte, Short, + // Long, Double, Float, Decimal(38, 18) or String. + targetTypes.foreach { targetType => + shouldCast(sourceType, targetType, targetType) + } + + // Tests that its not possible to setup implicit casts between two map types when + // source map's key type is integer and the target map's key type are either Binary, + // Boolean, Date, Timestamp, Array, Struct, CaleandarIntervalType or NullType + nonCastableTargetTypes.foreach { targetType => + shouldNotCast(sourceType, targetType) + } + + // Tests that its not possible to cast from nullable map type to not nullable map type. + val targetNotNullableTypes = allTypes.filterNot(_ == IntegerType).map { t => + MapType(t, sourceType.valueType, valueContainsNull = false) + } + val sourceMapExprWithValueNull = + CreateMap(Seq(Literal.default(sourceType.keyType), + Literal.create(null, sourceType.valueType))) + targetNotNullableTypes.foreach { targetType => + val castDefault = + TypeCoercion.ImplicitTypeCasts.implicitCast(sourceMapExprWithValueNull, targetType) + assert(castDefault.isEmpty, + s"Should not be able to cast $sourceType to $targetType, but got $castDefault") + } } test("implicit type cast - StructType().add(\"a1\", StringType)") { http://git-wip-us.apache.org/repos/asf/spark/blob/3c78ea25/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 88dbae8..60ebc5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1211,11 +1211,80 @@ class DataFrameFunctionsSuite 
extends QueryTest with SharedSQLContext { Seq(Row("3"), Row(""), Row(null)) ) - val e = intercept[AnalysisException] { + val e1 = intercept[AnalysisException] { Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)") } - assert(e.message.contains( - "argument 1 requires (array or map) type, however, '`_1`' is of string type")) + val errorMsg1 = + s""" + |The first argument to function element_at should have been array or map type, but + |its string type. + """.stripMargin.replace("\n", " ").trim() + assert(e1.message.contains(errorMsg1)) + + checkAnswer( + OneRowRelation().selectExpr("element_at(array(2, 1), 2S)"), + Seq(Row(1)) + ) + + checkAnswer( + OneRowRelation().selectExpr("element_at(array('a', 'b'), 1Y)"), + Seq(Row("a")) + ) + + checkAnswer( + OneRowRelation().selectExpr("element_at(array(1, 2, 3), 3)"), + Seq(Row(3)) + ) + + val e2 = intercept[AnalysisException] { + OneRowRelation().selectExpr("element_at(array('a', 'b'), 1L)") + } + val errorMsg2 = + s""" + |Input to function element_at should have been array followed by a int, but it's + |[array<string>, bigint]. 
+ """.stripMargin.replace("\n", " ").trim() + assert(e2.message.contains(errorMsg2)) + + checkAnswer( + OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2Y)"), + Seq(Row("b")) + ) + + checkAnswer( + OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1S)"), + Seq(Row("a")) + ) + + checkAnswer( + OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2)"), + Seq(Row("b")) + ) + + checkAnswer( + OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2L)"), + Seq(Row("b")) + ) + + checkAnswer( + OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.0D)"), + Seq(Row("a")) + ) + + checkAnswer( + OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.23D)"), + Seq(Row(null)) + ) + + val e3 = intercept[AnalysisException] { + OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), '1')") + } + val errorMsg3 = + s""" + |Input to function element_at should have been map followed by a value of same + |key type, but it's [map<int,string>, string]. + """.stripMargin.replace("\n", " ").trim() + assert(e3.message.contains(errorMsg3)) } test("array_union functions") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org