This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new f10fbb2  [SPARK-35636][SQL] Lambda keys should not be referenced outside of the lambda function
f10fbb2 is described below

commit f10fbb2055ecf4163b5b3d236e69138dfb228e1a
Author: Karen Feng <[email protected]>
AuthorDate: Fri Jun 4 15:44:32 2021 +0900

    [SPARK-35636][SQL] Lambda keys should not be referenced outside of the lambda function
    
    Sets `references` for `NamedLambdaVariable` and `LambdaFunction`.
    
    | Expression  | NamedLambdaVariable | LambdaFunction |
    | --- | --- | --- |
    | References before | None | All function references |
    | References after | self.toAttribute | Function references minus arguments' references |
    
    In `NestedColumnAliasing`, this means that `ExtractValue(ExtractValue(attr, lv: NamedLambdaVariable), ...)` now references both `attr` and `lv`, rather than just `attr`. As a result, such an expression is no longer collected as a nested column reference, so it is not aliased and pushed out of the lambda function.
    
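    Why this is needed: before this change, the lambda key could be referenced outside of its lambda function. The extracted field was aliased (`_extract_v1`) and evaluated in a projection below the lambda, as the plans below show.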
    
    Before (cross join):
    ```
    Project [transform(keys#0, lambdafunction(_extract_v1#0, lambda key#0, false)) AS a#0]
    +- 'Join Cross
       :- Project [kvs#0[lambda key#0].v1 AS _extract_v1#0]
       :  +- LocalRelation <empty>, [kvs#0]
       +- LocalRelation <empty>, [keys#0]
    ```
    
    After (cross join):
    ```
    Project [transform(keys#418, lambdafunction(kvs#417[lambda key#420].v1, lambda key#420, false)) AS a#419]
    +- Join Cross
       :- LocalRelation <empty>, [kvs#417]
       +- LocalRelation <empty>, [keys#418]
    ```
    
    Before (nested limits):
    ```
    Project [transform(keys#0, lambdafunction(kvs#0[lambda key#0].v1, lambda key#0, false)) AS a#0]
    +- GlobalLimit 5
      +- LocalLimit 5
        +- Project [keys#0, _extract_v1#0 AS _extract_v1#0]
          +- GlobalLimit 5
            +- LocalLimit 5
              +- Project [kvs#0[lambda key#0].v1 AS _extract_v1#0, keys#0]
                +- LocalRelation <empty>, [kvs#0, keys#0]
    ```
    
    After (nested limits):
    ```
    Project [transform(keys#428, lambdafunction(kvs#427[lambda key#430].v1, lambda key#430, false)) AS a#429]
    +- GlobalLimit 5
      +- LocalLimit 5
        +- Project [keys#428, kvs#427]
          +- GlobalLimit 5
            +- LocalLimit 5
              +- LocalRelation <empty>, [kvs#427, keys#428]
    ```
    
    Does this PR introduce any user-facing change? No.
    
    How was this patch tested? Scala unit tests for the examples above, added to `NestedColumnAliasingSuite`.
    
    Closes #32773 from karenfeng/SPARK-35636.
    
    Authored-by: Karen Feng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
    (cherry picked from commit 53a758b51b79c52ac5d4bf3fad72765e55607f36)
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../expressions/higherOrderFunctions.scala         | 10 +++++++
 .../optimizer/NestedColumnAliasingSuite.scala      | 32 +++++++++++++++++++++-
 2 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index a4e069d..0a6376f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -84,6 +84,9 @@ case class NamedLambdaVariable(
     AttributeReference(name, dataType, nullable, Metadata.empty)(exprId, 
Seq.empty)
   }
 
+  // Report this variable as a reference so rules can spot uses outside its lambda function
+  override def references: AttributeSet = AttributeSet(toAttribute)
+
   override def eval(input: InternalRow): Any = value.get
 
   override def toString: String = s"lambda $name#${exprId.id}$typeSuffix"
@@ -108,6 +111,13 @@ case class LambdaFunction(
   override def dataType: DataType = function.dataType
   override def nullable: Boolean = function.nullable
 
+  // Once resolved, hide the bound arguments: they are in scope only inside this function
+  override def references: AttributeSet = if (resolved) {
+    function.references -- AttributeSet(arguments.flatMap(_.references))
+  } else {
+    super.references
+  }
+
   lazy val bound: Boolean = arguments.forall(_.resolved)
 
   override def eval(input: InternalRow): Any = function.eval(input)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
index c83ab37..08e442d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
@@ -23,9 +23,10 @@ import org.apache.spark.sql.catalyst.SchemaPruningTest
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.Cross
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.{StringType, StructField, StructType}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, 
StructType}
 
 class NestedColumnAliasingSuite extends SchemaPruningTest {
 
@@ -684,6 +685,35 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
     ).analyze
     comparePlans(optimized2, expected2)
   }
+
+  test("SPARK-35636: do not push lambda key out of lambda function") {
+    val rel = LocalRelation(
+      'kvs.map(StringType, new StructType().add("v1", IntegerType)), 'keys.array(StringType))
+    val key = UnresolvedNamedLambdaVariable("key" :: Nil) // bound by `lambda`
+    val lambda = LambdaFunction('kvs.getItem(key).getField("v1"), key :: Nil)
+    val query = rel
+      .limit(5)
+      .select('keys, 'kvs)
+      .limit(5)
+      .select(ArrayTransform('keys, lambda).as("a"))
+      .analyze
+    val optimized = Optimize.execute(query)
+    comparePlans(optimized, query) // Optimize must leave the plan unchanged
+  }
+
+  test("SPARK-35636: do not push down extract value in higher order " +
+    "function that references both sides of a join") {
+    val left = LocalRelation('kvs.map(StringType, new StructType().add("v1", IntegerType)))
+    val right = LocalRelation('keys.array(StringType))
+    val key = UnresolvedNamedLambdaVariable("key" :: Nil)
+    val lambda = LambdaFunction('kvs.getItem(key).getField("v1"), key :: Nil)
+    val query = left
+      .join(right, Cross, None) // kvs comes from the left side, keys from the right
+      .select(ArrayTransform('keys, lambda).as("a"))
+      .analyze
+    val optimized = Optimize.execute(query)
+    comparePlans(optimized, query) // Optimize must leave the plan unchanged
+  }
 }
 
 object NestedColumnAliasingSuite {

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to