This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new af1ac1edc2a9 [SPARK-41049][SQL][FOLLOW-UP] Mark map related 
expressions as stateful expressions
af1ac1edc2a9 is described below

commit af1ac1edc2a96c9aba949e3100ddae37b6f0e5b2
Author: Rui Wang <[email protected]>
AuthorDate: Mon May 27 22:40:13 2024 -0700

    [SPARK-41049][SQL][FOLLOW-UP] Mark map related expressions as stateful 
expressions
    
    ### What changes were proposed in this pull request?
    
    MapConcat contains internal state, so it is stateful:
    ```
    private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, 
dataType.valueType)
    ```
    
    Similarly `MapFromEntries, CreateMap, MapFromArrays, StringToMap, and 
TransformKeys` need the same change.
    
    ### Why are the changes needed?
    
    Stateful expressions should be marked as stateful.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    N/A
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #46721 from amaliujia/statefulexpr.
    
    Authored-by: Rui Wang <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/expressions/collectionOperations.scala  |  3 +++
 .../spark/sql/catalyst/expressions/complexTypeCreator.scala    |  6 ++++++
 .../spark/sql/catalyst/expressions/higherOrderFunctions.scala  |  2 ++
 .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala   | 10 +++++++++-
 4 files changed, 20 insertions(+), 1 deletion(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 632e2f3d3e97..ea117f876550 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -713,6 +713,7 @@ case class MapConcat(children: Seq[Expression])
     }
   }
 
+  override def stateful: Boolean = true
   override def nullable: Boolean = children.exists(_.nullable)
 
   private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, 
dataType.valueType)
@@ -828,6 +829,8 @@ case class MapFromEntries(child: Expression)
 
   override def nullable: Boolean = child.nullable || nullEntries
 
+  override def stateful: Boolean = true
+
   @transient override lazy val dataType: MapType = dataTypeDetails.get._1
 
   override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 4c0d00534060..167c02c0bafc 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -245,6 +245,8 @@ case class CreateMap(children: Seq[Expression], 
useStringTypeWhenEmpty: Boolean)
 
   private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, 
dataType.valueType)
 
+  override def stateful: Boolean = true
+
   override def eval(input: InternalRow): Any = {
     var i = 0
     while (i < keys.length) {
@@ -320,6 +322,8 @@ case class MapFromArrays(left: Expression, right: 
Expression)
       valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull)
   }
 
+  override def stateful: Boolean = true
+
   private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, 
dataType.valueType)
 
   override def nullSafeEval(keyArray: Any, valueArray: Any): Any = {
@@ -568,6 +572,8 @@ case class StringToMap(text: Expression, pairDelim: 
Expression, keyValueDelim: E
     this(child, Literal(","), Literal(":"))
   }
 
+  override def stateful: Boolean = true
+
   override def first: Expression = text
   override def second: Expression = pairDelim
   override def third: Expression = keyValueDelim
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 896f3e9774f3..80bcf156133e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -920,6 +920,8 @@ case class TransformKeys(
 
   override def dataType: MapType = MapType(function.dataType, valueType, 
valueContainsNull)
 
+  override def stateful: Boolean = true
+
   override def checkInputDataTypes(): TypeCheckResult = {
     TypeUtils.checkForMapKeyType(function.dataType)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f11ad230ec16..760ee8026080 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, Cast, EqualTo, ExpressionSet, GreaterThan, Literal, 
PythonUDF, ScalarSubquery, Uuid}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, Cast, CreateMap, EqualTo, ExpressionSet, GreaterThan, 
Literal, PythonUDF, ScalarSubquery, Uuid}
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, 
LocalRelation, LogicalPlan, OneRowRelation}
@@ -2504,6 +2504,14 @@ class DataFrameSuite extends QueryTest
       assert(row.getInt(0).toString == row.getString(2))
       assert(row.getInt(0).toString == row.getString(3))
     }
+
+    val v3 = Column(CreateMap(Seq(Literal("key"), Literal("value"))))
+    val v4 = to_csv(struct(v3.as("a"))) // to_csv is CodegenFallback
+    df.select(v3, v3, v4, v4).collect().foreach { row =>
+      assert(row.getMap(0).toString() == row.getMap(1).toString())
+      assert(row.getString(2) == s"{key -> ${row.getMap(0).get("key").get}}")
+      assert(row.getString(3) == s"{key -> ${row.getMap(0).get("key").get}}")
+    }
   }
 
   test("SPARK-45216: Non-deterministic functions with seed") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to