This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 4bf5cd6  [SPARK-34829][SQL] Fix higher order function results
4bf5cd6 is described below

commit 4bf5cd6bb4acdb9b1e84213126f6063a469128e5
Author: Peter Toth <peter.t...@gmail.com>
AuthorDate: Sun Mar 28 10:01:09 2021 -0700

    [SPARK-34829][SQL] Fix higher order function results
    
    ### What changes were proposed in this pull request?
    This PR fixes a correctness issue with higher order functions: the result of the function expression needs to be copied in some higher order functions, because such an expression can return its internal, reused buffer, and a higher order function can evaluate the expression multiple times.
    The issue was discovered with typed `ScalaUDF`s after https://github.com/apache/spark/pull/28979.
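    
    A minimal sketch (plain Scala, not Spark internals; all names here are hypothetical) of why the copy is needed when an expression reuses one mutable buffer across eval calls:
    
        class BufferReusingExpr {
          private val buffer = new StringBuilder
          def eval(input: String): StringBuilder = {
            buffer.clear()
            buffer.append(input.reverse)
            buffer // the same object is handed out on every call
          }
        }
    
        val expr = new BufferReusingExpr
        // Keeping the raw results stores two references to one buffer,
        // so both slots show the last evaluation: Seq("fed", "fed")
        val wrong = Seq("abc", "def").map(expr.eval).map(_.toString)
        // Copying right after each evaluation preserves each value:
        // Seq("cba", "fed")
        val right = Seq("abc", "def").map(s => expr.eval(s).toString)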
    
    ### Why are the changes needed?
    To fix a bug.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, queries affected by this issue return correct results again.
    
    ### How was this patch tested?
    Added new UT.
    
    Closes #31955 from peter-toth/SPARK-34829-fix-scalaudf-resultconversion.
    
    Authored-by: Peter Toth <peter.t...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
    (cherry picked from commit 33821903490650463f47d03d20d2b16f77360e99)
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../expressions/higherOrderFunctions.scala         | 14 ++++--
 .../org/apache/spark/sql/DataFrameSuite.scala      | 56 ++++++++++++++++++++++
 2 files changed, 65 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 4454afb..ba447ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -277,7 +277,8 @@ case class ArrayTransform(
       if (indexVar.isDefined) {
         indexVar.get.value.set(i)
       }
-      result.update(i, f.eval(inputRow))
+      val v = InternalRow.copyValue(f.eval(inputRow))
+      result.update(i, v)
       i += 1
     }
     result
@@ -796,7 +797,7 @@ case class TransformKeys(
     while (i < map.numElements) {
       keyVar.value.set(map.keyArray().get(i, keyVar.dataType))
       valueVar.value.set(map.valueArray().get(i, valueVar.dataType))
-      val result = functionForEval.eval(inputRow)
+      val result = InternalRow.copyValue(functionForEval.eval(inputRow))
       resultKeys.update(i, result)
       i += 1
     }
@@ -843,7 +844,8 @@ case class TransformValues(
     while (i < map.numElements) {
       keyVar.value.set(map.keyArray().get(i, keyVar.dataType))
       valueVar.value.set(map.valueArray().get(i, valueVar.dataType))
-      resultValues.update(i, functionForEval.eval(inputRow))
+      val v = InternalRow.copyValue(functionForEval.eval(inputRow))
+      resultValues.update(i, v)
       i += 1
     }
     new ArrayBasedMapData(map.keyArray(), resultValues)
@@ -1026,7 +1028,8 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
       value1Var.value.set(v1)
       value2Var.value.set(v2)
       keys.update(i, key)
-      values.update(i, functionForEval.eval(inputRow))
+      val v = InternalRow.copyValue(functionForEval.eval(inputRow))
+      values.update(i, v)
       i += 1
     }
     new ArrayBasedMapData(keys, values)
@@ -1098,7 +1101,8 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
           } else {
             rightElemVar.value.set(null)
           }
-          result.update(i, f.eval(input))
+          val v = InternalRow.copyValue(f.eval(input))
+          result.update(i, v)
           i += 1
         }
         result
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 2fa840c..61a528c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2628,6 +2628,62 @@ class DataFrameSuite extends QueryTest
     )
     checkAnswer(test.select($"best_name.name"), Row("bob") :: Row("bob") :: Row("sam") :: Nil)
   }
+
+  test("SPARK-34829: Multiple applications of typed ScalaUDFs in higher order functions work") {
+    val reverse = udf((s: String) => s.reverse)
+    val reverse2 = udf((b: Bar2) => Bar2(b.s.reverse))
+
+    val df = Seq(Array("abc", "def")).toDF("array")
+    val test = df.select(transform(col("array"), s => reverse(s)))
+    checkAnswer(test, Row(Array("cba", "fed")) :: Nil)
+
+    val df2 = Seq(Array(Bar2("abc"), Bar2("def"))).toDF("array")
+    val test2 = df2.select(transform(col("array"), b => reverse2(b)))
+    checkAnswer(test2, Row(Array(Row("cba"), Row("fed"))) :: Nil)
+
+    val df3 = Seq(Map("abc" -> 1, "def" -> 2)).toDF("map")
+    val test3 = df3.select(transform_keys(col("map"), (s, _) => reverse(s)))
+    checkAnswer(test3, Row(Map("cba" -> 1, "fed" -> 2)) :: Nil)
+
+    val df4 = Seq(Map(Bar2("abc") -> 1, Bar2("def") -> 2)).toDF("map")
+    val test4 = df4.select(transform_keys(col("map"), (b, _) => reverse2(b)))
+    checkAnswer(test4, Row(Map(Row("cba") -> 1, Row("fed") -> 2)) :: Nil)
+
+    val df5 = Seq(Map(1 -> "abc", 2 -> "def")).toDF("map")
+    val test5 = df5.select(transform_values(col("map"), (_, s) => reverse(s)))
+    checkAnswer(test5, Row(Map(1 -> "cba", 2 -> "fed")) :: Nil)
+
+    val df6 = Seq(Map(1 -> Bar2("abc"), 2 -> Bar2("def"))).toDF("map")
+    val test6 = df6.select(transform_values(col("map"), (_, b) => reverse2(b)))
+    checkAnswer(test6, Row(Map(1 -> Row("cba"), 2 -> Row("fed"))) :: Nil)
+
+    val reverseThenConcat = udf((s1: String, s2: String) => s1.reverse ++ s2.reverse)
+    val reverseThenConcat2 = udf((b1: Bar2, b2: Bar2) => Bar2(b1.s.reverse ++ b2.s.reverse))
+
+    val df7 = Seq((Map(1 -> "abc", 2 -> "def"), Map(1 -> "ghi", 2 -> "jkl"))).toDF("map1", "map2")
+    val test7 =
+      df7.select(map_zip_with(col("map1"), col("map2"), (_, s1, s2) => reverseThenConcat(s1, s2)))
+    checkAnswer(test7, Row(Map(1 -> "cbaihg", 2 -> "fedlkj")) :: Nil)
+
+    val df8 = Seq((Map(1 -> Bar2("abc"), 2 -> Bar2("def")),
+      Map(1 -> Bar2("ghi"), 2 -> Bar2("jkl")))).toDF("map1", "map2")
+    val test8 =
+      df8.select(map_zip_with(col("map1"), col("map2"), (_, b1, b2) => reverseThenConcat2(b1, b2)))
+    checkAnswer(test8, Row(Map(1 -> Row("cbaihg"), 2 -> Row("fedlkj"))) :: Nil)
+
+    val df9 = Seq((Array("abc", "def"), Array("ghi", "jkl"))).toDF("array1", "array2")
+    val test9 =
+      df9.select(zip_with(col("array1"), col("array2"), (s1, s2) => reverseThenConcat(s1, s2)))
+    checkAnswer(test9, Row(Array("cbaihg", "fedlkj")) :: Nil)
+
+    val df10 = Seq((Array(Bar2("abc"), Bar2("def")), Array(Bar2("ghi"), Bar2("jkl"))))
+      .toDF("array1", "array2")
+    val test10 =
+      df10.select(zip_with(col("array1"), col("array2"), (b1, b2) => reverseThenConcat2(b1, b2)))
+    checkAnswer(test10, Row(Array(Row("cbaihg"), Row("fedlkj"))) :: Nil)
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)
+
+case class Bar2(s: String)
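
A hedged usage sketch mirroring the first case of the new UT (assumes a spark-shell style session on a build containing this fix, so `spark` and its implicits are available):

    import org.apache.spark.sql.functions.{col, transform, udf}
    import spark.implicits._

    val reverse = udf((s: String) => s.reverse)
    val df = Seq(Array("abc", "def")).toDF("array")
    // transform() evaluates the typed ScalaUDF once per element; without
    // the copy in ArrayTransform, the reused result buffer could make
    // every slot surface the last value instead.
    val fixed = df.select(transform(col("array"), s => reverse(s)))
    fixed.collect() // expected single row: [cba, fed]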

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
