zhengruifeng commented on code in PR #38865:
URL: https://github.com/apache/spark/pull/38865#discussion_r1042110305


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -4600,3 +4600,133 @@ case class ArrayExcept(left: Expression, right: 
Expression) extends ArrayBinaryL
   override protected def withNewChildrenInternal(
     newLeft: Expression, newRight: Expression): ArrayExcept = copy(left = 
newLeft, right = newRight)
 }
+
+/**
+ * Given an array, and another element append the element at the end of the 
array.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(array, element) - Append the element",
+  examples =
+    """
+    Examples:
+      > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
+       ["b","d","c","a","d"]
+
+  """,
+  since = "3.4.0",
+  group = "array_funcs")
+case class ArrayAppend(left: Expression, right: Expression)
+  extends BinaryExpression
+  with ImplicitCastInputTypes
+  with ComplexTypeMergingExpression
+  with QueryErrorsBase {
+  override def prettyName: String = "array_append"
+
+  override def inputTypes: Seq[AbstractDataType] = {
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, hasNull), e2) =>
+        TypeCoercion.findTightestCommonType(e1, e2) match {
+          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+          case _ => Seq.empty
+        }
+      case _ => Seq.empty
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val value1 = left.eval(input)
+    val value2 = right.eval(input)
+    if (value1 == null) {
+      null
+    } else {
+      nullSafeEval(value1, value2)
+    }
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, _), e2) => if (e1.sameType(e2)) {

Review Comment:
   would you mind reformatting it to:
   ```
   case (ArrayType(e1, _), e2) if (e1.sameType(e2)) => ...
   case (ArrayType(e1, _), e2) => ...
   case _ => ...
   
   ```
   
   



##########
sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala:
##########
@@ -5237,6 +5237,92 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSparkSession {
       )
     )
   }
+
+  test("array_append -> Unit Test cases for the function ") {
+    val df1 = Seq((Array[Int](3, 2, 5, 1, 2), 3)).toDF("a", "b")
+    checkAnswer(df1.select(array_append(col("a"), col("b"))), Seq(Row(Seq(3, 
2, 5, 1, 2, 3))))
+    val df2 = Seq((Array[String]("a", "b", "c"), "d")).toDF("a", "b")
+    checkAnswer(df2.select(array_append(col("a"), col("b"))), Seq(Row(Seq("a", 
"b", "c", "d"))))
+    val df3 = Seq((Array[String]("a", "b", "c"), 3)).toDF("a", "b")
+    checkError(
+      exception = intercept[AnalysisException] {
+        df3.select(array_append(col("a"), col("b")))
+      },
+      errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES",
+      parameters = Map(
+        "functionName" -> "`array_append`",
+        "dataType" -> "\"ARRAY\"",
+        "leftType" -> "\"ARRAY<STRING>\"",
+        "rightType" -> "\"INT\"",
+        "sqlExpr" -> "\"array_append(a, b)\"")
+    )
+
+    checkAnswer(df1.selectExpr("array_append(a, 3)"), Seq(Row(Seq(3, 2, 5, 1, 
2, 3))))
+
+    checkAnswer(df2.selectExpr("array_append(a, b)"), Seq(Row(Seq("a", "b", 
"c", "d"))))
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df3.selectExpr("array_append(a, b)")
+      },
+      errorClass = "DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES",
+      parameters = Map(
+        "functionName" -> "`array_append`",
+        "leftType" -> "\"ARRAY<STRING>\"",
+        "rightType" -> "\"INT\"",
+        "sqlExpr" -> "\"array_append(a, b)\"",
+        "dataType" -> "\"ARRAY\""
+      ),
+      context = ExpectedContext(
+        fragment = "array_append(a, b)",
+        start = 0,
+        stop = 17
+      )
+    )
+    // Adding null check Unit Tests
+    val df4 = Seq((Array[String]("a", "b", "c"), "d"),
+      (null, "d"),
+      (Array[String]("x", "y", "z"), null),
+      (null, null)
+    ).toDF("a", "b")
+    checkAnswer(df4.selectExpr("array_append(a, b)"),
+      Seq(Row(Seq("a", "b", "c", "d")), Row(null), Row(Seq("x", "y", "z", 
null)), Row(null)))
+
+    val df5 = Seq((Array[Double](3d, 2d, 5d, 1d, 2d), 3)).toDF("a", "b")
+    checkAnswer(df5.selectExpr("array_append(a, b)"),
+      Seq(Row(Seq(3d, 2d, 5d, 1d, 2d, 3d))))
+
+    val df6 = Seq(("x", "y")).toDF("a", "b")
+    checkError(
+      exception = intercept[AnalysisException] {
+        df6.selectExpr("array_append(a, b)")
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"array_append(a, b)\"",
+        "paramIndex" -> "0",
+        "requiredType" -> "\"ARRAY\"",
+        "inputSql" -> "\"a\"",
+        "inputType" -> "\"STRING\""
+      ),
+      context = ExpectedContext(
+        fragment = "array_append(a, b)",
+        start = 0,
+        stop = 17
+      )
+    )
+
+    val df7 = Seq((Array[Int](3, 2, 5, 1, 2), 3d)).toDF("a", "b")
+    checkAnswer(df7.select(array_append(col("a"), col("b"))),
+      Seq(Row(Seq(3d, 2d, 5d, 1d, 2d, 3d))))
+
+    val df8 = Seq((Array[Double](3d, 2d, 5d, 1d, 2d), 3)).toDF("a", "b")
+    checkAnswer(df8.select(array_append(col("a"), col("b"))),
+      Seq(Row(Seq(3d, 2d, 5d, 1d, 2d, 3d))))
+

Review Comment:
   nit: extra blank lines



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -4600,3 +4600,133 @@ case class ArrayExcept(left: Expression, right: 
Expression) extends ArrayBinaryL
   override protected def withNewChildrenInternal(
     newLeft: Expression, newRight: Expression): ArrayExcept = copy(left = 
newLeft, right = newRight)
 }
+
+/**
+ * Given an array, and another element append the element at the end of the 
array.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(array, element) - Append the element",
+  examples =
+    """
+    Examples:
+      > SELECT _FUNC_(array('b', 'd', 'c', 'a'), 'd');
+       ["b","d","c","a","d"]
+
+  """,
+  since = "3.4.0",
+  group = "array_funcs")
+case class ArrayAppend(left: Expression, right: Expression)
+  extends BinaryExpression
+  with ImplicitCastInputTypes
+  with ComplexTypeMergingExpression
+  with QueryErrorsBase {
+  override def prettyName: String = "array_append"
+
+  override def inputTypes: Seq[AbstractDataType] = {
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, hasNull), e2) =>
+        TypeCoercion.findTightestCommonType(e1, e2) match {
+          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+          case _ => Seq.empty
+        }
+      case _ => Seq.empty
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val value1 = left.eval(input)
+    val value2 = right.eval(input)
+    if (value1 == null) {
+      null
+    } else {
+      nullSafeEval(value1, value2)
+    }
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, _), e2) => if (e1.sameType(e2)) {
+        TypeCheckResult.TypeCheckSuccess
+      }
+      else {
+        DataTypeMismatch(
+          errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
+          messageParameters = Map(
+            "functionName" -> toSQLId(prettyName),
+            "leftType" -> toSQLType(left.dataType),
+            "rightType" -> toSQLType(right.dataType),
+            "dataType" -> toSQLType(ArrayType)
+          ))
+      }
+      case _ =>
+        DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> "0",
+            "requiredType" -> toSQLType(ArrayType),
+            "inputSql" -> toSQLExpr(left),
+            "inputType" -> toSQLType(left.dataType)
+          )
+        )
+    }
+  }
+
+  protected def withNewChildrenInternal(newLeft: Expression, newRight: 
Expression): ArrayAppend =
+    copy(left = newLeft, right = newRight)
+
+  override protected def nullSafeEval(input1: Any, input2: Any): Any = {
+    val arrayData = input1.asInstanceOf[ArrayData]
+    val arrayElementType = dataType.asInstanceOf[ArrayType].elementType
+    val elementData = input2
+    val numberOfElements = arrayData.numElements() + 1
+    if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+      throw 
QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements)
+    }
+    val finalData = new Array[Any](numberOfElements)
+    arrayData.foreach(arrayElementType, finalData.update)
+    finalData.update(numberOfElements - 1, elementData)
+    new GenericArrayData(finalData)
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+
+    val f = (left: String, right: String) => {
+      val expr = ctx.addReferenceObj("arraysAppendExpr", this)
+      s"${ev.value} = (ArrayData)$expr.nullSafeEval($left, $right);"
+    }
+
+    val leftGen = left.genCode(ctx)
+    val rightGen = right.genCode(ctx)

Review Comment:
   I tried looking at other built-in functions which accept an array and an 
element as inputs:
   
   ```
   scala> spark.sql("""SELECT 
a,b,array_contains(a,b),array_position(a,b),array_remove(a,b) FROM VALUES 
(ARRAY(1, NULL, 3), 1), (ARRAY(1, NULL, 3), NULL), (NULL, 1), (NULL, NULL) AS 
tab(a, b)""").show
   
+------------+----+--------------------+--------------------+------------------+
   |           a|   b|array_contains(a, b)|array_position(a, b)|array_remove(a, 
b)|
   
+------------+----+--------------------+--------------------+------------------+
   |[1, null, 3]|   1|                true|                   1|         [null, 
3]|
   |[1, null, 3]|null|                null|                null|              
null|
   |        null|   1|                null|                null|              
null|
   |        null|null|                null|                null|              
null|
   
+------------+----+--------------------+--------------------+------------------+
   ```
   
   It seems the expected output should be `NULL` if either of the two input 
columns is `NULL`.
   
   @HyukjinKwon  @cloud-fan   



##########
sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala:
##########
@@ -5237,6 +5237,59 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSparkSession {
       )
     )
   }
+
+  test("SPARK-41232 array_append -> Unit Test cases for the function ") {
+    val df1 = Seq((Array[Int](3, 2, 5, 1, 2), 3)).toDF("a", "b")

Review Comment:
   Good question. I tried other built-in functions, and I think we should follow 
`array_remove` in this case:
   
   ```
   scala> spark.sql("""SELECT array_remove(a,b) FROM VALUES (ARRAY(1, 2, 3), 3) 
AS tab(a, b)""").show
   +------------------+
   |array_remove(a, b)|
   +------------------+
   |            [1, 2]|
   +------------------+
   
   
   scala> spark.sql("""SELECT array_remove(a,b) FROM VALUES (ARRAY(1.0, 2.0, 
3.0), 3) AS tab(a, b)""").show
   org.apache.spark.sql.AnalysisException: 
[DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES] Cannot resolve "array_remove(a, 
b)" due to data type mismatch: Input to `array_remove` should have been "ARRAY" 
followed by a value with same element type, but it's ["ARRAY<DECIMAL(2,1)>", 
"INT"].; line 1 pos 7;
   'Project [unresolvedalias(array_remove(a#122, b#123), None)]
   +- SubqueryAlias tab
      +- LocalRelation [a#122, b#123]
   
     at 
org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.dataTypeMismatch(package.scala:73)
   ...
   
   scala> spark.sql("""SELECT array_remove(a,b) FROM VALUES (ARRAY(1, 2, 3), 
3.0) AS tab(a, b)""").show
   org.apache.spark.sql.AnalysisException: 
[DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES] Cannot resolve "array_remove(a, 
b)" due to data type mismatch: Input to `array_remove` should have been "ARRAY" 
followed by a value with same element type, but it's ["ARRAY<INT>", 
"DECIMAL(2,1)"].; line 1 pos 7;
   'Project [unresolvedalias(array_remove(a#124, b#125), None)]
   +- SubqueryAlias tab
      +- LocalRelation [a#124, b#125]
   
     at 
org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.dataTypeMismatch(package.scala:73)
   ...
   
   scala> spark.sql("""SELECT array_position(a,b),a,b FROM VALUES (ARRAY(1, 
NULL, 3), 1.0) AS tab(a, b)""")
   org.apache.spark.sql.AnalysisException: 
[DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES] Cannot resolve "array_position(a, 
b)" due to data type mismatch: Input to `array_position` should have been 
"ARRAY" followed by a value with same element type, but it's ["ARRAY<INT>", 
"DECIMAL(2,1)"].; line 1 pos 7;
   'Project [unresolvedalias(array_position(a#564, b#565), None), a#564, b#565]
   +- SubqueryAlias tab
      +- LocalRelation [a#564, b#565]
   
     at 
org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.dataTypeMismatch(package.scala:73)
     at 
org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$5(CheckAnalysis.scala:249)
     at 
org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$5$adapted(CheckAnalysis.scala:236)
   ...
   
   scala> spark.sql("""SELECT array_position(a,b),a,b FROM VALUES (ARRAY(1.0, 
NULL, 3.0), 1) AS tab(a, b)""")
   org.apache.spark.sql.AnalysisException: 
[DATATYPE_MISMATCH.ARRAY_FUNCTION_DIFF_TYPES] Cannot resolve "array_position(a, 
b)" due to data type mismatch: Input to `array_position` should have been 
"ARRAY" followed by a value with same element type, but it's 
["ARRAY<DECIMAL(2,1)>", "INT"].; line 1 pos 7;
   'Project [unresolvedalias(array_position(a#566, b#567), None), a#566, b#567]
   +- SubqueryAlias tab
      +- LocalRelation [a#566, b#567]
   
     at 
org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.dataTypeMismatch(package.scala:73)
   ...
   
   scala> spark.sql("""SELECT array_contains(a,b),a,b FROM VALUES (ARRAY(1.0, 
NULL, 3.0), 1) AS tab(a, b)""")
   res50: org.apache.spark.sql.DataFrame = [array_contains(a, b): boolean, a: 
array<decimal(2,1)> ... 1 more field]
   
   scala> spark.sql("""SELECT array_contains(a,b),a,b FROM VALUES (ARRAY(1, 
NULL, 3), 1.0) AS tab(a, b)""")
   res51: org.apache.spark.sql.DataFrame = [array_contains(a, b): boolean, a: 
array<int> ... 1 more field]
   ```
   
   So the current behavior (exact type matching) looks reasonable to me.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to