This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 50f20fbad36 [SPARK-44969][SQL] Reuse `ArrayInsert` in `ArrayAppend`
50f20fbad36 is described below
commit 50f20fbad36dbb05a5132a3043364eff7b1a565c
Author: Max Gekk <[email protected]>
AuthorDate: Fri Aug 25 22:51:08 2023 -0700
[SPARK-44969][SQL] Reuse `ArrayInsert` in `ArrayAppend`
### What changes were proposed in this pull request?
In this PR, I propose to replace the current implementation of the
`ArrayAppend` expression with a runtime-replaceable one that delegates to
`ArrayInsert` with `posExpr = -1`.
### Why are the changes needed?
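To improve code maintainability: `ArrayAppend` no longer needs its own evaluation and codegen logic and reuses that of `ArrayInsert`.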
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
By running the modified test suite:
```
$ build/sbt "test:testOnly *CollectionExpressionsSuite"
```
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #42660 from MaxGekk/array_append-to-insert-1.
Authored-by: Max Gekk <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../explain-results/function_array_append.explain | 2 +-
.../expressions/collectionOperations.scala | 229 ++++++---------------
.../expressions/CollectionExpressionsSuite.scala | 144 ++++++-------
3 files changed, 119 insertions(+), 256 deletions(-)
diff --git
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_append.explain
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_append.explain
index ca2804ebb60..e857e2e974f 100644
---
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_append.explain
+++
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_append.explain
@@ -1,2 +1,2 @@
-Project [array_append(e#0, 1) AS array_append(e, 1)#0]
+Project [array_insert(e#0, -1, 1, false) AS array_append(e, 1)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index fe9c4015c15..957aa1ab2d5 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1397,29 +1397,9 @@ case class ArrayContains(left: Expression, right:
Expression)
copy(left = newLeft, right = newRight)
}
-@ExpressionDescription(
- usage = """
- _FUNC_(array, element) - Add the element at the beginning of the array
passed as first
- argument. Type of element should be the same as the type of the elements
of the array.
- Null element is also prepended to the array. But if the array passed is
NULL
- output is NULL
- """,
- examples = """
- Examples:
- > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
- ["d","b","d","c","a"]
- > SELECT _FUNC_(array(1, 2, 3, null), null);
- [null,1,2,3,null]
- > SELECT _FUNC_(CAST(null as Array<Int>), 2);
- NULL
- """,
- group = "array_funcs",
- since = "3.5.0")
-case class ArrayPrepend(left: Expression, right: Expression) extends
RuntimeReplaceable
+trait ArrayPendBase extends RuntimeReplaceable
with ImplicitCastInputTypes with BinaryLike[Expression] with QueryErrorsBase
{
- override lazy val replacement: Expression = new ArrayInsert(left,
Literal(1), right)
-
override def inputTypes: Seq[AbstractDataType] = {
(left.dataType, right.dataType) match {
case (ArrayType(e1, hasNull), e2) =>
@@ -1455,6 +1435,29 @@ case class ArrayPrepend(left: Expression, right:
Expression) extends RuntimeRepl
)
}
}
+}
+
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array, element) - Add the element at the beginning of the array
passed as first
+ argument. Type of element should be the same as the type of the elements
of the array.
+ Null element is also prepended to the array. But if the array passed is
NULL
+ output is NULL
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
+ ["d","b","d","c","a"]
+ > SELECT _FUNC_(array(1, 2, 3, null), null);
+ [null,1,2,3,null]
+ > SELECT _FUNC_(CAST(null as Array<Int>), 2);
+ NULL
+ """,
+ group = "array_funcs",
+ since = "3.5.0")
+case class ArrayPrepend(left: Expression, right: Expression) extends
ArrayPendBase {
+
+ override lazy val replacement: Expression = new ArrayInsert(left,
Literal(1), right)
override def prettyName: String = "array_prepend"
@@ -1463,6 +1466,41 @@ case class ArrayPrepend(left: Expression, right:
Expression) extends RuntimeRepl
copy(left = newLeft, right = newRight)
}
+
+/**
+ * Given an array, and another element append the element at the end of the
array.
+ * This function does not return null when the elements are null. It appends
null at
+ * the end of the array. But returns null if the array passed is null.
+ */
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array, element) - Add the element at the end of the array passed
as first
+ argument. Type of element should be similar to type of the elements of
the array.
+ Null element is also appended into the array. But if the array passed,
is NULL
+ output is NULL
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
+ ["b","d","c","a","d"]
+ > SELECT _FUNC_(array(1, 2, 3, null), null);
+ [1,2,3,null,null]
+ > SELECT _FUNC_(CAST(null as Array<Int>), 2);
+ NULL
+ """,
+ since = "3.4.0",
+ group = "array_funcs")
+case class ArrayAppend(left: Expression, right: Expression) extends
ArrayPendBase {
+
+ override lazy val replacement: Expression = new ArrayInsert(left,
Literal(-1), right)
+
+ override def prettyName: String = "array_append"
+
+ override protected def withNewChildrenInternal(
+ newLeft: Expression, newRight: Expression): ArrayAppend =
+ copy(left = newLeft, right = newRight)
+}
+
/**
* Checks if the two arrays contain at least one common element.
*/
@@ -5039,152 +5077,3 @@ case class ArrayCompact(child: Expression)
override protected def withNewChildInternal(newChild: Expression):
ArrayCompact =
copy(child = newChild)
}
-
-/**
- * Given an array, and another element append the element at the end of the
array.
- * This function does not return null when the elements are null. It appends
null at
- * the end of the array. But returns null if the array passed is null.
- */
-@ExpressionDescription(
- usage = """
- _FUNC_(array, element) - Add the element at the end of the array passed
as first
- argument. Type of element should be similar to type of the elements of
the array.
- Null element is also appended into the array. But if the array passed,
is NULL
- output is NULL
- """,
- examples = """
- Examples:
- > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
- ["b","d","c","a","d"]
- > SELECT _FUNC_(array(1, 2, 3, null), null);
- [1,2,3,null,null]
- > SELECT _FUNC_(CAST(null as Array<Int>), 2);
- NULL
- """,
- since = "3.4.0",
- group = "array_funcs")
-case class ArrayAppend(left: Expression, right: Expression)
- extends BinaryExpression
- with ImplicitCastInputTypes
- with ComplexTypeMergingExpression
- with QueryErrorsBase {
- override def prettyName: String = "array_append"
-
- @transient protected lazy val elementType: DataType =
- inputTypes.head.asInstanceOf[ArrayType].elementType
-
- override def inputTypes: Seq[AbstractDataType] = {
- (left.dataType, right.dataType) match {
- case (ArrayType(e1, hasNull), e2) =>
- TypeCoercion.findTightestCommonType(e1, e2) match {
- case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
- case _ => Seq.empty
- }
- case _ => Seq.empty
- }
- }
-
- override def checkInputDataTypes(): TypeCheckResult = {
- (left.dataType, right.dataType) match {
- case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) =>
TypeCheckResult.TypeCheckSuccess
- case (ArrayType(e1, _), e2) => DataTypeMismatch(
- errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
- messageParameters = Map(
- "functionName" -> toSQLId(prettyName),
- "leftType" -> toSQLType(left.dataType),
- "rightType" -> toSQLType(right.dataType),
- "dataType" -> toSQLType(ArrayType)
- ))
- case _ =>
- DataTypeMismatch(
- errorSubClass = "UNEXPECTED_INPUT_TYPE",
- messageParameters = Map(
- "paramIndex" -> "0",
- "requiredType" -> toSQLType(ArrayType),
- "inputSql" -> toSQLExpr(left),
- "inputType" -> toSQLType(left.dataType)
- )
- )
- }
- }
-
- override def eval(input: InternalRow): Any = {
- val value1 = left.eval(input)
- if (value1 == null) {
- null
- } else {
- val value2 = right.eval(input)
- nullSafeEval(value1, value2)
- }
- }
-
- override protected def nullSafeEval(arr: Any, elementData: Any): Any = {
- val arrayData = arr.asInstanceOf[ArrayData]
- val numberOfElements = arrayData.numElements() + 1
- if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
- throw
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements)
- }
- val finalData = new Array[Any](numberOfElements)
- arrayData.foreach(elementType, finalData.update)
- finalData.update(arrayData.numElements(), elementData)
- new GenericArrayData(finalData)
- }
-
- override def nullable: Boolean = left.nullable
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val leftGen = left.genCode(ctx)
- val rightGen = right.genCode(ctx)
- val f = (eval1: String, eval2: String) => {
- val newArraySize = ctx.freshName("newArraySize")
- val i = ctx.freshName("i")
- val values = ctx.freshName("values")
- val allocation = CodeGenerator.createArrayData(
- values, elementType, newArraySize, s" $prettyName failed.")
- val assignment = CodeGenerator.createArrayAssignment(
- values, elementType, eval1, i, i,
left.dataType.asInstanceOf[ArrayType].containsNull)
- s"""
- |int $newArraySize = $eval1.numElements() + 1;
- |$allocation
- |int $i = 0;
- |while ($i < $eval1.numElements()) {
- | $assignment
- | $i ++;
- |}
- |${CodeGenerator.setArrayElement(values, elementType, i, eval2,
Some(rightGen.isNull))}
- |${ev.value} = $values;
- |""".stripMargin
- }
- val resultCode = f(leftGen.value, rightGen.value)
- if (nullable) {
- val nullSafeEval =
- leftGen.code + rightGen.code + ctx.nullSafeExec(left.nullable,
leftGen.isNull) {
- s"""
- ${ev.isNull} = false; // resultCode could change nullability.
- $resultCode
- """
- }
- ev.copy(code =
- code"""
- boolean ${ev.isNull} = true;
- ${CodeGenerator.javaType(dataType)} ${ev.value} =
${CodeGenerator.defaultValue(dataType)};
- $nullSafeEval
- """)
- } else {
- ev.copy(code =
- code"""
- ${leftGen.code}
- ${rightGen.code}
- ${CodeGenerator.javaType(dataType)} ${ev.value} =
${CodeGenerator.defaultValue(dataType)};
- $resultCode""", isNull = FalseLiteral)
- }
- }
-
- /**
- * Returns the [[DataType]] of the result of evaluating this expression. It
is invalid to query
- * the dataType of an unresolved expression (i.e., when `resolved` == false).
- */
- override def dataType: DataType = if (right.nullable)
left.dataType.asNullable else left.dataType
- protected def withNewChildrenInternal(newLeft: Expression, newRight:
Expression): ArrayAppend =
- copy(left = newLeft, right = newRight)
-
-}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 1787f6ac72d..ff393857c31 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -2328,6 +2328,27 @@ class CollectionExpressionsSuite extends SparkFunSuite
with ExpressionEvalHelper
)
// null handling
+ checkEvaluation(
+ ArrayInsert(
+ Literal.create(null, ArrayType(StringType)),
+ Literal(-1),
+ Literal.create("c", StringType),
+ legacyNegativeIndex = false),
+ null)
+ checkEvaluation(
+ ArrayInsert(
+ Literal.create(null, ArrayType(StringType)),
+ Literal(-1),
+ Literal.create(null, StringType),
+ legacyNegativeIndex = false),
+ null)
+ checkEvaluation(
+ ArrayInsert(
+ Literal.create(Seq(""), ArrayType(StringType)),
+ Literal(-1),
+ Literal.create(null, StringType),
+ legacyNegativeIndex = false),
+ Seq("", null))
checkEvaluation(new ArrayInsert(
a1, Literal(3), Literal.create(null, IntegerType)), Seq(1, 2, null, 4)
)
@@ -2336,6 +2357,38 @@ class CollectionExpressionsSuite extends SparkFunSuite
with ExpressionEvalHelper
Seq("b", null, "d", "a", "g", null))
checkEvaluation(new ArrayInsert(a11, Literal(3), Literal("d")), null)
checkEvaluation(new ArrayInsert(a10, Literal.create(null, IntegerType),
Literal("d")), null)
+
+ assert(
+ ArrayInsert(
+ Literal.create(Seq(null, 1d, 2d), ArrayType(DoubleType)),
+ Literal(-1),
+ Literal.create(3, IntegerType),
+ legacyNegativeIndex = false)
+ .checkInputDataTypes() ==
+ DataTypeMismatch(
+ errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
+ messageParameters = Map(
+ "functionName" -> "`array_insert`",
+ "dataType" -> "\"ARRAY\"",
+ "leftType" -> "\"ARRAY<DOUBLE>\"",
+ "rightType" -> "\"INT\""))
+ )
+
+ assert(
+ ArrayInsert(
+ Literal.create("Hi", StringType),
+ Literal(-1),
+ Literal.create("Spark", StringType),
+ legacyNegativeIndex = false)
+ .checkInputDataTypes() == DataTypeMismatch(
+ errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
+ messageParameters = Map(
+ "functionName" -> "`array_insert`",
+ "dataType" -> "\"ARRAY\"",
+ "leftType" -> "\"STRING\"",
+ "rightType" -> "\"STRING\"")
+ )
+ )
}
test("Array Intersect") {
@@ -2679,86 +2732,14 @@ class CollectionExpressionsSuite extends SparkFunSuite
with ExpressionEvalHelper
}
}
- test("ArrayAppend Expression Test") {
- checkEvaluation(
- ArrayAppend(
- Literal.create(null, ArrayType(StringType)),
- Literal.create("c", StringType)),
- null)
-
- checkEvaluation(
- ArrayAppend(
- Literal.create(null, ArrayType(StringType)),
- Literal.create(null, StringType)),
- null)
-
- checkEvaluation(
- ArrayAppend(
- Literal.create(Seq(""), ArrayType(StringType)),
- Literal.create(null, StringType)),
- Seq("", null))
-
- checkEvaluation(
- ArrayAppend(
- Literal.create(Seq("a", "b", "c"), ArrayType(StringType)),
- Literal.create(null, StringType)),
- Seq("a", "b", "c", null))
-
- checkEvaluation(
- ArrayAppend(
- Literal.create(Seq(Double.NaN, 1d, 2d), ArrayType(DoubleType)),
- Literal.create(3d, DoubleType)),
- Seq(Double.NaN, 1d, 2d, 3d))
- // Null entry check
- checkEvaluation(
- ArrayAppend(
- Literal.create(Seq(null, 1d, 2d), ArrayType(DoubleType)),
- Literal.create(3d, DoubleType)),
- Seq(null, 1d, 2d, 3d))
-
- checkEvaluation(
- ArrayAppend(
- Literal.create(Seq("a", "b", "c"), ArrayType(StringType)),
- Literal.create("c", StringType)),
- Seq("a", "b", "c", "c"))
-
- assert(
- ArrayAppend(
- Literal.create(Seq(null, 1d, 2d), ArrayType(DoubleType)),
- Literal.create(3, IntegerType))
- .checkInputDataTypes() ==
- DataTypeMismatch(
- errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
- messageParameters = Map(
- "functionName" -> "`array_append`",
- "dataType" -> "\"ARRAY\"",
- "leftType" -> "\"ARRAY<DOUBLE>\"",
- "rightType" -> "\"INT\""))
- )
-
-
- assert(
- ArrayAppend(
- Literal.create("Hi", StringType),
- Literal.create("Spark", StringType))
- .checkInputDataTypes() == DataTypeMismatch(
- errorSubClass = "UNEXPECTED_INPUT_TYPE",
- messageParameters = Map(
- "paramIndex" -> "0",
- "requiredType" -> "\"ARRAY\"",
- "inputSql" -> "\"Hi\"",
- "inputType" -> "\"STRING\""
- )
- )
- )
-
- }
-
test("SPARK-42401: Array insert of null value (explicit)") {
val a = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false))
- checkEvaluation(new ArrayInsert(
- a, Literal(2), Literal.create(null, StringType)), Seq("b", null, "a",
"c")
- )
+ checkEvaluation(
+ new ArrayInsert(a, Literal(2), Literal.create(null, StringType)),
+ Seq("b", null, "a", "c"))
+ checkEvaluation(
+ new ArrayInsert(a, Literal(-1), Literal.create(null, StringType)),
+ Seq("b", "a", "c", null))
}
test("SPARK-42401: Array insert of null value (implicit)") {
@@ -2767,11 +2748,4 @@ class CollectionExpressionsSuite extends SparkFunSuite
with ExpressionEvalHelper
a, Literal(5), Literal.create("q", StringType)), Seq("b", "a", "c",
null, "q")
)
}
-
- test("SPARK-42401: Array append of null value") {
- val a = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false))
- checkEvaluation(ArrayAppend(
- a, Literal.create(null, StringType)), Seq("b", "a", "c", null)
- )
- }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]