This is an automated email from the ASF dual-hosted git repository. yamamuro pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new 8ef4023 [SPARK-34794][SQL] Fix lambda variable name issues in nested DataFrame functions 8ef4023 is described below commit 8ef4023683dee537a40d376d93c329a802a929bd Author: dsolow <dso...@sayari.com> AuthorDate: Wed May 5 12:46:13 2021 +0900 [SPARK-34794][SQL] Fix lambda variable name issues in nested DataFrame functions ### What changes were proposed in this pull request? To fix lambda variable name issues in nested DataFrame functions, this PR modifies the code to use a global counter for `LambdaVariables` names created by higher-order functions. This is a rework of #31887. Closes #31887. ### Why are the changes needed? This moves away from the current hard-coded variable names, which break on nested function calls. There is currently a bug where nested transforms in particular fail (the inner variable shadows the outer variable). For this query: ``` val df = Seq( (Seq(1,2,3), Seq("a", "b", "c")) ).toDF("numbers", "letters") df.select( f.flatten( f.transform( $"numbers", (number: Column) => { f.transform( $"letters", (letter: Column) => { f.struct( number.as("number"), letter.as("letter") ) } ) } ) ).as("zipped") ).show(10, false) ``` This is the current (incorrect) output: ``` +------------------------------------------------------------------------+ |zipped | +------------------------------------------------------------------------+ |[{a, a}, {b, b}, {c, c}, {a, a}, {b, b}, {c, c}, {a, a}, {b, b}, {c, c}]| +------------------------------------------------------------------------+ ``` And this is the correct output after the fix: ``` +------------------------------------------------------------------------+ |zipped | +------------------------------------------------------------------------+ |[{1, a}, {1, b}, {1, c}, {2, a}, {2, b}, {2, c}, {3, a}, {3, b}, {3, c}]| +------------------------------------------------------------------------+ ``` ### Does this PR introduce _any_ user-facing change? 
No ### How was this patch tested? Added the new test in `DataFrameFunctionsSuite`. Closes #32424 from maropu/pr31887. Lead-authored-by: dsolow <dso...@sayari.com> Co-authored-by: Takeshi Yamamuro <yamam...@apache.org> Co-authored-by: dmsolow <dso...@sayarianalytics.com> Signed-off-by: Takeshi Yamamuro <yamam...@apache.org> (cherry picked from commit f550e03b96638de93381734c4eada2ace02d9a4f) Signed-off-by: Takeshi Yamamuro <yamam...@apache.org> --- .../expressions/higherOrderFunctions.scala | 12 ++++++++++- .../scala/org/apache/spark/sql/functions.scala | 12 +++++------ .../apache/spark/sql/DataFrameFunctionsSuite.scala | 23 ++++++++++++++++++++++ 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index e5cf8c0..a530ce5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator -import java.util.concurrent.atomic.AtomicReference +import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import scala.collection.mutable @@ -52,6 +52,16 @@ case class UnresolvedNamedLambdaVariable(nameParts: Seq[String]) override def sql: String = name } +object UnresolvedNamedLambdaVariable { + + // Counter to ensure lambda variable names are unique + private val nextVarNameId = new AtomicInteger(0) + + def freshVarName(name: String): String = { + s"${name}_${nextVarNameId.getAndIncrement()}" + } +} + /** * A named lambda variable. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bb77c7e..f6d6200 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3489,22 +3489,22 @@ object functions { } private def createLambda(f: Column => Column) = { - val x = UnresolvedNamedLambdaVariable(Seq("x")) + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) val function = f(Column(x)).expr LambdaFunction(function, Seq(x)) } private def createLambda(f: (Column, Column) => Column) = { - val x = UnresolvedNamedLambdaVariable(Seq("x")) - val y = UnresolvedNamedLambdaVariable(Seq("y")) + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) val function = f(Column(x), Column(y)).expr LambdaFunction(function, Seq(x, y)) } private def createLambda(f: (Column, Column, Column) => Column) = { - val x = UnresolvedNamedLambdaVariable(Seq("x")) - val y = UnresolvedNamedLambdaVariable(Seq("y")) - val z = UnresolvedNamedLambdaVariable(Seq("z")) + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) + val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z"))) val function = f(Column(x), Column(y), Column(z)).expr LambdaFunction(function, Seq(x, y, z)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ac98d3f..1a468a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ 
-3621,6 +3621,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { df.select(map(map_entries($"m"), lit(1))), Row(Map(Seq(Row(1, "a")) -> 1))) } + + test("SPARK-34794: lambda variable name issues in nested functions") { + val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("numbers", "letters") + + checkAnswer(df1.select(flatten(transform($"numbers", (number: Column) => + transform($"letters", (letter: Column) => + struct(number, letter))))), + Seq(Row(Seq(Row(1, "a"), Row(1, "b"), Row(2, "a"), Row(2, "b")))) + ) + checkAnswer(df1.select(flatten(transform($"numbers", (number: Column, i: Column) => + transform($"letters", (letter: Column, j: Column) => + struct(number + j, concat(letter, i)))))), + Seq(Row(Seq(Row(1, "a0"), Row(2, "b0"), Row(2, "a1"), Row(3, "b1")))) + ) + + val df2 = Seq((Map("a" -> 1, "b" -> 2), Map("a" -> 2, "b" -> 3))).toDF("m1", "m2") + + checkAnswer(df2.select(map_zip_with($"m1", $"m2", (k1: Column, ov1: Column, ov2: Column) => + map_zip_with($"m1", $"m2", (k2: Column, iv1: Column, iv2: Column) => + ov1 + iv1 + ov2 + iv2))), + Seq(Row(Map("a" -> Map("a" -> 6, "b" -> 8), "b" -> Map("a" -> 8, "b" -> 10)))) + ) + } } object DataFrameFunctionsSuite { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org