Repository: spark Updated Branches: refs/heads/master 804515f82 -> bb49661e1
[SPARK-25416][SQL] ArrayPosition function may return an incorrect result when the right expression is implicitly downcast ## What changes were proposed in this pull request? In ArrayPosition, we currently cast the right-hand side expression to match the element type of the left-hand side array. This may result in downcasting and may return a wrong or questionable result. Example: ```SQL spark-sql> select array_position(array(1), 1.34); 1 ``` (1.34 is downcast to the integer 1, so a match is reported.) ```SQL spark-sql> select array_position(array(1), 'foo'); null ``` ('foo' cannot be cast to an integer, so the cast yields null and the function silently returns null.) Instead, both the left- and right-hand side expressions should be safely coerced to their tightest common type. ## How was this patch tested? Added tests in DataFrameFunctionsSuite. Closes #22407 from dilipbiswal/SPARK-25416. Authored-by: Dilip Biswal <[email protected]> Signed-off-by: Wenchen Fan <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bb49661e Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bb49661e Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bb49661e Branch: refs/heads/master Commit: bb49661e192eed78a8a306deffd83c73bd4a9eff Parents: 804515f Author: Dilip Biswal <[email protected]> Authored: Mon Sep 24 21:37:51 2018 +0800 Committer: Wenchen Fan <[email protected]> Committed: Mon Sep 24 21:37:51 2018 +0800 ---------------------------------------------------------------------- .../expressions/collectionOperations.scala | 21 +++++--- .../spark/sql/DataFrameFunctionsSuite.scala | 57 +++++++++++++++++--- 2 files changed, 64 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/bb49661e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 161adc9..85bc1cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2071,18 +2071,23 @@ case class ArrayPosition(left: Expression, right: Expression) override def dataType: DataType = LongType override def inputTypes: Seq[AbstractDataType] = { - val elementType = left.dataType match { - case t: ArrayType => t.elementType - case _ => AnyDataType + (left.dataType, right.dataType) match { + case (ArrayType(e1, hasNull), e2) => + TypeCoercion.findTightestCommonType(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull), dt) + case _ => Seq.empty + } + case _ => Seq.empty } - Seq(ArrayType, elementType) } override def checkInputDataTypes(): TypeCheckResult = { - super.checkInputDataTypes() match { - case f: TypeCheckResult.TypeCheckFailure => f - case TypeCheckResult.TypeCheckSuccess => - TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + (left.dataType, right.dataType) match { + case (ArrayType(e1, _), e2) if e1.sameType(e2) => + TypeUtils.checkForOrderingExpr(e2, s"function $prettyName") + case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " + + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") } } http://git-wip-us.apache.org/repos/asf/spark/blob/bb49661e/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ad52fd0..fd71f24 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1097,18 +1097,63 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) checkAnswer( - df.selectExpr("array_position(array(array(1), null)[0], 1)"), - Seq(Row(1L), Row(1L)) + OneRowRelation().selectExpr("array_position(array(1), 1.23D)"), + Seq(Row(0L)) ) + checkAnswer( - df.selectExpr("array_position(array(1, null), array(1, null)[0])"), - Seq(Row(1L), Row(1L)) + OneRowRelation().selectExpr("array_position(array(1), 1.0D)"), + Seq(Row(1L)) ) - val e = intercept[AnalysisException] { + checkAnswer( + OneRowRelation().selectExpr("array_position(array(1.D), 1)"), + Seq(Row(1L)) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_position(array(1.23D), 1)"), + Seq(Row(0L)) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_position(array(array(1)), array(1.0D))"), + Seq(Row(1L)) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_position(array(array(1)), array(1.23D))"), + Seq(Row(0L)) + ) + + checkAnswer( + OneRowRelation().selectExpr("array_position(array(array(1), null)[0], 1)"), + Seq(Row(1L)) + ) + checkAnswer( + OneRowRelation().selectExpr("array_position(array(1, null), array(1, null)[0])"), + Seq(Row(1L)) + ) + + val e1 = intercept[AnalysisException] { Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)") } - assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) + val errorMsg1 = + s""" + |Input to function array_position should have been array followed by a + |value with same element type, but it's [string, string]. 
+ """.stripMargin.replace("\n", " ").trim() + assert(e1.message.contains(errorMsg1)) + + val e2 = intercept[AnalysisException] { + OneRowRelation().selectExpr("array_position(array(1), '1')") + } + val errorMsg2 = + s""" + |Input to function array_position should have been array followed by a + |value with same element type, but it's [array<int>, string]. + """.stripMargin.replace("\n", " ").trim() + assert(e2.message.contains(errorMsg2)) } test("element_at function") { --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
