Repository: spark
Updated Branches:
  refs/heads/branch-2.4 13bc58d28 -> 51d5378f8


[SPARK-25416][SQL] ArrayPosition function may return incorrect result when 
right expression is implicitly down casted

## What changes were proposed in this pull request?
In ArrayPosition, we currently cast the right-hand side expression to match the
element type of the left-hand side array. This may result in downcasting and
may return a wrong or questionable result.

Examples:
```SQL
spark-sql> select array_position(array(1), 1.34);
1
```
```SQL
spark-sql> select array_position(array(1), 'foo');
null
```

We should safely coerce both the left-hand and right-hand side expressions.

## How was this patch tested?
Added tests in DataFrameFunctionsSuite.

Closes #22407 from dilipbiswal/SPARK-25416.

Authored-by: Dilip Biswal <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit bb49661e192eed78a8a306deffd83c73bd4a9eff)
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/51d5378f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/51d5378f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/51d5378f

Branch: refs/heads/branch-2.4
Commit: 51d5378f8956e0abc1bae0ffdea78637011935a1
Parents: 13bc58d
Author: Dilip Biswal <[email protected]>
Authored: Mon Sep 24 21:37:51 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Mon Sep 24 21:43:34 2018 +0800

----------------------------------------------------------------------
 .../expressions/collectionOperations.scala      | 21 +++++---
 .../spark/sql/DataFrameFunctionsSuite.scala     | 57 +++++++++++++++++---
 2 files changed, 64 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/51d5378f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 161adc9..85bc1cd 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2071,18 +2071,23 @@ case class ArrayPosition(left: Expression, right: 
Expression)
   override def dataType: DataType = LongType
 
   override def inputTypes: Seq[AbstractDataType] = {
-    val elementType = left.dataType match {
-      case t: ArrayType => t.elementType
-      case _ => AnyDataType
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, hasNull), e2) =>
+        TypeCoercion.findTightestCommonType(e1, e2) match {
+          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+          case _ => Seq.empty
+        }
+      case _ => Seq.empty
     }
-    Seq(ArrayType, elementType)
   }
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    super.checkInputDataTypes() match {
-      case f: TypeCheckResult.TypeCheckFailure => f
-      case TypeCheckResult.TypeCheckSuccess =>
-        TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, _), e2) if e1.sameType(e2) =>
+        TypeUtils.checkForOrderingExpr(e2, s"function $prettyName")
+      case _ => TypeCheckResult.TypeCheckFailure(s"Input to function 
$prettyName should have " +
+        s"been ${ArrayType.simpleString} followed by a value with same element 
type, but it's " +
+        s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/51d5378f/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index ad52fd0..fd71f24 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1097,18 +1097,63 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSQLContext {
     )
 
     checkAnswer(
-      df.selectExpr("array_position(array(array(1), null)[0], 1)"),
-      Seq(Row(1L), Row(1L))
+      OneRowRelation().selectExpr("array_position(array(1), 1.23D)"),
+      Seq(Row(0L))
     )
+
     checkAnswer(
-      df.selectExpr("array_position(array(1, null), array(1, null)[0])"),
-      Seq(Row(1L), Row(1L))
+      OneRowRelation().selectExpr("array_position(array(1), 1.0D)"),
+      Seq(Row(1L))
     )
 
-    val e = intercept[AnalysisException] {
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(1.D), 1)"),
+      Seq(Row(1L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(1.23D), 1)"),
+      Seq(Row(0L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(array(1)), 
array(1.0D))"),
+      Seq(Row(1L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(array(1)), 
array(1.23D))"),
+      Seq(Row(0L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(array(1), null)[0], 
1)"),
+      Seq(Row(1L))
+    )
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(1, null), array(1, 
null)[0])"),
+      Seq(Row(1L))
+    )
+
+    val e1 = intercept[AnalysisException] {
       Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, 
_2)")
     }
-    assert(e.message.contains("argument 1 requires array type, however, '`_1`' 
is of string type"))
+    val errorMsg1 =
+      s"""
+         |Input to function array_position should have been array followed by a
+         |value with same element type, but it's [string, string].
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e1.message.contains(errorMsg1))
+
+    val e2 = intercept[AnalysisException] {
+      OneRowRelation().selectExpr("array_position(array(1), '1')")
+    }
+    val errorMsg2 =
+      s"""
+         |Input to function array_position should have been array followed by a
+         |value with same element type, but it's [array<int>, string].
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e2.message.contains(errorMsg2))
   }
 
   test("element_at function") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to