This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3e4a32d1d4a [SPARK-41233][FOLLOWUP] Refactor `array_prepend` with 
`RuntimeReplaceable`
3e4a32d1d4a is described below

commit 3e4a32d1d4a0bace21651a80203cda2c3f2f3b68
Author: jiaan Geng <[email protected]>
AuthorDate: Mon Apr 24 18:30:03 2023 +0800

    [SPARK-41233][FOLLOWUP] Refactor `array_prepend` with `RuntimeReplaceable`
    
    ### What changes were proposed in this pull request?
    Recently, Spark SQL added support for `array_insert` and `array_prepend`. Each 
of them has its own independent implementation.
    In fact, `array_prepend` is a special case of `array_insert`, so we can reuse 
`array_insert` by extending `RuntimeReplaceable`.
    
    ### Why are the changes needed?
    Simplify the implementation of `array_prepend`.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    Just update the inner implementation.
    
    ### How was this patch tested?
    Existing test cases.
    
    Closes #40563 from beliefer/SPARK-41232_SPARK-41233_followup.
    
    Lead-authored-by: jiaan Geng <[email protected]>
    Co-authored-by: Jiaan Geng <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../explain-results/function_array_prepend.explain |   2 +-
 .../expressions/collectionOperations.scala         | 118 ++++-----------------
 .../expressions/CollectionExpressionsSuite.scala   |  44 --------
 .../apache/spark/sql/DataFrameFunctionsSuite.scala |  16 +++
 4 files changed, 36 insertions(+), 144 deletions(-)

diff --git 
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_prepend.explain
 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_prepend.explain
index 539e1eaf767..4c3e7c85d64 100644
--- 
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_prepend.explain
+++ 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_prepend.explain
@@ -1,2 +1,2 @@
-Project [array_prepend(e#0, 1) AS array_prepend(e, 1)#0]
+Project [array_insert(e#0, 1, 1) AS array_prepend(e, 1)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index beed5a6e365..63060e61d56 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1397,7 +1397,6 @@ case class ArrayContains(left: Expression, right: 
Expression)
     copy(left = newLeft, right = newRight)
 }
 
-// scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = """
       _FUNC_(array, element) - Add the element at the beginning of the array 
passed as first
@@ -1416,101 +1415,26 @@ case class ArrayContains(left: Expression, right: 
Expression)
   """,
   group = "array_funcs",
   since = "3.5.0")
-case class ArrayPrepend(left: Expression, right: Expression)
-  extends BinaryExpression
-    with ImplicitCastInputTypes
-    with ComplexTypeMergingExpression
-    with QueryErrorsBase {
+case class ArrayPrepend(left: Expression, right: Expression) extends 
RuntimeReplaceable
+  with ImplicitCastInputTypes with BinaryLike[Expression] with QueryErrorsBase 
{
 
-  override def nullable: Boolean = left.nullable
+  override lazy val replacement: Expression = ArrayInsert(left, Literal(1), 
right)
 
-  @transient protected lazy val elementType: DataType =
-    inputTypes.head.asInstanceOf[ArrayType].elementType
-
-  override def eval(input: InternalRow): Any = {
-    val value1 = left.eval(input)
-    if (value1 == null) {
-      null
-    } else {
-      val value2 = right.eval(input)
-      nullSafeEval(value1, value2)
-    }
-  }
-  override def nullSafeEval(arr: Any, elementData: Any): Any = {
-    val arrayData = arr.asInstanceOf[ArrayData]
-    val numberOfElements = arrayData.numElements() + 1
-    if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-      throw 
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements)
-    }
-    val finalData = new Array[Any](numberOfElements)
-    finalData.update(0, elementData)
-    arrayData.foreach(elementType, (i: Int, v: Any) => finalData.update(i + 1, 
v))
-    new GenericArrayData(finalData)
-  }
-  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
-    val leftGen = left.genCode(ctx)
-    val rightGen = right.genCode(ctx)
-    val f = (arr: String, value: String) => {
-      val newArraySize = s"$arr.numElements() + 1"
-      val newArray = ctx.freshName("newArray")
-      val i = ctx.freshName("i")
-      val iPlus1 = s"$i+1"
-      val zero = "0"
-      val allocation = CodeGenerator.createArrayData(
-        newArray,
-        elementType,
-        newArraySize,
-        s" $prettyName failed.")
-      val assignment =
-        CodeGenerator.createArrayAssignment(newArray, elementType, arr, 
iPlus1, i, false)
-      val newElemAssignment =
-        CodeGenerator.setArrayElement(newArray, elementType, zero, value, 
Some(rightGen.isNull))
-      s"""
-         |$allocation
-         |$newElemAssignment
-         |for (int $i = 0; $i < $arr.numElements(); $i ++) {
-         |  $assignment
-         |}
-         |${ev.value} = $newArray;
-         |""".stripMargin
-    }
-    val resultCode = f(leftGen.value, rightGen.value)
-    if(nullable) {
-      val nullSafeEval = leftGen.code + rightGen.code + 
ctx.nullSafeExec(nullable, leftGen.isNull) {
-        s"""
-           |${ev.isNull} = false;
-           |${resultCode}
-           |""".stripMargin
-      }
-      ev.copy(code =
-        code"""
-          |boolean ${ev.isNull} = true;
-          |${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};
-          |$nullSafeEval
-        """.stripMargin
-      )
-    } else {
-      ev.copy(code =
-        code"""
-          |${leftGen.code}
-          |${rightGen.code}
-          |${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};
-          |$resultCode
-        """.stripMargin, isNull = FalseLiteral)
+  override def inputTypes: Seq[AbstractDataType] = {
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, hasNull), e2) =>
+        TypeCoercion.findTightestCommonType(e1, e2) match {
+          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+          case _ => Seq.empty
+        }
+      case _ => Seq.empty
     }
   }
 
-  override def prettyName: String = "array_prepend"
-
-  override protected def withNewChildrenInternal(
-      newLeft: Expression, newRight: Expression): ArrayPrepend =
-    copy(left = newLeft, right = newRight)
-
-  override def dataType: DataType = if (right.nullable) 
left.dataType.asNullable else left.dataType
-
   override def checkInputDataTypes(): TypeCheckResult = {
     (left.dataType, right.dataType) match {
-      case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) => 
TypeCheckResult.TypeCheckSuccess
+      case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) =>
+        TypeCheckResult.TypeCheckSuccess
       case (ArrayType(e1, _), e2) => DataTypeMismatch(
         errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
         messageParameters = Map(
@@ -1531,16 +1455,12 @@ case class ArrayPrepend(left: Expression, right: 
Expression)
         )
     }
   }
-  override def inputTypes: Seq[AbstractDataType] = {
-    (left.dataType, right.dataType) match {
-      case (ArrayType(e1, hasNull), e2) =>
-        TypeCoercion.findTightestCommonType(e1, e2) match {
-          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
-          case _ => Seq.empty
-        }
-      case _ => Seq.empty
-    }
-  }
+
+  override def prettyName: String = "array_prepend"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ArrayPrepend =
+    copy(left = newLeft, right = newRight)
 }
 
 /**
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 8f1ff97a78e..485579230c0 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1855,50 +1855,6 @@ class CollectionExpressionsSuite extends SparkFunSuite 
with ExpressionEvalHelper
     checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), 
null)
   }
 
-  test("SPARK-41233: ArrayPrepend") {
-    val a0 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType))
-    val a1 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType))
-    val a2 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
-    val a3 = Literal.create(null, ArrayType(StringType))
-
-    checkEvaluation(ArrayPrepend(a0, Literal(0)), Seq(0, 1, 2, 3, 4))
-    checkEvaluation(ArrayPrepend(a1, Literal("a")), Seq("a", "a", "b", "c"))
-    checkEvaluation(ArrayPrepend(a2, Literal(1)), Seq(1))
-    checkEvaluation(ArrayPrepend(a2, Literal(null, IntegerType)), Seq(null))
-    checkEvaluation(ArrayPrepend(a3, Literal("a")), null)
-    checkEvaluation(ArrayPrepend(a3, Literal(null, StringType)), null)
-
-    // complex data types
-    val data = Seq[Array[Byte]](
-      Array[Byte](5, 6),
-      Array[Byte](1, 2),
-      Array[Byte](1, 2),
-      Array[Byte](5, 6))
-    val b0 = Literal.create(
-      data,
-      ArrayType(BinaryType))
-    val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), 
ArrayType(BinaryType))
-    val nullBinary = Literal.create(null, BinaryType)
-    // Calling ArrayPrepend with a null element should result in NULL being 
prepended to the array
-    val dataWithNullPrepended = null +: data
-    checkEvaluation(ArrayPrepend(b0, nullBinary), dataWithNullPrepended)
-    val dataToPrepend1 = Literal.create(Array[Byte](5, 6), BinaryType)
-    checkEvaluation(
-      ArrayPrepend(b1, dataToPrepend1),
-      Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](2, 1), null))
-
-    val c0 = Literal.create(
-      Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
-      ArrayType(ArrayType(IntegerType)))
-    val dataToPrepend2 = Literal.create(Seq[Int](5, 6), ArrayType(IntegerType))
-    checkEvaluation(
-      ArrayPrepend(c0, dataToPrepend2),
-      Seq(Seq[Int](5, 6), Seq[Int](1, 2), Seq[Int](3, 4)))
-    checkEvaluation(
-      ArrayPrepend(c0, Literal.create(Seq.empty[Int], ArrayType(IntegerType))),
-      Seq(Seq.empty[Int], Seq[Int](1, 2), Seq[Int](3, 4)))
-  }
-
   test("Array remove") {
     val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType))
     val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), 
ArrayType(StringType))
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 09812194ba7..5386150c8a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -2717,6 +2717,22 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSparkSession {
     ).toDF("a", "b")
     checkAnswer(df2.selectExpr("array_prepend(a, b)"),
       Seq(Row(Seq("d", "a", "b", "c")), Row(null), Row(Seq(null, "x", "y", 
"z")), Row(null)))
+    val dataA = Seq[Array[Byte]](
+      Array[Byte](5, 6),
+      Array[Byte](1, 2),
+      Array[Byte](1, 2),
+      Array[Byte](5, 6))
+    val dataB = Seq[Array[Int]](Array[Int](1, 2), Array[Int](3, 4))
+    val df3 = Seq((dataA, dataB)).toDF("a", "b")
+    val dataToPrepend = Array[Byte](5, 6)
+    checkAnswer(
+      df3.select(array_prepend($"a", null), array_prepend($"a", 
dataToPrepend)),
+      Seq(Row(null +: dataA, dataToPrepend +: dataA)))
+    checkAnswer(
+      df3.select(array_prepend($"b", Array.empty[Int]), array_prepend($"b", 
Array[Int](5, 6))),
+      Seq(Row(
+        Seq(Seq.empty[Int], Seq[Int](1, 2), Seq[Int](3, 4)),
+        Seq(Seq[Int](5, 6), Seq[Int](1, 2), Seq[Int](3, 4)))))
   }
 
   test("array remove") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to