This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3e4a32d1d4a [SPARK-41233][FOLLOWUP] Refactor `array_prepend` with
`RuntimeReplaceable`
3e4a32d1d4a is described below
commit 3e4a32d1d4a0bace21651a80203cda2c3f2f3b68
Author: jiaan Geng <[email protected]>
AuthorDate: Mon Apr 24 18:30:03 2023 +0800
[SPARK-41233][FOLLOWUP] Refactor `array_prepend` with `RuntimeReplaceable`
### What changes were proposed in this pull request?
Recently, Spark SQL added support for `array_insert` and `array_prepend`. The
two functions were implemented independently of each other.
In fact, `array_prepend` is a special case of `array_insert`, so we can reuse
`array_insert` by extending `RuntimeReplaceable`.
### Why are the changes needed?
Simplify the implementation of `array_prepend`.
### Does this PR introduce _any_ user-facing change?
'No'.
This change only updates the internal implementation.
### How was this patch tested?
Existing test cases.
Closes #40563 from beliefer/SPARK-41232_SPARK-41233_followup.
Lead-authored-by: jiaan Geng <[email protected]>
Co-authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../explain-results/function_array_prepend.explain | 2 +-
.../expressions/collectionOperations.scala | 118 ++++-----------------
.../expressions/CollectionExpressionsSuite.scala | 44 --------
.../apache/spark/sql/DataFrameFunctionsSuite.scala | 16 +++
4 files changed, 36 insertions(+), 144 deletions(-)
diff --git
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_prepend.explain
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_prepend.explain
index 539e1eaf767..4c3e7c85d64 100644
---
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_prepend.explain
+++
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_prepend.explain
@@ -1,2 +1,2 @@
-Project [array_prepend(e#0, 1) AS array_prepend(e, 1)#0]
+Project [array_insert(e#0, 1, 1) AS array_prepend(e, 1)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index beed5a6e365..63060e61d56 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1397,7 +1397,6 @@ case class ArrayContains(left: Expression, right:
Expression)
copy(left = newLeft, right = newRight)
}
-// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(array, element) - Add the element at the beginning of the array
passed as first
@@ -1416,101 +1415,26 @@ case class ArrayContains(left: Expression, right:
Expression)
""",
group = "array_funcs",
since = "3.5.0")
-case class ArrayPrepend(left: Expression, right: Expression)
- extends BinaryExpression
- with ImplicitCastInputTypes
- with ComplexTypeMergingExpression
- with QueryErrorsBase {
+case class ArrayPrepend(left: Expression, right: Expression) extends
RuntimeReplaceable
+ with ImplicitCastInputTypes with BinaryLike[Expression] with QueryErrorsBase
{
- override def nullable: Boolean = left.nullable
+ override lazy val replacement: Expression = ArrayInsert(left, Literal(1),
right)
- @transient protected lazy val elementType: DataType =
- inputTypes.head.asInstanceOf[ArrayType].elementType
-
- override def eval(input: InternalRow): Any = {
- val value1 = left.eval(input)
- if (value1 == null) {
- null
- } else {
- val value2 = right.eval(input)
- nullSafeEval(value1, value2)
- }
- }
- override def nullSafeEval(arr: Any, elementData: Any): Any = {
- val arrayData = arr.asInstanceOf[ArrayData]
- val numberOfElements = arrayData.numElements() + 1
- if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
- throw
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements)
- }
- val finalData = new Array[Any](numberOfElements)
- finalData.update(0, elementData)
- arrayData.foreach(elementType, (i: Int, v: Any) => finalData.update(i + 1,
v))
- new GenericArrayData(finalData)
- }
- override protected def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode = {
- val leftGen = left.genCode(ctx)
- val rightGen = right.genCode(ctx)
- val f = (arr: String, value: String) => {
- val newArraySize = s"$arr.numElements() + 1"
- val newArray = ctx.freshName("newArray")
- val i = ctx.freshName("i")
- val iPlus1 = s"$i+1"
- val zero = "0"
- val allocation = CodeGenerator.createArrayData(
- newArray,
- elementType,
- newArraySize,
- s" $prettyName failed.")
- val assignment =
- CodeGenerator.createArrayAssignment(newArray, elementType, arr,
iPlus1, i, false)
- val newElemAssignment =
- CodeGenerator.setArrayElement(newArray, elementType, zero, value,
Some(rightGen.isNull))
- s"""
- |$allocation
- |$newElemAssignment
- |for (int $i = 0; $i < $arr.numElements(); $i ++) {
- | $assignment
- |}
- |${ev.value} = $newArray;
- |""".stripMargin
- }
- val resultCode = f(leftGen.value, rightGen.value)
- if(nullable) {
- val nullSafeEval = leftGen.code + rightGen.code +
ctx.nullSafeExec(nullable, leftGen.isNull) {
- s"""
- |${ev.isNull} = false;
- |${resultCode}
- |""".stripMargin
- }
- ev.copy(code =
- code"""
- |boolean ${ev.isNull} = true;
- |${CodeGenerator.javaType(dataType)} ${ev.value} =
${CodeGenerator.defaultValue(dataType)};
- |$nullSafeEval
- """.stripMargin
- )
- } else {
- ev.copy(code =
- code"""
- |${leftGen.code}
- |${rightGen.code}
- |${CodeGenerator.javaType(dataType)} ${ev.value} =
${CodeGenerator.defaultValue(dataType)};
- |$resultCode
- """.stripMargin, isNull = FalseLiteral)
+ override def inputTypes: Seq[AbstractDataType] = {
+ (left.dataType, right.dataType) match {
+ case (ArrayType(e1, hasNull), e2) =>
+ TypeCoercion.findTightestCommonType(e1, e2) match {
+ case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+ case _ => Seq.empty
+ }
+ case _ => Seq.empty
}
}
- override def prettyName: String = "array_prepend"
-
- override protected def withNewChildrenInternal(
- newLeft: Expression, newRight: Expression): ArrayPrepend =
- copy(left = newLeft, right = newRight)
-
- override def dataType: DataType = if (right.nullable)
left.dataType.asNullable else left.dataType
-
override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
- case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) =>
TypeCheckResult.TypeCheckSuccess
+ case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) =>
+ TypeCheckResult.TypeCheckSuccess
case (ArrayType(e1, _), e2) => DataTypeMismatch(
errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
messageParameters = Map(
@@ -1531,16 +1455,12 @@ case class ArrayPrepend(left: Expression, right:
Expression)
)
}
}
- override def inputTypes: Seq[AbstractDataType] = {
- (left.dataType, right.dataType) match {
- case (ArrayType(e1, hasNull), e2) =>
- TypeCoercion.findTightestCommonType(e1, e2) match {
- case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
- case _ => Seq.empty
- }
- case _ => Seq.empty
- }
- }
+
+ override def prettyName: String = "array_prepend"
+
+ override protected def withNewChildrenInternal(
+ newLeft: Expression, newRight: Expression): ArrayPrepend =
+ copy(left = newLeft, right = newRight)
}
/**
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 8f1ff97a78e..485579230c0 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1855,50 +1855,6 @@ class CollectionExpressionsSuite extends SparkFunSuite
with ExpressionEvalHelper
checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)),
null)
}
- test("SPARK-41233: ArrayPrepend") {
- val a0 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType))
- val a1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType))
- val a2 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
- val a3 = Literal.create(null, ArrayType(StringType))
-
- checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4))
- checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c"))
- checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1))
- checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), Seq(null))
- checkEvaluation(ArrayPrepend(a3, Literal("a")), null)
- checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null)
-
- // complex data types
- val data = Seq[Array[Byte]](
- Array[Byte](5, 6),
- Array[Byte](1, 2),
- Array[Byte](1, 2),
- Array[Byte](5, 6))
- val b0 = Literal.create(
- data,
- ArrayType(BinaryType))
- val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null),
ArrayType(BinaryType))
- val nullBinary = Literal.create(null, BinaryType)
- // Calling ArrayPrepend with a null element should result in NULL being
prepended to the array
- val dataWithNullPrepended = null +: data
- checkEvaluation(ArrayPrepend(b0, nullBinary), dataWithNullPrepended)
- val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType)
- checkEvaluation(
- ArrayPrepend(b1, dataToPrepend1),
- Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](2, 1), null))
-
- val c0 = Literal.create(
- Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
- ArrayType(ArrayType(IntegerType)))
- val dataToPrepend2 = Literal.create(Seq[Int](5, 6), ArrayType(IntegerType))
- checkEvaluation(
- ArrayPrepend(c0, dataToPrepend2),
- Seq(Seq[Int](5, 6), Seq[Int](1, 2), Seq[Int](3, 4)))
- checkEvaluation(
- ArrayPrepend(c0, Literal.create(Seq.empty[Int], ArrayType(IntegerType))),
- Seq(Seq.empty[Int], Seq[Int](1, 2), Seq[Int](3, 4)))
- }
-
test("Array remove") {
val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType))
val a1 = Literal.create(Seq("b", "a", "a", "c", "b"),
ArrayType(StringType))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 09812194ba7..5386150c8a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -2717,6 +2717,22 @@ class DataFrameFunctionsSuite extends QueryTest with
SharedSparkSession {
).toDF("a", "b")
checkAnswer(df2.selectExpr("array_prepend(a, b)"),
Seq(Row(Seq("d", "a", "b", "c")), Row(null), Row(Seq(null, "x", "y",
"z")), Row(null)))
+ val dataA = Seq[Array[Byte]](
+ Array[Byte](5, 6),
+ Array[Byte](1, 2),
+ Array[Byte](1, 2),
+ Array[Byte](5, 6))
+ val dataB = Seq[Array[Int]](Array[Int](1, 2), Array[Int](3, 4))
+ val df3 = Seq((dataA, dataB)).toDF("a", "b")
+ val dataToPrepend = Array[Byte](5, 6)
+ checkAnswer(
+ df3.select(array_prepend($"a", null), array_prepend($"a",
dataToPrepend)),
+ Seq(Row(null +: dataA, dataToPrepend +: dataA)))
+ checkAnswer(
+ df3.select(array_prepend($"b", Array.empty[Int]), array_prepend($"b",
Array[Int](5, 6))),
+ Seq(Row(
+ Seq(Seq.empty[Int], Seq[Int](1, 2), Seq[Int](3, 4)),
+ Seq(Seq[Int](5, 6), Seq[Int](1, 2), Seq[Int](3, 4)))))
}
test("array remove") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]