asfgit closed pull request #23320: [SPARK-26370][SQL] Fix resolution of
higher-order function for the same identifier.
URL: https://github.com/apache/spark/pull/23320
This is a PR merged from a forked repository.
Because GitHub hides the original diff of a foreign (forked) pull request
once it is merged, the diff is reproduced below for the sake of provenance:
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
index a8a7bbd9f9cd0..1cd7f412bb678 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
@@ -150,13 +150,14 @@ case class ResolveLambdaVariables(conf: SQLConf) extends
Rule[LogicalPlan] {
val lambdaMap = l.arguments.map(v => canonicalizer(v.name) -> v).toMap
l.mapChildren(resolve(_, parentLambdaMap ++ lambdaMap))
- case u @ UnresolvedAttribute(name +: nestedFields) =>
+ case u @ UnresolvedNamedLambdaVariable(name +: nestedFields) =>
parentLambdaMap.get(canonicalizer(name)) match {
case Some(lambda) =>
nestedFields.foldLeft(lambda: Expression) { (expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), conf.resolver)
}
- case None => u
+ case None =>
+ UnresolvedAttribute(u.nameParts)
}
case _ =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index a8639d29f964d..7141b6e996389 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -22,12 +22,34 @@ import java.util.concurrent.atomic.AtomicReference
import scala.collection.mutable
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion,
UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion,
UnresolvedAttribute, UnresolvedException}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods
+/**
+ * A placeholder of lambda variables to prevent unexpected resolution of
[[LambdaFunction]].
+ */
+case class UnresolvedNamedLambdaVariable(nameParts: Seq[String])
+ extends LeafExpression with NamedExpression with Unevaluable {
+
+ override def name: String =
+ nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
+
+ override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
+ override def dataType: DataType = throw new UnresolvedException(this,
"dataType")
+ override def nullable: Boolean = throw new UnresolvedException(this,
"nullable")
+ override def qualifier: Seq[String] = throw new UnresolvedException(this,
"qualifier")
+ override def toAttribute: Attribute = throw new UnresolvedException(this,
"toAttribute")
+ override def newInstance(): NamedExpression = throw new
UnresolvedException(this, "newInstance")
+ override lazy val resolved = false
+
+ override def toString: String = s"lambda '$name"
+
+ override def sql: String = name
+}
+
/**
* A named lambda variable.
*/
@@ -79,7 +101,7 @@ case class LambdaFunction(
object LambdaFunction {
val identity: LambdaFunction = {
- val id = UnresolvedAttribute.quoted("id")
+ val id = UnresolvedNamedLambdaVariable(Seq("id"))
LambdaFunction(id, Seq(id))
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 672bffcfc0cad..8959f78b656d2 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1338,9 +1338,12 @@ class AstBuilder(conf: SQLConf) extends
SqlBaseBaseVisitor[AnyRef] with Logging
*/
override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
val arguments = ctx.IDENTIFIER().asScala.map { name =>
- UnresolvedAttribute.quoted(name.getText)
+
UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts)
}
- LambdaFunction(expression(ctx.expression), arguments)
+ val function = expression(ctx.expression).transformUp {
+ case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts)
+ }
+ LambdaFunction(function, arguments)
}
/**
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
index c4171c75ecd03..a5847ba7c522d 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
@@ -49,19 +49,21 @@ class ResolveLambdaVariablesSuite extends PlanTest {
comparePlans(Analyzer.execute(plan(e1)), plan(e2))
}
+ private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
test("resolution - no op") {
checkExpression(key, key)
}
test("resolution - simple") {
- val in = ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr ::
Nil))
+ val in = ArrayTransform(values1, LambdaFunction(lv('x) + 1, lv('x) :: Nil))
val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil))
checkExpression(in, out)
}
test("resolution - nested") {
val in = ArrayTransform(values2, LambdaFunction(
- ArrayTransform('x.attr, LambdaFunction('x.attr + 1, 'x.attr :: Nil)),
'x.attr :: Nil))
+ ArrayTransform(lv('x), LambdaFunction(lv('x) + 1, lv('x) :: Nil)),
lv('x) :: Nil))
val out = ArrayTransform(values2, LambdaFunction(
ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)),
lvArray :: Nil))
checkExpression(in, out)
@@ -75,14 +77,14 @@ class ResolveLambdaVariablesSuite extends PlanTest {
test("fail - name collisions") {
val p = plan(ArrayTransform(values1,
- LambdaFunction('x.attr + 'X.attr, 'x.attr :: 'X.attr :: Nil)))
+ LambdaFunction(lv('x) + lv('X), lv('x) :: lv('X) :: Nil)))
val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
assert(msg.contains("arguments should not have names that are semantically
the same"))
}
test("fail - lambda arguments") {
val p = plan(ArrayTransform(values1,
- LambdaFunction('x.attr + 'y.attr + 'z.attr, 'x.attr :: 'y.attr ::
'z.attr :: Nil)))
+ LambdaFunction(lv('x) + lv('y) + lv('z), lv('x) :: lv('y) :: lv('z) ::
Nil)))
val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
assert(msg.contains("does not match the number of arguments expected"))
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
index ee0d04da3e46c..748075bfd6a68 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists,
ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If,
LambdaFunction, Literal, MapFilter, NamedExpression, Or}
+import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists,
ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If,
LambdaFunction, Literal, MapFilter, NamedExpression, Or,
UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral,
TrueLiteral}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
@@ -306,22 +306,24 @@ class ReplaceNullWithFalseInPredicateSuite extends
PlanTest {
testProjection(originalExpr = column, expectedExpr = column)
}
+ private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
test("replace nulls in lambda function of ArrayFilter") {
- testHigherOrderFunc('a, ArrayFilter, Seq('e))
+ testHigherOrderFunc('a, ArrayFilter, Seq(lv('e)))
}
test("replace nulls in lambda function of ArrayExists") {
- testHigherOrderFunc('a, ArrayExists, Seq('e))
+ testHigherOrderFunc('a, ArrayExists, Seq(lv('e)))
}
test("replace nulls in lambda function of MapFilter") {
- testHigherOrderFunc('m, MapFilter, Seq('k, 'v))
+ testHigherOrderFunc('m, MapFilter, Seq(lv('k), lv('v)))
}
test("inability to replace nulls in arbitrary higher-order function") {
val lambdaFunc = LambdaFunction(
- function = If('e > 0, Literal(null, BooleanType), TrueLiteral),
- arguments = Seq[NamedExpression]('e))
+ function = If(lv('e) > 0, Literal(null, BooleanType), TrueLiteral),
+ arguments = Seq[NamedExpression](lv('e)))
val column = ArrayTransform('a, lambdaFunc)
testProjection(originalExpr = column, expectedExpr = column)
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index b4df22c5b29fa..8bcc69d580d83 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -246,9 +246,11 @@ class ExpressionParserSuite extends PlanTest {
intercept("foo(a x)", "extraneous input 'x'")
}
+ private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
test("lambda functions") {
- assertEqual("x -> x + 1", LambdaFunction('x + 1, Seq('x.attr)))
- assertEqual("(x, y) -> x + y", LambdaFunction('x + 'y, Seq('x.attr,
'y.attr)))
+ assertEqual("x -> x + 1", LambdaFunction(lv('x) + 1, Seq(lv('x))))
+ assertEqual("(x, y) -> x + y", LambdaFunction(lv('x) + lv('y), Seq(lv('x),
lv('y))))
}
test("window function expressions") {
diff --git
a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
index 35740094ba53e..86a578ca013df 100644
---
a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
+++
b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
@@ -85,7 +85,7 @@ FROM various_maps
struct<>
-- !query 5 output
org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map1`,
various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(),
`k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due
to argument data type mismatch: The input to function map_zip_with should have
been two maps with compatible key types, but the key types are [decimal(36,0),
decimal(36,35)].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map1`,
various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), k,
NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument
data type mismatch: The input to function map_zip_with should have been two
maps with compatible key types, but the key types are [decimal(36,0),
decimal(36,35)].; line 1 pos 7
-- !query 6
@@ -113,7 +113,7 @@ FROM various_maps
struct<>
-- !query 8 output
org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map2`,
various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), `k`,
NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to
argument data type mismatch: The input to function map_zip_with should have
been two maps with compatible key types, but the key types are [decimal(36,35),
int].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map2`,
various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), k,
NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument
data type mismatch: The input to function map_zip_with should have been two
maps with compatible key types, but the key types are [decimal(36,35), int].;
line 1 pos 7
-- !query 9
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index e6d1a038a5918..b7fc9570af919 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -2908,6 +2908,26 @@ class DataFrameFunctionsSuite extends QueryTest with
SharedSQLContext {
}
assert(ex.getMessage.contains("Cannot use null as map key"))
}
+
+ test("SPARK-26370: Fix resolution of higher-order function for the same
identifier") {
+ val df = Seq(
+ (Seq(1, 9, 8, 7), 1, 2),
+ (Seq(5, 9, 7), 2, 2),
+ (Seq.empty, 3, 2),
+ (null, 4, 2)
+ ).toDF("i", "x", "d")
+
+ checkAnswer(df.selectExpr("x", "exists(i, x -> x % d == 0)"),
+ Seq(
+ Row(1, true),
+ Row(2, false),
+ Row(3, false),
+ Row(4, null)))
+ checkAnswer(df.filter("exists(i, x -> x % d == 0)"),
+ Seq(Row(Seq(1, 9, 8, 7), 1, 2)))
+ checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
+ Seq(Row(1)))
+ }
}
object DataFrameFunctionsSuite {
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]