This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3e9574c54f1 [SPARK-43011][SQL] `array_insert` should fail with 0 index
3e9574c54f1 is described below

commit 3e9574c54f149b13ca768c0930c634eb67ea14c8
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Apr 4 10:22:16 2023 +0300

    [SPARK-43011][SQL] `array_insert` should fail with 0 index
    
    ### What changes were proposed in this pull request?
    Make `array_insert` fail when input index `pos` is zero.
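    
    For illustration, the behavior change as captured by the golden files updated in this patch (the error rendering below is approximate):
    
    ```sql
    -- Before this patch, index 0 was silently accepted and behaved like index 1:
    select array_insert(array(2, 3, 4), 0, 1);
    -- [1,2,3,4]
    
    -- After this patch, the same query fails:
    select array_insert(array(2, 3, 4), 0, 1);
    -- org.apache.spark.SparkRuntimeException [INVALID_INDEX_OF_ZERO]:
    -- The index 0 is invalid. An index shall be either < 0 or > 0
    -- (the first element has index 1). (SQLSTATE 22003)
    ```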
    
    ### Why are the changes needed?
    see https://github.com/apache/spark/pull/40563#discussion_r1155673089
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `array_insert` no longer accepts index 0 (which previously behaved like index 1); it now throws a `SparkRuntimeException` with error class `INVALID_INDEX_OF_ZERO` (SQLSTATE `22003`), the renamed successor of `ELEMENT_AT_BY_INDEX_ZERO`.
    
    ### How was this patch tested?
    Updated the existing unit tests and golden result files.
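    
    The non-zero index edge cases in `CollectionExpressionsSuite` are unchanged; as a sketch, their SQL equivalents (translated here from the Scala expectations, not the literal test code) are:
    
    ```sql
    select array_insert(array(1, 2, 4),  2, 3);  -- [1,3,2,4]
    select array_insert(array(1, 2, 4),  1, 3);  -- [3,1,2,4]
    select array_insert(array(1, 2, 4),  4, 3);  -- [1,2,4,3]
    select array_insert(array(1, 2, 4), -2, 3);  -- [1,3,2,4]
    -- the former expectation for index 0 (result [3,1,2,4]) is removed
    ```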
    
    Closes #40641 from zhengruifeng/sql_array_insert_fails_zero.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 core/src/main/resources/error/error-classes.json     | 12 ++++++------
 docs/sql-error-conditions-sqlstates.md               |  2 +-
 docs/sql-error-conditions.md                         | 12 ++++++------
 .../catalyst/expressions/collectionOperations.scala  | 18 ++++++++++++++----
 .../spark/sql/errors/QueryExecutionErrors.scala      |  4 ++--
 .../expressions/CollectionExpressionsSuite.scala     |  1 -
 .../resources/sql-tests/results/ansi/array.sql.out   | 17 ++++++++++++++---
 .../sql-tests/results/ansi/string-functions.sql.out  |  2 +-
 .../sql-tests/results/ansi/try_element_at.sql.out    |  2 +-
 .../test/resources/sql-tests/results/array.sql.out   | 17 ++++++++++++++---
 .../sql-tests/results/string-functions.sql.out       |  2 +-
 .../sql-tests/results/try_element_at.sql.out         |  2 +-
 .../apache/spark/sql/DataFrameFunctionsSuite.scala   | 20 +++++++++++++++-----
 .../sql/errors/QueryExecutionAnsiErrorsSuite.scala   |  4 ++--
 14 files changed, 78 insertions(+), 37 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 8369c7c5666..d330ea09f30 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -536,12 +536,6 @@
     ],
     "sqlState" : "23505"
   },
-  "ELEMENT_AT_BY_INDEX_ZERO" : {
-    "message" : [
-      "The index 0 is invalid. An index shall be either < 0 or > 0 (the first 
element has index 1)."
-    ],
-    "sqlState" : "22003"
-  },
   "EMPTY_JSON_FIELD_VALUE" : {
     "message" : [
       "Failed to parse an empty string for data type <dataType>."
@@ -915,6 +909,12 @@
     ],
     "sqlState" : "42602"
   },
+  "INVALID_INDEX_OF_ZERO" : {
+    "message" : [
+      "The index 0 is invalid. An index shall be either < 0 or > 0 (the first 
element has index 1)."
+    ],
+    "sqlState" : "22003"
+  },
   "INVALID_JSON_ROOT_FIELD" : {
     "message" : [
       "Cannot convert JSON root field to target Spark type."
diff --git a/docs/sql-error-conditions-sqlstates.md b/docs/sql-error-conditions-sqlstates.md
index 1eea335ac9b..6b4c7e62f71 100644
--- a/docs/sql-error-conditions-sqlstates.md
+++ b/docs/sql-error-conditions-sqlstates.md
@@ -71,7 +71,7 @@ Spark SQL uses the following `SQLSTATE` classes:
 </tr>
 <tr>
   <td></td>
-  <td><a href="arithmetic-overflow-error-class.md">ARITHMETIC_OVERFLOW</a>, <a href="sql-error-conditions.html#cast_overflow">CAST_OVERFLOW</a>, <a href="sql-error-conditions.html#cast_overflow_in_table_insert">CAST_OVERFLOW_IN_TABLE_INSERT</a>, <a href="sql-error-conditions.html#decimal_precision_exceeds_max_precision">DECIMAL_PRECISION_EXCEEDS_MAX_PRECISION</a>, <a href="sql-error-conditions.html#element_at_by_index_zero">ELEMENT_AT_BY_INDEX_ZERO</a>, <a href="sql-error-conditions.html [...]
+  <td><a href="arithmetic-overflow-error-class.md">ARITHMETIC_OVERFLOW</a>, <a href="sql-error-conditions.html#cast_overflow">CAST_OVERFLOW</a>, <a href="sql-error-conditions.html#cast_overflow_in_table_insert">CAST_OVERFLOW_IN_TABLE_INSERT</a>, <a href="sql-error-conditions.html#decimal_precision_exceeds_max_precision">DECIMAL_PRECISION_EXCEEDS_MAX_PRECISION</a>, <a href="sql-error-conditions.html#invalid_index_of_zero">INVALID_INDEX_OF_ZERO</a>, <a href="sql-error-conditions.html#incor [...]
   </td>
 </tr>
     <tr>
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index 027daccd3e0..6075d35dde9 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -289,12 +289,6 @@ Duplicate map key `<key>` was found, please check the input data. If you want to
 
 Found duplicate keys `<keyColumn>`.
 
-### ELEMENT_AT_BY_INDEX_ZERO
-
-[SQLSTATE: 22003](sql-error-conditions-sqlstates.html#class-22-data-exception)
-
-The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1).
-
 ### EMPTY_JSON_FIELD_VALUE
 
 [SQLSTATE: 42604](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
@@ -573,6 +567,12 @@ The fraction of sec must be zero. Valid range is [0, 60]. If necessary set `<ans
 
 The identifier `<ident>` is invalid. Please, consider quoting it with back-quotes as ``<ident>``.
 
+### INVALID_INDEX_OF_ZERO
+
+[SQLSTATE: 22003](sql-error-conditions-sqlstates.html#class-22-data-exception)
+
+The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1).
+
 ### INVALID_JSON_ROOT_FIELD
 
 [SQLSTATE: 22032](sql-error-conditions-sqlstates.html#class-22-data-exception)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index adeccb3ec7e..58d18af4d0c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2489,7 +2489,7 @@ case class ElementAt(
           }
         } else {
           val idx = if (index == 0) {
-            throw QueryExecutionErrors.elementAtByIndexZeroError(getContextOrNull())
+            throw QueryExecutionErrors.invalidIndexOfZeroError(getContextOrNull())
           } else if (index > 0) {
             index - 1
           } else {
@@ -2544,7 +2544,7 @@ case class ElementAt(
              |  $indexOutOfBoundBranch
              |} else {
              |  if ($index == 0) {
-             |    throw QueryExecutionErrors.elementAtByIndexZeroError($errorContext);
+             |    throw QueryExecutionErrors.invalidIndexOfZeroError($errorContext);
              |  } else if ($index > 0) {
              |    $index--;
              |  } else {
@@ -4767,7 +4767,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
   since = "3.4.0")
 case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr: Expression)
   extends TernaryExpression with ImplicitCastInputTypes with ComplexTypeMergingExpression
-    with QueryErrorsBase {
+    with QueryErrorsBase with SupportQueryContext {
 
   override def inputTypes: Seq[AbstractDataType] = {
     (srcArrayExpr.dataType, posExpr.dataType, itemExpr.dataType) match {
@@ -4820,8 +4820,11 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
   }
 
   override def nullSafeEval(arr: Any, pos: Any, item: Any): Any = {
-    val baseArr = arr.asInstanceOf[ArrayData]
     var posInt = pos.asInstanceOf[Int]
+    if (posInt == 0) {
+      throw QueryExecutionErrors.invalidIndexOfZeroError(getContextOrNull())
+    }
+    val baseArr = arr.asInstanceOf[ArrayData]
     val arrayElementType = dataType.asInstanceOf[ArrayType].elementType
 
     val newPosExtendsArrayLeft = (posInt < 0) && (-posInt > baseArr.numElements())
@@ -4895,6 +4898,7 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
         values, elementType, resLength, s"$prettyName failed.")
       val assignment = CodeGenerator.createArrayAssignment(values, elementType, arr,
         adjustedAllocIdx, i, first.dataType.asInstanceOf[ArrayType].containsNull)
+      val errorContext = getContextOrNullCode(ctx)
 
       s"""
          |int $itemInsertionIndex = 0;
@@ -4902,6 +4906,10 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
          |int $adjustedAllocIdx = 0;
          |boolean $insertedItemIsNull = ${itemExpr.isNull};
          |
+         |if ($pos == 0) {
+         |  throw QueryExecutionErrors.invalidIndexOfZeroError($errorContext);
+         |}
+         |
          |if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) {
          |
          |  $resLength = java.lang.Math.abs($pos) + 1;
@@ -5002,6 +5010,8 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
   override protected def withNewChildrenInternal(
       newSrcArrayExpr: Expression, newPosExpr: Expression, newItemExpr: Expression): ArrayInsert =
     copy(srcArrayExpr = newSrcArrayExpr, posExpr = newPosExpr, itemExpr = newItemExpr)
+
+  override def initQueryContext(): Option[SQLQueryContext] = Some(origin.context)
 }
 
 @ExpressionDescription(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 17c5b2f4f10..7ec9f41af36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1605,9 +1605,9 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
         "prettyName" -> prettyName))
   }
 
-  def elementAtByIndexZeroError(context: SQLQueryContext): RuntimeException = {
+  def invalidIndexOfZeroError(context: SQLQueryContext): RuntimeException = {
     new SparkRuntimeException(
-      errorClass = "ELEMENT_AT_BY_INDEX_ZERO",
+      errorClass = "INVALID_INDEX_OF_ZERO",
       cause = null,
       messageParameters = Map.empty,
       context = getQueryContext(context),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 3abc70a3d55..8f1ff97a78e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -2356,7 +2356,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
 
     // index edge cases
     checkEvaluation(ArrayInsert(a1, Literal(2), Literal(3)), Seq(1, 3, 2, 4))
-    checkEvaluation(ArrayInsert(a1, Literal(0), Literal(3)), Seq(3, 1, 2, 4))
     checkEvaluation(ArrayInsert(a1, Literal(1), Literal(3)), Seq(3, 1, 2, 4))
     checkEvaluation(ArrayInsert(a1, Literal(4), Literal(3)), Seq(1, 2, 4, 3))
     checkEvaluation(ArrayInsert(a1, Literal(-2), Literal(3)), Seq(1, 3, 2, 4))
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
index d228c605705..91294fffe04 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
@@ -248,7 +248,7 @@ struct<>
 -- !query output
 org.apache.spark.SparkRuntimeException
 {
-  "errorClass" : "ELEMENT_AT_BY_INDEX_ZERO",
+  "errorClass" : "INVALID_INDEX_OF_ZERO",
   "sqlState" : "22003",
   "queryContext" : [ {
     "objectType" : "",
@@ -561,9 +561,20 @@ struct<array_insert(array(1, 2, 3), 3, 4):array<int>>
 -- !query
 select array_insert(array(2, 3, 4), 0, 1)
 -- !query schema
-struct<array_insert(array(2, 3, 4), 0, 1):array<int>>
+struct<>
 -- !query output
-[1,2,3,4]
+org.apache.spark.SparkRuntimeException
+{
+  "errorClass" : "INVALID_INDEX_OF_ZERO",
+  "sqlState" : "22003",
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 41,
+    "fragment" : "array_insert(array(2, 3, 4), 0, 1)"
+  } ]
+}
 
 
 -- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
index 837a79c92b6..3e72daf6dfc 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
@@ -252,7 +252,7 @@ struct<>
 -- !query output
 org.apache.spark.SparkRuntimeException
 {
-  "errorClass" : "ELEMENT_AT_BY_INDEX_ZERO",
+  "errorClass" : "INVALID_INDEX_OF_ZERO",
   "sqlState" : "22003"
 }
 
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/try_element_at.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/try_element_at.sql.out
index 0a518fcaf11..0437f9d6dd9 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/try_element_at.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/try_element_at.sql.out
@@ -6,7 +6,7 @@ struct<>
 -- !query output
 org.apache.spark.SparkRuntimeException
 {
-  "errorClass" : "ELEMENT_AT_BY_INDEX_ZERO",
+  "errorClass" : "INVALID_INDEX_OF_ZERO",
   "sqlState" : "22003"
 }
 
diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out
index 029bd767f54..58d25f674ea 100644
--- a/sql/core/src/test/resources/sql-tests/results/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out
@@ -216,7 +216,7 @@ struct<>
 -- !query output
 org.apache.spark.SparkRuntimeException
 {
-  "errorClass" : "ELEMENT_AT_BY_INDEX_ZERO",
+  "errorClass" : "INVALID_INDEX_OF_ZERO",
   "sqlState" : "22003"
 }
 
@@ -442,9 +442,20 @@ struct<array_insert(array(1, 2, 3), 3, 4):array<int>>
 -- !query
 select array_insert(array(2, 3, 4), 0, 1)
 -- !query schema
-struct<array_insert(array(2, 3, 4), 0, 1):array<int>>
+struct<>
 -- !query output
-[1,2,3,4]
+org.apache.spark.SparkRuntimeException
+{
+  "errorClass" : "INVALID_INDEX_OF_ZERO",
+  "sqlState" : "22003",
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 41,
+    "fragment" : "array_insert(array(2, 3, 4), 0, 1)"
+  } ]
+}
 
 
 -- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
index 98b8505d503..e3fb58e907a 100644
--- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
@@ -218,7 +218,7 @@ struct<>
 -- !query output
 org.apache.spark.SparkRuntimeException
 {
-  "errorClass" : "ELEMENT_AT_BY_INDEX_ZERO",
+  "errorClass" : "INVALID_INDEX_OF_ZERO",
   "sqlState" : "22003"
 }
 
diff --git a/sql/core/src/test/resources/sql-tests/results/try_element_at.sql.out b/sql/core/src/test/resources/sql-tests/results/try_element_at.sql.out
index 0a518fcaf11..0437f9d6dd9 100644
--- a/sql/core/src/test/resources/sql-tests/results/try_element_at.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/try_element_at.sql.out
@@ -6,7 +6,7 @@ struct<>
 -- !query output
 org.apache.spark.SparkRuntimeException
 {
-  "errorClass" : "ELEMENT_AT_BY_INDEX_ZERO",
+  "errorClass" : "INVALID_INDEX_OF_ZERO",
   "sqlState" : "22003"
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 355f2dfffb5..09812194ba7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -3191,17 +3191,27 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       Seq(Row(Seq[Double](3.0, 3.0, 2.0, 5.0, 1.0, 2.0)))
     )
     checkAnswer(df4.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq(true, false, false))))
-    checkAnswer(df5.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq("d", "a", "b", "c"))))
+
+    val e1 = intercept[SparkException] {
+      df5.selectExpr("array_insert(a, b, c)").show()
+    }
+    assert(e1.getCause.isInstanceOf[SparkRuntimeException])
+    checkError(
+      exception = e1.getCause.asInstanceOf[SparkRuntimeException],
+      errorClass = "INVALID_INDEX_OF_ZERO",
+      parameters = Map.empty,
+      context = ExpectedContext(
+        fragment = "array_insert(a, b, c)",
+        start = 0,
+        stop = 20)
+    )
+
     checkAnswer(df5.select(
       array_insert(col("a"), lit(1), col("c"))),
       Seq(Row(Seq("d", "a", "b", "c")))
     )
     // null checks
     checkAnswer(df6.selectExpr("array_insert(a, b, c)"), Seq(Row(Seq("a", null, "b", "c", "d"))))
-    checkAnswer(df5.select(
-      array_insert(col("a"), col("b"), lit(null).cast("string"))),
-      Seq(Row(Seq(null, "a", "b", "c")))
-    )
     checkAnswer(df6.select(
       array_insert(col("a"), col("b"), lit(null).cast("string"))),
       Seq(Row(Seq("a", null, "b", "c", null)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
index 45c7898dfa2..ee28a90aed9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
@@ -117,12 +117,12 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest
         stop = 41))
   }
 
-  test("ELEMENT_AT_BY_INDEX_ZERO: element_at from array by index zero") {
+  test("INVALID_INDEX_OF_ZERO: element_at from array by index zero") {
     checkError(
       exception = intercept[SparkRuntimeException](
         sql("select element_at(array(1, 2, 3, 4, 5), 0)").collect()
       ),
-      errorClass = "ELEMENT_AT_BY_INDEX_ZERO",
+      errorClass = "INVALID_INDEX_OF_ZERO",
       parameters = Map.empty,
       context = ExpectedContext(
         fragment = "element_at(array(1, 2, 3, 4, 5), 0)",

