[ 
https://issues.apache.org/jira/browse/SPARK-26370?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16721575#comment-16721575
 ] 

ASF GitHub Bot commented on SPARK-26370:
----------------------------------------

asfgit closed pull request #23320: [SPARK-26370][SQL] Fix resolution of 
higher-order function for the same identifier.
URL: https://github.com/apache/spark/pull/23320
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
index a8a7bbd9f9cd0..1cd7f412bb678 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
@@ -150,13 +150,14 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
       val lambdaMap = l.arguments.map(v => canonicalizer(v.name) -> v).toMap
       l.mapChildren(resolve(_, parentLambdaMap ++ lambdaMap))
 
-    case u @ UnresolvedAttribute(name +: nestedFields) =>
+    case u @ UnresolvedNamedLambdaVariable(name +: nestedFields) =>
       parentLambdaMap.get(canonicalizer(name)) match {
         case Some(lambda) =>
           nestedFields.foldLeft(lambda: Expression) { (expr, fieldName) =>
             ExtractValue(expr, Literal(fieldName), conf.resolver)
           }
-        case None => u
+        case None =>
+          UnresolvedAttribute(u.nameParts)
       }
 
     case _ =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index a8639d29f964d..7141b6e996389 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -22,12 +22,34 @@ import java.util.concurrent.atomic.AtomicReference
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.array.ByteArrayMethods
 
+/**
+ * A placeholder of lambda variables to prevent unexpected resolution of [[LambdaFunction]].
+ */
+case class UnresolvedNamedLambdaVariable(nameParts: Seq[String])
+  extends LeafExpression with NamedExpression with Unevaluable {
+
+  override def name: String =
+    nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
+
+  override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
+  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+  override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier")
+  override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
+  override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
+  override lazy val resolved = false
+
+  override def toString: String = s"lambda '$name"
+
+  override def sql: String = name
+}
+
 /**
  * A named lambda variable.
  */
@@ -79,7 +101,7 @@ case class LambdaFunction(
 
 object LambdaFunction {
   val identity: LambdaFunction = {
-    val id = UnresolvedAttribute.quoted("id")
+    val id = UnresolvedNamedLambdaVariable(Seq("id"))
     LambdaFunction(id, Seq(id))
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 672bffcfc0cad..8959f78b656d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1338,9 +1338,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
    */
   override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
     val arguments = ctx.IDENTIFIER().asScala.map { name =>
-      UnresolvedAttribute.quoted(name.getText)
+      UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts)
     }
-    LambdaFunction(expression(ctx.expression), arguments)
+    val function = expression(ctx.expression).transformUp {
+      case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts)
+    }
+    LambdaFunction(function, arguments)
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
index c4171c75ecd03..a5847ba7c522d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
@@ -49,19 +49,21 @@ class ResolveLambdaVariablesSuite extends PlanTest {
     comparePlans(Analyzer.execute(plan(e1)), plan(e2))
   }
 
+  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
   test("resolution - no op") {
     checkExpression(key, key)
   }
 
   test("resolution - simple") {
-    val in = ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr :: Nil))
+    val in = ArrayTransform(values1, LambdaFunction(lv('x) + 1, lv('x) :: Nil))
     val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil))
     checkExpression(in, out)
   }
 
   test("resolution - nested") {
     val in = ArrayTransform(values2, LambdaFunction(
-      ArrayTransform('x.attr, LambdaFunction('x.attr + 1, 'x.attr :: Nil)), 'x.attr :: Nil))
+      ArrayTransform(lv('x), LambdaFunction(lv('x) + 1, lv('x) :: Nil)), lv('x) :: Nil))
     val out = ArrayTransform(values2, LambdaFunction(
       ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil))
     checkExpression(in, out)
@@ -75,14 +77,14 @@ class ResolveLambdaVariablesSuite extends PlanTest {
 
   test("fail - name collisions") {
     val p = plan(ArrayTransform(values1,
-      LambdaFunction('x.attr + 'X.attr, 'x.attr :: 'X.attr :: Nil)))
+      LambdaFunction(lv('x) + lv('X), lv('x) :: lv('X) :: Nil)))
     val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
     assert(msg.contains("arguments should not have names that are semantically the same"))
   }
 
   test("fail - lambda arguments") {
     val p = plan(ArrayTransform(values1,
-      LambdaFunction('x.attr + 'y.attr + 'z.attr, 'x.attr :: 'y.attr :: 'z.attr :: Nil)))
+      LambdaFunction(lv('x) + lv('y) + lv('z), lv('x) :: lv('y) :: lv('z) :: Nil)))
     val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
     assert(msg.contains("does not match the number of arguments expected"))
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
index ee0d04da3e46c..748075bfd6a68 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or}
+import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
@@ -306,22 +306,24 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
     testProjection(originalExpr = column, expectedExpr = column)
   }
 
+  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
   test("replace nulls in lambda function of ArrayFilter") {
-    testHigherOrderFunc('a, ArrayFilter, Seq('e))
+    testHigherOrderFunc('a, ArrayFilter, Seq(lv('e)))
   }
 
   test("replace nulls in lambda function of ArrayExists") {
-    testHigherOrderFunc('a, ArrayExists, Seq('e))
+    testHigherOrderFunc('a, ArrayExists, Seq(lv('e)))
   }
 
   test("replace nulls in lambda function of MapFilter") {
-    testHigherOrderFunc('m, MapFilter, Seq('k, 'v))
+    testHigherOrderFunc('m, MapFilter, Seq(lv('k), lv('v)))
   }
 
   test("inability to replace nulls in arbitrary higher-order function") {
     val lambdaFunc = LambdaFunction(
-      function = If('e > 0, Literal(null, BooleanType), TrueLiteral),
-      arguments = Seq[NamedExpression]('e))
+      function = If(lv('e) > 0, Literal(null, BooleanType), TrueLiteral),
+      arguments = Seq[NamedExpression](lv('e)))
     val column = ArrayTransform('a, lambdaFunc)
     testProjection(originalExpr = column, expectedExpr = column)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index b4df22c5b29fa..8bcc69d580d83 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -246,9 +246,11 @@ class ExpressionParserSuite extends PlanTest {
     intercept("foo(a x)", "extraneous input 'x'")
   }
 
+  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
   test("lambda functions") {
-    assertEqual("x -> x + 1", LambdaFunction('x + 1, Seq('x.attr)))
-    assertEqual("(x, y) -> x + y", LambdaFunction('x + 'y, Seq('x.attr, 
'y.attr)))
+    assertEqual("x -> x + 1", LambdaFunction(lv('x) + 1, Seq(lv('x))))
+    assertEqual("(x, y) -> x + y", LambdaFunction(lv('x) + lv('y), Seq(lv('x), 
lv('y))))
   }
 
   test("window function expressions") {
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
index 35740094ba53e..86a578ca013df 100644
--- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
@@ -85,7 +85,7 @@ FROM various_maps
 struct<>
 -- !query 5 output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
 
 
 -- !query 6
@@ -113,7 +113,7 @@ FROM various_maps
 struct<>
 -- !query 8 output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7
 
 
 -- !query 9
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index e6d1a038a5918..b7fc9570af919 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -2908,6 +2908,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     }
     assert(ex.getMessage.contains("Cannot use null as map key"))
   }
+
+  test("SPARK-26370: Fix resolution of higher-order function for the same 
identifier") {
+    val df = Seq(
+      (Seq(1, 9, 8, 7), 1, 2),
+      (Seq(5, 9, 7), 2, 2),
+      (Seq.empty, 3, 2),
+      (null, 4, 2)
+    ).toDF("i", "x", "d")
+
+    checkAnswer(df.selectExpr("x", "exists(i, x -> x % d == 0)"),
+      Seq(
+        Row(1, true),
+        Row(2, false),
+        Row(3, false),
+        Row(4, null)))
+    checkAnswer(df.filter("exists(i, x -> x % d == 0)"),
+      Seq(Row(Seq(1, 9, 8, 7), 1, 2)))
+    checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
+      Seq(Row(1)))
+  }
 }
 
 object DataFrameFunctionsSuite {
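
In essence, the patch makes the parser wrap every identifier inside a lambda in
the new UnresolvedNamedLambdaVariable placeholder, which generic bottom-up
resolution ignores; ResolveLambdaVariables then binds the matching names and
rewrites any leftover placeholder back to UnresolvedAttribute. A minimal sketch
of the resulting tree for "x -> x % d == 0" (hand-built with plain constructors
instead of the parser, purely for illustration):

    import org.apache.spark.sql.catalyst.expressions.{
      EqualTo, LambdaFunction, Literal, Remainder, UnresolvedNamedLambdaVariable}

    // What the patched AstBuilder.visitLambda produces, conceptually:
    // the argument `x` and the body identifiers `x` and `d` are all
    // placeholders, so Analyzer.resolveExpressionBottomUp (which only
    // rewrites UnresolvedAttribute) leaves the lambda untouched.
    val x = UnresolvedNamedLambdaVariable(Seq("x"))
    val d = UnresolvedNamedLambdaVariable(Seq("d"))
    val lambda = LambdaFunction(
      function = EqualTo(Remainder(x, d), Literal(0)),
      arguments = Seq(x))

    // Later, ResolveLambdaVariables binds `x` to the lambda argument and,
    // via the new `case None => UnresolvedAttribute(u.nameParts)` branch,
    // hands the unmatched `d` back to normal column resolution.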


 


> Fix resolution of higher-order function for the same identifier.
> ----------------------------------------------------------------
>
>                 Key: SPARK-26370
>                 URL: https://issues.apache.org/jira/browse/SPARK-26370
>             Project: Spark
>          Issue Type: Bug
>          Components: SQL
>    Affects Versions: 2.4.0
>            Reporter: Takuya Ueshin
>            Priority: Major
>
> When a higher-order function uses a lambda variable with the same name as an
> existing column, in a {{Filter}} or any other operator that resolves its
> expressions via {{Analyzer.resolveExpressionBottomUp}}, e.g.:
> {code}
> val df = Seq(
>   (Seq(1, 9, 8, 7), 1, 2),
>   (Seq(5, 9, 7), 2, 2),
>   (Seq.empty, 3, 2),
>   (null, 4, 2)
> ).toDF("i", "x", "d")
> checkAnswer(df.filter("exists(i, x -> x % d == 0)"),
>   Seq(Row(Seq(1, 9, 8, 7), 1, 2)))
> checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
>   Seq(Row(1)))
> {code}
> the following exception happens:
> {code:java}
> java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.BoundReference cannot be cast to org.apache.spark.sql.catalyst.expressions.NamedExpression
>   at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:237)
>   at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
>   at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
>   at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
>   at scala.collection.TraversableLike.map(TraversableLike.scala:237)
>   at scala.collection.TraversableLike.map$(TraversableLike.scala:230)
>   at scala.collection.AbstractTraversable.map(Traversable.scala:108)
>   at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.$anonfun$functionsForEval$1(higherOrderFunctions.scala:147)
>   at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:237)
>   at scala.collection.immutable.List.foreach(List.scala:392)
>   at scala.collection.TraversableLike.map(TraversableLike.scala:237)
>   at scala.collection.TraversableLike.map$(TraversableLike.scala:230)
>   at scala.collection.immutable.List.map(List.scala:298)
>   at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.functionsForEval(higherOrderFunctions.scala:145)
>   at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.functionsForEval$(higherOrderFunctions.scala:145)
>   at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionsForEval$lzycompute(higherOrderFunctions.scala:369)
>   at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionsForEval(higherOrderFunctions.scala:369)
>   at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.functionForEval(higherOrderFunctions.scala:176)
>   at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.functionForEval$(higherOrderFunctions.scala:176)
>   at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionForEval(higherOrderFunctions.scala:369)
>   at org.apache.spark.sql.catalyst.expressions.ArrayExists.nullSafeEval(higherOrderFunctions.scala:387)
>   at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.eval(higherOrderFunctions.scala:190)
>   at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.eval$(higherOrderFunctions.scala:185)
>   at org.apache.spark.sql.catalyst.expressions.ArrayExists.eval(higherOrderFunctions.scala:369)
>   at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificPredicate.eval(Unknown Source)
>   at org.apache.spark.sql.execution.FilterExec.$anonfun$doExecute$3(basicPhysicalOperators.scala:216)
>   at org.apache.spark.sql.execution.FilterExec.$anonfun$doExecute$3$adapted(basicPhysicalOperators.scala:215)
> ...
> {code}
> This happens because the {{UnresolvedAttribute}}s in the {{LambdaFunction}} are
> unexpectedly resolved by the rule, leaving a {{BoundReference}} where a
> {{NamedExpression}} is expected.
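> A self-contained version of the reproduction, sketched against a local
> {{SparkSession}} ({{checkAnswer}} above is a test-suite helper, so {{show()}}
> stands in for it here):
> {code}
> import org.apache.spark.sql.SparkSession
>
> object Spark26370Repro {
>   def main(args: Array[String]): Unit = {
>     val spark = SparkSession.builder()
>       .master("local[*]")
>       .appName("SPARK-26370 repro")
>       .getOrCreate()
>     import spark.implicits._
>
>     val df = Seq(
>       (Seq(1, 9, 8, 7), 1, 2),
>       (Seq(5, 9, 7), 2, 2),
>       (Seq.empty[Int], 3, 2),
>       (null, 4, 2)
>     ).toDF("i", "x", "d")
>
>     // The lambda variable `x` shadows column `x`; on 2.4.0 this throws
>     // the ClassCastException above instead of returning one row.
>     df.filter("exists(i, x -> x % d == 0)").show()
>
>     spark.stop()
>   }
> }
> {code}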


