This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
new f10fbb2 [SPARK-35636][SQL] Lambda keys should not be referenced
outside of the lambda function
f10fbb2 is described below
commit f10fbb2055ecf4163b5b3d236e69138dfb228e1a
Author: Karen Feng <[email protected]>
AuthorDate: Fri Jun 4 15:44:32 2021 +0900
[SPARK-35636][SQL] Lambda keys should not be referenced outside of the
lambda function
Sets `references` for `NamedLambdaVariable` and `LambdaFunction`.
| Expression | NamedLambdaVariable | LambdaFunction |
| --- | --- | --- |
| References before | None | All function references |
| References after | self.toAttribute | Function references minus
arguments' references |
In `NestedColumnAliasing`, this means that `ExtractValue(ExtractValue(attr,
lv: NamedLambdaVariable), ...)` now references both `attr` and `lv`, rather
than just `attr`. As a result, it will not be included in the nested column
references.
Before, the lambda key was referenced outside of the lambda function.
Before:
```
Project [transform(keys#0, lambdafunction(_extract_v1#0, lambda key#0,
false)) AS a#0]
+- 'Join Cross
:- Project [kvs#0[lambda key#0].v1 AS _extract_v1#0]
: +- LocalRelation <empty>, [kvs#0]
+- LocalRelation <empty>, [keys#0]
```
After:
```
Project [transform(keys#418, lambdafunction(kvs#417[lambda key#420].v1,
lambda key#420, false)) AS a#419]
+- Join Cross
:- LocalRelation <empty>, [kvs#417]
+- LocalRelation <empty>, [keys#418]
```
Before:
```
Project [transform(keys#0, lambdafunction(kvs#0[lambda key#0].v1, lambda
key#0, false)) AS a#0]
+- GlobalLimit 5
+- LocalLimit 5
+- Project [keys#0, _extract_v1#0 AS _extract_v1#0]
+- GlobalLimit 5
+- LocalLimit 5
+- Project [kvs#0[lambda key#0].v1 AS _extract_v1#0, keys#0]
+- LocalRelation <empty>, [kvs#0, keys#0]
```
After:
```
Project [transform(keys#428, lambdafunction(kvs#427[lambda key#430].v1,
lambda key#430, false)) AS a#429]
+- GlobalLimit 5
+- LocalLimit 5
+- Project [keys#428, kvs#427]
+- GlobalLimit 5
+- LocalLimit 5
+- LocalRelation <empty>, [kvs#427, keys#428]
```
No
Scala unit tests for the examples above
Closes #32773 from karenfeng/SPARK-35636.
Authored-by: Karen Feng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit 53a758b51b79c52ac5d4bf3fad72765e55607f36)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../expressions/higherOrderFunctions.scala | 10 +++++++
.../optimizer/NestedColumnAliasingSuite.scala | 32 +++++++++++++++++++++-
2 files changed, 41 insertions(+), 1 deletion(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index a4e069d..0a6376f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -84,6 +84,9 @@ case class NamedLambdaVariable(
AttributeReference(name, dataType, nullable, Metadata.empty)(exprId,
Seq.empty)
}
+  // Check if this lambda variable is referenced outside the lambda function it is bound to
+  override def references: AttributeSet = AttributeSet(toAttribute)
+
override def eval(input: InternalRow): Any = value.get
override def toString: String = s"lambda $name#${exprId.id}$typeSuffix"
@@ -108,6 +111,13 @@ case class LambdaFunction(
override def dataType: DataType = function.dataType
override def nullable: Boolean = function.nullable
+  // Check if lambda variables bound to this lambda function are referenced in the wrong scope
+  override def references: AttributeSet = if (resolved) {
+    function.references -- AttributeSet(arguments.flatMap(_.references))
+ } else {
+ super.references
+ }
+
lazy val bound: Boolean = arguments.forall(_.resolved)
override def eval(input: InternalRow): Any = function.eval(input)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
index c83ab37..08e442d 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
@@ -23,9 +23,10 @@ import org.apache.spark.sql.catalyst.SchemaPruningTest
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.Cross
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.{StringType, StructField, StructType}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
class NestedColumnAliasingSuite extends SchemaPruningTest {
@@ -684,6 +685,35 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
).analyze
comparePlans(optimized2, expected2)
}
+
+ test("SPARK-35636: do not push lambda key out of lambda function") {
+ val rel = LocalRelation(
+      'kvs.map(StringType, new StructType().add("v1", IntegerType)), 'keys.array(StringType))
+ val key = UnresolvedNamedLambdaVariable("key" :: Nil)
+ val lambda = LambdaFunction('kvs.getItem(key).getField("v1"), key :: Nil)
+ val query = rel
+ .limit(5)
+ .select('keys, 'kvs)
+ .limit(5)
+ .select(ArrayTransform('keys, lambda).as("a"))
+ .analyze
+ val optimized = Optimize.execute(query)
+ comparePlans(optimized, query)
+ }
+
+ test("SPARK-35636: do not push down extract value in higher order " +
+ "function that references both sides of a join") {
+    val left = LocalRelation('kvs.map(StringType, new StructType().add("v1", IntegerType)))
+ val right = LocalRelation('keys.array(StringType))
+ val key = UnresolvedNamedLambdaVariable("key" :: Nil)
+ val lambda = LambdaFunction('kvs.getItem(key).getField("v1"), key :: Nil)
+ val query = left
+ .join(right, Cross, None)
+ .select(ArrayTransform('keys, lambda).as("a"))
+ .analyze
+ val optimized = Optimize.execute(query)
+ comparePlans(optimized, query)
+ }
}
object NestedColumnAliasingSuite {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]