This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 861df43e8d0 [SPARK-39419][SQL][3.2] Fix ArraySort to throw an exception when the comparator returns null
861df43e8d0 is described below

commit 861df43e8d022f51727e0a12a7cca5e119e3c4cc
Author: Takuya UESHIN <[email protected]>
AuthorDate: Fri Jun 10 16:50:20 2022 -0700

    [SPARK-39419][SQL][3.2] Fix ArraySort to throw an exception when the comparator returns null
    
    ### What changes were proposed in this pull request?
    
    Backport of #36812.
    
    Fixes `ArraySort` to throw an exception when the comparator returns `null`.
    
    Also updates the `array_sort` doc to match the corrected behavior.
    
    ### Why are the changes needed?
    
    Currently, when the comparator of `ArraySort` returns `null`, the result is treated as `0` (equal).
    
    According to the doc,
    
    ```
    It returns -1, 0, or 1 as the first element is less than, equal to, or greater than
    the second element. If the comparator function returns other
    values (including null), the function will fail and raise an error.
    ```
    
    Returning integers other than -1, 0, or 1 is fine, since it follows the Java `Comparator` convention (the doc still needs to be updated to say so), but a `null` result should throw an exception.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. After this PR, `array_sort` throws an error if the comparator returns `null`.
    
    Setting the legacy flag `spark.sql.legacy.allowNullComparisonResultInArraySort` to `true` restores the legacy behavior, which treats `null` as `0` (equal).
    
    ### How was this patch tested?
    
    Added unit tests in `HigherOrderFunctionsSuite`.
    
    Closes #36835 from ueshin/issues/SPARK-39419/3.2/array_sort.
    
    Authored-by: Takuya UESHIN <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 core/src/main/resources/error/error-classes.json   |  3 +++
 .../expressions/higherOrderFunctions.scala         | 26 +++++++++++++++++-----
 .../spark/sql/errors/QueryExecutionErrors.scala    |  5 +++++
 .../org/apache/spark/sql/internal/SQLConf.scala    | 10 +++++++++
 .../expressions/HigherOrderFunctionsSuite.scala    | 22 +++++++++++++++++-
 5 files changed, 60 insertions(+), 6 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json 
b/core/src/main/resources/error/error-classes.json
index 9ac5f06a225..9999eb5f6e4 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -49,6 +49,9 @@
     "message" : [ "PARTITION clause cannot contain a non-partition column 
name: %s" ],
     "sqlState" : "42000"
   },
+  "NULL_COMPARISON_RESULT" : {
+    "message" : [ "The comparison result is null. If you want to handle null 
as 0 (equal), you can set 
\"spark.sql.legacy.allowNullComparisonResultInArraySort\" to \"true\"." ]
+  },
   "PIVOT_VALUE_DATA_TYPE_MISMATCH" : {
     "message" : [ "Invalid pivot value '%s': value data type %s does not match 
pivot column data type %s" ],
     "sqlState" : "42000"
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 0ec817836a5..da7f371c17b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -357,9 +357,9 @@ case class ArrayTransform(
     Since 3.0.0 this function also sorts and returns the array based on the
     given comparator function. The comparator will take two arguments 
representing
     two elements of the array.
-    It returns -1, 0, or 1 as the first element is less than, equal to, or 
greater
-    than the second element. If the comparator function returns other
-    values (including null), the function will fail and raise an error.
+    It returns a negative integer, 0, or a positive integer as the first 
element is less than,
+    equal to, or greater than the second element. If the comparator function 
returns null,
+    the function will fail and raise an error.
     """,
   examples = """
     Examples:
@@ -375,9 +375,17 @@ case class ArrayTransform(
 // scalastyle:on line.size.limit
 case class ArraySort(
     argument: Expression,
-    function: Expression)
+    function: Expression,
+    allowNullComparisonResult: Boolean)
   extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
 
+  def this(argument: Expression, function: Expression) = {
+    this(
+      argument,
+      function,
+      
SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT))
+  }
+
   def this(argument: Expression) = this(argument, ArraySort.defaultComparator)
 
   @transient lazy val elementType: DataType =
@@ -421,7 +429,11 @@ case class ArraySort(
     (o1: Any, o2: Any) => {
       firstElemVar.value.set(o1)
       secondElemVar.value.set(o2)
-      f.eval(inputRow).asInstanceOf[Int]
+      val cmp = f.eval(inputRow)
+      if (!allowNullComparisonResult && cmp == null) {
+        throw QueryExecutionErrors.nullComparisonResultError()
+      }
+      cmp.asInstanceOf[Int]
     }
   }
 
@@ -442,6 +454,10 @@ case class ArraySort(
 
 object ArraySort {
 
+  def apply(argument: Expression, function: Expression): ArraySort = {
+    new ArraySort(argument, function)
+  }
+
   def comparator(left: Expression, right: Expression): Expression = {
     val lit0 = Literal(0)
     val lit1 = Literal(1)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index c5ac476aa31..3c922dec29d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1816,4 +1816,9 @@ private[sql] object QueryExecutionErrors {
         s". To solve this try to set $maxDynamicPartitionsKey" +
         s" to at least $numWrittenParts.")
   }
+
+  def nullComparisonResultError(): Throwable = {
+    new SparkException(errorClass = "NULL_COMPARISON_RESULT",
+      messageParameters = Array(), cause = null)
+  }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index f2e41845908..33765619823 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3405,6 +3405,16 @@ object SQLConf {
     .intConf
     .createWithDefault(0)
 
+  val LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT =
+    buildConf("spark.sql.legacy.allowNullComparisonResultInArraySort")
+      .internal()
+      .doc("When set to false, `array_sort` function throws an error " +
+        "if the comparator function returns null. " +
+        "If set to true, it restores the legacy behavior that handles null as 
zero (equal).")
+      .version("3.2.2")
+      .booleanConf
+      .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index c0db6d8dc29..b1c4c441427 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -838,4 +838,24 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with 
ExpressionEvalHelper
       Literal.create(Seq(Double.NaN, 1d, 2d, null), ArrayType(DoubleType))),
       Seq(1d, 2d, Double.NaN, null))
   }
+
+  test("SPARK-39419: ArraySort should throw an exception when the comparator 
returns null") {
+    val comparator = {
+      val comp = ArraySort.comparator _
+      (left: Expression, right: Expression) =>
+        If(comp(left, right) === 0, Literal.create(null, IntegerType), 
comp(left, right))
+    }
+
+    withSQLConf(
+        SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT.key -> 
"false") {
+      checkExceptionInExpression[SparkException](
+        arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator), "The 
comparison result is null")
+    }
+
+    withSQLConf(
+        SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT.key -> 
"true") {
+      checkEvaluation(arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator),
+        Seq(1, 1, 2, 3))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to